diff --git a/src/data/market.rs b/src/data/market.rs index 9f3c476..5769000 100644 --- a/src/data/market.rs +++ b/src/data/market.rs @@ -6,10 +6,7 @@ use crate::{ data::authenticate_websocket, database, types::{ - alpaca::{ - api::{incoming, outgoing}, - websocket, Source, - }, + alpaca::{api, websocket, Source}, asset::{self, Asset}, Bar, BarValidity, BroadcastMessage, Class, }, @@ -63,16 +60,13 @@ pub async fn run( broadcast_sender.subscribe(), )); - database::assets::select_where_class(&app_config.clickhouse_client, class) - .await - .into_iter() - .for_each(|asset| { - broadcast_sender - .send(BroadcastMessage::Asset(asset::BroadcastMessage::Added( - asset, - ))) - .unwrap(); - }); + let assets = database::assets::select_where_class(&app_config.clickhouse_client, class).await; + broadcast_sender + .send(BroadcastMessage::Asset(( + asset::BroadcastMessage::Added, + assets, + ))) + .unwrap(); websocket_handler(app_config, class, stream, sink).await; @@ -86,39 +80,41 @@ async fn broadcast_handler( ) { loop { match broadcast_receiver.recv().await.unwrap() { - BroadcastMessage::Asset(asset::BroadcastMessage::Added(asset)) - if asset.class == class => - { + BroadcastMessage::Asset((action, assets)) => { + let symbols = assets + .into_iter() + .filter(|asset| asset.class == class) + .map(|asset| asset.symbol) + .collect::>(); + + if symbols.is_empty() { + continue; + } + sink.write() .await .send(Message::Text( - serde_json::to_string(&websocket::data::outgoing::Message::Subscribe( - websocket::data::outgoing::subscribe::Message::new( - asset.clone().symbol, - ), - )) + serde_json::to_string(&match action { + asset::BroadcastMessage::Added => { + websocket::data::outgoing::Message::Subscribe( + websocket::data::outgoing::subscribe::Message::new( + symbols.clone(), + ), + ) + } + asset::BroadcastMessage::Deleted => { + websocket::data::outgoing::Message::Unsubscribe( + websocket::data::outgoing::subscribe::Message::new( + symbols.clone(), + ), + ) + } + }) .unwrap(), )) .await .unwrap(); } - BroadcastMessage::Asset(asset::BroadcastMessage::Deleted(asset)) - if asset.class == class => - { - sink.write() - .await - .send(Message::Text( - serde_json::to_string(&websocket::data::outgoing::Message::Unsubscribe( - websocket::data::outgoing::subscribe::Message::new( - asset.clone().symbol, - ), - )) - .unwrap(), - )) - .await - .unwrap(); - } - BroadcastMessage::Asset(_) => {} } } } @@ -143,7 +139,12 @@ async fn websocket_handler( } for message in parsed_data.unwrap_or_default() { - websocket_handle_message(&app_config, class, &backfilled, message).await; + spawn(websocket_handle_message( + app_config.clone(), + class, + backfilled.clone(), + message, + )); } } Some(Ok(Message::Ping(_))) => sink @@ -159,68 +160,75 @@ async fn websocket_handler( } async fn websocket_handle_message( - app_config: &Arc, + app_config: Arc, class: Class, - backfilled: &Arc>>, + backfilled: Arc>>, message: websocket::data::incoming::Message, ) { match message { - websocket::data::incoming::Message::Subscription(subscription_message) => { - let old_assets = backfilled - .read() - .await - .keys() - .cloned() - .collect::>(); - let new_assets = subscription_message - .bars - .into_iter() - .collect::>(); + websocket::data::incoming::Message::Subscription(message) => { + let added_asset_symbols; + let deleted_asset_symbols; - let added_assets = new_assets.difference(&old_assets).collect::>(); - let deleted_assets = old_assets.difference(&new_assets).collect::>(); + { + let mut backfilled = backfilled.write().await; - for asset_symbol in &added_assets { + let old_asset_sybols = backfilled.keys().cloned().collect::>(); + let new_asset_symbols = message.bars.into_iter().collect::>(); + + added_asset_symbols = new_asset_symbols + .difference(&old_asset_sybols) + .cloned() + .collect::>(); + + for asset_symbol in &added_asset_symbols { + backfilled.insert(asset_symbol.clone(), false); + } + + deleted_asset_symbols = old_asset_sybols + .difference(&new_asset_symbols) + .cloned() + .collect::>(); + + for asset_symbol in &deleted_asset_symbols { + backfilled.remove(asset_symbol); + } + + drop(backfilled); + + info!( + "Subscription update for {:?}: {:?} added, {:?} deleted.", + class, added_asset_symbols, deleted_asset_symbols + ); + } + + for asset_symbol in added_asset_symbols { let asset = database::assets::select_where_symbol( &app_config.clickhouse_client, - asset_symbol, + &asset_symbol, ) .await .unwrap(); - backfilled.write().await.insert(asset.symbol.clone(), false); - - let bar_validity = BarValidity::none(asset.symbol.clone()); database::bars::insert_validity_if_not_exists( &app_config.clickhouse_client, - &bar_validity, + &BarValidity::none(asset.symbol.clone()), ) .await; - spawn(backfill( - app_config.clone(), - backfilled.clone(), - asset.clone(), - )); + spawn(backfill(app_config.clone(), backfilled.clone(), asset)); } - for asset_symbol in &deleted_assets { + for asset_symbol in deleted_asset_symbols { database::bars::delete_validity_where_symbol( &app_config.clickhouse_client, - asset_symbol, + &asset_symbol, ) .await; - database::bars::delete_where_symbol(&app_config.clickhouse_client, asset_symbol) + database::bars::delete_where_symbol(&app_config.clickhouse_client, &asset_symbol) .await; - - backfilled.write().await.remove(*asset_symbol); } - - info!( - "Subscription update for {:?}: {:?} added, {:?} deleted.", - class, added_assets, deleted_assets - ); } websocket::data::incoming::Message::Bars(bar_message) | websocket::data::incoming::Message::UpdatedBars(bar_message) => { @@ -228,7 +236,7 @@ async fn websocket_handle_message( info!("websocket::Incoming bar for {}: {}", bar.symbol, bar.time); database::bars::upsert(&app_config.clickhouse_client, &bar).await; - if backfilled.read().await[&bar.symbol] { + if *backfilled.read().await.get(&bar.symbol).unwrap() { database::bars::upsert_validity(&app_config.clickhouse_client, &bar.into()).await; } } @@ -255,7 +263,7 @@ pub async fn backfill( .send() .await .unwrap() - .json::() + .json::() .await .unwrap(); @@ -294,9 +302,9 @@ pub async fn backfill( Class::UsEquity => ALPACA_STOCK_DATA_URL, Class::Crypto => ALPACA_CRYPTO_DATA_URL, }) - .query(&outgoing::bar::Bar::new( + .query(&api::outgoing::bar::Bar::new( vec![asset.symbol.clone()], - String::from("1Min"), + ONE_MINUTE, fetch_from, fetch_until, 10000, @@ -305,7 +313,7 @@ pub async fn backfill( .send() .await .unwrap() - .json::() + .json::() .await .unwrap(); diff --git a/src/routes/assets.rs b/src/routes/assets.rs index f12f245..657cb10 100644 --- a/src/routes/assets.rs +++ b/src/routes/assets.rs @@ -69,8 +69,9 @@ pub async fn add( database::assets::upsert(&app_config.clickhouse_client, &asset).await; broadcast_sender - .send(BroadcastMessage::Asset(asset::BroadcastMessage::Added( - asset.clone(), + .send(BroadcastMessage::Asset(( + asset::BroadcastMessage::Added, + vec![asset.clone()], ))) .unwrap(); @@ -89,8 +90,9 @@ pub async fn delete( .unwrap(); broadcast_sender - .send(BroadcastMessage::Asset(asset::BroadcastMessage::Deleted( - asset, + .send(BroadcastMessage::Asset(( + asset::BroadcastMessage::Deleted, + vec![asset.clone()], ))) .unwrap(); diff --git a/src/types/alpaca/api/outgoing/bar.rs b/src/types/alpaca/api/outgoing/bar.rs index 3a94ccd..5617593 100644 --- a/src/types/alpaca/api/outgoing/bar.rs +++ b/src/types/alpaca/api/outgoing/bar.rs @@ -1,10 +1,54 @@ -use serde::{Deserialize, Serialize}; +use std::time::Duration; + +use serde::{Serialize, Serializer}; use time::OffsetDateTime; -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +fn serialize_symbols(symbols: &[String], serializer: S) -> Result +where + S: Serializer, +{ + let string = symbols.join(","); + serializer.serialize_str(&string) +} + +fn serialize_timeframe(timeframe: &Duration, serializer: S) -> Result +where + S: serde::Serializer, +{ + let mins = timeframe.as_secs() / 60; + if mins < 60 { + return serializer.serialize_str(&format!("{mins}Min")); + } + + let hours = mins / 60; + if hours < 24 { + return serializer.serialize_str(&format!("{hours}Hour")); + } + + let days = hours / 24; + if days == 1 { + return serializer.serialize_str("1Day"); + } + + let weeks = days / 7; + if weeks == 1 { + return serializer.serialize_str("1Week"); + } + + let months = days / 30; + if [1, 2, 3, 4, 6, 12].contains(&months) { + return serializer.serialize_str(&format!("{months}Month")); + }; + + Err(serde::ser::Error::custom("Invalid timeframe duration")) +} + +#[derive(Serialize)] pub struct Bar { + #[serde(serialize_with = "serialize_symbols")] pub symbols: Vec, - pub timeframe: String, + #[serde(serialize_with = "serialize_timeframe")] + pub timeframe: Duration, #[serde(with = "time::serde::rfc3339")] pub start: OffsetDateTime, #[serde(with = "time::serde::rfc3339")] @@ -15,9 +59,9 @@ pub struct Bar { } impl Bar { - pub fn new( + pub const fn new( symbols: Vec, - timeframe: String, + timeframe: Duration, start: OffsetDateTime, end: OffsetDateTime, limit: i64, diff --git a/src/types/alpaca/websocket/data/outgoing/subscribe.rs b/src/types/alpaca/websocket/data/outgoing/subscribe.rs index 5250dd9..afea0af 100644 --- a/src/types/alpaca/websocket/data/outgoing/subscribe.rs +++ b/src/types/alpaca/websocket/data/outgoing/subscribe.rs @@ -8,10 +8,10 @@ pub struct Message { } impl Message { - pub fn new(symbol: String) -> Self { + pub fn new(symbols: Vec) -> Self { Self { - bars: vec![symbol.clone()], - updated_bars: vec![symbol], + bars: symbols.clone(), + updated_bars: symbols, } } } diff --git a/src/types/asset.rs b/src/types/asset.rs index de80381..1f53ebc 100644 --- a/src/types/asset.rs +++ b/src/types/asset.rs @@ -34,6 +34,6 @@ pub struct Asset { #[derive(Clone, Debug, PartialEq, Eq)] pub enum BroadcastMessage { - Added(Asset), - Deleted(Asset), + Added, + Deleted, } diff --git a/src/types/mod.rs b/src/types/mod.rs index f1aba5f..4fd2b8e 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -7,5 +7,5 @@ pub use bar::{Bar, BarValidity}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum BroadcastMessage { - Asset(asset::BroadcastMessage), + Asset((asset::BroadcastMessage, Vec)), }