From 5ed0c7670ac44d29fc68dca2cdf896e478989694 Mon Sep 17 00:00:00 2001
From: Nikolaos Karaolidis
Date: Tue, 12 Mar 2024 21:00:11 +0000
Subject: [PATCH] Fix backfill sentiment batching bug
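
Previously, the news backfill pushed unclassified items into the same
news buffer that held already-classified items, so every time the
buffer crossed BERT_MAX_INPUTS the entire buffer was fed through the
classifier again, while the flush threshold lived in a separate
hard-coded BATCH_FLUSH_SIZE constant. Collect incoming items in a
dedicated batch buffer instead, classify each item exactly once, and
move the classified results into the news buffer until the flush
threshold is reached.

While at it, make the flush thresholds configurable through the
BATCH_BACKFILL_BARS_SIZE and BATCH_BACKFILL_NEWS_SIZE environment
variables (validated eagerly in main), and put the sequence classifier
behind a std::sync::Mutex that is only locked inside block_in_place,
instead of holding a tokio Mutex across blocking inference.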
Signed-off-by: Nikolaos Karaolidis
---
 src/config.rs                      |  14 ++-
 src/lib/database/bars.rs           |   2 -
 src/lib/database/news.rs           |   2 -
 src/main.rs                        |  13 ++-
 src/threads/data/backfill/bars.rs  |  72 ++++++++-------
 src/threads/data/backfill/news.rs  | 141 ++++++++++++++++-------------
 src/threads/data/websocket/news.rs |   7 +-
 7 files changed, 141 insertions(+), 110 deletions(-)

diff --git a/src/config.rs b/src/config.rs
index f57e040..d1bd81a 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -13,7 +13,7 @@ use rust_bert::{
     resources::LocalResource,
 };
 use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
-use tokio::sync::{Mutex, Semaphore};
+use tokio::sync::Semaphore;
 
 lazy_static! {
     pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
@@ -32,6 +32,14 @@ lazy_static! {
         env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
     pub static ref ALPACA_API_SECRET: String =
         env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
+    pub static ref BATCH_BACKFILL_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE")
+        .expect("BATCH_BACKFILL_BARS_SIZE must be set.")
+        .parse()
+        .expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer.");
+    pub static ref BATCH_BACKFILL_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE")
+        .expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
+        .parse()
+        .expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
     pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS")
         .expect("BERT_MAX_INPUTS must be set.")
         .parse()
@@ -47,7 +55,7 @@ pub struct Config {
     pub alpaca_rate_limiter: DefaultDirectRateLimiter,
     pub clickhouse_client: clickhouse::Client,
     pub clickhouse_concurrency_limiter: Arc<Semaphore>,
-    pub sequence_classifier: Mutex<SequenceClassificationModel>,
+    pub sequence_classifier: std::sync::Mutex<SequenceClassificationModel>,
 }
 
 impl Config {
@@ -81,7 +89,7 @@ impl Config {
             )
             .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
             clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
-            sequence_classifier: Mutex::new(
+            sequence_classifier: std::sync::Mutex::new(
                 SequenceClassificationModel::new(SequenceClassificationConfig::new(
                     ModelType::Bert,
                     ModelResource::Torch(Box::new(LocalResource {
diff --git a/src/lib/database/bars.rs b/src/lib/database/bars.rs
index bc7182b..ca9ae01 100644
--- a/src/lib/database/bars.rs
+++ b/src/lib/database/bars.rs
@@ -4,8 +4,6 @@ use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
 use clickhouse::Client;
 use tokio::sync::Semaphore;
 
-pub const BATCH_FLUSH_SIZE: usize = 100_000;
-
 upsert!(Bar, "bars");
 upsert_batch!(Bar, "bars");
 delete_where_symbols!("bars");
diff --git a/src/lib/database/news.rs b/src/lib/database/news.rs
index bbcd7dc..a028c21 100644
--- a/src/lib/database/news.rs
+++ b/src/lib/database/news.rs
@@ -5,8 +5,6 @@ use clickhouse::{error::Error, Client};
 use serde::Serialize;
 use tokio::sync::Semaphore;
 
-pub const BATCH_FLUSH_SIZE: usize = 500;
-
 upsert!(News, "news");
 upsert_batch!(News, "news");
 optimize!("news");
diff --git a/src/main.rs b/src/main.rs
index b34c273..efa3fbd 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -7,7 +7,10 @@ mod init;
 mod routes;
 mod threads;
 
-use config::Config;
+use config::{
+    Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, BATCH_BACKFILL_BARS_SIZE,
+    BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS, CLICKHOUSE_MAX_CONNECTIONS,
+};
 use dotenv::dotenv;
 use log4rs::config::Deserializers;
 use qrust::{create_send_await, database};
@@ -19,6 +22,14 @@ async fn main() {
     log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
     let config = Config::arc_from_env();
 
+    let _ = *ALPACA_MODE;
+    let _ = *ALPACA_API_BASE;
+    let _ = *ALPACA_SOURCE;
+    let _ = *BATCH_BACKFILL_BARS_SIZE;
+    let _ = *BATCH_BACKFILL_NEWS_SIZE;
+    let _ = *BERT_MAX_INPUTS;
+    let _ = *CLICKHOUSE_MAX_CONNECTIONS;
+
     try_join!(
         database::backfills_bars::unfresh(
             &config.clickhouse_client,
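
Note: the "let _ = *VAR;" lines added to main() above dereference each
lazy_static once at startup, so a missing or malformed environment
variable panics immediately rather than on first use deep inside a
backfill. A minimal standalone sketch of the pattern, using a
hypothetical BATCH_SIZE variable in place of the real config statics:

    use lazy_static::lazy_static;
    use std::env;

    lazy_static! {
        // Parsed lazily on first dereference; panics with a clear
        // message if the variable is unset or not a valid usize.
        static ref BATCH_SIZE: usize = env::var("BATCH_SIZE")
            .expect("BATCH_SIZE must be set.")
            .parse()
            .expect("BATCH_SIZE must be a positive integer.");
    }

    fn main() {
        // Touch the static once at startup to fail fast on bad config.
        let _ = *BATCH_SIZE;
        println!("batch size: {}", *BATCH_SIZE);
    }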
diff --git a/src/threads/data/backfill/bars.rs b/src/threads/data/backfill/bars.rs
index 702d188..09f12ae 100644
--- a/src/threads/data/backfill/bars.rs
+++ b/src/threads/data/backfill/bars.rs
@@ -1,6 +1,6 @@
 use super::Job;
 use crate::{
-    config::{Config, ALPACA_SOURCE},
+    config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_BARS_SIZE},
     database,
 };
 use async_trait::async_trait;
@@ -116,12 +116,12 @@ impl super::Handler for Handler {
         let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
         let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
 
-        info!("Backfilling bars for {:?}.", symbols);
-
-        let mut bars = Vec::with_capacity(database::bars::BATCH_FLUSH_SIZE);
+        let mut bars = Vec::with_capacity(*BATCH_BACKFILL_BARS_SIZE);
         let mut last_times = HashMap::new();
         let mut next_page_token = None;
 
+        info!("Backfilling bars for {:?}.", symbols);
+
         loop {
             let message = alpaca::bars::get(
                 &self.config.alpaca_client,
@@ -144,45 +144,47 @@ impl super::Handler for Handler {
 
             let message = message.unwrap();
 
-            for (symbol, bar_vec) in message.bars {
-                if let Some(last) = bar_vec.last() {
+            for (symbol, bars_vec) in message.bars {
+                if let Some(last) = bars_vec.last() {
                     last_times.insert(symbol.clone(), last.time);
                 }
 
-                for bar in bar_vec {
+                for bar in bars_vec {
                     bars.push(Bar::from((bar, symbol.clone())));
                 }
             }
 
-            if bars.len() >= database::bars::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
-                database::bars::upsert_batch(
-                    &self.config.clickhouse_client,
-                    &self.config.clickhouse_concurrency_limiter,
-                    &bars,
-                )
-                .await
-                .unwrap();
-
-                let backfilled = last_times
-                    .drain()
-                    .map(|(symbol, time)| Backfill { symbol, time })
-                    .collect::<Vec<_>>();
-
-                database::backfills_bars::upsert_batch(
-                    &self.config.clickhouse_client,
-                    &self.config.clickhouse_concurrency_limiter,
-                    &backfilled,
-                )
-                .await
-                .unwrap();
-
-                if message.next_page_token.is_none() {
-                    break;
-                }
-
-                next_page_token = message.next_page_token;
-                bars.clear();
+            if bars.len() < *BATCH_BACKFILL_BARS_SIZE && message.next_page_token.is_some() {
+                continue;
             }
+
+            database::bars::upsert_batch(
+                &self.config.clickhouse_client,
+                &self.config.clickhouse_concurrency_limiter,
+                &bars,
+            )
+            .await
+            .unwrap();
+
+            let backfilled = last_times
+                .drain()
+                .map(|(symbol, time)| Backfill { symbol, time })
+                .collect::<Vec<_>>();
+
+            database::backfills_bars::upsert_batch(
+                &self.config.clickhouse_client,
+                &self.config.clickhouse_concurrency_limiter,
+                &backfilled,
+            )
+            .await
+            .unwrap();
+
+            if message.next_page_token.is_none() {
+                break;
+            }
+
+            next_page_token = message.next_page_token;
+            bars.clear();
         }
 
         info!("Backfilled bars for {:?}.", symbols);
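
Note: the loop above now short-circuits with an early continue instead
of nesting the flush logic inside an if block, so the flush path reads
top to bottom at one indentation level. A condensed sketch of the
control flow, with a hypothetical fetch_page standing in for
alpaca::bars::get and println! standing in for the ClickHouse upserts:

    const FLUSH_SIZE: usize = 4; // stand-in for *BATCH_BACKFILL_BARS_SIZE

    // One page of results plus the token of the next page, if any.
    fn fetch_page(token: Option<u32>) -> (Vec<u32>, Option<u32>) {
        let start = token.unwrap_or(0);
        if start >= 9 {
            (Vec::new(), None)
        } else {
            ((start..start + 3).collect(), Some(start + 3))
        }
    }

    fn main() {
        let mut buffer: Vec<u32> = Vec::with_capacity(FLUSH_SIZE);
        let mut token = None;

        loop {
            let (items, next) = fetch_page(token);
            buffer.extend(items);

            // Keep paging until the buffer is full or the data runs out.
            if buffer.len() < FLUSH_SIZE && next.is_some() {
                token = next;
                continue;
            }

            // Flush: stand-in for upsert_batch on bars and backfills.
            println!("flushing {} items", buffer.len());

            if next.is_none() {
                break;
            }

            token = next;
            buffer.clear();
        }
    }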
diff --git a/src/threads/data/backfill/news.rs b/src/threads/data/backfill/news.rs
index e23b3ff..43faf85 100644
--- a/src/threads/data/backfill/news.rs
+++ b/src/threads/data/backfill/news.rs
@@ -1,10 +1,9 @@
 use super::Job;
 use crate::{
-    config::{Config, ALPACA_SOURCE, BERT_MAX_INPUTS},
+    config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS},
     database,
 };
 use async_trait::async_trait;
-use futures_util::future::join_all;
 use log::{error, info};
 use qrust::{
     alpaca,
@@ -16,7 +15,10 @@ use qrust::{
     },
     utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
 };
-use std::{collections::HashMap, sync::Arc};
+use std::{
+    collections::{HashMap, HashSet},
+    sync::Arc,
+};
 use tokio::{task::block_in_place, time::sleep};
 
 pub struct Handler {
@@ -68,22 +70,26 @@ impl super::Handler for Handler {
         sleep(run_delay).await;
     }
 
+    #[allow(clippy::too_many_lines)]
+    #[allow(clippy::iter_with_drain)]
     async fn backfill(&self, jobs: HashMap<String, Job>) {
         if jobs.is_empty() {
             return;
         }
 
         let symbols = jobs.keys().cloned().collect::<Vec<_>>();
-        let symbols_set = symbols.iter().collect::<HashSet<_>>();
+        let symbols_set = symbols.clone().into_iter().collect::<HashSet<_>>();
         let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
         let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
 
-        info!("Backfilling news for {:?}.", symbols);
-
-        let mut news = Vec::with_capacity(database::news::BATCH_FLUSH_SIZE);
+        let mut news = Vec::with_capacity(*BATCH_BACKFILL_NEWS_SIZE);
+        let mut batch = Vec::with_capacity(*BERT_MAX_INPUTS);
+        let mut predictions = Vec::with_capacity(*BERT_MAX_INPUTS);
         let mut last_times = HashMap::new();
         let mut next_page_token = None;
 
+        info!("Backfilling news for {:?}.", symbols);
+
         loop {
             let message = alpaca::news::get(
                 &self.config.alpaca_client,
@@ -116,70 +122,77 @@ impl super::Handler for Handler {
                     }
                 }
 
-                news.push(news_item);
+                batch.push(news_item);
             }
 
-            if news.len() >= *BERT_MAX_INPUTS || message.next_page_token.is_none() {
-                let inputs = news
-                    .iter()
-                    .map(|news| format!("{}\n\n{}", news.headline, news.content))
-                    .collect::<Vec<_>>();
-
-                let predictions =
-                    join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
-                        let sequence_classifier = self.config.sequence_classifier.lock().await;
-                        block_in_place(|| {
-                            sequence_classifier
-                                .predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
-                                .into_iter()
-                                .map(|label| Prediction::try_from(label).unwrap())
-                                .collect::<Vec<_>>()
-                        })
-                    }))
-                    .await
-                    .into_iter()
-                    .flatten();
-
-                news = news
-                    .into_iter()
-                    .zip(predictions)
-                    .map(|(news, prediction)| News {
-                        sentiment: prediction.sentiment,
-                        confidence: prediction.confidence,
-                        ..news
-                    })
-                    .collect::<Vec<_>>();
+            if batch.len() < *BERT_MAX_INPUTS
+                && batch.len() < *BATCH_BACKFILL_NEWS_SIZE
+                && message.next_page_token.is_some()
+            {
+                continue;
             }
 
-            if news.len() >= database::news::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
-                database::news::upsert_batch(
-                    &self.config.clickhouse_client,
-                    &self.config.clickhouse_concurrency_limiter,
-                    &news,
-                )
-                .await
-                .unwrap();
+            let inputs = batch
+                .iter()
+                .map(|news| format!("{}\n\n{}", news.headline, news.content))
+                .collect::<Vec<_>>();
 
-                let backfilled = last_times
-                    .drain()
-                    .map(|(symbol, time)| Backfill { symbol, time })
-                    .collect::<Vec<_>>();
+            for chunk in inputs.chunks(*BERT_MAX_INPUTS) {
+                let chunk_predictions = block_in_place(|| {
+                    self.config
+                        .sequence_classifier
+                        .lock()
+                        .unwrap()
+                        .predict(chunk.iter().map(String::as_str).collect::<Vec<_>>())
+                        .into_iter()
+                        .map(|label| Prediction::try_from(label).unwrap())
+                });
 
-                database::backfills_news::upsert_batch(
-                    &self.config.clickhouse_client,
-                    &self.config.clickhouse_concurrency_limiter,
-                    &backfilled,
-                )
-                .await
-                .unwrap();
+                predictions.extend(chunk_predictions);
+            }
 
-                if message.next_page_token.is_none() {
-                    break;
-                }
-
-                next_page_token = message.next_page_token;
-                news.clear();
+            let zipped = batch
+                .drain(..)
+                .zip(predictions.drain(..))
+                .map(|(news, prediction)| News {
+                    sentiment: prediction.sentiment,
+                    confidence: prediction.confidence,
+                    ..news
+                });
+
+            news.extend(zipped);
+
+            if news.len() < *BATCH_BACKFILL_NEWS_SIZE && message.next_page_token.is_some() {
+                continue;
             }
+
+            database::news::upsert_batch(
+                &self.config.clickhouse_client,
+                &self.config.clickhouse_concurrency_limiter,
+                &news,
+            )
+            .await
+            .unwrap();
+
+            let backfilled = last_times
+                .drain()
+                .map(|(symbol, time)| Backfill { symbol, time })
+                .collect::<Vec<_>>();
+
+            database::backfills_news::upsert_batch(
+                &self.config.clickhouse_client,
+                &self.config.clickhouse_concurrency_limiter,
+                &backfilled,
+            )
+            .await
+            .unwrap();
+
+            if message.next_page_token.is_none() {
+                break;
+            }
+
+            next_page_token = message.next_page_token;
+            news.clear();
         }
 
         info!("Backfilled news for {:?}.", symbols);
diff --git a/src/threads/data/websocket/news.rs b/src/threads/data/websocket/news.rs
index 08a32cf..be46cc2 100644
--- a/src/threads/data/websocket/news.rs
+++ b/src/threads/data/websocket/news.rs
@@ -82,15 +82,16 @@ impl super::Handler for Handler {
 
         let input = format!("{}\n\n{}", news.headline, news.content);
 
-        let sequence_classifier = self.config.sequence_classifier.lock().await;
         let prediction = block_in_place(|| {
-            sequence_classifier
+            self.config
+                .sequence_classifier
+                .lock()
+                .unwrap()
                 .predict(vec![input.as_str()])
                 .into_iter()
                 .map(|label| Prediction::try_from(label).unwrap())
                 .collect::<Vec<_>>()[0]
         });
-        drop(sequence_classifier);
 
         let news = News {
             sentiment: prediction.sentiment,
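
Note on the locking change in the two files above: rust-bert inference
blocks the thread and is CPU-bound, so the classifier now sits behind a
std::sync::Mutex that is locked inside tokio::task::block_in_place,
rather than a tokio Mutex acquired with .await and held across the
blocking predict() call. A minimal sketch of the pattern, with a
hypothetical Model standing in for SequenceClassificationModel:

    use std::sync::Mutex;
    use tokio::task::block_in_place;

    struct Model;

    impl Model {
        // Stand-in for a blocking, CPU-bound predict().
        fn predict(&self, inputs: Vec<&str>) -> Vec<usize> {
            inputs.iter().map(|s| s.len()).collect()
        }
    }

    // block_in_place requires the multi-threaded runtime.
    #[tokio::main(flavor = "multi_thread")]
    async fn main() {
        let model = Mutex::new(Model);

        // block_in_place tells the runtime this thread is about to
        // block, so taking a synchronous lock and running inference
        // here is safe, and no lock is ever held across an .await.
        let out = block_in_place(|| model.lock().unwrap().predict(vec!["hello", "world"]));
        println!("{out:?}");
    }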