diff --git a/src/database/backfills_bars.rs b/src/database/backfills_bars.rs
index 1360665..75f191f 100644
--- a/src/database/backfills_bars.rs
+++ b/src/database/backfills_bars.rs
@@ -1,13 +1,13 @@
 use std::sync::Arc;
 
 use crate::{
-    cleanup, delete_where_symbols, optimize, select_where_symbols, types::Backfill, upsert,
+    cleanup, delete_where_symbols, optimize, select_where_symbols, types::Backfill, upsert_batch,
 };
 use clickhouse::{error::Error, Client};
 use tokio::sync::Semaphore;
 
 select_where_symbols!(Backfill, "backfills_bars");
-upsert!(Backfill, "backfills_bars");
+upsert_batch!(Backfill, "backfills_bars");
 delete_where_symbols!("backfills_bars");
 cleanup!("backfills_bars");
 optimize!("backfills_bars");
diff --git a/src/database/backfills_news.rs b/src/database/backfills_news.rs
index 5688eda..92d5040 100644
--- a/src/database/backfills_news.rs
+++ b/src/database/backfills_news.rs
@@ -1,13 +1,13 @@
 use std::sync::Arc;
 
 use crate::{
-    cleanup, delete_where_symbols, optimize, select_where_symbols, types::Backfill, upsert,
+    cleanup, delete_where_symbols, optimize, select_where_symbols, types::Backfill, upsert_batch,
 };
 use clickhouse::{error::Error, Client};
 use tokio::sync::Semaphore;
 
 select_where_symbols!(Backfill, "backfills_news");
-upsert!(Backfill, "backfills_news");
+upsert_batch!(Backfill, "backfills_news");
 delete_where_symbols!("backfills_news");
 cleanup!("backfills_news");
 optimize!("backfills_news");
diff --git a/src/database/bars.rs b/src/database/bars.rs
index ca9ae01..bc7182b 100644
--- a/src/database/bars.rs
+++ b/src/database/bars.rs
@@ -4,6 +4,8 @@ use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
 use clickhouse::Client;
 use tokio::sync::Semaphore;
 
+pub const BATCH_FLUSH_SIZE: usize = 100_000;
+
 upsert!(Bar, "bars");
 upsert_batch!(Bar, "bars");
 delete_where_symbols!("bars");
diff --git a/src/database/news.rs b/src/database/news.rs
index a028c21..bbcd7dc 100644
--- a/src/database/news.rs
+++ b/src/database/news.rs
@@ -5,6 +5,8 @@ use clickhouse::{error::Error, Client};
 use serde::Serialize;
 use tokio::sync::Semaphore;
 
+pub const BATCH_FLUSH_SIZE: usize = 500;
+
 upsert!(News, "news");
 upsert_batch!(News, "news");
 optimize!("news");
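The four database modules above switch the backfill bookkeeping tables from the single-row `upsert!` macro to `upsert_batch!`, and give the two data tables flush thresholds (100_000 bars, 500 news items) that the backfill thread below uses to write incrementally instead of accumulating a whole range in memory. The macro definitions are not part of this diff; the sketch below is only a guess at what a generated `upsert_batch` helper might look like, assuming the macros wrap the `clickhouse` crate's row-insert API (the signature and names are assumptions, not the project's actual expansion).

```rust
use std::sync::Arc;

use clickhouse::{error::Error, Client};
use tokio::sync::Semaphore;

use crate::types::Bar;

// Hypothetical expansion of `upsert_batch!(Bar, "bars")`; illustrative only.
pub async fn upsert_batch(
    client: &Client,
    concurrency_limiter: &Arc<Semaphore>,
    rows: &[Bar],
) -> Result<(), Error> {
    // Hold a permit for the duration of the insert so the number of
    // concurrent ClickHouse writes stays bounded.
    let _permit = concurrency_limiter.acquire().await.unwrap();

    // One INSERT for the whole slice instead of one round trip per row.
    let mut insert = client.insert("bars")?;
    for row in rows {
        insert.write(row).await?;
    }
    insert.end().await
}
```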
diff --git a/src/threads/data/backfill.rs b/src/threads/data/backfill.rs
index 9f86d10..a4b587b 100644
--- a/src/threads/data/backfill.rs
+++ b/src/threads/data/backfill.rs
@@ -6,7 +6,10 @@ use crate::{
     },
     database,
     types::{
-        alpaca::{self, shared::Source},
+        alpaca::{
+            self,
+            shared::{Sort, Source},
+        },
         news::Prediction,
         Backfill, Bar, Class, News,
     },
@@ -14,6 +17,7 @@ use crate::{
 };
 use async_trait::async_trait;
 use futures_util::future::join_all;
+use itertools::{Either, Itertools};
 use log::{error, info, warn};
 use std::{collections::HashMap, sync::Arc};
 use time::OffsetDateTime;
@@ -24,6 +28,7 @@ use tokio::{
     time::sleep,
     try_join,
 };
+use uuid::Uuid;
 
 pub enum Action {
     Backfill,
@@ -50,6 +55,12 @@ impl Message {
     }
 }
 
+#[derive(Clone)]
+pub struct Job {
+    pub fetch_from: OffsetDateTime,
+    pub fetch_to: OffsetDateTime,
+}
+
 #[async_trait]
 pub trait Handler: Send + Sync {
     async fn select_latest_backfills(
@@ -58,13 +69,44 @@ pub trait Handler: Send + Sync {
         &self,
         symbols: &[String],
     ) -> Result<Vec<Backfill>, clickhouse::error::Error>;
     async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
     async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
-    async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime);
-    async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime);
+    async fn queue_backfill(&self, jobs: &HashMap<String, Job>);
+    async fn backfill(&self, jobs: HashMap<String, Job>);
+    fn max_limit(&self) -> i64;
     fn log_string(&self) -> &'static str;
 }
 
+pub struct Jobs {
+    pub symbol_to_uuid: HashMap<String, Uuid>,
+    pub uuid_to_job: HashMap<Uuid, JoinHandle<()>>,
+}
+
+impl Jobs {
+    pub fn insert(&mut self, jobs: Vec<String>, fut: JoinHandle<()>) {
+        let uuid = Uuid::new_v4();
+        for symbol in jobs {
+            self.symbol_to_uuid.insert(symbol.clone(), uuid);
+        }
+        self.uuid_to_job.insert(uuid, fut);
+    }
+
+    pub fn get(&self, symbol: &str) -> Option<&JoinHandle<()>> {
+        self.symbol_to_uuid
+            .get(symbol)
+            .and_then(|uuid| self.uuid_to_job.get(uuid))
+    }
+
+    pub fn remove(&mut self, symbol: &str) -> Option<JoinHandle<()>> {
+        self.symbol_to_uuid
+            .remove(symbol)
+            .and_then(|uuid| self.uuid_to_job.remove(&uuid))
+    }
+}
+
 pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
-    let backfill_jobs = Arc::new(Mutex::new(HashMap::new()));
+    let backfill_jobs = Arc::new(Mutex::new(Jobs {
+        symbol_to_uuid: HashMap::new(),
+        uuid_to_job: HashMap::new(),
+    }));
 
     loop {
         let message = receiver.recv().await.unwrap();
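`Jobs` replaces the old flat `HashMap<String, JoinHandle<()>>`: every symbol in a request group points at one shared UUID, which owns the group's single task. One subtlety worth noting is that `remove` drops the shared handle but leaves the other members' `symbol_to_uuid` entries dangling; that is harmless because every lookup goes through `uuid_to_job`, as the illustrative snippet below shows (the symbols and the empty task are made up).

```rust
use std::collections::HashMap;

use tokio::task::spawn;

// Illustrative only: how the new Jobs index behaves for one grouped task.
async fn jobs_index_demo() {
    let mut jobs = Jobs {
        symbol_to_uuid: HashMap::new(),
        uuid_to_job: HashMap::new(),
    };

    // "AAPL" and "MSFT" were packed into one request group: one shared task.
    jobs.insert(vec!["AAPL".into(), "MSFT".into()], spawn(async {}));
    assert!(jobs.get("AAPL").is_some());

    // Purging one member detaches the shared handle for the whole group...
    jobs.remove("MSFT").unwrap().abort();

    // ...and the leftover index entry for "AAPL" now resolves to nothing.
    assert!(jobs.get("AAPL").is_none());
}
```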
@@ -78,7 +120,7 @@ pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<M
 
 async fn handle_backfill_message(
     handler: Arc<Box<dyn Handler>>,
-    backfill_jobs: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
+    backfill_jobs: Arc<Mutex<Jobs>>,
     message: Message,
 ) {
     let mut backfill_jobs = backfill_jobs.lock().await;
@@ -86,6 +128,7 @@ async fn handle_backfill_message(
     match message.action {
         Action::Backfill => {
             let log_string = handler.log_string();
+            let max_limit = handler.max_limit();
 
             let backfills = handler
                 .select_latest_backfills(&message.symbols)
@@ -95,6 +138,8 @@ async fn handle_backfill_message(
                 .map(|backfill| (backfill.symbol.clone(), backfill))
                 .collect::<HashMap<_, _>>();
 
+            let mut jobs = vec![];
+
             for symbol in message.symbols {
                 if let Some(job) = backfill_jobs.get(&symbol) {
                     if !job.is_finished() {
@@ -119,14 +164,49 @@ async fn handle_backfill_message(
                         return;
                     }
 
+                jobs.push((
+                    symbol,
+                    Job {
+                        fetch_from,
+                        fetch_to,
+                    },
+                ));
+            }
+
+            let jobs = jobs
+                .into_iter()
+                .sorted_by_key(|job| job.1.fetch_from)
+                .collect::<Vec<_>>();
+
+            let mut job_groups = vec![HashMap::new()];
+            let mut current_minutes = 0;
+
+            for job in jobs {
+                let minutes = (job.1.fetch_to - job.1.fetch_from).whole_minutes();
+
+                if job_groups.last().unwrap().is_empty() || (current_minutes + minutes) <= max_limit
+                {
+                    let job_group = job_groups.last_mut().unwrap();
+                    job_group.insert(job.0, job.1);
+                    current_minutes += minutes;
+                } else {
+                    let mut job_group = HashMap::new();
+                    job_group.insert(job.0, job.1);
+                    job_groups.push(job_group);
+                    current_minutes = minutes;
+                }
+            }
+
+            for job_group in job_groups {
+                let symbols = job_group.keys().cloned().collect::<Vec<_>>();
+
                 let handler = handler.clone();
-                backfill_jobs.insert(
-                    symbol.clone(),
-                    spawn(async move {
-                        handler.queue_backfill(&symbol, fetch_to).await;
-                        handler.backfill(symbol, fetch_from, fetch_to).await;
-                    }),
-                );
+                let fut = spawn(async move {
+                    handler.queue_backfill(&job_group).await;
+                    handler.backfill(job_group).await;
+                });
+
+                backfill_jobs.insert(symbols, fut);
             }
         }
         Action::Purge => {
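Symbols in one group share a single paginated request, so the group's combined time span has to respect how many rows one request may return: `max_limit` is 10_000 for bars (at most one row per minute) and 50 for news. The greedy packing above reduces to the plain-numbers sketch below (a hypothetical helper, not part of the diff).

```rust
// Plain-numbers restatement of the grouping loop above; illustrative only.
// Jobs arrive sorted by `fetch_from`, and the current group absorbs the next
// job unless its summed span in minutes would exceed the handler's budget.
fn group_by_budget(spans: &[i64], max_limit: i64) -> Vec<Vec<i64>> {
    let mut groups: Vec<Vec<i64>> = vec![vec![]];
    let mut current = 0;

    for &span in spans {
        if groups.last().unwrap().is_empty() || current + span <= max_limit {
            groups.last_mut().unwrap().push(span);
            current += span;
        } else {
            groups.push(vec![span]);
            current = span;
        }
    }

    groups
}

fn main() {
    // Under the bar budget of 10_000 one-minute bars per request, spans of
    // 6_000, 3_000 and 5_000 minutes pack into two request groups.
    assert_eq!(
        group_by_budget(&[6_000, 3_000, 5_000], 10_000),
        vec![vec![6_000, 3_000], vec![5_000]]
    );
}
```

Note the `is_empty()` guard: a single job whose own span exceeds the budget still gets a group to itself, and pagination picks up the overflow.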
@@ -154,7 +234,7 @@ struct BarHandler {
     config: Arc<Config>,
     data_url: &'static str,
     api_query_constructor: fn(
-        symbol: String,
+        symbols: Vec<String>,
         fetch_from: OffsetDateTime,
         fetch_to: OffsetDateTime,
         next_page_token: Option<String>,
@@ -162,31 +242,33 @@ struct BarHandler {
 }
 
 fn us_equity_query_constructor(
-    symbol: String,
+    symbols: Vec<String>,
     fetch_from: OffsetDateTime,
     fetch_to: OffsetDateTime,
     next_page_token: Option<String>,
 ) -> alpaca::api::outgoing::bar::Bar {
     alpaca::api::outgoing::bar::Bar::UsEquity(alpaca::api::outgoing::bar::UsEquity {
-        symbols: vec![symbol],
+        symbols,
         start: Some(fetch_from),
         end: Some(fetch_to),
         page_token: next_page_token,
+        sort: Some(Sort::Asc),
         ..Default::default()
     })
 }
 
 fn crypto_query_constructor(
-    symbol: String,
+    symbols: Vec<String>,
     fetch_from: OffsetDateTime,
     fetch_to: OffsetDateTime,
     next_page_token: Option<String>,
 ) -> alpaca::api::outgoing::bar::Bar {
     alpaca::api::outgoing::bar::Bar::Crypto(alpaca::api::outgoing::bar::Crypto {
-        symbols: vec![symbol],
+        symbols,
         start: Some(fetch_from),
         end: Some(fetch_to),
         page_token: next_page_token,
+        sort: Some(Sort::Asc),
         ..Default::default()
     })
 }
@@ -223,18 +305,31 @@ impl Handler for BarHandler {
         .await
     }
 
-    async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
-        if *ALPACA_SOURCE == Source::Iex {
-            let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
-            info!("Queing bar backfill for {} in {:?}.", symbol, run_delay);
-            sleep(run_delay).await;
+    async fn queue_backfill(&self, jobs: &HashMap<String, Job>) {
+        if *ALPACA_SOURCE == Source::Sip {
+            return;
         }
+
+        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
+        let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
+        let symbols = jobs.keys().collect::<Vec<_>>();
+
+        info!("Queuing bar backfill for {:?} in {:?}.", symbols, run_delay);
+        sleep(run_delay).await;
     }
 
-    async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) {
-        info!("Backfilling bars for {}.", symbol);
+    async fn backfill(&self, jobs: HashMap<String, Job>) {
+        let symbols = jobs.keys().cloned().collect::<Vec<_>>();
+        let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
+        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
+
+        info!("Backfilling bars for {:?}.", symbols);
 
         let mut bars = vec![];
+        let mut last_time = symbols
+            .iter()
+            .map(|symbol| (symbol.clone(), None))
+            .collect::<HashMap<_, _>>();
         let mut next_page_token = None;
 
         loop {
@@ -243,7 +338,7 @@ impl Handler for BarHandler {
                 &self.config.alpaca_rate_limiter,
                 self.data_url,
                 &(self.api_query_constructor)(
-                    symbol.clone(),
+                    symbols.clone(),
                     fetch_from,
                     fetch_to,
                     next_page_token.clone(),
@@ -252,15 +347,30 @@ impl Handler for BarHandler {
             )
             .await
             else {
-                error!("Failed to backfill bars for {}.", symbol);
+                error!("Failed to backfill bars for {:?}.", symbols);
                 return;
             };
 
-            message.bars.into_iter().for_each(|(symbol, bar_vec)| {
+            for (symbol, bar_vec) in message.bars {
+                if let Some(last) = bar_vec.last() {
+                    last_time.insert(symbol.clone(), Some(last.time));
+                }
+
                 for bar in bar_vec {
                     bars.push(Bar::from((bar, symbol.clone())));
                 }
-            });
+            }
+
+            if bars.len() >= database::bars::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
+                database::bars::upsert_batch(
+                    &self.config.clickhouse_client,
+                    &self.config.clickhouse_concurrency_limiter,
+                    &bars,
+                )
+                .await
+                .unwrap();
+                bars = vec![];
+            }
 
             if message.next_page_token.is_none() {
                 break;
@@ -268,29 +378,29 @@ impl Handler for BarHandler {
             }
             next_page_token = message.next_page_token;
         }
 
-        if bars.is_empty() {
-            info!("No bars to backfill for {}.", symbol);
-            return;
-        }
+        let (backfilled, skipped): (Vec<_>, Vec<_>) =
+            last_time.into_iter().partition_map(|(symbol, time)| {
+                if let Some(time) = time {
+                    Either::Left(Backfill { symbol, time })
+                } else {
+                    Either::Right(symbol)
+                }
+            });
 
-        let backfill = bars.last().unwrap().clone().into();
-
-        database::bars::upsert_batch(
+        database::backfills_bars::upsert_batch(
             &self.config.clickhouse_client,
             &self.config.clickhouse_concurrency_limiter,
-            &bars,
-        )
-        .await
-        .unwrap();
-        database::backfills_bars::upsert(
-            &self.config.clickhouse_client,
-            &self.config.clickhouse_concurrency_limiter,
-            &backfill,
+            &backfilled,
         )
         .await
         .unwrap();
 
-        info!("Backfilled bars for {}.", symbol);
+        info!("No bars to backfill for {:?}.", skipped);
+        info!("Backfilled bars for {:?}.", backfilled);
+    }
+
+    fn max_limit(&self) -> i64 {
+        alpaca::api::outgoing::bar::MAX_LIMIT
     }
 
     fn log_string(&self) -> &'static str {
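Because bars are now flushed mid-pagination, the old trick of deriving the `Backfill` marker from `bars.last()` no longer works: the buffer is emptied before the loop finishes. The `last_time` map carries a per-symbol watermark across flushes instead, and `itertools::partition_map` splits it into symbols that produced data (new `Backfill` rows) and symbols that only get a log line. A minimal standalone example of that split, with made-up data:

```rust
use itertools::{Either, Itertools};

fn main() {
    // Stand-ins for the per-symbol watermarks collected during pagination.
    let last_time = [("AAPL", Some(3)), ("MSFT", None)];

    let (backfilled, skipped): (Vec<_>, Vec<_>) =
        last_time.into_iter().partition_map(|(symbol, time)| match time {
            Some(time) => Either::Left((symbol, time)),
            None => Either::Right(symbol),
        });

    assert_eq!(backfilled, vec![("AAPL", 3)]);
    assert_eq!(skipped, vec!["MSFT"]);
}
```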
@@ -334,16 +444,31 @@ impl Handler for NewsHandler {
         .await
     }
 
-    async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
+    async fn queue_backfill(&self, jobs: &HashMap<String, Job>) {
+        if *ALPACA_SOURCE == Source::Sip {
+            return;
+        }
+
+        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
         let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
-        info!("Queing news backfill for {} in {:?}.", symbol, run_delay);
+        let symbols = jobs.keys().cloned().collect::<Vec<_>>();
+
+        info!("Queuing news backfill for {:?} in {:?}.", symbols, run_delay);
         sleep(run_delay).await;
     }
 
-    async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) {
-        info!("Backfilling news for {}.", symbol);
+    async fn backfill(&self, jobs: HashMap<String, Job>) {
+        let symbols = jobs.keys().cloned().collect::<Vec<_>>();
+        let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
+        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
+
+        info!("Backfilling news for {:?}.", symbols);
 
         let mut news = vec![];
+        let mut last_time = symbols
+            .iter()
+            .map(|symbol| (symbol.clone(), None))
+            .collect::<HashMap<_, _>>();
         let mut next_page_token = None;
 
         loop {
@@ -351,7 +476,7 @@ impl Handler for NewsHandler {
                 &self.config.alpaca_client,
                 &self.config.alpaca_rate_limiter,
                 &alpaca::api::outgoing::news::News {
-                    symbols: vec![symbol.clone()],
+                    symbols: symbols.clone(),
                     start: Some(fetch_from),
                     end: Some(fetch_to),
                     page_token: next_page_token.clone(),
@@ -361,13 +486,62 @@ impl Handler for NewsHandler {
             )
             .await
             else {
-                error!("Failed to backfill news for {}.", symbol);
+                error!("Failed to backfill news for {:?}.", symbols);
                 return;
             };
 
-            message.news.into_iter().for_each(|news_item| {
-                news.push(News::from(news_item));
-            });
+            for news_item in message.news {
+                let news_item = News::from(news_item);
+
+                for symbol in &news_item.symbols {
+                    last_time.insert(symbol.clone(), Some(news_item.time_created));
+                }
+
+                news.push(news_item);
+            }
+
+            if news.len() >= *BERT_MAX_INPUTS || message.next_page_token.is_none() {
+                let inputs = news
+                    .iter()
+                    .map(|news| format!("{}\n\n{}", news.headline, news.content))
+                    .collect::<Vec<_>>();
+
+                let predictions =
+                    join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
+                        let sequence_classifier = self.config.sequence_classifier.lock().await;
+                        block_in_place(|| {
+                            sequence_classifier
+                                .predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
+                                .into_iter()
+                                .map(|label| Prediction::try_from(label).unwrap())
+                                .collect::<Vec<_>>()
+                        })
+                    }))
+                    .await
+                    .into_iter()
+                    .flatten();
+
+                news = news
+                    .into_iter()
+                    .zip(predictions)
+                    .map(|(news, prediction)| News {
+                        sentiment: prediction.sentiment,
+                        confidence: prediction.confidence,
+                        ..news
+                    })
+                    .collect::<Vec<_>>();
+            }
+
+            if news.len() >= database::news::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
+                database::news::upsert_batch(
+                    &self.config.clickhouse_client,
+                    &self.config.clickhouse_concurrency_limiter,
+                    &news,
+                )
+                .await
+                .unwrap();
+                news = vec![];
+            }
 
             if message.next_page_token.is_none() {
                 break;
@@ -375,58 +549,29 @@ impl Handler for NewsHandler {
             }
             next_page_token = message.next_page_token;
         }
 
-        if news.is_empty() {
-            info!("No news to backfill for {}.", symbol);
-            return;
-        }
+        let (backfilled, skipped): (Vec<_>, Vec<_>) =
+            last_time.into_iter().partition_map(|(symbol, time)| {
+                if let Some(time) = time {
+                    Either::Left(Backfill { symbol, time })
+                } else {
+                    Either::Right(symbol)
+                }
+            });
 
-        let inputs = news
-            .iter()
-            .map(|news| format!("{}\n\n{}", news.headline, news.content))
-            .collect::<Vec<_>>();
-
-        let predictions = join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
-            let sequence_classifier = self.config.sequence_classifier.lock().await;
-            block_in_place(|| {
-                sequence_classifier
-                    .predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
-                    .into_iter()
-                    .map(|label| Prediction::try_from(label).unwrap())
-                    .collect::<Vec<_>>()
-            })
-        }))
-        .await
-        .into_iter()
-        .flatten();
-
-        let news = news
-            .into_iter()
-            .zip(predictions)
-            .map(|(news, prediction)| News {
-                sentiment: prediction.sentiment,
-                confidence: prediction.confidence,
-                ..news
-            })
-            .collect::<Vec<_>>();
-
-        let backfill = (news.last().unwrap().clone(), symbol.clone()).into();
-
-        database::news::upsert_batch(
+        database::backfills_news::upsert_batch(
             &self.config.clickhouse_client,
             &self.config.clickhouse_concurrency_limiter,
-            &news,
-        )
-        .await
-        .unwrap();
-        database::backfills_news::upsert(
-            &self.config.clickhouse_client,
-            &self.config.clickhouse_concurrency_limiter,
-            &backfill,
+            &backfilled,
         )
         .await
         .unwrap();
 
-        info!("Backfilled news for {}.", symbol);
+        info!("No news to backfill for {:?}.", skipped);
+        info!("Backfilled news for {:?}.", backfilled);
+    }
+
+    fn max_limit(&self) -> i64 {
+        alpaca::api::outgoing::news::MAX_LIMIT
     }
 
     fn log_string(&self) -> &'static str {
diff --git a/src/types/alpaca/api/outgoing/bar.rs b/src/types/alpaca/api/outgoing/bar.rs
index d548a03..4300292 100644
--- a/src/types/alpaca/api/outgoing/bar.rs
+++ b/src/types/alpaca/api/outgoing/bar.rs
@@ -7,6 +7,8 @@ use serde::Serialize;
 use std::time::Duration;
 use time::OffsetDateTime;
 
+pub const MAX_LIMIT: i64 = 10_000;
+
 #[derive(Serialize)]
 #[serde(rename_all = "snake_case")]
 #[allow(dead_code)]
@@ -53,7 +55,7 @@ impl Default for UsEquity {
             timeframe: ONE_MINUTE,
             start: None,
             end: None,
-            limit: Some(10000),
+            limit: Some(MAX_LIMIT),
             adjustment: Some(Adjustment::All),
             asof: None,
             feed: Some(*ALPACA_SOURCE),
@@ -91,7 +93,7 @@ impl Default for Crypto {
             timeframe: ONE_MINUTE,
             start: None,
             end: None,
-            limit: Some(10000),
+            limit: Some(MAX_LIMIT),
             page_token: None,
             sort: Some(Sort::Asc),
         }
diff --git a/src/types/alpaca/api/outgoing/news.rs b/src/types/alpaca/api/outgoing/news.rs
index aae45d1..c271fe1 100644
--- a/src/types/alpaca/api/outgoing/news.rs
+++ b/src/types/alpaca/api/outgoing/news.rs
@@ -2,6 +2,8 @@ use crate::{types::alpaca::shared::Sort, utils::ser};
 use serde::Serialize;
 use time::OffsetDateTime;
 
+pub const MAX_LIMIT: i64 = 50;
+
 #[derive(Serialize)]
 pub struct News {
     #[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")]
@@ -30,7 +32,7 @@ impl Default for News {
             symbols: vec![],
             start: None,
             end: None,
-            limit: Some(50),
+            limit: Some(MAX_LIMIT),
             include_content: Some(true),
             exclude_contentless: Some(false),
             page_token: None,
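Two smaller changes round the patch off: the page-size literals become named `MAX_LIMIT` constants, so the grouping budget in the backfill thread and the request `limit` cannot silently drift apart, and the bar query constructors now set `Sort::Asc` explicitly — presumably because the per-symbol watermarks take `bar_vec.last()` from every page, which is only the latest bar when pages arrive in ascending time order. A back-of-envelope check of the budget arithmetic (illustrative, with the constant copied rather than imported):

```rust
fn main() {
    // Mirrors MAX_LIMIT in src/types/alpaca/api/outgoing/bar.rs.
    const MAX_LIMIT: i64 = 10_000;

    // With one-minute bars, a job spanning `m` minutes yields at most `m` rows,
    // so a group whose summed span stays within MAX_LIMIT fits in a single
    // response page even in the worst case; pagination is only a fallback.
    let group_spans_minutes = [6_000_i64, 3_000];
    assert!(group_spans_minutes.iter().sum::<i64>() <= MAX_LIMIT);
}
```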