use crate::{ config::{Config, ALPACA_CRYPTO_WEBSOCKET_URL, ALPACA_STOCK_WEBSOCKET_URL}, data::authenticate_websocket, database, types::{ alpaca::{api, websocket, Source}, state, Asset, Backfill, Bar, BroadcastMessage, Class, }, utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE}, }; use backoff::{future::retry, ExponentialBackoff}; use futures_util::{ stream::{SplitSink, SplitStream}, SinkExt, StreamExt, }; use log::{error, info, warn}; use serde_json::{from_str, to_string}; use std::{ collections::{HashMap, HashSet}, sync::Arc, }; use time::OffsetDateTime; use tokio::{ net::TcpStream, spawn, sync::{broadcast::Sender, Mutex, RwLock}, task::JoinHandle, time::sleep, }; use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream}; pub struct Guard { symbols: HashSet, backfill_jobs: HashMap>, pending_subscriptions: HashMap, pending_unsubscriptions: HashMap, } pub async fn run( app_config: Arc, class: Class, broadcast_bus_sender: Sender, ) { info!("Running live threads for {:?}.", class); let websocket_url = match class { Class::UsEquity => format!( "{}/{}", ALPACA_STOCK_WEBSOCKET_URL, app_config.alpaca_source ), Class::Crypto => ALPACA_CRYPTO_WEBSOCKET_URL.to_string(), }; let (stream, _) = connect_async(websocket_url).await.unwrap(); let (mut sink, mut stream) = stream.split(); authenticate_websocket(&app_config, &mut stream, &mut sink).await; let sink = Arc::new(Mutex::new(sink)); let guard = Arc::new(RwLock::new(Guard { symbols: HashSet::new(), backfill_jobs: HashMap::new(), pending_subscriptions: HashMap::new(), pending_unsubscriptions: HashMap::new(), })); spawn(broadcast_bus_handler( app_config.clone(), class, sink.clone(), broadcast_bus_sender.clone(), guard.clone(), )); spawn(websocket_handler( app_config.clone(), stream, sink, broadcast_bus_sender.clone(), guard.clone(), )); let assets = database::assets::select_where_class(&app_config.clickhouse_client, &class).await; broadcast_bus_sender .send(BroadcastMessage::Asset(( state::asset::BroadcastMessage::Add, assets, ))) .unwrap(); } pub async fn broadcast_bus_handler( app_config: Arc, class: Class, sink: Arc>, Message>>>, broadcast_bus_sender: Sender, guard: Arc>, ) { let mut broadcast_bus_receiver = broadcast_bus_sender.subscribe(); loop { let app_config = app_config.clone(); let sink = sink.clone(); let broadcast_bus_sender = broadcast_bus_sender.clone(); let guard = guard.clone(); let message = broadcast_bus_receiver.recv().await.unwrap(); spawn(broadcast_bus_handle_message( app_config, class, sink, broadcast_bus_sender, guard, message, )); } } #[allow(clippy::significant_drop_tightening)] #[allow(clippy::too_many_lines)] async fn broadcast_bus_handle_message( app_config: Arc, class: Class, sink: Arc>, Message>>>, broadcast_bus_sender: Sender, guard: Arc>, message: BroadcastMessage, ) { match message { BroadcastMessage::Asset((action, mut assets)) => { assets.retain(|asset| asset.class == class); if assets.is_empty() { return; } let assets = assets .into_iter() .map(|asset| (asset.symbol.clone(), asset)) .collect::>(); let symbols = assets.keys().cloned().collect::>(); match action { state::asset::BroadcastMessage::Add => { database::assets::upsert_batch( &app_config.clickhouse_client, assets.clone().into_values(), ) .await; let mut guard = guard.write().await; guard.symbols.extend(symbols.clone()); guard.pending_subscriptions.extend(assets); info!("Added {:?}.", symbols); sink.lock() .await .send(Message::Text( to_string(&websocket::data::outgoing::Message::Subscribe( websocket::data::outgoing::subscribe::Message::new(symbols), )) .unwrap(), )) .await .unwrap(); } state::asset::BroadcastMessage::Delete => { database::assets::delete_where_symbols(&app_config.clickhouse_client, &symbols) .await; let mut guard = guard.write().await; guard.symbols.retain(|symbol| !assets.contains_key(symbol)); guard.pending_unsubscriptions.extend(assets); info!("Deleted {:?}.", symbols); sink.lock() .await .send(Message::Text( to_string(&websocket::data::outgoing::Message::Unsubscribe( websocket::data::outgoing::subscribe::Message::new(symbols), )) .unwrap(), )) .await .unwrap(); } state::asset::BroadcastMessage::Backfill => { let guard_clone = guard.clone(); let mut guard = guard.write().await; info!("Creating backfill jobs for {:?}.", symbols); for (symbol, asset) in assets { if let Some(backfill_job) = guard.backfill_jobs.remove(&symbol) { backfill_job.abort(); backfill_job.await.unwrap_err(); } guard.backfill_jobs.insert(symbol.clone(), { let guard = guard_clone.clone(); let app_config = app_config.clone(); spawn(async move { backfill(app_config, class, asset.clone()).await; let mut guard = guard.write().await; guard.backfill_jobs.remove(&symbol); }) }); } } state::asset::BroadcastMessage::Purge => { let mut guard = guard.write().await; info!("Purging {:?}.", symbols); for (symbol, _) in assets { if let Some(backfill_job) = guard.backfill_jobs.remove(&symbol) { backfill_job.abort(); backfill_job.await.unwrap_err(); } } database::backfills::delete_where_symbols( &app_config.clickhouse_client, &symbols, ) .await; database::bars::delete_where_symbols(&app_config.clickhouse_client, &symbols) .await; } } } BroadcastMessage::Clock(_) => { broadcast_bus_sender .send(BroadcastMessage::Asset(( state::asset::BroadcastMessage::Backfill, database::assets::select(&app_config.clickhouse_client).await, ))) .unwrap(); } } } async fn websocket_handler( app_config: Arc, mut stream: SplitStream>>, sink: Arc>, Message>>>, broadcast_bus_sender: Sender, guard: Arc>, ) { loop { let app_config = app_config.clone(); let sink = sink.clone(); let broadcast_bus_sender = broadcast_bus_sender.clone(); let guard = guard.clone(); let message = stream.next().await.expect("Websocket stream closed."); spawn(async move { match message { Ok(Message::Text(data)) => { let parsed_data = from_str::>(&data); if let Ok(messages) = parsed_data { for message in messages { websocket_handle_message( app_config.clone(), broadcast_bus_sender.clone(), guard.clone(), message, ) .await; } } else { error!( "Unparsed websocket message: {:?}: {}.", data, parsed_data.unwrap_err() ); } } Ok(Message::Ping(_)) => { sink.lock().await.send(Message::Pong(vec![])).await.unwrap(); } _ => error!("Unknown websocket message: {:?}.", message), } }); } } #[allow(clippy::significant_drop_tightening)] async fn websocket_handle_message( app_config: Arc, broadcast_bus_sender: Sender, guard: Arc>, message: websocket::data::incoming::Message, ) { match message { websocket::data::incoming::Message::Subscription(message) => { let symbols = message.bars.into_iter().collect::>(); let mut guard = guard.write().await; let newly_subscribed_assets = guard .pending_subscriptions .extract_if(|symbol, _| symbols.contains(symbol)) .collect::>(); if !newly_subscribed_assets.is_empty() { info!( "Subscribed to {:?}.", newly_subscribed_assets.keys().collect::>() ); broadcast_bus_sender .send(BroadcastMessage::Asset(( state::asset::BroadcastMessage::Backfill, newly_subscribed_assets.into_values().collect::>(), ))) .unwrap(); } let newly_unsubscribed_assets = guard .pending_unsubscriptions .extract_if(|symbol, _| !symbols.contains(symbol)) .collect::>(); if !newly_unsubscribed_assets.is_empty() { info!( "Unsubscribed from {:?}.", newly_unsubscribed_assets.keys().collect::>() ); broadcast_bus_sender .send(BroadcastMessage::Asset(( state::asset::BroadcastMessage::Purge, newly_unsubscribed_assets.into_values().collect::>(), ))) .unwrap(); } } websocket::data::incoming::Message::Bars(bar_message) | websocket::data::incoming::Message::UpdatedBars(bar_message) => { let bar = Bar::from(bar_message); let guard = guard.read().await; let symbol_status = guard.symbols.get(&bar.symbol); if symbol_status.is_none() { warn!( "Race condition: received bar for unsubscribed symbol: {:?}.", bar.symbol ); return; } info!("Received bar for {}: {}.", bar.symbol, bar.time); database::bars::upsert(&app_config.clickhouse_client, &bar).await; } websocket::data::incoming::Message::Success(_) => {} } } pub async fn backfill(app_config: Arc, class: Class, asset: Asset) { let latest_backfill = database::backfills::select_latest_where_symbol( &app_config.clickhouse_client, &asset.symbol, ) .await; let fetch_from = if let Some(backfill) = latest_backfill { backfill.time + ONE_MINUTE } else { OffsetDateTime::UNIX_EPOCH }; let fetch_until = last_minute(); if fetch_from > fetch_until { return; } if app_config.alpaca_source == Source::Iex { let task_run_delay = duration_until(fetch_until + FIFTEEN_MINUTES + ONE_MINUTE); info!( "Queing backfill for {} in {:?}.", asset.symbol, task_run_delay ); sleep(task_run_delay).await; } info!("Running backfill for {}.", asset.symbol); let mut bars = Vec::new(); let mut next_page_token = None; loop { let message = retry(ExponentialBackoff::default(), || async { app_config.alpaca_rate_limit.until_ready().await; app_config .alpaca_client .get(class.get_data_url()) .query(&api::outgoing::bar::Bar::new( vec![asset.symbol.clone()], ONE_MINUTE, fetch_from, fetch_until, 10000, next_page_token.clone(), )) .send() .await? .error_for_status()? .json::() .await .map_err(backoff::Error::Permanent) }) .await; let message = match message { Ok(message) => message, Err(e) => { error!("Failed to backfill data for {}: {}.", asset.symbol, e); return; } }; message.bars.into_iter().for_each(|(symbol, bar_vec)| { bar_vec.unwrap_or_default().into_iter().for_each(|bar| { bars.push(Bar::from((bar, symbol.clone()))); }); }); if message.next_page_token.is_none() { break; } next_page_token = message.next_page_token; } database::bars::upsert_batch(&app_config.clickhouse_client, bars).await; database::backfills::upsert( &app_config.clickhouse_client, &Backfill::new(asset.symbol.clone(), fetch_until), ) .await; info!("Backfilled data for {}.", asset.symbol); }