use super::ThreadType; use crate::{ config::{Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_STOCK_DATA_API_URL}, database, types::{ alpaca::{ api, shared::{Sort, Source}, }, news::Prediction, Backfill, Bar, Class, News, }, utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE, ONE_SECOND}, }; use async_trait::async_trait; use futures_util::future::join_all; use log::{error, info, warn}; use std::{collections::HashMap, sync::Arc}; use time::OffsetDateTime; use tokio::{ spawn, sync::{mpsc, oneshot, Mutex}, task::{block_in_place, JoinHandle}, time::sleep, try_join, }; pub enum Action { Backfill, Purge, } impl From for Action { fn from(action: super::Action) -> Self { match action { super::Action::Add => Self::Backfill, super::Action::Remove => Self::Purge, } } } pub struct Message { pub action: Action, pub symbols: Vec, pub response: oneshot::Sender<()>, } impl Message { pub fn new(action: Action, symbols: Vec) -> (Self, oneshot::Receiver<()>) { let (sender, receiver) = oneshot::channel::<()>(); ( Self { action, symbols, response: sender, }, receiver, ) } } #[async_trait] pub trait Handler: Send + Sync { async fn select_latest_backfill( &self, symbol: String, ) -> Result, clickhouse::error::Error>; async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>; async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>; async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime); async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime); fn log_string(&self) -> &'static str; } pub async fn run(handler: Arc>, mut receiver: mpsc::Receiver) { let backfill_jobs = Arc::new(Mutex::new(HashMap::new())); loop { let message = receiver.recv().await.unwrap(); spawn(handle_backfill_message( handler.clone(), backfill_jobs.clone(), message, )); } } async fn handle_backfill_message( handler: Arc>, backfill_jobs: Arc>>>, message: Message, ) { let mut backfill_jobs = backfill_jobs.lock().await; match message.action { Action::Backfill => { let log_string = handler.log_string(); for symbol in message.symbols { if let Some(job) = backfill_jobs.get(&symbol) { if !job.is_finished() { warn!( "Backfill for {} {} is already running, skipping.", symbol, log_string ); continue; } } let handler = handler.clone(); backfill_jobs.insert( symbol.clone(), spawn(async move { let fetch_from = match handler .select_latest_backfill(symbol.clone()) .await .unwrap() { Some(latest_backfill) => latest_backfill.time + ONE_SECOND, None => OffsetDateTime::UNIX_EPOCH, }; let fetch_to = last_minute(); if fetch_from > fetch_to { info!("No need to backfill {} {}.", symbol, log_string,); return; } handler.queue_backfill(&symbol, fetch_to).await; handler.backfill(symbol, fetch_from, fetch_to).await; }), ); } } Action::Purge => { for symbol in &message.symbols { if let Some(job) = backfill_jobs.remove(symbol) { if !job.is_finished() { job.abort(); } let _ = job.await; } } try_join!( handler.delete_backfills(&message.symbols), handler.delete_data(&message.symbols) ) .unwrap(); } } message.response.send(()).unwrap(); } struct BarHandler { config: Arc, data_url: &'static str, api_query_constructor: fn( config: &Arc, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime, next_page_token: Option, ) -> api::outgoing::bar::Bar, } fn us_equity_query_constructor( config: &Arc, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime, next_page_token: Option, ) -> api::outgoing::bar::Bar { api::outgoing::bar::Bar::UsEquity { symbols: vec![symbol], timeframe: ONE_MINUTE, start: Some(fetch_from), end: Some(fetch_to), limit: Some(10000), adjustment: None, asof: None, feed: Some(config.alpaca_source), currency: None, page_token: next_page_token, sort: Some(Sort::Asc), } } fn crypto_query_constructor( _: &Arc, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime, next_page_token: Option, ) -> api::outgoing::bar::Bar { api::outgoing::bar::Bar::Crypto { symbols: vec![symbol], timeframe: ONE_MINUTE, start: Some(fetch_from), end: Some(fetch_to), limit: Some(10000), page_token: next_page_token, sort: Some(Sort::Asc), } } #[async_trait] impl Handler for BarHandler { async fn select_latest_backfill( &self, symbol: String, ) -> Result, clickhouse::error::Error> { database::backfills_bars::select_where_symbol(&self.config.clickhouse_client, &symbol).await } async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { database::backfills_bars::delete_where_symbols(&self.config.clickhouse_client, symbols) .await } async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { database::bars::delete_where_symbols(&self.config.clickhouse_client, symbols).await } async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) { if self.config.alpaca_source == Source::Iex { let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE); info!("Queing bar backfill for {} in {:?}.", symbol, run_delay); sleep(run_delay).await; } } async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) { info!("Backfilling bars for {}.", symbol); let mut bars = vec![]; let mut next_page_token = None; loop { let Ok(message) = api::incoming::bar::get_historical( &self.config, self.data_url, &(self.api_query_constructor)( &self.config, symbol.clone(), fetch_from, fetch_to, next_page_token.clone(), ), None, ) .await else { error!("Failed to backfill bars for {}.", symbol); return; }; message.bars.into_iter().for_each(|(symbol, bar_vec)| { for bar in bar_vec { bars.push(Bar::from((bar, symbol.clone()))); } }); if message.next_page_token.is_none() { break; } next_page_token = message.next_page_token; } if bars.is_empty() { info!("No bars to backfill for {}.", symbol); return; } let backfill = bars.last().unwrap().clone().into(); database::bars::upsert_batch(&self.config.clickhouse_client, &bars) .await .unwrap(); database::backfills_bars::upsert(&self.config.clickhouse_client, &backfill) .await .unwrap(); info!("Backfilled bars for {}.", symbol); } fn log_string(&self) -> &'static str { "bars" } } struct NewsHandler { config: Arc, } #[async_trait] impl Handler for NewsHandler { async fn select_latest_backfill( &self, symbol: String, ) -> Result, clickhouse::error::Error> { database::backfills_news::select_where_symbol(&self.config.clickhouse_client, &symbol).await } async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { database::backfills_news::delete_where_symbols(&self.config.clickhouse_client, symbols) .await } async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { database::news::delete_where_symbols(&self.config.clickhouse_client, symbols).await } async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) { let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE); info!("Queing news backfill for {} in {:?}.", symbol, run_delay); sleep(run_delay).await; } async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) { info!("Backfilling news for {}.", symbol); let mut news = vec![]; let mut next_page_token = None; loop { let Ok(message) = api::incoming::news::get_historical( &self.config, &api::outgoing::news::News { symbols: vec![symbol.clone()], start: Some(fetch_from), end: Some(fetch_to), limit: Some(50), include_content: Some(true), exclude_contentless: Some(false), page_token: next_page_token.clone(), sort: Some(Sort::Asc), }, None, ) .await else { error!("Failed to backfill news for {}.", symbol); return; }; message.news.into_iter().for_each(|news_item| { news.push(News::from(news_item)); }); if message.next_page_token.is_none() { break; } next_page_token = message.next_page_token; } if news.is_empty() { info!("No news to backfill for {}.", symbol); return; } let inputs = news .iter() .map(|news| format!("{}\n\n{}", news.headline, news.content)) .collect::>(); let predictions = join_all(inputs.chunks(self.config.max_bert_inputs).map(|inputs| { let sequence_classifier = self.config.sequence_classifier.clone(); async move { let sequence_classifier = sequence_classifier.lock().await; block_in_place(|| { sequence_classifier .predict(inputs.iter().map(String::as_str).collect::>()) .into_iter() .map(|label| Prediction::try_from(label).unwrap()) .collect::>() }) } })) .await .into_iter() .flatten(); let news = news .into_iter() .zip(predictions) .map(|(news, prediction)| News { sentiment: prediction.sentiment, confidence: prediction.confidence, ..news }) .collect::>(); let backfill = (news.last().unwrap().clone(), symbol.clone()).into(); database::news::upsert_batch(&self.config.clickhouse_client, &news) .await .unwrap(); database::backfills_news::upsert(&self.config.clickhouse_client, &backfill) .await .unwrap(); info!("Backfilled news for {}.", symbol); } fn log_string(&self) -> &'static str { "news" } } pub fn create_handler(thread_type: ThreadType, config: Arc) -> Box { match thread_type { ThreadType::Bars(Class::UsEquity) => Box::new(BarHandler { config, data_url: ALPACA_STOCK_DATA_API_URL, api_query_constructor: us_equity_query_constructor, }), ThreadType::Bars(Class::Crypto) => Box::new(BarHandler { config, data_url: ALPACA_CRYPTO_DATA_API_URL, api_query_constructor: crypto_query_constructor, }), ThreadType::News => Box::new(NewsHandler { config }), } }