use super::{Guard, ThreadType}; use crate::{ config::Config, database, types::{alpaca::websocket, Asset}, }; use futures_util::{stream::SplitSink, SinkExt}; use log::info; use serde_json::to_string; use std::sync::Arc; use tokio::{ join, net::TcpStream, spawn, sync::{mpsc, oneshot, Mutex, RwLock}, }; use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; #[derive(Clone)] pub enum Action { Add, Remove, } pub struct Message { pub action: Action, pub assets: Vec, pub response: oneshot::Sender<()>, } impl Message { pub fn new(action: Action, assets: Vec) -> (Self, oneshot::Receiver<()>) { let (sender, receiver) = oneshot::channel::<()>(); ( Self { action, assets, response: sender, }, receiver, ) } } pub async fn run( app_config: Arc, thread_type: ThreadType, guard: Arc>, mut asset_status_receiver: mpsc::Receiver, websocket_sender: Arc< Mutex>, tungstenite::Message>>, >, ) { loop { let message = asset_status_receiver.recv().await.unwrap(); spawn(handle_asset_status_message( app_config.clone(), thread_type, guard.clone(), websocket_sender.clone(), message, )); } } #[allow(clippy::significant_drop_tightening)] async fn handle_asset_status_message( app_config: Arc, thread_type: ThreadType, guard: Arc>, websocket_sender: Arc< Mutex>, tungstenite::Message>>, >, message: Message, ) { let symbols = message .assets .clone() .into_iter() .map(|asset| match thread_type { ThreadType::Bars(_) => asset.symbol, ThreadType::News => asset.abbreviation, }) .collect::>(); match message.action { Action::Add => { let mut guard = guard.write().await; guard.symbols.extend(symbols.clone()); guard .pending_subscriptions .extend(symbols.clone().into_iter().zip(message.assets.clone())); info!("{:?} - Added {:?}.", thread_type, symbols); let database_future = async { if matches!(thread_type, ThreadType::Bars(_)) { database::assets::upsert_batch(&app_config.clickhouse_client, message.assets) .await; } }; let websocket_future = async move { websocket_sender .lock() .await .send(tungstenite::Message::Text( to_string(&websocket::outgoing::Message::Subscribe( websocket_market_message_factory(thread_type, symbols), )) .unwrap(), )) .await .unwrap(); }; join!(database_future, websocket_future); } Action::Remove => { let mut guard = guard.write().await; guard.symbols.retain(|symbol| !symbols.contains(symbol)); guard .pending_unsubscriptions .extend(symbols.clone().into_iter().zip(message.assets.clone())); info!("{:?} - Removed {:?}.", thread_type, symbols); let sybols_clone = symbols.clone(); let database_future = database::assets::delete_where_symbols( &app_config.clickhouse_client, &sybols_clone, ); let websocket_future = async move { websocket_sender .lock() .await .send(tungstenite::Message::Text( to_string(&websocket::outgoing::Message::Unsubscribe( websocket_market_message_factory(thread_type, symbols), )) .unwrap(), )) .await .unwrap(); }; join!(database_future, websocket_future); } } message.response.send(()).unwrap(); } fn websocket_market_message_factory( thread_type: ThreadType, symbols: Vec, ) -> websocket::outgoing::subscribe::Message { match thread_type { ThreadType::Bars(_) => websocket::outgoing::subscribe::Message::Market( websocket::outgoing::subscribe::MarketMessage::new(symbols), ), ThreadType::News => websocket::outgoing::subscribe::Message::News( websocket::outgoing::subscribe::NewsMessage::new(symbols), ), } }