diff --git a/Cargo.lock b/Cargo.lock index 79b8655..b06d862 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -195,12 +195,6 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" -[[package]] -name = "bimap" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "230c5f1ca6a325a32553f8640d31ac9b49f2411e901e427570154868b46da4f7" - [[package]] name = "bitflags" version = "1.3.2" @@ -903,9 +897,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.4" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d3d0e0f38255e7fa3cf31335b3a56f05febd18025f4db5ef7a0cfb4f8da651f" +checksum = "d0c62115964e08cb8039170eb33c1d0e2388a256930279edca206fff675f82c3" [[package]] name = "hmac" @@ -1073,9 +1067,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.59" +version = "0.1.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6a67363e2aa4443928ce15e57ebae94fd8949958fd1223c4cfc0cd473ad7539" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -1169,6 +1163,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" @@ -1186,9 +1189,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.67" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a1d36f1235bc969acba30b7f5990b864423a6068a10f7c90ae8f0112e3a59d1" +checksum = "406cda4b368d531c842222cf9d2600a9a4acce8d29423695379c6868a143a9ee" dependencies = [ "wasm-bindgen", ] @@ -1338,9 +1341,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "miniz_oxide" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" dependencies = [ "adler", ] @@ -1401,9 +1404,9 @@ checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" [[package]] name = "num-complex" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" dependencies = [ "num-traits", ] @@ -1656,13 +1659,13 @@ dependencies = [ "async-trait", "axum", "backoff", - "bimap", "clickhouse", "dotenv", "futures-util", "governor", "html-escape", "http 1.0.0", + "itertools 0.12.1", "log", "log4rs", "regex", @@ -1884,7 +1887,7 @@ checksum = "19599f60a688b5160247ee9c37a6af8b0c742ee8b160c5b44acc0f0eb265a59f" dependencies = [ "csv", "hashbrown 0.14.3", - "itertools", + "itertools 0.11.0", "lazy_static", "protobuf", "rayon", @@ -2242,13 +2245,12 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.9.0" +version = "3.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01ce4141aa927a6d1bd34a041795abd0db1cccba5d5f24b009f694bdf3a1f3fa" +checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67" dependencies = [ "cfg-if", "fastrand", - "redox_syscall", "rustix", "windows-sys 0.52.0", ] @@ -2285,9 +2287,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.32" +version = "0.3.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe80ced77cbfb4cb91a94bf72b378b4b6791a0d9b7f09d0be747d1bdff4e68bd" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" dependencies = [ "deranged", "itoa", @@ -2528,9 +2530,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" [[package]] name = "unsafe-any-ors" @@ -2602,9 +2604,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406" +checksum = "c1e124130aee3fb58c5bdd6b639a0509486b0338acaaae0c84a5124b0f588b7f" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -2612,9 +2614,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd" +checksum = "c9e7e1900c352b609c8488ad12639a311045f40a35491fb69ba8c12f758af70b" dependencies = [ "bumpalo", "log", @@ -2627,9 +2629,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.40" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bde2032aeb86bdfaecc8b261eef3cba735cc426c1f3a3416d1e0791be95fc461" +checksum = "877b9c3f61ceea0e56331985743b13f3d25c406a7098d45180fb5f09bc19ed97" dependencies = [ "cfg-if", "js-sys", @@ -2639,9 +2641,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" +checksum = "b30af9e2d358182b5c7449424f017eba305ed32a7010509ede96cdc4696c46ed" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2649,9 +2651,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" +checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", @@ -2662,15 +2664,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" +checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838" [[package]] name = "web-sys" -version = "0.3.67" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58cd2333b6e0be7a39605f0e255892fd7418a682d8da8fe042fe25128794d2ed" +checksum = "96565907687f7aceb35bc5fc03770a8a0471d82e479f25832f54a0e3f4b28446" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index 5021cb7..5836868 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,5 +52,5 @@ backoff = { version = "0.4.0", features = [ regex = "1.10.3" html-escape = "0.2.13" rust-bert = "0.22.0" -bimap = "0.6.3" async-trait = "0.1.77" +itertools = "0.12.1" diff --git a/src/main.rs b/src/main.rs index 25667cc..5e9ab66 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,27 +23,27 @@ async fn main() { cleanup(&app_config.clickhouse_client).await; - let (asset_status_sender, asset_status_receiver) = - mpsc::channel::(100); + let (data_sender, data_receiver) = mpsc::channel::(100); let (clock_sender, clock_receiver) = mpsc::channel::(1); spawn(threads::data::run( app_config.clone(), - asset_status_receiver, + data_receiver, clock_receiver, )); spawn(threads::clock::run(app_config.clone(), clock_sender)); - let assets = database::assets::select(&app_config.clickhouse_client).await; - - let (asset_status_message, asset_status_receiver) = - threads::data::asset_status::Message::new(threads::data::asset_status::Action::Add, assets); - asset_status_sender - .send(asset_status_message) + let assets = database::assets::select(&app_config.clickhouse_client) .await - .unwrap(); - asset_status_receiver.await.unwrap(); + .into_iter() + .map(|asset| (asset.symbol, asset.class)) + .collect::>(); - routes::run(app_config, asset_status_sender).await; + let (data_message, data_receiver) = + threads::data::Message::new(threads::data::Action::Add, assets); + data_sender.send(data_message).await.unwrap(); + data_receiver.await.unwrap(); + + routes::run(app_config, data_sender).await; } diff --git a/src/routes/assets.rs b/src/routes/assets.rs index 681311a..4d361aa 100644 --- a/src/routes/assets.rs +++ b/src/routes/assets.rs @@ -1,14 +1,9 @@ use crate::{ - config::{Config, ALPACA_ASSET_API_URL}, + config::Config, database, threads, - types::{ - alpaca::api::incoming::{self, asset::Status}, - Asset, - }, + types::{alpaca::api::incoming, Asset}, }; use axum::{extract::Path, Extension, Json}; -use backoff::{future::retry, ExponentialBackoff}; -use core::panic; use http::StatusCode; use serde::Deserialize; use std::sync::Arc; @@ -38,9 +33,9 @@ pub struct AddAssetRequest { pub async fn add( Extension(app_config): Extension>, - Extension(asset_status_sender): Extension>, + Extension(data_sender): Extension>, Json(request): Json, -) -> Result<(StatusCode, Json), StatusCode> { +) -> Result { if database::assets::select_where_symbol(&app_config.clickhouse_client, &request.symbol) .await .is_some() @@ -48,66 +43,38 @@ pub async fn add( return Err(StatusCode::CONFLICT); } - let asset = retry(ExponentialBackoff::default(), || async { - app_config.alpaca_rate_limit.until_ready().await; - app_config - .alpaca_client - .get(&format!("{}/{}", ALPACA_ASSET_API_URL, request.symbol)) - .send() - .await? - .error_for_status() - .map_err(|e| match e.status() { - Some(reqwest::StatusCode::NOT_FOUND) => backoff::Error::Permanent(e), - _ => e.into(), - })? - .json::() - .await - .map_err(backoff::Error::Permanent) - }) - .await - .map_err(|e| match e.status() { - Some(reqwest::StatusCode::NOT_FOUND) => StatusCode::NOT_FOUND, - _ => panic!("Unexpected error: {}.", e), - })?; - - if asset.status != Status::Active || !asset.tradable || !asset.fractionable { + let asset = incoming::asset::get_by_symbol(&app_config, &request.symbol).await?; + if !asset.tradable || !asset.fractionable { return Err(StatusCode::FORBIDDEN); } - let asset = Asset::from(asset); - let (asset_status_message, asset_status_response) = threads::data::asset_status::Message::new( - threads::data::asset_status::Action::Add, - vec![asset.clone()], + let (data_message, data_response) = threads::data::Message::new( + threads::data::Action::Add, + vec![(asset.symbol, asset.class)], ); - asset_status_sender - .send(asset_status_message) - .await - .unwrap(); - asset_status_response.await.unwrap(); + data_sender.send(data_message).await.unwrap(); + data_response.await.unwrap(); - Ok((StatusCode::CREATED, Json(asset))) + Ok(StatusCode::CREATED) } pub async fn delete( Extension(app_config): Extension>, - Extension(asset_status_sender): Extension>, + Extension(data_sender): Extension>, Path(symbol): Path, ) -> Result { let asset = database::assets::select_where_symbol(&app_config.clickhouse_client, &symbol) .await .ok_or(StatusCode::NOT_FOUND)?; - let (asset_status_message, asset_status_response) = threads::data::asset_status::Message::new( - threads::data::asset_status::Action::Remove, - vec![asset], + let (asset_status_message, asset_status_response) = threads::data::Message::new( + threads::data::Action::Remove, + vec![(asset.symbol, asset.class)], ); - asset_status_sender - .send(asset_status_message) - .await - .unwrap(); + data_sender.send(asset_status_message).await.unwrap(); asset_status_response.await.unwrap(); Ok(StatusCode::NO_CONTENT) diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 3c371ff..247d52c 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -10,10 +10,7 @@ use log::info; use std::{net::SocketAddr, sync::Arc}; use tokio::{net::TcpListener, sync::mpsc}; -pub async fn run( - app_config: Arc, - asset_status_sender: mpsc::Sender, -) { +pub async fn run(app_config: Arc, data_sender: mpsc::Sender) { let app = Router::new() .route("/health", get(health::get)) .route("/assets", get(assets::get)) @@ -21,7 +18,7 @@ pub async fn run( .route("/assets", post(assets::add)) .route("/assets/:symbol", delete(assets::delete)) .layer(Extension(app_config)) - .layer(Extension(asset_status_sender)); + .layer(Extension(data_sender)); let addr = SocketAddr::from(([0, 0, 0, 0], 7878)); let listener = TcpListener::bind(addr).await.unwrap(); diff --git a/src/threads/clock.rs b/src/threads/clock.rs index 631b1df..851e8a7 100644 --- a/src/threads/clock.rs +++ b/src/threads/clock.rs @@ -1,9 +1,4 @@ -use crate::{ - config::{Config, ALPACA_CLOCK_API_URL}, - types::alpaca, - utils::duration_until, -}; -use backoff::{future::retry, ExponentialBackoff}; +use crate::{config::Config, types::alpaca, utils::duration_until}; use log::info; use std::sync::Arc; use time::OffsetDateTime; @@ -37,20 +32,7 @@ impl From for Message { pub async fn run(app_config: Arc, sender: mpsc::Sender) { loop { - let clock = retry(ExponentialBackoff::default(), || async { - app_config.alpaca_rate_limit.until_ready().await; - app_config - .alpaca_client - .get(ALPACA_CLOCK_API_URL) - .send() - .await? - .error_for_status()? - .json::() - .await - .map_err(backoff::Error::Permanent) - }) - .await - .unwrap(); + let clock = alpaca::api::incoming::clock::get(&app_config).await; let sleep_until = duration_until(if clock.is_open { info!("Market is open, will close at {}.", clock.next_close); diff --git a/src/threads/data/asset_status.rs b/src/threads/data/asset_status.rs deleted file mode 100644 index d58f27d..0000000 --- a/src/threads/data/asset_status.rs +++ /dev/null @@ -1,218 +0,0 @@ -use super::{Guard, ThreadType}; -use crate::{ - config::Config, - database, - types::{alpaca::websocket, Asset}, -}; -use async_trait::async_trait; -use futures_util::{stream::SplitSink, SinkExt}; -use log::info; -use serde_json::to_string; -use std::sync::Arc; -use tokio::{ - join, - net::TcpStream, - spawn, - sync::{mpsc, oneshot, Mutex, RwLock}, -}; -use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; - -#[derive(Clone)] -pub enum Action { - Add, - Remove, -} - -pub struct Message { - pub action: Action, - pub assets: Vec, - pub response: oneshot::Sender<()>, -} - -impl Message { - pub fn new(action: Action, assets: Vec) -> (Self, oneshot::Receiver<()>) { - let (sender, receiver) = oneshot::channel::<()>(); - ( - Self { - action, - assets, - response: sender, - }, - receiver, - ) - } -} - -#[async_trait] -pub trait Handler: Send + Sync { - async fn add_assets(&self, assets: Vec, symbols: Vec); - async fn remove_assets(&self, assets: Vec, symbols: Vec); -} - -pub async fn run( - handler: Arc>, - guard: Arc>, - mut receiver: mpsc::Receiver, -) { - loop { - let message = receiver.recv().await.unwrap(); - - spawn(handle_asset_status_message( - handler.clone(), - guard.clone(), - message, - )); - } -} - -#[allow(clippy::significant_drop_tightening)] -async fn handle_asset_status_message( - handler: Arc>, - guard: Arc>, - message: Message, -) { - let symbols = message - .assets - .clone() - .into_iter() - .map(|asset| asset.symbol) - .collect::>(); - - match message.action { - Action::Add => { - let mut guard = guard.write().await; - - guard.assets.extend( - message - .assets - .iter() - .map(|asset| (asset.clone(), asset.symbol.clone())), - ); - guard.pending_subscriptions.extend(message.assets.clone()); - - handler.add_assets(message.assets, symbols).await; - } - Action::Remove => { - let mut guard = guard.write().await; - - guard - .assets - .retain(|asset, _| !message.assets.contains(asset)); - guard.pending_unsubscriptions.extend(message.assets.clone()); - - handler.remove_assets(message.assets, symbols).await; - } - } - - message.response.send(()).unwrap(); -} - -pub fn create_asset_status_handler( - thread_type: ThreadType, - app_config: Arc, - websocket_sender: Arc< - Mutex>, tungstenite::Message>>, - >, -) -> Box { - match thread_type { - ThreadType::Bars(_) => Box::new(BarsHandler { - app_config, - websocket_sender, - }), - ThreadType::News => Box::new(NewsHandler { websocket_sender }), - } -} - -struct BarsHandler { - app_config: Arc, - websocket_sender: - Arc>, tungstenite::Message>>>, -} - -#[async_trait] -impl Handler for BarsHandler { - async fn add_assets(&self, assets: Vec, symbols: Vec) { - let database_future = - database::assets::upsert_batch(&self.app_config.clickhouse_client, assets); - - let symbols_clone = symbols.clone(); - let websocket_future = async move { - self.websocket_sender - .lock() - .await - .send(tungstenite::Message::Text( - to_string(&websocket::outgoing::Message::Subscribe( - websocket::outgoing::subscribe::Message::new_market(symbols_clone), - )) - .unwrap(), - )) - .await - .unwrap(); - }; - - join!(database_future, websocket_future); - info!("Added {:?}.", symbols); - } - - async fn remove_assets(&self, _: Vec, symbols: Vec) { - let symbols_clone = symbols.clone(); - let database_future = database::assets::delete_where_symbols( - &self.app_config.clickhouse_client, - &symbols_clone, - ); - - let symbols_clone = symbols.clone(); - let websocket_future = async move { - self.websocket_sender - .lock() - .await - .send(tungstenite::Message::Text( - to_string(&websocket::outgoing::Message::Unsubscribe( - websocket::outgoing::subscribe::Message::new_market(symbols_clone), - )) - .unwrap(), - )) - .await - .unwrap(); - }; - - join!(database_future, websocket_future); - info!("Removed {:?}.", symbols); - } -} - -struct NewsHandler { - websocket_sender: - Arc>, tungstenite::Message>>>, -} - -#[async_trait] -impl Handler for NewsHandler { - async fn add_assets(&self, _: Vec, symbols: Vec) { - self.websocket_sender - .lock() - .await - .send(tungstenite::Message::Text( - to_string(&websocket::outgoing::Message::Subscribe( - websocket::outgoing::subscribe::Message::new_news(symbols), - )) - .unwrap(), - )) - .await - .unwrap(); - } - - async fn remove_assets(&self, _: Vec, symbols: Vec) { - self.websocket_sender - .lock() - .await - .send(tungstenite::Message::Text( - to_string(&websocket::outgoing::Message::Unsubscribe( - websocket::outgoing::subscribe::Message::new_news(symbols), - )) - .unwrap(), - )) - .await - .unwrap(); - } -} diff --git a/src/threads/data/backfill.rs b/src/threads/data/backfill.rs index dabc168..cc7af7b 100644 --- a/src/threads/data/backfill.rs +++ b/src/threads/data/backfill.rs @@ -1,26 +1,29 @@ -use super::{Guard, ThreadType}; +use super::ThreadType; use crate::{ - config::{Config, ALPACA_CRYPTO_DATA_URL, ALPACA_NEWS_DATA_URL, ALPACA_STOCK_DATA_URL}, + config::{Config, ALPACA_CRYPTO_DATA_URL, ALPACA_STOCK_DATA_URL}, database, types::{ alpaca::{ + self, api::{self, outgoing::Sort}, Source, }, news::Prediction, - Asset, Bar, Class, News, Subset, + Bar, Class, News, + }, + utils::{ + duration_until, last_minute, remove_slash_from_pair, FIFTEEN_MINUTES, ONE_MINUTE, + ONE_SECOND, }, - utils::{duration_until, last_minute, remove_slash_from_pair, FIFTEEN_MINUTES, ONE_MINUTE}, }; use async_trait::async_trait; -use backoff::{future::retry, ExponentialBackoff}; use futures_util::future::join_all; -use log::{error, info, warn}; +use log::{info, warn}; use std::{collections::HashMap, sync::Arc}; use time::OffsetDateTime; use tokio::{ join, spawn, - sync::{mpsc, oneshot, Mutex, RwLock}, + sync::{mpsc, oneshot, Mutex}, task::{block_in_place, JoinHandle}, time::sleep, }; @@ -30,19 +33,28 @@ pub enum Action { Purge, } +impl From for Action { + fn from(action: super::Action) -> Self { + match action { + super::Action::Add => Self::Backfill, + super::Action::Remove => Self::Purge, + } + } +} + pub struct Message { pub action: Action, - pub assets: Subset, + pub symbols: Vec, pub response: oneshot::Sender<()>, } impl Message { - pub fn new(action: Action, assets: Subset) -> (Self, oneshot::Receiver<()>) { + pub fn new(action: Action, symbols: Vec) -> (Self, oneshot::Receiver<()>) { let (sender, receiver) = oneshot::channel::<()>(); ( Self { action, - assets, + symbols, response: sender, }, receiver, @@ -60,58 +72,31 @@ pub trait Handler: Send + Sync { fn log_string(&self) -> &'static str; } -pub async fn run( - handler: Arc>, - guard: Arc>, - mut receiver: mpsc::Receiver, -) { +pub async fn run(handler: Arc>, mut receiver: mpsc::Receiver) { let backfill_jobs = Arc::new(Mutex::new(HashMap::new())); loop { let message = receiver.recv().await.unwrap(); - spawn(handle_backfill_message( handler.clone(), - guard.clone(), backfill_jobs.clone(), message, )); } } -#[allow(clippy::significant_drop_tightening)] -#[allow(clippy::too_many_lines)] async fn handle_backfill_message( handler: Arc>, - guard: Arc>, backfill_jobs: Arc>>>, message: Message, ) { - let guard = guard.read().await; let mut backfill_jobs = backfill_jobs.lock().await; - let symbols = match message.assets { - Subset::All => guard - .assets - .clone() - .into_iter() - .map(|(_, symbol)| symbol) - .collect(), - Subset::Some(assets) => assets - .into_iter() - .map(|asset| asset.symbol) - .filter(|symbol| match message.action { - Action::Backfill => guard.assets.contains_right(symbol), - Action::Purge => !guard.assets.contains_right(symbol), - }) - .collect::>(), - }; - match message.action { Action::Backfill => { let log_string = handler.log_string(); - for symbol in symbols { + for symbol in message.symbols { if let Some(job) = backfill_jobs.get(&symbol) { if !job.is_finished() { warn!( @@ -131,7 +116,7 @@ async fn handle_backfill_message( .await .as_ref() .map_or(OffsetDateTime::UNIX_EPOCH, |backfill| { - backfill.time + ONE_MINUTE + backfill.time + ONE_SECOND }); let fetch_to = last_minute(); @@ -148,7 +133,7 @@ async fn handle_backfill_message( } } Action::Purge => { - for symbol in &symbols { + for symbol in &message.symbols { if let Some(job) = backfill_jobs.remove(symbol) { if !job.is_finished() { job.abort(); @@ -158,8 +143,8 @@ async fn handle_backfill_message( } join!( - handler.delete_backfills(&symbols), - handler.delete_data(&symbols) + handler.delete_backfills(&message.symbols), + handler.delete_data(&message.symbols) ); } } @@ -167,25 +152,6 @@ async fn handle_backfill_message( message.response.send(()).unwrap(); } -pub fn create_backfill_handler( - thread_type: ThreadType, - app_config: Arc, -) -> Box { - match thread_type { - ThreadType::Bars(Class::UsEquity) => Box::new(BarHandler { - app_config, - data_url: ALPACA_STOCK_DATA_URL, - api_query_constructor: us_equity_query_constructor, - }), - ThreadType::Bars(Class::Crypto) => Box::new(BarHandler { - app_config, - data_url: ALPACA_CRYPTO_DATA_URL, - api_query_constructor: crypto_query_constructor, - }), - ThreadType::News => Box::new(NewsHandler { app_config }), - } -} - struct BarHandler { app_config: Arc, data_url: &'static str, @@ -277,35 +243,19 @@ impl Handler for BarHandler { let mut next_page_token = None; loop { - let message = retry(ExponentialBackoff::default(), || async { - self.app_config.alpaca_rate_limit.until_ready().await; - self.app_config - .alpaca_client - .get(self.data_url) - .query(&(self.api_query_constructor)( - &self.app_config, - symbol.clone(), - fetch_from, - fetch_to, - next_page_token.clone(), - )) - .send() - .await? - .error_for_status()? - .json::() - .await - .map_err(backoff::Error::Permanent) - }) + let message = alpaca::api::incoming::bar::get_historical( + &self.app_config, + self.data_url, + &(self.api_query_constructor)( + &self.app_config, + symbol.clone(), + fetch_from, + fetch_to, + next_page_token.clone(), + ), + ) .await; - let message = match message { - Ok(message) => message, - Err(e) => { - error!("Failed to backfill bars for {}: {}.", symbol, e); - return; - } - }; - message.bars.into_iter().for_each(|(symbol, bar_vec)| { for bar in bar_vec { bars.push(Bar::from((bar, symbol.clone()))); @@ -381,38 +331,21 @@ impl Handler for NewsHandler { let mut next_page_token = None; loop { - let message = retry(ExponentialBackoff::default(), || async { - self.app_config.alpaca_rate_limit.until_ready().await; - self.app_config - .alpaca_client - .get(ALPACA_NEWS_DATA_URL) - .query(&api::outgoing::news::News { - symbols: vec![remove_slash_from_pair(&symbol)], - start: Some(fetch_from), - end: Some(fetch_to), - limit: Some(50), - include_content: Some(true), - exclude_contentless: Some(false), - page_token: next_page_token.clone(), - sort: Some(Sort::Asc), - }) - .send() - .await? - .error_for_status()? - .json::() - .await - .map_err(backoff::Error::Permanent) - }) + let message = alpaca::api::incoming::news::get_historical( + &self.app_config, + &api::outgoing::news::News { + symbols: vec![remove_slash_from_pair(&symbol)], + start: Some(fetch_from), + end: Some(fetch_to), + limit: Some(50), + include_content: Some(true), + exclude_contentless: Some(false), + page_token: next_page_token.clone(), + sort: Some(Sort::Asc), + }, + ) .await; - let message = match message { - Ok(message) => message, - Err(e) => { - error!("Failed to backfill news for {}: {}.", symbol, e); - return; - } - }; - message.news.into_iter().for_each(|news_item| { news.push(News::from(news_item)); }); @@ -480,3 +413,19 @@ impl Handler for NewsHandler { "news" } } + +pub fn create_handler(thread_type: ThreadType, app_config: Arc) -> Box { + match thread_type { + ThreadType::Bars(Class::UsEquity) => Box::new(BarHandler { + app_config, + data_url: ALPACA_STOCK_DATA_URL, + api_query_constructor: us_equity_query_constructor, + }), + ThreadType::Bars(Class::Crypto) => Box::new(BarHandler { + app_config, + data_url: ALPACA_CRYPTO_DATA_URL, + api_query_constructor: crypto_query_constructor, + }), + ThreadType::News => Box::new(NewsHandler { app_config }), + } +} diff --git a/src/threads/data/mod.rs b/src/threads/data/mod.rs index abb3d2b..ba7cfd8 100644 --- a/src/threads/data/mod.rs +++ b/src/threads/data/mod.rs @@ -1,24 +1,50 @@ -pub mod asset_status; pub mod backfill; pub mod websocket; -use self::asset_status::create_asset_status_handler; -use super::{clock, guard::Guard}; +use super::clock; use crate::{ config::{ Config, ALPACA_CRYPTO_WEBSOCKET_URL, ALPACA_NEWS_WEBSOCKET_URL, ALPACA_STOCK_WEBSOCKET_URL, }, - types::{Class, Subset}, - utils::authenticate, + database, + types::{alpaca, Asset, Class}, + utils::{authenticate, cleanup}, }; -use futures_util::StreamExt; +use futures_util::{future::join_all, StreamExt}; +use itertools::{Either, Itertools}; use std::sync::Arc; use tokio::{ join, select, spawn, - sync::{mpsc, Mutex, RwLock}, + sync::{mpsc, oneshot}, }; use tokio_tungstenite::connect_async; +#[derive(Clone)] +pub enum Action { + Add, + Remove, +} + +pub struct Message { + pub action: Action, + pub assets: Vec<(String, Class)>, + pub response: oneshot::Sender<()>, +} + +impl Message { + pub fn new(action: Action, assets: Vec<(String, Class)>) -> (Self, oneshot::Receiver<()>) { + let (sender, receiver) = oneshot::channel(); + ( + Self { + action, + assets, + response: sender, + }, + receiver, + ) + } +} + #[derive(Clone, Copy, Debug)] pub enum ThreadType { Bars(Class), @@ -27,36 +53,39 @@ pub enum ThreadType { pub async fn run( app_config: Arc, - mut asset_receiver: mpsc::Receiver, + mut receiver: mpsc::Receiver, mut clock_receiver: mpsc::Receiver, ) { - let (bars_us_equity_asset_status_sender, bars_us_equity_backfill_sender) = + let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) = init_thread(app_config.clone(), ThreadType::Bars(Class::UsEquity)).await; - let (bars_crypto_asset_status_sender, bars_crypto_backfill_sender) = + let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) = init_thread(app_config.clone(), ThreadType::Bars(Class::Crypto)).await; - let (news_asset_status_sender, news_backfill_sender) = + let (news_websocket_sender, news_backfill_sender) = init_thread(app_config.clone(), ThreadType::News).await; loop { select! { - Some(asset_message) = asset_receiver.recv() => { - spawn(handle_asset_message( - bars_us_equity_asset_status_sender.clone(), - bars_crypto_asset_status_sender.clone(), - news_asset_status_sender.clone(), - asset_message, + Some(message) = receiver.recv() => { + spawn(handle_message( + app_config.clone(), + bars_us_equity_websocket_sender.clone(), + bars_us_equity_backfill_sender.clone(), + bars_crypto_websocket_sender.clone(), + bars_crypto_backfill_sender.clone(), + news_websocket_sender.clone(), + news_backfill_sender.clone(), + message, )); } Some(_) = clock_receiver.recv() => { spawn(handle_clock_message( + app_config.clone(), bars_us_equity_backfill_sender.clone(), bars_crypto_backfill_sender.clone(), news_backfill_sender.clone(), )); } - else => { - panic!("Communication channel unexpectedly closed.") - } + else => panic!("Communication channel unexpectedly closed.") } } } @@ -65,11 +94,9 @@ async fn init_thread( app_config: Arc, thread_type: ThreadType, ) -> ( - mpsc::Sender, + mpsc::Sender, mpsc::Sender, ) { - let guard = Arc::new(RwLock::new(Guard::new())); - let websocket_url = match thread_type { ThreadType::Bars(Class::UsEquity) => format!( "{}/{}", @@ -80,130 +107,190 @@ async fn init_thread( }; let (websocket, _) = connect_async(websocket_url).await.unwrap(); - let (mut websocket_sender, mut websocket_receiver) = websocket.split(); - authenticate(&app_config, &mut websocket_sender, &mut websocket_receiver).await; - let websocket_sender = Arc::new(Mutex::new(websocket_sender)); - - let (asset_status_sender, asset_status_receiver) = mpsc::channel(100); - spawn(asset_status::run( - Arc::new(create_asset_status_handler( - thread_type, - app_config.clone(), - websocket_sender.clone(), - )), - guard.clone(), - asset_status_receiver, - )); + let (mut websocket_sink, mut websocket_stream) = websocket.split(); + authenticate(&app_config, &mut websocket_sink, &mut websocket_stream).await; let (backfill_sender, backfill_receiver) = mpsc::channel(100); spawn(backfill::run( - Arc::new(backfill::create_backfill_handler( - thread_type, - app_config.clone(), - )), - guard.clone(), + Arc::new(backfill::create_handler(thread_type, app_config.clone())), backfill_receiver, )); + let (websocket_sender, websocket_receiver) = mpsc::channel(100); spawn(websocket::run( - app_config.clone(), - guard.clone(), - websocket_sender, + Arc::new(websocket::create_handler(thread_type, app_config.clone())), websocket_receiver, - backfill_sender.clone(), + websocket_stream, + websocket_sink, )); - (asset_status_sender, backfill_sender) + (websocket_sender, backfill_sender) } -async fn handle_asset_message( - bars_us_equity_asset_status_sender: mpsc::Sender, - bars_crypto_asset_status_sender: mpsc::Sender, - news_asset_status_sender: mpsc::Sender, - asset_status_message: asset_status::Message, +macro_rules! create_send_await { + ($sender:expr, $action:expr, $($contents:expr),*) => { + let (message, receiver) = $action($($contents),*); + $sender.send(message).await.unwrap(); + receiver.await.unwrap(); + }; +} + +#[allow(clippy::too_many_arguments)] +async fn handle_message( + app_config: Arc, + bars_us_equity_websocket_sender: mpsc::Sender, + bars_us_equity_backfill_sender: mpsc::Sender, + bars_crypto_websocket_sender: mpsc::Sender, + bars_crypto_backfill_sender: mpsc::Sender, + news_websocket_sender: mpsc::Sender, + news_backfill_sender: mpsc::Sender, + message: Message, ) { - let (us_equity_assets, crypto_assets): (Vec<_>, Vec<_>) = asset_status_message + let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = message .assets .clone() .into_iter() - .partition(|asset| asset.class == Class::UsEquity); + .partition_map(|asset| match asset.1 { + Class::UsEquity => Either::Left(asset.0), + Class::Crypto => Either::Right(asset.0), + }); + + let symbols = message + .assets + .into_iter() + .map(|(symbol, _)| symbol) + .collect::>(); let bars_us_equity_future = async { - if !us_equity_assets.is_empty() { - let (bars_us_equity_asset_status_message, bars_us_equity_asset_status_receiver) = - asset_status::Message::new(asset_status_message.action.clone(), us_equity_assets); - bars_us_equity_asset_status_sender - .send(bars_us_equity_asset_status_message) - .await - .unwrap(); - bars_us_equity_asset_status_receiver.await.unwrap(); + if us_equity_symbols.is_empty() { + return; } + + create_send_await!( + bars_us_equity_websocket_sender, + websocket::Message::new, + message.action.clone().into(), + us_equity_symbols.clone() + ); + + create_send_await!( + bars_us_equity_backfill_sender, + backfill::Message::new, + message.action.clone().into(), + us_equity_symbols + ); }; let bars_crypto_future = async { - if !crypto_assets.is_empty() { - let (crypto_asset_status_message, crypto_asset_status_receiver) = - asset_status::Message::new(asset_status_message.action.clone(), crypto_assets); - bars_crypto_asset_status_sender - .send(crypto_asset_status_message) - .await - .unwrap(); - crypto_asset_status_receiver.await.unwrap(); + if crypto_symbols.is_empty() { + return; } + + create_send_await!( + bars_crypto_websocket_sender, + websocket::Message::new, + message.action.clone().into(), + crypto_symbols.clone() + ); + + create_send_await!( + bars_crypto_backfill_sender, + backfill::Message::new, + message.action.clone().into(), + crypto_symbols + ); }; let news_future = async { - if !asset_status_message.assets.is_empty() { - let (news_asset_status_message, news_asset_status_receiver) = - asset_status::Message::new( - asset_status_message.action.clone(), - asset_status_message.assets, - ); - news_asset_status_sender - .send(news_asset_status_message) - .await - .unwrap(); - news_asset_status_receiver.await.unwrap(); - } + create_send_await!( + news_websocket_sender, + websocket::Message::new, + message.action.clone().into(), + symbols.clone() + ); + + create_send_await!( + news_backfill_sender, + backfill::Message::new, + message.action.clone().into(), + symbols.clone() + ); }; join!(bars_us_equity_future, bars_crypto_future, news_future); - asset_status_message.response.send(()).unwrap(); + + match message.action { + Action::Add => { + let assets = + join_all(symbols.into_iter().map(|symbol| { + let app_config = app_config.clone(); + async move { + alpaca::api::incoming::asset::get_by_symbol(&app_config, &symbol).await + } + })) + .await + .into_iter() + .map(|result| Asset::from(result.unwrap())) + .collect::>(); + + database::assets::upsert_batch(&app_config.clickhouse_client, assets).await; + } + Action::Remove => { + database::assets::delete_where_symbols(&app_config.clickhouse_client, &symbols).await; + } + } + + message.response.send(()).unwrap(); } async fn handle_clock_message( + app_config: Arc, bars_us_equity_backfill_sender: mpsc::Sender, bars_crypto_backfill_sender: mpsc::Sender, news_backfill_sender: mpsc::Sender, ) { + cleanup(&app_config.clickhouse_client).await; + + let assets = database::assets::select(&app_config.clickhouse_client).await; + + let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets + .clone() + .into_iter() + .partition_map(|asset| match asset.class { + Class::UsEquity => Either::Left(asset.symbol), + Class::Crypto => Either::Right(asset.symbol), + }); + + let symbols = assets + .into_iter() + .map(|asset| asset.symbol) + .collect::>(); + let bars_us_equity_future = async { - let (bars_us_equity_backfill_message, bars_us_equity_backfill_receiver) = - backfill::Message::new(backfill::Action::Backfill, Subset::All); - bars_us_equity_backfill_sender - .send(bars_us_equity_backfill_message) - .await - .unwrap(); - bars_us_equity_backfill_receiver.await.unwrap(); + create_send_await!( + bars_us_equity_backfill_sender, + backfill::Message::new, + backfill::Action::Backfill, + us_equity_symbols.clone() + ); }; let bars_crypto_future = async { - let (bars_crypto_backfill_message, bars_crypto_backfill_receiver) = - backfill::Message::new(backfill::Action::Backfill, Subset::All); - bars_crypto_backfill_sender - .send(bars_crypto_backfill_message) - .await - .unwrap(); - bars_crypto_backfill_receiver.await.unwrap(); + create_send_await!( + bars_crypto_backfill_sender, + backfill::Message::new, + backfill::Action::Backfill, + crypto_symbols.clone() + ); }; let news_future = async { - let (news_backfill_message, news_backfill_receiver) = - backfill::Message::new(backfill::Action::Backfill, Subset::All); - news_backfill_sender - .send(news_backfill_message) - .await - .unwrap(); - news_backfill_receiver.await.unwrap(); + create_send_await!( + news_backfill_sender, + backfill::Message::new, + backfill::Action::Backfill, + symbols + ); }; join!(bars_us_equity_future, bars_crypto_future, news_future); diff --git a/src/threads/data/websocket.rs b/src/threads/data/websocket.rs index 84af18e..fffec0f 100644 --- a/src/threads/data/websocket.rs +++ b/src/threads/data/websocket.rs @@ -1,51 +1,192 @@ -use super::{backfill, Guard}; +use super::ThreadType; use crate::{ config::Config, database, - types::{alpaca::websocket, news::Prediction, Bar, News, Subset}, + types::{alpaca::websocket, news::Prediction, Bar, News}, utils::add_slash_to_pair, }; +use async_trait::async_trait; use futures_util::{ + future::join_all, stream::{SplitSink, SplitStream}, SinkExt, StreamExt, }; -use log::{debug, error, info, warn}; -use serde_json::from_str; -use std::{collections::HashSet, sync::Arc}; +use log::{debug, error, info}; +use serde_json::{from_str, to_string}; +use std::{collections::HashMap, sync::Arc}; use tokio::{ - join, net::TcpStream, - spawn, - sync::{mpsc, Mutex, RwLock}, + select, spawn, + sync::{mpsc, oneshot, Mutex, RwLock}, task::block_in_place, }; use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; -pub async fn run( - app_config: Arc, - guard: Arc>, - sender: Arc>, tungstenite::Message>>>, - mut receiver: SplitStream>>, - backfill_sender: mpsc::Sender, -) { - loop { - let message = receiver.next().await.unwrap().unwrap(); +pub enum Action { + Subscribe, + Unsubscribe, +} - spawn(handle_websocket_message( - app_config.clone(), - guard.clone(), - sender.clone(), - backfill_sender.clone(), - message, - )); +impl From for Action { + fn from(action: super::Action) -> Self { + match action { + super::Action::Add => Self::Subscribe, + super::Action::Remove => Self::Unsubscribe, + } } } +pub struct Message { + pub action: Action, + pub symbols: Vec, + pub response: oneshot::Sender<()>, +} + +impl Message { + pub fn new(action: Action, symbols: Vec) -> (Self, oneshot::Receiver<()>) { + let (sender, receiver) = oneshot::channel(); + ( + Self { + action, + symbols, + response: sender, + }, + receiver, + ) + } +} + +pub struct Pending { + pub subscriptions: HashMap>, + pub unsubscriptions: HashMap>, +} + +#[async_trait] +pub trait Handler: Send + Sync { + fn create_subscription_message( + &self, + symbols: Vec, + ) -> websocket::outgoing::subscribe::Message; + async fn handle_parsed_websocket_message( + &self, + pending: Arc>, + message: websocket::incoming::Message, + ); +} + +pub async fn run( + handler: Arc>, + mut receiver: mpsc::Receiver, + mut websocket_stream: SplitStream>>, + websocket_sink: SplitSink>, tungstenite::Message>, +) { + let pending = Arc::new(RwLock::new(Pending { + subscriptions: HashMap::new(), + unsubscriptions: HashMap::new(), + })); + let websocket_sink = Arc::new(Mutex::new(websocket_sink)); + + loop { + select! { + Some(message) = receiver.recv() => { + spawn(handle_message( + handler.clone(), + pending.clone(), + websocket_sink.clone(), + message, + )); + } + Some(Ok(message)) = websocket_stream.next() => { + spawn(handle_websocket_message( + handler.clone(), + pending.clone(), + websocket_sink.clone(), + message, + )); + } + else => panic!("Communication channel unexpectedly closed.") + } + } +} + +async fn handle_message( + handler: Arc>, + pending: Arc>, + websocket_sender: Arc< + Mutex>, tungstenite::Message>>, + >, + message: Message, +) { + match message.action { + Action::Subscribe => { + let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message + .symbols + .iter() + .map(|symbol| { + let (sender, receiver) = oneshot::channel(); + ((symbol.clone(), sender), receiver) + }) + .unzip(); + + pending + .write() + .await + .subscriptions + .extend(pending_subscriptions); + + websocket_sender + .lock() + .await + .send(tungstenite::Message::Text( + to_string(&websocket::outgoing::Message::Subscribe( + handler.create_subscription_message(message.symbols), + )) + .unwrap(), + )) + .await + .unwrap(); + + join_all(receivers).await; + } + Action::Unsubscribe => { + let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message + .symbols + .iter() + .map(|symbol| { + let (sender, receiver) = oneshot::channel(); + ((symbol.clone(), sender), receiver) + }) + .unzip(); + + pending + .write() + .await + .unsubscriptions + .extend(pending_unsubscriptions); + + websocket_sender + .lock() + .await + .send(tungstenite::Message::Text( + to_string(&websocket::outgoing::Message::Unsubscribe( + handler.create_subscription_message(message.symbols.clone()), + )) + .unwrap(), + )) + .await + .unwrap(); + + join_all(receivers).await; + } + } + + message.response.send(()).unwrap(); +} + async fn handle_websocket_message( - app_config: Arc, - guard: Arc>, + handler: Arc>, + pending: Arc>, sender: Arc>, tungstenite::Message>>>, - backfill_sender: mpsc::Sender, message: tungstenite::Message, ) { match message { @@ -54,12 +195,14 @@ async fn handle_websocket_message( if let Ok(message) = message { for message in message { - spawn(handle_parsed_websocket_message( - app_config.clone(), - guard.clone(), - backfill_sender.clone(), - message, - )); + let handler = handler.clone(); + let pending = pending.clone(); + + spawn(async move { + handler + .handle_parsed_websocket_message(pending, message) + .await; + }); } } else { error!("Failed to deserialize websocket message: {:?}", message); @@ -77,143 +220,190 @@ async fn handle_websocket_message( } } -#[allow(clippy::significant_drop_tightening)] -#[allow(clippy::too_many_lines)] -async fn handle_parsed_websocket_message( +struct BarsHandler { app_config: Arc, - guard: Arc>, - backfill_sender: mpsc::Sender, - message: websocket::incoming::Message, -) { - match message { - websocket::incoming::Message::Subscription(message) => { - let (symbols, log_string) = match message { - websocket::incoming::subscription::Message::Market { bars, .. } => (bars, "bars"), - websocket::incoming::subscription::Message::News { news } => ( - news.into_iter() - .map(|symbol| add_slash_to_pair(&symbol)) - .collect(), - "news", - ), - }; +} - let mut guard = guard.write().await; +#[async_trait] +impl Handler for BarsHandler { + fn create_subscription_message( + &self, + symbols: Vec, + ) -> websocket::outgoing::subscribe::Message { + websocket::outgoing::subscribe::Message::new_market(symbols) + } - let newly_subscribed = guard - .pending_subscriptions - .extract_if(|asset| symbols.contains(&asset.symbol)) - .collect::>(); + async fn handle_parsed_websocket_message( + &self, + pending: Arc>, + message: websocket::incoming::Message, + ) { + match message { + websocket::incoming::Message::Subscription(message) => { + let websocket::incoming::subscription::Message::Market { bars: symbols, .. } = + message + else { + unreachable!() + }; - let newly_unsubscribed = guard - .pending_unsubscriptions - .extract_if(|asset| !symbols.contains(&asset.symbol)) - .collect::>(); + let mut pending = pending.write().await; - drop(guard); + let newly_subscribed = pending + .subscriptions + .extract_if(|symbol, _| symbols.contains(symbol)) + .collect::>(); + + let newly_unsubscribed = pending + .unsubscriptions + .extract_if(|symbol, _| !symbols.contains(symbol)) + .collect::>(); + + drop(pending); - let newly_subscribed_future = async { if !newly_subscribed.is_empty() { info!( - "Subscribed to {} for {:?}.", - log_string, - newly_subscribed - .iter() - .map(|asset| asset.symbol.clone()) - .collect::>() + "Subscribed to bars for {:?}.", + newly_subscribed.keys().collect::>() ); - let (backfill_message, backfill_receiver) = backfill::Message::new( - backfill::Action::Backfill, - Subset::Some(newly_subscribed.into_iter().collect::>()), - ); - - backfill_sender.send(backfill_message).await.unwrap(); - backfill_receiver.await.unwrap(); + for sender in newly_subscribed.into_values() { + sender.send(()).unwrap(); + } } - }; - let newly_unsubscribed_future = async { if !newly_unsubscribed.is_empty() { info!( - "Unsubscribed from {} for {:?}.", - log_string, - newly_unsubscribed - .iter() - .map(|asset| asset.symbol.clone()) - .collect::>() + "Unsubscribed from bars for {:?}.", + newly_unsubscribed.keys().collect::>() ); - let (purge_message, purge_receiver) = backfill::Message::new( - backfill::Action::Purge, - Subset::Some(newly_unsubscribed.into_iter().collect::>()), - ); - - backfill_sender.send(purge_message).await.unwrap(); - purge_receiver.await.unwrap(); + for sender in newly_unsubscribed.into_values() { + sender.send(()).unwrap(); + } } - }; - - join!(newly_subscribed_future, newly_unsubscribed_future); - } - websocket::incoming::Message::Bar(message) - | websocket::incoming::Message::UpdatedBar(message) => { - let bar = Bar::from(message); - - let guard = guard.read().await; - if !guard.assets.contains_right(&bar.symbol) { - warn!( - "Race condition: received bar for unsubscribed symbol: {:?}.", - bar.symbol - ); - return; } - - debug!("Received bar for {}: {}.", bar.symbol, bar.time); - database::bars::upsert(&app_config.clickhouse_client, &bar).await; - } - websocket::incoming::Message::News(message) => { - let news = News::from(message); - - let guard = guard.read().await; - if !news - .symbols - .iter() - .any(|symbol| guard.assets.contains_right(symbol)) - { - warn!( - "Race condition: received news for unsubscribed symbols: {:?}.", - news.symbols - ); - return; + websocket::incoming::Message::Bar(message) + | websocket::incoming::Message::UpdatedBar(message) => { + let bar = Bar::from(message); + debug!("Received bar for {}: {}.", bar.symbol, bar.time); + database::bars::upsert(&self.app_config.clickhouse_client, &bar).await; } - - debug!( - "Received news for {:?}: {}.", - news.symbols, news.time_created - ); - - let input = format!("{}\n\n{}", news.headline, news.content); - - let sequence_classifier = app_config.sequence_classifier.lock().await; - let prediction = block_in_place(|| { - sequence_classifier - .predict(vec![input.as_str()]) - .into_iter() - .map(|label| Prediction::try_from(label).unwrap()) - .collect::>()[0] - }); - drop(sequence_classifier); - - let news = News { - sentiment: prediction.sentiment, - confidence: prediction.confidence, - ..news - }; - database::news::upsert(&app_config.clickhouse_client, &news).await; - } - websocket::incoming::Message::Success(_) => {} - websocket::incoming::Message::Error(message) => { - error!("Received error message: {}.", message.message); + websocket::incoming::Message::Success(_) => {} + websocket::incoming::Message::Error(message) => { + error!("Received error message: {}.", message.message); + } + websocket::incoming::Message::News(_) => unreachable!(), } } } + +struct NewsHandler { + app_config: Arc, +} + +#[async_trait] +impl Handler for NewsHandler { + fn create_subscription_message( + &self, + symbols: Vec, + ) -> websocket::outgoing::subscribe::Message { + websocket::outgoing::subscribe::Message::new_news(symbols) + } + + async fn handle_parsed_websocket_message( + &self, + pending: Arc>, + message: websocket::incoming::Message, + ) { + match message { + websocket::incoming::Message::Subscription(message) => { + let websocket::incoming::subscription::Message::News { news: symbols } = message + else { + unreachable!() + }; + + let symbols = symbols + .into_iter() + .map(|symbol| add_slash_to_pair(&symbol)) + .collect::>(); + + let mut pending = pending.write().await; + + let newly_subscribed = pending + .subscriptions + .extract_if(|symbol, _| symbols.contains(symbol)) + .collect::>(); + + let newly_unsubscribed = pending + .unsubscriptions + .extract_if(|symbol, _| !symbols.contains(symbol)) + .collect::>(); + + drop(pending); + + if !newly_subscribed.is_empty() { + info!( + "Subscribed to news for {:?}.", + newly_subscribed.keys().collect::>() + ); + + for sender in newly_subscribed.into_values() { + sender.send(()).unwrap(); + } + } + + if !newly_unsubscribed.is_empty() { + info!( + "Unsubscribed from news for {:?}.", + newly_unsubscribed.keys().collect::>() + ); + + for sender in newly_unsubscribed.into_values() { + sender.send(()).unwrap(); + } + } + } + websocket::incoming::Message::News(message) => { + let news = News::from(message); + + debug!( + "Received news for {:?}: {}.", + news.symbols, news.time_created + ); + + let input = format!("{}\n\n{}", news.headline, news.content); + + let sequence_classifier = self.app_config.sequence_classifier.lock().await; + let prediction = block_in_place(|| { + sequence_classifier + .predict(vec![input.as_str()]) + .into_iter() + .map(|label| Prediction::try_from(label).unwrap()) + .collect::>()[0] + }); + drop(sequence_classifier); + + let news = News { + sentiment: prediction.sentiment, + confidence: prediction.confidence, + ..news + }; + database::news::upsert(&self.app_config.clickhouse_client, &news).await; + } + websocket::incoming::Message::Success(_) => {} + websocket::incoming::Message::Error(message) => { + error!("Received error message: {}.", message.message); + } + websocket::incoming::Message::Bar(_) | websocket::incoming::Message::UpdatedBar(_) => { + unreachable!() + } + } + } +} + +pub fn create_handler(thread_type: ThreadType, app_config: Arc) -> Box { + match thread_type { + ThreadType::Bars(_) => Box::new(BarsHandler { app_config }), + ThreadType::News => Box::new(NewsHandler { app_config }), + } +} diff --git a/src/threads/guard.rs b/src/threads/guard.rs deleted file mode 100644 index 6e67641..0000000 --- a/src/threads/guard.rs +++ /dev/null @@ -1,19 +0,0 @@ -use crate::types::Asset; -use bimap::BiMap; -use std::collections::HashSet; - -pub struct Guard { - pub assets: BiMap, - pub pending_subscriptions: HashSet, - pub pending_unsubscriptions: HashSet, -} - -impl Guard { - pub fn new() -> Self { - Self { - assets: BiMap::new(), - pending_subscriptions: HashSet::new(), - pending_unsubscriptions: HashSet::new(), - } - } -} diff --git a/src/threads/mod.rs b/src/threads/mod.rs index 227caf0..5f09b94 100644 --- a/src/threads/mod.rs +++ b/src/threads/mod.rs @@ -1,3 +1,2 @@ pub mod clock; pub mod data; -pub mod guard; diff --git a/src/types/algebraic/mod.rs b/src/types/algebraic/mod.rs deleted file mode 100644 index 192af7d..0000000 --- a/src/types/algebraic/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod subset; - -pub use subset::Subset; diff --git a/src/types/algebraic/subset.rs b/src/types/algebraic/subset.rs deleted file mode 100644 index 5a902f3..0000000 --- a/src/types/algebraic/subset.rs +++ /dev/null @@ -1,5 +0,0 @@ -#[derive(Clone, Debug)] -pub enum Subset { - Some(Vec), - All, -} diff --git a/src/types/alpaca/api/incoming/asset.rs b/src/types/alpaca/api/incoming/asset.rs index d164a28..0a00a35 100644 --- a/src/types/alpaca/api/incoming/asset.rs +++ b/src/types/alpaca/api/incoming/asset.rs @@ -1,5 +1,11 @@ -use crate::types::{self, alpaca::api::impl_from_enum}; +use crate::{ + config::{Config, ALPACA_ASSET_API_URL}, + types::{self, alpaca::api::impl_from_enum}, +}; +use backoff::{future::retry, ExponentialBackoff}; +use http::StatusCode; use serde::Deserialize; +use std::sync::Arc; #[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)] #[serde(rename_all = "snake_case")] @@ -80,3 +86,27 @@ impl From for types::Asset { } } } + +pub async fn get_by_symbol(app_config: &Arc, symbol: &str) -> Result { + retry(ExponentialBackoff::default(), || async { + app_config.alpaca_rate_limit.until_ready().await; + app_config + .alpaca_client + .get(&format!("{ALPACA_ASSET_API_URL}/{symbol}")) + .send() + .await? + .error_for_status() + .map_err(|e| match e.status() { + Some(reqwest::StatusCode::NOT_FOUND) => backoff::Error::Permanent(e), + _ => e.into(), + })? + .json::() + .await + .map_err(backoff::Error::Permanent) + }) + .await + .map_err(|e| match e.status() { + Some(reqwest::StatusCode::NOT_FOUND) => StatusCode::NOT_FOUND, + _ => panic!("Unexpected error: {e}."), + }) +} diff --git a/src/types/alpaca/api/incoming/bar.rs b/src/types/alpaca/api/incoming/bar.rs index 37feb60..d3f49a5 100644 --- a/src/types/alpaca/api/incoming/bar.rs +++ b/src/types/alpaca/api/incoming/bar.rs @@ -1,6 +1,10 @@ -use crate::types; +use crate::{ + config::Config, + types::{self, alpaca::api::outgoing}, +}; +use backoff::{future::retry, ExponentialBackoff}; use serde::Deserialize; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use time::OffsetDateTime; #[derive(Clone, Debug, PartialEq, Deserialize)] @@ -45,3 +49,25 @@ pub struct Message { pub bars: HashMap>, pub next_page_token: Option, } + +pub async fn get_historical( + app_config: &Arc, + data_url: &str, + query: &outgoing::bar::Bar, +) -> Message { + retry(ExponentialBackoff::default(), || async { + app_config.alpaca_rate_limit.until_ready().await; + app_config + .alpaca_client + .get(data_url) + .query(query) + .send() + .await? + .error_for_status()? + .json::() + .await + .map_err(backoff::Error::Permanent) + }) + .await + .unwrap() +} diff --git a/src/types/alpaca/api/incoming/clock.rs b/src/types/alpaca/api/incoming/clock.rs index e808b99..fb9486f 100644 --- a/src/types/alpaca/api/incoming/clock.rs +++ b/src/types/alpaca/api/incoming/clock.rs @@ -1,4 +1,7 @@ +use crate::config::{Config, ALPACA_CLOCK_API_URL}; +use backoff::{future::retry, ExponentialBackoff}; use serde::Deserialize; +use std::sync::Arc; use time::OffsetDateTime; #[derive(Clone, Debug, PartialEq, Eq, Deserialize)] @@ -11,3 +14,19 @@ pub struct Clock { #[serde(with = "time::serde::rfc3339")] pub next_close: OffsetDateTime, } + +pub async fn get(app_config: &Arc) -> Clock { + retry(ExponentialBackoff::default(), || async { + app_config.alpaca_rate_limit.until_ready().await; + app_config + .alpaca_client + .get(ALPACA_CLOCK_API_URL) + .send() + .await? + .json::() + .await + .map_err(backoff::Error::Permanent) + }) + .await + .unwrap() +} diff --git a/src/types/alpaca/api/incoming/news.rs b/src/types/alpaca/api/incoming/news.rs index 4cd32e4..f0a9c9a 100644 --- a/src/types/alpaca/api/incoming/news.rs +++ b/src/types/alpaca/api/incoming/news.rs @@ -1,8 +1,11 @@ use crate::{ - types, + config::{Config, ALPACA_NEWS_DATA_URL}, + types::{self, alpaca::api::outgoing}, utils::{add_slash_to_pair, normalize_news_content}, }; +use backoff::{future::retry, ExponentialBackoff}; use serde::Deserialize; +use std::sync::Arc; use time::OffsetDateTime; #[derive(Clone, Debug, PartialEq, Eq, Deserialize)] @@ -66,3 +69,21 @@ pub struct Message { pub news: Vec, pub next_page_token: Option, } + +pub async fn get_historical(app_config: &Arc, query: &outgoing::news::News) -> Message { + retry(ExponentialBackoff::default(), || async { + app_config.alpaca_rate_limit.until_ready().await; + app_config + .alpaca_client + .get(ALPACA_NEWS_DATA_URL) + .query(query) + .send() + .await? + .error_for_status()? + .json::() + .await + .map_err(backoff::Error::Permanent) + }) + .await + .unwrap() +} diff --git a/src/types/mod.rs b/src/types/mod.rs index b19bfab..5c4c806 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,11 +1,9 @@ -pub mod algebraic; pub mod alpaca; pub mod asset; pub mod backfill; pub mod bar; pub mod news; -pub use algebraic::Subset; pub use asset::{Asset, Class, Exchange}; pub use backfill::Backfill; pub use bar::Bar; diff --git a/src/utils/cleanup.rs b/src/utils/cleanup.rs index 05bb69f..5e4dd53 100644 --- a/src/utils/cleanup.rs +++ b/src/utils/cleanup.rs @@ -3,9 +3,9 @@ use clickhouse::Client; use tokio::join; pub async fn cleanup(clickhouse_client: &Client) { - let bars_future = database::bars::cleanup(clickhouse_client); - let news_future = database::news::cleanup(clickhouse_client); - let backfills_future = database::backfills::cleanup(clickhouse_client); - - join!(bars_future, news_future, backfills_future); + join!( + database::bars::cleanup(clickhouse_client), + database::news::cleanup(clickhouse_client), + database::backfills::cleanup(clickhouse_client) + ); } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 4ed449d..e9c9726 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -5,5 +5,5 @@ pub mod websocket; pub use cleanup::cleanup; pub use news::{add_slash_to_pair, normalize_news_content, remove_slash_from_pair}; -pub use time::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE}; +pub use time::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE, ONE_SECOND}; pub use websocket::authenticate; diff --git a/src/utils/time.rs b/src/utils/time.rs index fa0c2ae..9be0e7e 100644 --- a/src/utils/time.rs +++ b/src/utils/time.rs @@ -1,6 +1,7 @@ use std::time::Duration; use time::OffsetDateTime; +pub const ONE_SECOND: Duration = Duration::from_secs(1); pub const ONE_MINUTE: Duration = Duration::from_secs(60); pub const FIFTEEN_MINUTES: Duration = Duration::from_secs(60 * 15); diff --git a/src/utils/websocket.rs b/src/utils/websocket.rs index 6983bff..29f3eaa 100644 --- a/src/utils/websocket.rs +++ b/src/utils/websocket.rs @@ -11,10 +11,10 @@ use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream}; pub async fn authenticate( app_config: &Arc, - sender: &mut SplitSink>, Message>, - receiver: &mut SplitStream>>, + sink: &mut SplitSink>, Message>, + stream: &mut SplitStream>>, ) { - match receiver.next().await.unwrap().unwrap() { + match stream.next().await.unwrap().unwrap() { Message::Text(data) if from_str::>(&data) .unwrap() @@ -25,20 +25,19 @@ pub async fn authenticate( _ => panic!("Failed to connect to Alpaca websocket."), } - sender - .send(Message::Text( - to_string(&websocket::outgoing::Message::Auth( - websocket::outgoing::auth::Message { - key: app_config.alpaca_api_key.clone(), - secret: app_config.alpaca_api_secret.clone(), - }, - )) - .unwrap(), + sink.send(Message::Text( + to_string(&websocket::outgoing::Message::Auth( + websocket::outgoing::auth::Message { + key: app_config.alpaca_api_key.clone(), + secret: app_config.alpaca_api_secret.clone(), + }, )) - .await - .unwrap(); + .unwrap(), + )) + .await + .unwrap(); - match receiver.next().await.unwrap().unwrap() { + match stream.next().await.unwrap().unwrap() { Message::Text(data) if from_str::>(&data) .unwrap()