use super::{Guard, ThreadType}; use crate::{ config::Config, database, types::{alpaca::websocket, Asset}, }; use async_trait::async_trait; use futures_util::{stream::SplitSink, SinkExt}; use log::info; use serde_json::to_string; use std::sync::Arc; use tokio::{ join, net::TcpStream, spawn, sync::{mpsc, oneshot, Mutex, RwLock}, }; use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; #[derive(Clone)] pub enum Action { Add, Remove, } pub struct Message { pub action: Action, pub assets: Vec, pub response: oneshot::Sender<()>, } impl Message { pub fn new(action: Action, assets: Vec) -> (Self, oneshot::Receiver<()>) { let (sender, receiver) = oneshot::channel::<()>(); ( Self { action, assets, response: sender, }, receiver, ) } } #[async_trait] pub trait Handler: Send + Sync { async fn add_assets(&self, assets: Vec, symbols: Vec); async fn remove_assets(&self, assets: Vec, symbols: Vec); } pub async fn run( handler: Arc>, guard: Arc>, mut receiver: mpsc::Receiver, ) { loop { let message = receiver.recv().await.unwrap(); spawn(handle_asset_status_message( handler.clone(), guard.clone(), message, )); } } #[allow(clippy::significant_drop_tightening)] async fn handle_asset_status_message( handler: Arc>, guard: Arc>, message: Message, ) { let symbols = message .assets .clone() .into_iter() .map(|asset| asset.symbol) .collect::>(); match message.action { Action::Add => { let mut guard = guard.write().await; guard.assets.extend( message .assets .iter() .map(|asset| (asset.clone(), asset.symbol.clone())), ); guard.pending_subscriptions.extend(message.assets.clone()); handler.add_assets(message.assets, symbols).await; } Action::Remove => { let mut guard = guard.write().await; guard .assets .retain(|asset, _| !message.assets.contains(asset)); guard.pending_unsubscriptions.extend(message.assets.clone()); handler.remove_assets(message.assets, symbols).await; } } message.response.send(()).unwrap(); } pub fn create_asset_status_handler( thread_type: ThreadType, app_config: Arc, websocket_sender: Arc< Mutex>, tungstenite::Message>>, >, ) -> Box { match thread_type { ThreadType::Bars(_) => Box::new(BarsHandler { app_config, websocket_sender, }), ThreadType::News => Box::new(NewsHandler { websocket_sender }), } } struct BarsHandler { app_config: Arc, websocket_sender: Arc>, tungstenite::Message>>>, } #[async_trait] impl Handler for BarsHandler { async fn add_assets(&self, assets: Vec, symbols: Vec) { let database_future = database::assets::upsert_batch(&self.app_config.clickhouse_client, assets); let symbols_clone = symbols.clone(); let websocket_future = async move { self.websocket_sender .lock() .await .send(tungstenite::Message::Text( to_string(&websocket::outgoing::Message::Subscribe( websocket::outgoing::subscribe::Message::new_market(symbols_clone), )) .unwrap(), )) .await .unwrap(); }; join!(database_future, websocket_future); info!("Added {:?}.", symbols); } async fn remove_assets(&self, _: Vec, symbols: Vec) { let symbols_clone = symbols.clone(); let database_future = database::assets::delete_where_symbols( &self.app_config.clickhouse_client, &symbols_clone, ); let symbols_clone = symbols.clone(); let websocket_future = async move { self.websocket_sender .lock() .await .send(tungstenite::Message::Text( to_string(&websocket::outgoing::Message::Unsubscribe( websocket::outgoing::subscribe::Message::new_market(symbols_clone), )) .unwrap(), )) .await .unwrap(); }; join!(database_future, websocket_future); info!("Removed {:?}.", symbols); } } struct NewsHandler { websocket_sender: Arc>, tungstenite::Message>>>, } #[async_trait] impl Handler for NewsHandler { async fn add_assets(&self, _: Vec, symbols: Vec) { self.websocket_sender .lock() .await .send(tungstenite::Message::Text( to_string(&websocket::outgoing::Message::Subscribe( websocket::outgoing::subscribe::Message::new_news(symbols), )) .unwrap(), )) .await .unwrap(); } async fn remove_assets(&self, _: Vec, symbols: Vec) { self.websocket_sender .lock() .await .send(tungstenite::Message::Text( to_string(&websocket::outgoing::Message::Unsubscribe( websocket::outgoing::subscribe::Message::new_news(symbols), )) .unwrap(), )) .await .unwrap(); } }