use super::{backfill, Guard, ThreadType}; use crate::{ config::Config, database, types::{alpaca::websocket, news::Prediction, Bar, News, Subset}, utils::add_slash_to_pair, }; use futures_util::{ stream::{SplitSink, SplitStream}, SinkExt, StreamExt, }; use log::{error, info, warn}; use serde_json::from_str; use std::{collections::HashSet, sync::Arc}; use tokio::{ join, net::TcpStream, spawn, sync::{mpsc, Mutex, RwLock}, task::block_in_place, }; use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; pub async fn run( app_config: Arc, thread_type: ThreadType, guard: Arc>, websocket_sender: Arc< Mutex>, tungstenite::Message>>, >, mut websocket_receiver: SplitStream>>, backfill_sender: mpsc::Sender, ) { loop { let message = websocket_receiver.next().await.unwrap().unwrap(); spawn(handle_websocket_message( app_config.clone(), thread_type, guard.clone(), websocket_sender.clone(), backfill_sender.clone(), message, )); } } async fn handle_websocket_message( app_config: Arc, thread_type: ThreadType, guard: Arc>, websocket_sender: Arc< Mutex>, tungstenite::Message>>, >, backfill_sender: mpsc::Sender, message: tungstenite::Message, ) { match message { tungstenite::Message::Text(message) => { let message = from_str::>(&message); if let Ok(message) = message { for message in message { spawn(handle_parsed_websocket_message( app_config.clone(), thread_type, guard.clone(), backfill_sender.clone(), message, )); } } else { error!( "{:?} - Failed to deserialize websocket message: {:?}", thread_type, message ); } } tungstenite::Message::Ping(_) => { websocket_sender .lock() .await .send(tungstenite::Message::Pong(vec![])) .await .unwrap(); } _ => error!( "{:?} - Unexpected websocket message: {:?}", thread_type, message ), } } #[allow(clippy::significant_drop_tightening)] #[allow(clippy::too_many_lines)] async fn handle_parsed_websocket_message( app_config: Arc, thread_type: ThreadType, guard: Arc>, backfill_sender: mpsc::Sender, message: websocket::incoming::Message, ) { match message { websocket::incoming::Message::Subscription(message) => { let symbols = match message { websocket::incoming::subscription::Message::Market { bars, .. } => bars, websocket::incoming::subscription::Message::News { news } => news .into_iter() .map(|symbol| add_slash_to_pair(&symbol)) .collect(), }; let mut guard = guard.write().await; let newly_subscribed = guard .pending_subscriptions .extract_if(|asset| symbols.contains(&asset.symbol)) .collect::>(); let newly_unsubscribed = guard .pending_unsubscriptions .extract_if(|asset| !symbols.contains(&asset.symbol)) .collect::>(); drop(guard); let newly_subscribed_future = async { if !newly_subscribed.is_empty() { info!( "{:?} - Subscribed to {:?}.", thread_type, newly_subscribed .iter() .map(|asset| asset.symbol.clone()) .collect::>() ); let (backfill_message, backfill_receiver) = backfill::Message::new( backfill::Action::Backfill, Subset::Some(newly_subscribed.into_iter().collect::>()), ); backfill_sender.send(backfill_message).await.unwrap(); backfill_receiver.await.unwrap(); } }; let newly_unsubscribed_future = async { if !newly_unsubscribed.is_empty() { info!( "{:?} - Unsubscribed from {:?}.", thread_type, newly_unsubscribed .iter() .map(|asset| asset.symbol.clone()) .collect::>() ); let (purge_message, purge_receiver) = backfill::Message::new( backfill::Action::Purge, Subset::Some(newly_unsubscribed.into_iter().collect::>()), ); backfill_sender.send(purge_message).await.unwrap(); purge_receiver.await.unwrap(); } }; join!(newly_subscribed_future, newly_unsubscribed_future); } websocket::incoming::Message::Bar(message) | websocket::incoming::Message::UpdatedBar(message) => { let bar = Bar::from(message); let guard = guard.read().await; if !guard.assets.contains_right(&bar.symbol) { warn!( "{:?} - Race condition: received bar for unsubscribed symbol: {:?}.", thread_type, bar.symbol ); return; } info!( "{:?} - Received bar for {}: {}.", thread_type, bar.symbol, bar.time ); database::bars::upsert(&app_config.clickhouse_client, &bar).await; } websocket::incoming::Message::News(message) => { let news = News::from(message); let guard = guard.read().await; if !news .symbols .iter() .any(|symbol| guard.assets.contains_right(symbol)) { warn!( "{:?} - Race condition: received news for unsubscribed symbols: {:?}.", thread_type, news.symbols ); return; } info!( "{:?} - Received news for {:?}: {}.", thread_type, news.symbols, news.time_created ); let input = format!("{}\n\n{}", news.headline, news.content); let sequence_classifier = app_config.sequence_classifier.lock().await; let prediction = block_in_place(|| { sequence_classifier .predict(vec![input.as_str()]) .into_iter() .map(|label| Prediction::try_from(label).unwrap()) .collect::>()[0] }); drop(sequence_classifier); let news = News { sentiment: prediction.sentiment, confidence: prediction.confidence, ..news }; database::news::upsert(&app_config.clickhouse_client, &news).await; } websocket::incoming::Message::Success(_) => {} websocket::incoming::Message::Error(message) => { error!( "{:?} - Received error message: {}.", thread_type, message.message ); } } }