use super::ThreadType; use crate::{ config::Config, database, types::{alpaca::websocket, news::Prediction, Bar, Class, News}, }; use async_trait::async_trait; use futures_util::{ future::join_all, stream::{SplitSink, SplitStream}, SinkExt, StreamExt, }; use log::{debug, error, info}; use serde_json::{from_str, to_string}; use std::{collections::HashMap, sync::Arc}; use tokio::{ net::TcpStream, select, spawn, sync::{mpsc, oneshot, Mutex, RwLock}, task::block_in_place, }; use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; pub enum Action { Subscribe, Unsubscribe, } impl From for Option { fn from(action: super::Action) -> Self { match action { super::Action::Add | super::Action::Enable => Some(Action::Subscribe), super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe), } } } pub struct Message { pub action: Option, pub symbols: Vec, pub response: oneshot::Sender<()>, } impl Message { pub fn new(action: Option, symbols: Vec) -> (Self, oneshot::Receiver<()>) { let (sender, receiver) = oneshot::channel(); ( Self { action, symbols, response: sender, }, receiver, ) } } pub struct Pending { pub subscriptions: HashMap>, pub unsubscriptions: HashMap>, } #[async_trait] pub trait Handler: Send + Sync { fn create_subscription_message( &self, symbols: Vec, ) -> websocket::data::outgoing::subscribe::Message; async fn handle_websocket_message( &self, pending: Arc>, message: websocket::data::incoming::Message, ); } pub async fn run( handler: Arc>, mut receiver: mpsc::Receiver, mut websocket_stream: SplitStream>>, websocket_sink: SplitSink>, tungstenite::Message>, ) { let pending = Arc::new(RwLock::new(Pending { subscriptions: HashMap::new(), unsubscriptions: HashMap::new(), })); let websocket_sink = Arc::new(Mutex::new(websocket_sink)); loop { select! { Some(message) = receiver.recv() => { spawn(handle_message( handler.clone(), pending.clone(), websocket_sink.clone(), message, )); } Some(Ok(message)) = websocket_stream.next() => { match message { tungstenite::Message::Text(message) => { let parsed_message = from_str::>(&message); if parsed_message.is_err() { error!("Failed to deserialize websocket message: {:?}", message); continue; } for message in parsed_message.unwrap() { let handler = handler.clone(); let pending = pending.clone(); spawn(async move { handler.handle_websocket_message(pending, message).await; }); } } tungstenite::Message::Ping(_) => {} _ => error!("Unexpected websocket message: {:?}", message), } } else => panic!("Communication channel unexpectedly closed.") } } } async fn handle_message( handler: Arc>, pending: Arc>, sink: Arc>, tungstenite::Message>>>, message: Message, ) { if message.symbols.is_empty() { message.response.send(()).unwrap(); return; } match message.action { Some(Action::Subscribe) => { let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message .symbols .iter() .map(|symbol| { let (sender, receiver) = oneshot::channel(); ((symbol.clone(), sender), receiver) }) .unzip(); pending .write() .await .subscriptions .extend(pending_subscriptions); sink.lock() .await .send(tungstenite::Message::Text( to_string(&websocket::data::outgoing::Message::Subscribe( handler.create_subscription_message(message.symbols), )) .unwrap(), )) .await .unwrap(); join_all(receivers).await; } Some(Action::Unsubscribe) => { let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message .symbols .iter() .map(|symbol| { let (sender, receiver) = oneshot::channel(); ((symbol.clone(), sender), receiver) }) .unzip(); pending .write() .await .unsubscriptions .extend(pending_unsubscriptions); sink.lock() .await .send(tungstenite::Message::Text( to_string(&websocket::data::outgoing::Message::Unsubscribe( handler.create_subscription_message(message.symbols.clone()), )) .unwrap(), )) .await .unwrap(); join_all(receivers).await; } None => {} } message.response.send(()).unwrap(); } struct BarsHandler { config: Arc, subscription_message_constructor: fn(Vec) -> websocket::data::outgoing::subscribe::Message, } #[async_trait] impl Handler for BarsHandler { fn create_subscription_message( &self, symbols: Vec, ) -> websocket::data::outgoing::subscribe::Message { (self.subscription_message_constructor)(symbols) } async fn handle_websocket_message( &self, pending: Arc>, message: websocket::data::incoming::Message, ) { match message { websocket::data::incoming::Message::Subscription(message) => { let websocket::data::incoming::subscription::Message::Market { bars: symbols, .. } = message else { unreachable!() }; let mut pending = pending.write().await; let newly_subscribed = pending .subscriptions .extract_if(|symbol, _| symbols.contains(symbol)) .collect::>(); let newly_unsubscribed = pending .unsubscriptions .extract_if(|symbol, _| !symbols.contains(symbol)) .collect::>(); drop(pending); if !newly_subscribed.is_empty() { info!( "Subscribed to bars for {:?}.", newly_subscribed.keys().collect::>() ); for sender in newly_subscribed.into_values() { sender.send(()).unwrap(); } } if !newly_unsubscribed.is_empty() { info!( "Unsubscribed from bars for {:?}.", newly_unsubscribed.keys().collect::>() ); for sender in newly_unsubscribed.into_values() { sender.send(()).unwrap(); } } } websocket::data::incoming::Message::Bar(message) | websocket::data::incoming::Message::UpdatedBar(message) => { let bar = Bar::from(message); debug!("Received bar for {}: {}.", bar.symbol, bar.time); database::bars::upsert( &self.config.clickhouse_client, &self.config.clickhouse_concurrency_limiter, &bar, ) .await .unwrap(); } websocket::data::incoming::Message::Status(message) => { debug!( "Received status message for {}: {:?}.", message.symbol, message.status ); match message.status { websocket::data::incoming::status::Status::TradingHalt(_) | websocket::data::incoming::status::Status::VolatilityTradingPause(_) => { database::assets::update_status_where_symbol( &self.config.clickhouse_client, &self.config.clickhouse_concurrency_limiter, &message.symbol, false, ) .await .unwrap(); } websocket::data::incoming::status::Status::Resume(_) | websocket::data::incoming::status::Status::TradingResumption(_) => { database::assets::update_status_where_symbol( &self.config.clickhouse_client, &self.config.clickhouse_concurrency_limiter, &message.symbol, true, ) .await .unwrap(); } _ => {} } } websocket::data::incoming::Message::Error(message) => { error!("Received error message: {}.", message.message); } _ => unreachable!(), } } } struct NewsHandler { config: Arc, } #[async_trait] impl Handler for NewsHandler { fn create_subscription_message( &self, symbols: Vec, ) -> websocket::data::outgoing::subscribe::Message { websocket::data::outgoing::subscribe::Message::new_news(symbols) } async fn handle_websocket_message( &self, pending: Arc>, message: websocket::data::incoming::Message, ) { match message { websocket::data::incoming::Message::Subscription(message) => { let websocket::data::incoming::subscription::Message::News { news: symbols } = message else { unreachable!() }; let mut pending = pending.write().await; let newly_subscribed = pending .subscriptions .extract_if(|symbol, _| symbols.contains(symbol)) .collect::>(); let newly_unsubscribed = pending .unsubscriptions .extract_if(|symbol, _| !symbols.contains(symbol)) .collect::>(); drop(pending); if !newly_subscribed.is_empty() { info!( "Subscribed to news for {:?}.", newly_subscribed.keys().collect::>() ); for sender in newly_subscribed.into_values() { sender.send(()).unwrap(); } } if !newly_unsubscribed.is_empty() { info!( "Unsubscribed from news for {:?}.", newly_unsubscribed.keys().collect::>() ); for sender in newly_unsubscribed.into_values() { sender.send(()).unwrap(); } } } websocket::data::incoming::Message::News(message) => { let news = News::from(message); debug!( "Received news for {:?}: {}.", news.symbols, news.time_created ); let input = format!("{}\n\n{}", news.headline, news.content); let sequence_classifier = self.config.sequence_classifier.lock().await; let prediction = block_in_place(|| { sequence_classifier .predict(vec![input.as_str()]) .into_iter() .map(|label| Prediction::try_from(label).unwrap()) .collect::>()[0] }); drop(sequence_classifier); let news = News { sentiment: prediction.sentiment, confidence: prediction.confidence, ..news }; database::news::upsert( &self.config.clickhouse_client, &self.config.clickhouse_concurrency_limiter, &news, ) .await .unwrap(); } websocket::data::incoming::Message::Error(message) => { error!("Received error message: {}.", message.message); } _ => unreachable!(), } } } pub fn create_handler(thread_type: ThreadType, config: Arc) -> Box { match thread_type { ThreadType::Bars(Class::UsEquity) => Box::new(BarsHandler { config, subscription_message_constructor: websocket::data::outgoing::subscribe::Message::new_market_us_equity, }), ThreadType::Bars(Class::Crypto) => Box::new(BarsHandler { config, subscription_message_constructor: websocket::data::outgoing::subscribe::Message::new_market_crypto, }), ThreadType::News => Box::new(NewsHandler { config }), } }