From 85eef2bf0bbc348ef1fee46f6c37126678eb9646 Mon Sep 17 00:00:00 2001 From: Nikolaos Karaolidis Date: Mon, 5 Feb 2024 13:47:43 +0000 Subject: [PATCH] Refactor threads to use trait implementations Signed-off-by: Nikolaos Karaolidis --- Cargo.lock | 1 + Cargo.toml | 1 + src/database/backfills.rs | 49 +- src/threads/clock.rs | 4 +- src/threads/data/asset_status.rs | 202 ++++---- src/threads/data/backfill.rs | 609 ++++++++++++++----------- src/threads/data/mod.rs | 16 +- src/threads/data/websocket.rs | 77 ++-- src/types/alpaca/api/incoming/asset.rs | 4 +- 9 files changed, 524 insertions(+), 439 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1008425..79b8655 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1653,6 +1653,7 @@ checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" name = "qrust" version = "0.1.0" dependencies = [ + "async-trait", "axum", "backoff", "bimap", diff --git a/Cargo.toml b/Cargo.toml index e439373..5021cb7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,3 +53,4 @@ regex = "1.10.3" html-escape = "0.2.13" rust-bert = "0.22.0" bimap = "0.6.3" +async-trait = "0.1.77" diff --git a/src/database/backfills.rs b/src/database/backfills.rs index 0639cb4..23d2c11 100644 --- a/src/database/backfills.rs +++ b/src/database/backfills.rs @@ -1,11 +1,26 @@ -use crate::{threads::data::ThreadType, types::Backfill}; +use crate::types::Backfill; use clickhouse::Client; use serde::Serialize; +use std::fmt::Display; use tokio::join; +pub enum Table { + Bars, + News, +} + +impl Display for Table { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Bars => write!(f, "backfills_bars"), + Self::News => write!(f, "backfills_news"), + } + } +} + pub async fn select_latest_where_symbol( clickhouse_client: &Client, - thread_type: &ThreadType, + table: &Table, symbol: &T, ) -> Option where @@ -13,11 +28,7 @@ where { clickhouse_client .query(&format!( - "SELECT ?fields FROM {} FINAL WHERE symbol = ? ORDER BY time DESC LIMIT 1", - match thread_type { - ThreadType::Bars(_) => "backfills_bars", - ThreadType::News => "backfills_news", - } + "SELECT ?fields FROM {table} FINAL WHERE symbol = ? ORDER BY time DESC LIMIT 1", )) .bind(symbol) .fetch_optional::() @@ -25,32 +36,18 @@ where .unwrap() } -pub async fn upsert(clickhouse_client: &Client, thread_type: &ThreadType, backfill: &Backfill) { - let mut insert = clickhouse_client - .insert(match thread_type { - ThreadType::Bars(_) => "backfills_bars", - ThreadType::News => "backfills_news", - }) - .unwrap(); +pub async fn upsert(clickhouse_client: &Client, table: &Table, backfill: &Backfill) { + let mut insert = clickhouse_client.insert(&table.to_string()).unwrap(); insert.write(backfill).await.unwrap(); insert.end().await.unwrap(); } -pub async fn delete_where_symbols( - clickhouse_client: &Client, - thread_type: &ThreadType, - symbols: &[T], -) where +pub async fn delete_where_symbols(clickhouse_client: &Client, table: &Table, symbols: &[T]) +where T: AsRef + Serialize + Send + Sync, { clickhouse_client - .query(&format!( - "DELETE FROM {} WHERE symbol IN ?", - match thread_type { - ThreadType::Bars(_) => "backfills_bars", - ThreadType::News => "backfills_news", - } - )) + .query(&format!("DELETE FROM {table} WHERE symbol IN ?")) .bind(symbols) .execute() .await diff --git a/src/threads/clock.rs b/src/threads/clock.rs index 34b7853..631b1df 100644 --- a/src/threads/clock.rs +++ b/src/threads/clock.rs @@ -35,7 +35,7 @@ impl From for Message { } } -pub async fn run(app_config: Arc, clock_sender: mpsc::Sender) { +pub async fn run(app_config: Arc, sender: mpsc::Sender) { loop { let clock = retry(ExponentialBackoff::default(), || async { app_config.alpaca_rate_limit.until_ready().await; @@ -61,6 +61,6 @@ pub async fn run(app_config: Arc, clock_sender: mpsc::Sender) { }); sleep(sleep_until).await; - clock_sender.send(clock.into()).await.unwrap(); + sender.send(clock.into()).await.unwrap(); } } diff --git a/src/threads/data/asset_status.rs b/src/threads/data/asset_status.rs index 959c743..d58f27d 100644 --- a/src/threads/data/asset_status.rs +++ b/src/threads/data/asset_status.rs @@ -4,6 +4,7 @@ use crate::{ database, types::{alpaca::websocket, Asset}, }; +use async_trait::async_trait; use futures_util::{stream::SplitSink, SinkExt}; use log::info; use serde_json::to_string; @@ -42,23 +43,23 @@ impl Message { } } +#[async_trait] +pub trait Handler: Send + Sync { + async fn add_assets(&self, assets: Vec, symbols: Vec); + async fn remove_assets(&self, assets: Vec, symbols: Vec); +} + pub async fn run( - app_config: Arc, - thread_type: ThreadType, + handler: Arc>, guard: Arc>, - mut asset_status_receiver: mpsc::Receiver, - websocket_sender: Arc< - Mutex>, tungstenite::Message>>, - >, + mut receiver: mpsc::Receiver, ) { loop { - let message = asset_status_receiver.recv().await.unwrap(); + let message = receiver.recv().await.unwrap(); spawn(handle_asset_status_message( - app_config.clone(), - thread_type, + handler.clone(), guard.clone(), - websocket_sender.clone(), message, )); } @@ -66,12 +67,8 @@ pub async fn run( #[allow(clippy::significant_drop_tightening)] async fn handle_asset_status_message( - app_config: Arc, - thread_type: ThreadType, + handler: Arc>, guard: Arc>, - websocket_sender: Arc< - Mutex>, tungstenite::Message>>, - >, message: Message, ) { let symbols = message @@ -93,37 +90,7 @@ async fn handle_asset_status_message( ); guard.pending_subscriptions.extend(message.assets.clone()); - info!("{:?} - Added {:?}.", thread_type, symbols); - - let database_future = async { - if matches!(thread_type, ThreadType::Bars(_)) { - database::assets::upsert_batch(&app_config.clickhouse_client, message.assets) - .await; - } - }; - - let websocket_future = async move { - websocket_sender - .lock() - .await - .send(tungstenite::Message::Text( - to_string(&websocket::outgoing::Message::Subscribe( - match thread_type { - ThreadType::Bars(_) => { - websocket::outgoing::subscribe::Message::new_market(symbols) - } - ThreadType::News => { - websocket::outgoing::subscribe::Message::new_news(symbols) - } - }, - )) - .unwrap(), - )) - .await - .unwrap(); - }; - - join!(database_future, websocket_future); + handler.add_assets(message.assets, symbols).await; } Action::Remove => { let mut guard = guard.write().await; @@ -131,40 +98,121 @@ async fn handle_asset_status_message( guard .assets .retain(|asset, _| !message.assets.contains(asset)); - guard.pending_unsubscriptions.extend(message.assets); + guard.pending_unsubscriptions.extend(message.assets.clone()); - info!("{:?} - Removed {:?}.", thread_type, symbols); - - let sybols_clone = symbols.clone(); - let database_future = database::assets::delete_where_symbols( - &app_config.clickhouse_client, - &sybols_clone, - ); - - let websocket_future = async move { - websocket_sender - .lock() - .await - .send(tungstenite::Message::Text( - to_string(&websocket::outgoing::Message::Unsubscribe( - match thread_type { - ThreadType::Bars(_) => { - websocket::outgoing::subscribe::Message::new_market(symbols) - } - ThreadType::News => { - websocket::outgoing::subscribe::Message::new_news(symbols) - } - }, - )) - .unwrap(), - )) - .await - .unwrap(); - }; - - join!(database_future, websocket_future); + handler.remove_assets(message.assets, symbols).await; } } message.response.send(()).unwrap(); } + +pub fn create_asset_status_handler( + thread_type: ThreadType, + app_config: Arc, + websocket_sender: Arc< + Mutex>, tungstenite::Message>>, + >, +) -> Box { + match thread_type { + ThreadType::Bars(_) => Box::new(BarsHandler { + app_config, + websocket_sender, + }), + ThreadType::News => Box::new(NewsHandler { websocket_sender }), + } +} + +struct BarsHandler { + app_config: Arc, + websocket_sender: + Arc>, tungstenite::Message>>>, +} + +#[async_trait] +impl Handler for BarsHandler { + async fn add_assets(&self, assets: Vec, symbols: Vec) { + let database_future = + database::assets::upsert_batch(&self.app_config.clickhouse_client, assets); + + let symbols_clone = symbols.clone(); + let websocket_future = async move { + self.websocket_sender + .lock() + .await + .send(tungstenite::Message::Text( + to_string(&websocket::outgoing::Message::Subscribe( + websocket::outgoing::subscribe::Message::new_market(symbols_clone), + )) + .unwrap(), + )) + .await + .unwrap(); + }; + + join!(database_future, websocket_future); + info!("Added {:?}.", symbols); + } + + async fn remove_assets(&self, _: Vec, symbols: Vec) { + let symbols_clone = symbols.clone(); + let database_future = database::assets::delete_where_symbols( + &self.app_config.clickhouse_client, + &symbols_clone, + ); + + let symbols_clone = symbols.clone(); + let websocket_future = async move { + self.websocket_sender + .lock() + .await + .send(tungstenite::Message::Text( + to_string(&websocket::outgoing::Message::Unsubscribe( + websocket::outgoing::subscribe::Message::new_market(symbols_clone), + )) + .unwrap(), + )) + .await + .unwrap(); + }; + + join!(database_future, websocket_future); + info!("Removed {:?}.", symbols); + } +} + +struct NewsHandler { + websocket_sender: + Arc>, tungstenite::Message>>>, +} + +#[async_trait] +impl Handler for NewsHandler { + async fn add_assets(&self, _: Vec, symbols: Vec) { + self.websocket_sender + .lock() + .await + .send(tungstenite::Message::Text( + to_string(&websocket::outgoing::Message::Subscribe( + websocket::outgoing::subscribe::Message::new_news(symbols), + )) + .unwrap(), + )) + .await + .unwrap(); + } + + async fn remove_assets(&self, _: Vec, symbols: Vec) { + self.websocket_sender + .lock() + .await + .send(tungstenite::Message::Text( + to_string(&websocket::outgoing::Message::Unsubscribe( + websocket::outgoing::subscribe::Message::new_news(symbols), + )) + .unwrap(), + )) + .await + .unwrap(); + } +} diff --git a/src/threads/data/backfill.rs b/src/threads/data/backfill.rs index 9c0e662..dabc168 100644 --- a/src/threads/data/backfill.rs +++ b/src/threads/data/backfill.rs @@ -12,6 +12,7 @@ use crate::{ }, utils::{duration_until, last_minute, remove_slash_from_pair, FIFTEEN_MINUTES, ONE_MINUTE}, }; +use async_trait::async_trait; use backoff::{future::retry, ExponentialBackoff}; use futures_util::future::join_all; use log::{error, info, warn}; @@ -49,28 +50,29 @@ impl Message { } } +#[async_trait] +pub trait Handler: Send + Sync { + async fn select_latest_backfill(&self, symbol: String) -> Option; + async fn delete_backfills(&self, symbol: &[String]); + async fn delete_data(&self, symbol: &[String]); + async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime); + async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime); + fn log_string(&self) -> &'static str; +} + pub async fn run( - app_config: Arc, - thread_type: ThreadType, + handler: Arc>, guard: Arc>, - mut backfill_receiver: mpsc::Receiver, + mut receiver: mpsc::Receiver, ) { let backfill_jobs = Arc::new(Mutex::new(HashMap::new())); - let data_url = match thread_type { - ThreadType::Bars(Class::UsEquity) => ALPACA_STOCK_DATA_URL.to_string(), - ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_URL.to_string(), - ThreadType::News => ALPACA_NEWS_DATA_URL.to_string(), - }; - loop { - let message = backfill_receiver.recv().await.unwrap(); + let message = receiver.recv().await.unwrap(); spawn(handle_backfill_message( - app_config.clone(), - thread_type, + handler.clone(), guard.clone(), - data_url.clone(), backfill_jobs.clone(), message, )); @@ -80,10 +82,8 @@ pub async fn run( #[allow(clippy::significant_drop_tightening)] #[allow(clippy::too_many_lines)] async fn handle_backfill_message( - app_config: Arc, - thread_type: ThreadType, + handler: Arc>, guard: Arc>, - data_url: String, backfill_jobs: Arc>>>, message: Message, ) { @@ -109,50 +109,40 @@ async fn handle_backfill_message( match message.action { Action::Backfill => { + let log_string = handler.log_string(); + for symbol in symbols { if let Some(job) = backfill_jobs.get(&symbol) { if !job.is_finished() { warn!( - "{:?} - Backfill for {} is already running, skipping.", - thread_type, symbol + "Backfill for {} {} is already running, skipping.", + symbol, log_string ); continue; } } - let app_config = app_config.clone(); - let data_url = data_url.clone(); - + let handler = handler.clone(); backfill_jobs.insert( symbol.clone(), spawn(async move { - let (fetch_from, fetch_to) = - queue_backfill(&app_config, thread_type, &symbol).await; + let fetch_from = handler + .select_latest_backfill(symbol.clone()) + .await + .as_ref() + .map_or(OffsetDateTime::UNIX_EPOCH, |backfill| { + backfill.time + ONE_MINUTE + }); - match thread_type { - ThreadType::Bars(_) => { - execute_backfill_bars( - app_config, - thread_type, - data_url, - symbol, - fetch_from, - fetch_to, - ) - .await; - } - ThreadType::News => { - execute_backfill_news( - app_config, - thread_type, - data_url, - symbol, - fetch_from, - fetch_to, - ) - .await; - } + let fetch_to = last_minute(); + + if fetch_from > fetch_to { + info!("No need to backfill {} {}.", symbol, log_string,); + return; } + + handler.queue_backfill(&symbol, fetch_to).await; + handler.backfill(symbol, fetch_from, fetch_to).await; }), ); } @@ -167,263 +157,326 @@ async fn handle_backfill_message( } } - let backfills_future = database::backfills::delete_where_symbols( - &app_config.clickhouse_client, - &thread_type, - &symbols, + join!( + handler.delete_backfills(&symbols), + handler.delete_data(&symbols) ); - - let data_future = async { - match thread_type { - ThreadType::Bars(_) => { - database::bars::delete_where_symbols( - &app_config.clickhouse_client, - &symbols, - ) - .await; - } - ThreadType::News => { - database::news::delete_where_symbols( - &app_config.clickhouse_client, - &symbols, - ) - .await; - } - } - }; - - join!(backfills_future, data_future); } } message.response.send(()).unwrap(); } -async fn queue_backfill( - app_config: &Arc, +pub fn create_backfill_handler( thread_type: ThreadType, - symbol: &String, -) -> (OffsetDateTime, OffsetDateTime) { - let latest_backfill = database::backfills::select_latest_where_symbol( - &app_config.clickhouse_client, - &thread_type, - &symbol, - ) - .await; + app_config: Arc, +) -> Box { + match thread_type { + ThreadType::Bars(Class::UsEquity) => Box::new(BarHandler { + app_config, + data_url: ALPACA_STOCK_DATA_URL, + api_query_constructor: us_equity_query_constructor, + }), + ThreadType::Bars(Class::Crypto) => Box::new(BarHandler { + app_config, + data_url: ALPACA_CRYPTO_DATA_URL, + api_query_constructor: crypto_query_constructor, + }), + ThreadType::News => Box::new(NewsHandler { app_config }), + } +} - let fetch_from = latest_backfill - .as_ref() - .map_or(OffsetDateTime::UNIX_EPOCH, |backfill| { - backfill.time + ONE_MINUTE - }); +struct BarHandler { + app_config: Arc, + data_url: &'static str, + api_query_constructor: fn( + app_config: &Arc, + symbol: String, + fetch_from: OffsetDateTime, + fetch_to: OffsetDateTime, + next_page_token: Option, + ) -> api::outgoing::bar::Bar, +} - let fetch_to = last_minute(); +fn us_equity_query_constructor( + app_config: &Arc, + symbol: String, + fetch_from: OffsetDateTime, + fetch_to: OffsetDateTime, + next_page_token: Option, +) -> api::outgoing::bar::Bar { + api::outgoing::bar::Bar::UsEquity { + symbols: vec![symbol], + timeframe: ONE_MINUTE, + start: Some(fetch_from), + end: Some(fetch_to), + limit: Some(10000), + adjustment: None, + asof: None, + feed: Some(app_config.alpaca_source), + currency: None, + page_token: next_page_token, + sort: Some(Sort::Asc), + } +} - if app_config.alpaca_source == Source::Iex { +fn crypto_query_constructor( + _: &Arc, + symbol: String, + fetch_from: OffsetDateTime, + fetch_to: OffsetDateTime, + next_page_token: Option, +) -> api::outgoing::bar::Bar { + api::outgoing::bar::Bar::Crypto { + symbols: vec![symbol], + timeframe: ONE_MINUTE, + start: Some(fetch_from), + end: Some(fetch_to), + limit: Some(10000), + page_token: next_page_token, + sort: Some(Sort::Asc), + } +} + +#[async_trait] +impl Handler for BarHandler { + async fn select_latest_backfill(&self, symbol: String) -> Option { + database::backfills::select_latest_where_symbol( + &self.app_config.clickhouse_client, + &database::backfills::Table::Bars, + &symbol, + ) + .await + } + + async fn delete_backfills(&self, symbols: &[String]) { + database::backfills::delete_where_symbols( + &self.app_config.clickhouse_client, + &database::backfills::Table::Bars, + symbols, + ) + .await; + } + + async fn delete_data(&self, symbols: &[String]) { + database::bars::delete_where_symbols(&self.app_config.clickhouse_client, symbols).await; + } + + async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) { + if self.app_config.alpaca_source == Source::Iex { + let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE); + info!("Queing bar backfill for {} in {:?}.", symbol, run_delay); + sleep(run_delay).await; + } + } + + async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) { + info!("Backfilling bars for {}.", symbol); + + let mut bars = Vec::new(); + let mut next_page_token = None; + + loop { + let message = retry(ExponentialBackoff::default(), || async { + self.app_config.alpaca_rate_limit.until_ready().await; + self.app_config + .alpaca_client + .get(self.data_url) + .query(&(self.api_query_constructor)( + &self.app_config, + symbol.clone(), + fetch_from, + fetch_to, + next_page_token.clone(), + )) + .send() + .await? + .error_for_status()? + .json::() + .await + .map_err(backoff::Error::Permanent) + }) + .await; + + let message = match message { + Ok(message) => message, + Err(e) => { + error!("Failed to backfill bars for {}: {}.", symbol, e); + return; + } + }; + + message.bars.into_iter().for_each(|(symbol, bar_vec)| { + for bar in bar_vec { + bars.push(Bar::from((bar, symbol.clone()))); + } + }); + + if message.next_page_token.is_none() { + break; + } + next_page_token = message.next_page_token; + } + + if bars.is_empty() { + info!("No bars to backfill for {}.", symbol); + return; + } + + let backfill = bars.last().unwrap().clone().into(); + database::bars::upsert_batch(&self.app_config.clickhouse_client, bars).await; + database::backfills::upsert( + &self.app_config.clickhouse_client, + &database::backfills::Table::Bars, + &backfill, + ) + .await; + + info!("Backfilled bars for {}.", symbol); + } + + fn log_string(&self) -> &'static str { + "bars" + } +} + +struct NewsHandler { + app_config: Arc, +} + +#[async_trait] +impl Handler for NewsHandler { + async fn select_latest_backfill(&self, symbol: String) -> Option { + database::backfills::select_latest_where_symbol( + &self.app_config.clickhouse_client, + &database::backfills::Table::News, + &symbol, + ) + .await + } + + async fn delete_backfills(&self, symbols: &[String]) { + database::backfills::delete_where_symbols( + &self.app_config.clickhouse_client, + &database::backfills::Table::News, + symbols, + ) + .await; + } + + async fn delete_data(&self, symbols: &[String]) { + database::news::delete_where_symbols(&self.app_config.clickhouse_client, symbols).await; + } + + async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) { let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE); - info!( - "{:?} - Queing backfill for {} in {:?}.", - thread_type, symbol, run_delay - ); + info!("Queing news backfill for {} in {:?}.", symbol, run_delay); sleep(run_delay).await; } - (fetch_from, fetch_to) -} + async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) { + info!("Backfilling news for {}.", symbol); -async fn execute_backfill_bars( - app_config: Arc, - thread_type: ThreadType, - data_url: String, - symbol: String, - fetch_from: OffsetDateTime, - fetch_to: OffsetDateTime, -) { - if fetch_from > fetch_to { - return; - } + let mut news = Vec::new(); + let mut next_page_token = None; - info!("{:?} - Backfilling data for {}.", thread_type, symbol); - - let mut bars = Vec::new(); - let mut next_page_token = None; - - loop { - let message = retry(ExponentialBackoff::default(), || async { - app_config.alpaca_rate_limit.until_ready().await; - app_config - .alpaca_client - .get(&data_url) - .query(&match thread_type { - ThreadType::Bars(Class::UsEquity) => api::outgoing::bar::Bar::UsEquity { - symbols: vec![symbol.clone()], - timeframe: ONE_MINUTE, + loop { + let message = retry(ExponentialBackoff::default(), || async { + self.app_config.alpaca_rate_limit.until_ready().await; + self.app_config + .alpaca_client + .get(ALPACA_NEWS_DATA_URL) + .query(&api::outgoing::news::News { + symbols: vec![remove_slash_from_pair(&symbol)], start: Some(fetch_from), end: Some(fetch_to), - limit: Some(10000), - adjustment: None, - asof: None, - feed: Some(app_config.alpaca_source), - currency: None, + limit: Some(50), + include_content: Some(true), + exclude_contentless: Some(false), page_token: next_page_token.clone(), sort: Some(Sort::Asc), - }, - ThreadType::Bars(Class::Crypto) => api::outgoing::bar::Bar::Crypto { - symbols: vec![symbol.clone()], - timeframe: ONE_MINUTE, - start: Some(fetch_from), - end: Some(fetch_to), - limit: Some(10000), - page_token: next_page_token.clone(), - sort: Some(Sort::Asc), - }, - _ => unreachable!(), - }) - .send() - .await? - .error_for_status()? - .json::() - .await - .map_err(backoff::Error::Permanent) - }) - .await; - - let message = match message { - Ok(message) => message, - Err(e) => { - error!( - "{:?} - Failed to backfill data for {}: {}.", - thread_type, symbol, e - ); - return; - } - }; - - message.bars.into_iter().for_each(|(symbol, bar_vec)| { - for bar in bar_vec { - bars.push(Bar::from((bar, symbol.clone()))); - } - }); - - if message.next_page_token.is_none() { - break; - } - next_page_token = message.next_page_token; - } - - if bars.is_empty() { - return; - } - - let backfill = bars.last().unwrap().clone().into(); - database::bars::upsert_batch(&app_config.clickhouse_client, bars).await; - database::backfills::upsert(&app_config.clickhouse_client, &thread_type, &backfill).await; - - info!("{:?} - Backfilled data for {}.", thread_type, symbol); -} - -async fn execute_backfill_news( - app_config: Arc, - thread_type: ThreadType, - data_url: String, - symbol: String, - fetch_from: OffsetDateTime, - fetch_to: OffsetDateTime, -) { - if fetch_from > fetch_to { - return; - } - - info!("{:?} - Backfilling data for {}.", thread_type, symbol); - - let mut news = Vec::new(); - let mut next_page_token = None; - - loop { - let message = retry(ExponentialBackoff::default(), || async { - app_config.alpaca_rate_limit.until_ready().await; - app_config - .alpaca_client - .get(&data_url) - .query(&api::outgoing::news::News { - symbols: vec![remove_slash_from_pair(&symbol)], - start: Some(fetch_from), - end: Some(fetch_to), - limit: Some(50), - include_content: Some(true), - exclude_contentless: Some(false), - page_token: next_page_token.clone(), - sort: Some(Sort::Asc), - }) - .send() - .await? - .error_for_status()? - .json::() - .await - .map_err(backoff::Error::Permanent) - }) - .await; - - let message = match message { - Ok(message) => message, - Err(e) => { - error!( - "{:?} - Failed to backfill data for {}: {}.", - thread_type, symbol, e - ); - return; - } - }; - - message.news.into_iter().for_each(|news_item| { - news.push(News::from(news_item)); - }); - - if message.next_page_token.is_none() { - break; - } - next_page_token = message.next_page_token; - } - - if news.is_empty() { - return; - } - - let inputs = news - .iter() - .map(|news| format!("{}\n\n{}", news.headline, news.content)) - .collect::>(); - - let predictions = join_all(inputs.chunks(app_config.max_bert_inputs).map(|inputs| { - let sequence_classifier = app_config.sequence_classifier.clone(); - async move { - let sequence_classifier = sequence_classifier.lock().await; - block_in_place(|| { - sequence_classifier - .predict(inputs.iter().map(String::as_str).collect::>()) - .into_iter() - .map(|label| Prediction::try_from(label).unwrap()) - .collect::>() + }) + .send() + .await? + .error_for_status()? + .json::() + .await + .map_err(backoff::Error::Permanent) }) + .await; + + let message = match message { + Ok(message) => message, + Err(e) => { + error!("Failed to backfill news for {}: {}.", symbol, e); + return; + } + }; + + message.news.into_iter().for_each(|news_item| { + news.push(News::from(news_item)); + }); + + if message.next_page_token.is_none() { + break; + } + next_page_token = message.next_page_token; } - })) - .await - .into_iter() - .flatten(); - let news = news + if news.is_empty() { + info!("No news to backfill for {}.", symbol); + return; + } + + let inputs = news + .iter() + .map(|news| format!("{}\n\n{}", news.headline, news.content)) + .collect::>(); + + let predictions = join_all( + inputs + .chunks(self.app_config.max_bert_inputs) + .map(|inputs| { + let sequence_classifier = self.app_config.sequence_classifier.clone(); + async move { + let sequence_classifier = sequence_classifier.lock().await; + block_in_place(|| { + sequence_classifier + .predict(inputs.iter().map(String::as_str).collect::>()) + .into_iter() + .map(|label| Prediction::try_from(label).unwrap()) + .collect::>() + }) + } + }), + ) + .await .into_iter() - .zip(predictions) - .map(|(news, prediction)| News { - sentiment: prediction.sentiment, - confidence: prediction.confidence, - ..news - }) - .collect::>(); + .flatten(); - let backfill = (news.last().unwrap().clone(), symbol.clone()).into(); - database::news::upsert_batch(&app_config.clickhouse_client, news).await; - database::backfills::upsert(&app_config.clickhouse_client, &thread_type, &backfill).await; + let news = news + .into_iter() + .zip(predictions) + .map(|(news, prediction)| News { + sentiment: prediction.sentiment, + confidence: prediction.confidence, + ..news + }) + .collect::>(); - info!("{:?} - Backfilled data for {}.", thread_type, symbol); + let backfill = (news.last().unwrap().clone(), symbol.clone()).into(); + database::news::upsert_batch(&self.app_config.clickhouse_client, news).await; + database::backfills::upsert( + &self.app_config.clickhouse_client, + &database::backfills::Table::News, + &backfill, + ) + .await; + + info!("Backfilled news for {}.", symbol); + } + + fn log_string(&self) -> &'static str { + "news" + } } diff --git a/src/threads/data/mod.rs b/src/threads/data/mod.rs index cf37c1f..abb3d2b 100644 --- a/src/threads/data/mod.rs +++ b/src/threads/data/mod.rs @@ -2,6 +2,7 @@ pub mod asset_status; pub mod backfill; pub mod websocket; +use self::asset_status::create_asset_status_handler; use super::{clock, guard::Guard}; use crate::{ config::{ @@ -85,24 +86,27 @@ async fn init_thread( let (asset_status_sender, asset_status_receiver) = mpsc::channel(100); spawn(asset_status::run( - app_config.clone(), - thread_type, + Arc::new(create_asset_status_handler( + thread_type, + app_config.clone(), + websocket_sender.clone(), + )), guard.clone(), asset_status_receiver, - websocket_sender.clone(), )); let (backfill_sender, backfill_receiver) = mpsc::channel(100); spawn(backfill::run( - app_config.clone(), - thread_type, + Arc::new(backfill::create_backfill_handler( + thread_type, + app_config.clone(), + )), guard.clone(), backfill_receiver, )); spawn(websocket::run( app_config.clone(), - thread_type, guard.clone(), websocket_sender, websocket_receiver, diff --git a/src/threads/data/websocket.rs b/src/threads/data/websocket.rs index a691eea..84af18e 100644 --- a/src/threads/data/websocket.rs +++ b/src/threads/data/websocket.rs @@ -1,4 +1,4 @@ -use super::{backfill, Guard, ThreadType}; +use super::{backfill, Guard}; use crate::{ config::Config, database, @@ -23,22 +23,18 @@ use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; pub async fn run( app_config: Arc, - thread_type: ThreadType, guard: Arc>, - websocket_sender: Arc< - Mutex>, tungstenite::Message>>, - >, - mut websocket_receiver: SplitStream>>, + sender: Arc>, tungstenite::Message>>>, + mut receiver: SplitStream>>, backfill_sender: mpsc::Sender, ) { loop { - let message = websocket_receiver.next().await.unwrap().unwrap(); + let message = receiver.next().await.unwrap().unwrap(); spawn(handle_websocket_message( app_config.clone(), - thread_type, guard.clone(), - websocket_sender.clone(), + sender.clone(), backfill_sender.clone(), message, )); @@ -47,11 +43,8 @@ pub async fn run( async fn handle_websocket_message( app_config: Arc, - thread_type: ThreadType, guard: Arc>, - websocket_sender: Arc< - Mutex>, tungstenite::Message>>, - >, + sender: Arc>, tungstenite::Message>>>, backfill_sender: mpsc::Sender, message: tungstenite::Message, ) { @@ -63,31 +56,24 @@ async fn handle_websocket_message( for message in message { spawn(handle_parsed_websocket_message( app_config.clone(), - thread_type, guard.clone(), backfill_sender.clone(), message, )); } } else { - error!( - "{:?} - Failed to deserialize websocket message: {:?}", - thread_type, message - ); + error!("Failed to deserialize websocket message: {:?}", message); } } tungstenite::Message::Ping(_) => { - websocket_sender + sender .lock() .await .send(tungstenite::Message::Pong(vec![])) .await .unwrap(); } - _ => error!( - "{:?} - Unexpected websocket message: {:?}", - thread_type, message - ), + _ => error!("Unexpected websocket message: {:?}", message), } } @@ -95,19 +81,20 @@ async fn handle_websocket_message( #[allow(clippy::too_many_lines)] async fn handle_parsed_websocket_message( app_config: Arc, - thread_type: ThreadType, guard: Arc>, backfill_sender: mpsc::Sender, message: websocket::incoming::Message, ) { match message { websocket::incoming::Message::Subscription(message) => { - let symbols = match message { - websocket::incoming::subscription::Message::Market { bars, .. } => bars, - websocket::incoming::subscription::Message::News { news } => news - .into_iter() - .map(|symbol| add_slash_to_pair(&symbol)) - .collect(), + let (symbols, log_string) = match message { + websocket::incoming::subscription::Message::Market { bars, .. } => (bars, "bars"), + websocket::incoming::subscription::Message::News { news } => ( + news.into_iter() + .map(|symbol| add_slash_to_pair(&symbol)) + .collect(), + "news", + ), }; let mut guard = guard.write().await; @@ -127,8 +114,8 @@ async fn handle_parsed_websocket_message( let newly_subscribed_future = async { if !newly_subscribed.is_empty() { info!( - "{:?} - Subscribed to {:?}.", - thread_type, + "Subscribed to {} for {:?}.", + log_string, newly_subscribed .iter() .map(|asset| asset.symbol.clone()) @@ -148,8 +135,8 @@ async fn handle_parsed_websocket_message( let newly_unsubscribed_future = async { if !newly_unsubscribed.is_empty() { info!( - "{:?} - Unsubscribed from {:?}.", - thread_type, + "Unsubscribed from {} for {:?}.", + log_string, newly_unsubscribed .iter() .map(|asset| asset.symbol.clone()) @@ -175,16 +162,13 @@ async fn handle_parsed_websocket_message( let guard = guard.read().await; if !guard.assets.contains_right(&bar.symbol) { warn!( - "{:?} - Race condition: received bar for unsubscribed symbol: {:?}.", - thread_type, bar.symbol + "Race condition: received bar for unsubscribed symbol: {:?}.", + bar.symbol ); return; } - debug!( - "{:?} - Received bar for {}: {}.", - thread_type, bar.symbol, bar.time - ); + debug!("Received bar for {}: {}.", bar.symbol, bar.time); database::bars::upsert(&app_config.clickhouse_client, &bar).await; } websocket::incoming::Message::News(message) => { @@ -197,15 +181,15 @@ async fn handle_parsed_websocket_message( .any(|symbol| guard.assets.contains_right(symbol)) { warn!( - "{:?} - Race condition: received news for unsubscribed symbols: {:?}.", - thread_type, news.symbols + "Race condition: received news for unsubscribed symbols: {:?}.", + news.symbols ); return; } debug!( - "{:?} - Received news for {:?}: {}.", - thread_type, news.symbols, news.time_created + "Received news for {:?}: {}.", + news.symbols, news.time_created ); let input = format!("{}\n\n{}", news.headline, news.content); @@ -229,10 +213,7 @@ async fn handle_parsed_websocket_message( } websocket::incoming::Message::Success(_) => {} websocket::incoming::Message::Error(message) => { - error!( - "{:?} - Received error message: {}.", - thread_type, message.message - ); + error!("Received error message: {}.", message.message); } } } diff --git a/src/types/alpaca/api/incoming/asset.rs b/src/types/alpaca/api/incoming/asset.rs index 93bc193..d164a28 100644 --- a/src/types/alpaca/api/incoming/asset.rs +++ b/src/types/alpaca/api/incoming/asset.rs @@ -44,8 +44,8 @@ pub enum Status { } impl From for bool { - fn from(item: Status) -> Self { - match item { + fn from(status: Status) -> Self { + match status { Status::Active => true, Status::Inactive => false, }