// qrust/src/threads/data/websocket.rs
// Snapshot 2024-03-09 20:13:36 +00:00 (438 lines, 14 KiB, Rust)

use super::ThreadType;
use crate::{
config::Config,
database,
types::{alpaca::websocket, news::Prediction, Bar, Class, News},
};
use async_trait::async_trait;
use futures_util::{
future::join_all,
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use log::{debug, error, info};
use serde_json::{from_str, to_string};
use std::{collections::HashMap, sync::Arc};
use tokio::{
net::TcpStream,
select, spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
task::block_in_place,
};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
/// Subscription change that the websocket task should request from the server.
pub enum Action {
    /// Start receiving data for the given symbols.
    Subscribe,
    /// Stop receiving data for the given symbols.
    Unsubscribe,
}

impl From<super::Action> for Option<Action> {
    /// Translates a thread-level action into the websocket request it implies.
    ///
    /// Adding or enabling an asset subscribes to it; removing or disabling it
    /// unsubscribes. Every current variant yields `Some`.
    fn from(action: super::Action) -> Self {
        let websocket_action = match action {
            super::Action::Add | super::Action::Enable => Action::Subscribe,
            super::Action::Remove | super::Action::Disable => Action::Unsubscribe,
        };
        Some(websocket_action)
    }
}
/// Control request handed to the websocket task through its mpsc channel.
pub struct Message {
    /// What to ask the server for; `None` means nothing is sent.
    pub action: Option<Action>,
    /// Symbols the action applies to.
    pub symbols: Vec<String>,
    /// Signalled once the request has been fully processed.
    pub response: oneshot::Sender<()>,
}

impl Message {
    /// Creates a request plus the receiver that completes when the websocket task
    /// signals `response`.
    pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
        let (response, completion) = oneshot::channel();
        let request = Self {
            action,
            symbols,
            response,
        };
        (request, completion)
    }
}
/// Acknowledgements the websocket task is still waiting for, keyed by symbol.
///
/// `handle_message` inserts one sender per requested symbol. The handlers in this file
/// remove and fire it once a subscription message from the server reflects the change.
#[derive(Default)]
pub struct Pending {
    /// Symbols with an outstanding subscribe request.
    pub subscriptions: HashMap<String, oneshot::Sender<()>>,
    /// Symbols with an outstanding unsubscribe request.
    pub unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}
/// Per-stream behaviour plugged into `run`: how to build (un)subscription payloads
/// and how to react to messages received from the server.
#[async_trait]
pub trait Handler: Send + Sync {
    /// Builds the subscription payload for `symbols`.
    ///
    /// `handle_message` wraps the result in both the outgoing `Subscribe` and
    /// `Unsubscribe` messages.
    fn create_subscription_message(
        &self,
        symbols: Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message;
    /// Handles one parsed incoming message.
    ///
    /// `run` calls this on a freshly spawned task for every message in a text frame.
    /// `pending` holds the acknowledgements `handle_message` is waiting for. The
    /// implementations in this file resolve them when a subscription message arrives.
    async fn handle_websocket_message(
        &self,
        pending: Arc<RwLock<Pending>>,
        message: websocket::data::incoming::Message,
    );
}
/// Drives one data websocket connection for as long as it stays open.
///
/// Multiplexes two sources:
/// - control [`Message`]s from `receiver`, each processed by `handle_message` on its
///   own task, which sends the request through the shared `websocket_sink`;
/// - frames from `websocket_stream`: every text frame is parsed as a JSON array of
///   incoming messages, and each one is passed to `handler` on its own task.
///
/// If `receiver` is closed, its branch is skipped and the websocket keeps being served.
///
/// # Panics
///
/// When the websocket stream yields an error or ends, because the connection can no
/// longer deliver data.
pub async fn run(
    handler: Arc<Box<dyn Handler>>,
    mut receiver: mpsc::Receiver<Message>,
    mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>,
) {
    let pending = Arc::new(RwLock::new(Pending {
        subscriptions: HashMap::new(),
        unsubscriptions: HashMap::new(),
    }));
    let websocket_sink = Arc::new(Mutex::new(websocket_sink));

    loop {
        select! {
            Some(message) = receiver.recv() => {
                spawn(handle_message(
                    handler.clone(),
                    pending.clone(),
                    websocket_sink.clone(),
                    message,
                ));
            }
            // Matched without a pattern so that errors and the end of the stream are
            // handled explicitly. A `Some(Ok(..))` pattern would silently disable the
            // branch and leave a dead connection unnoticed.
            frame = websocket_stream.next() => {
                match frame {
                    Some(Ok(tungstenite::Message::Text(message))) => {
                        let messages = match from_str::<Vec<websocket::data::incoming::Message>>(
                            &message,
                        ) {
                            Ok(messages) => messages,
                            Err(err) => {
                                error!(
                                    "Failed to deserialize websocket message: {:?} ({:?})",
                                    message, err
                                );
                                continue;
                            }
                        };

                        for message in messages {
                            let handler = handler.clone();
                            let pending = pending.clone();
                            spawn(async move {
                                handler.handle_websocket_message(pending, message).await;
                            });
                        }
                    }
                    Some(Ok(tungstenite::Message::Ping(_))) => {}
                    Some(Ok(message)) => error!("Unexpected websocket message: {:?}", message),
                    Some(Err(err)) => panic!("Websocket stream failed: {:?}", err),
                    None => panic!("Communication channel unexpectedly closed."),
                }
            }
        }
    }
}
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
pending: Arc<RwLock<Pending>>,
sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
message: Message,
) {
if message.symbols.is_empty() {
message.response.send(()).unwrap();
return;
}
match message.action {
Some(Action::Subscribe) => {
let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip();
pending
.write()
.await
.subscriptions
.extend(pending_subscriptions);
sink.lock()
.await
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(message.symbols),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
Some(Action::Unsubscribe) => {
let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip();
pending
.write()
.await
.unsubscriptions
.extend(pending_unsubscriptions);
sink.lock()
.await
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Unsubscribe(
handler.create_subscription_message(message.symbols.clone()),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
None => {}
}
message.response.send(()).unwrap();
}
/// [`Handler`] for market-data streams: bars, updated bars and trading statuses.
///
/// One struct serves every asset class. Only the subscription payload differs, so it
/// is injected as `subscription_message_constructor`.
struct BarsHandler {
    config: Arc<Config>,
    /// Builds the asset-class-specific subscription payload for a list of symbols.
    subscription_message_constructor:
        fn(Vec<String>) -> websocket::data::outgoing::subscribe::Message,
}

#[async_trait]
impl Handler for BarsHandler {
    fn create_subscription_message(
        &self,
        symbols: Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message {
        (self.subscription_message_constructor)(symbols)
    }

    /// Handles one incoming market-data message.
    ///
    /// - `Subscription`: the `bars` list is treated as the complete set of subscribed
    ///   symbols. Pending subscriptions now in the set, and pending unsubscriptions no
    ///   longer in it, are removed from `pending` and acknowledged.
    /// - `Bar` / `UpdatedBar`: upserted via `database::bars::upsert`.
    /// - `Status`: halts and volatility pauses set the asset status to `false`;
    ///   resumptions set it to `true`. Other statuses are ignored.
    /// - `Error`: logged.
    ///
    /// # Panics
    ///
    /// On a non-market subscription message, on any other message kind, or when a
    /// database write fails.
    async fn handle_websocket_message(
        &self,
        pending: Arc<RwLock<Pending>>,
        message: websocket::data::incoming::Message,
    ) {
        match message {
            websocket::data::incoming::Message::Subscription(message) => {
                let websocket::data::incoming::subscription::Message::Market {
                    bars: symbols, ..
                } = message
                else {
                    unreachable!()
                };

                let mut pending = pending.write().await;
                let newly_subscribed = pending
                    .subscriptions
                    .extract_if(|symbol, _| symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();
                let newly_unsubscribed = pending
                    .unsubscriptions
                    .extract_if(|symbol, _| !symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();
                // Release the lock before waking the waiters.
                drop(pending);

                if !newly_subscribed.is_empty() {
                    info!(
                        "Subscribed to bars for {:?}.",
                        newly_subscribed.keys().collect::<Vec<_>>()
                    );
                    for sender in newly_subscribed.into_values() {
                        // The waiting request task may be gone (e.g. it panicked while
                        // sending). A closed receiver is not an error here.
                        let _ = sender.send(());
                    }
                }

                if !newly_unsubscribed.is_empty() {
                    info!(
                        "Unsubscribed from bars for {:?}.",
                        newly_unsubscribed.keys().collect::<Vec<_>>()
                    );
                    for sender in newly_unsubscribed.into_values() {
                        let _ = sender.send(());
                    }
                }
            }
            websocket::data::incoming::Message::Bar(message)
            | websocket::data::incoming::Message::UpdatedBar(message) => {
                let bar = Bar::from(message);
                debug!("Received bar for {}: {}.", bar.symbol, bar.time);

                database::bars::upsert(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &bar,
                )
                .await
                .unwrap();
            }
            websocket::data::incoming::Message::Status(message) => {
                debug!(
                    "Received status message for {}: {:?}.",
                    message.symbol, message.status
                );

                match message.status {
                    websocket::data::incoming::status::Status::TradingHalt(_)
                    | websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
                        database::assets::update_status_where_symbol(
                            &self.config.clickhouse_client,
                            &self.config.clickhouse_concurrency_limiter,
                            &message.symbol,
                            false,
                        )
                        .await
                        .unwrap();
                    }
                    websocket::data::incoming::status::Status::Resume(_)
                    | websocket::data::incoming::status::Status::TradingResumption(_) => {
                        database::assets::update_status_where_symbol(
                            &self.config.clickhouse_client,
                            &self.config.clickhouse_concurrency_limiter,
                            &message.symbol,
                            true,
                        )
                        .await
                        .unwrap();
                    }
                    _ => {}
                }
            }
            websocket::data::incoming::Message::Error(message) => {
                error!("Received error message: {}.", message.message);
            }
            _ => unreachable!(),
        }
    }
}
/// [`Handler`] for the news stream. It stores each article together with the sentiment
/// predicted by the configured sequence classifier.
struct NewsHandler {
    config: Arc<Config>,
}

#[async_trait]
impl Handler for NewsHandler {
    fn create_subscription_message(
        &self,
        symbols: Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message {
        websocket::data::outgoing::subscribe::Message::new_news(symbols)
    }

    /// Handles one incoming news-stream message.
    ///
    /// - `Subscription`: the `news` list is treated as the complete set of subscribed
    ///   symbols. Pending subscriptions now in the set, and pending unsubscriptions no
    ///   longer in it, are removed from `pending` and acknowledged.
    /// - `News`: classifies "headline\n\ncontent", copies the predicted sentiment and
    ///   confidence onto the article, and upserts it via `database::news::upsert`.
    /// - `Error`: logged.
    ///
    /// # Panics
    ///
    /// On a non-news subscription message, on any other message kind, when the
    /// classifier yields no or an unconvertible label, or when the database write fails.
    async fn handle_websocket_message(
        &self,
        pending: Arc<RwLock<Pending>>,
        message: websocket::data::incoming::Message,
    ) {
        match message {
            websocket::data::incoming::Message::Subscription(message) => {
                let websocket::data::incoming::subscription::Message::News { news: symbols } =
                    message
                else {
                    unreachable!()
                };

                let mut pending = pending.write().await;
                let newly_subscribed = pending
                    .subscriptions
                    .extract_if(|symbol, _| symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();
                let newly_unsubscribed = pending
                    .unsubscriptions
                    .extract_if(|symbol, _| !symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();
                // Release the lock before waking the waiters.
                drop(pending);

                if !newly_subscribed.is_empty() {
                    info!(
                        "Subscribed to news for {:?}.",
                        newly_subscribed.keys().collect::<Vec<_>>()
                    );
                    for sender in newly_subscribed.into_values() {
                        // The waiting request task may be gone (e.g. it panicked while
                        // sending). A closed receiver is not an error here.
                        let _ = sender.send(());
                    }
                }

                if !newly_unsubscribed.is_empty() {
                    info!(
                        "Unsubscribed from news for {:?}.",
                        newly_unsubscribed.keys().collect::<Vec<_>>()
                    );
                    for sender in newly_unsubscribed.into_values() {
                        let _ = sender.send(());
                    }
                }
            }
            websocket::data::incoming::Message::News(message) => {
                let news = News::from(message);
                debug!(
                    "Received news for {:?}: {}.",
                    news.symbols, news.time_created
                );

                let input = format!("{}\n\n{}", news.headline, news.content);

                let sequence_classifier = self.config.sequence_classifier.lock().await;
                // Inference is synchronous. `block_in_place` tells the runtime that this
                // worker is about to block (tokio: panics on a current-thread runtime).
                let prediction = block_in_place(|| {
                    sequence_classifier
                        .predict(vec![input.as_str()])
                        .into_iter()
                        .next()
                        .map(|label| Prediction::try_from(label).unwrap())
                        .expect("sequence classifier returned no label for a single input")
                });
                drop(sequence_classifier);

                let news = News {
                    sentiment: prediction.sentiment,
                    confidence: prediction.confidence,
                    ..news
                };

                database::news::upsert(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &news,
                )
                .await
                .unwrap();
            }
            websocket::data::incoming::Message::Error(message) => {
                error!("Received error message: {}.", message.message);
            }
            _ => unreachable!(),
        }
    }
}
/// Builds the websocket [`Handler`] for `thread_type`.
///
/// Both bar threads use `BarsHandler` and differ only in the subscription payload
/// constructor for their asset class. The news thread uses `NewsHandler`.
pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
    let subscription_message_constructor: fn(
        Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message = match thread_type {
        ThreadType::Bars(Class::UsEquity) => {
            websocket::data::outgoing::subscribe::Message::new_market_us_equity
        }
        ThreadType::Bars(Class::Crypto) => {
            websocket::data::outgoing::subscribe::Message::new_market_crypto
        }
        ThreadType::News => return Box::new(NewsHandler { config }),
    };

    Box::new(BarsHandler {
        config,
        subscription_message_constructor,
    })
}