Fix backfill sentiment batching bug

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-03-12 21:00:11 +00:00
parent d2d20e2978
commit 5ed0c7670a
7 changed files with 141 additions and 110 deletions

View File

@@ -13,7 +13,7 @@ use rust_bert::{
resources::LocalResource,
};
use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
use tokio::sync::{Mutex, Semaphore};
use tokio::sync::Semaphore;
lazy_static! {
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
@@ -32,6 +32,14 @@ lazy_static! {
env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
pub static ref ALPACA_API_SECRET: String =
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
pub static ref BATCH_BACKFILL_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE")
.expect("BATCH_BACKFILL_BARS_SIZE must be set.")
.parse()
.expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer.");
pub static ref BATCH_BACKFILL_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE")
.expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
.parse()
.expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS")
.expect("BERT_MAX_INPUTS must be set.")
.parse()
@@ -47,7 +55,7 @@ pub struct Config {
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
pub clickhouse_client: clickhouse::Client,
pub clickhouse_concurrency_limiter: Arc<Semaphore>,
pub sequence_classifier: Mutex<SequenceClassificationModel>,
pub sequence_classifier: std::sync::Mutex<SequenceClassificationModel>,
}
impl Config {
@@ -81,7 +89,7 @@ impl Config {
)
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
sequence_classifier: Mutex::new(
sequence_classifier: std::sync::Mutex::new(
SequenceClassificationModel::new(SequenceClassificationConfig::new(
ModelType::Bert,
ModelResource::Torch(Box::new(LocalResource {

View File

@@ -4,8 +4,6 @@ use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
use clickhouse::Client;
use tokio::sync::Semaphore;
pub const BATCH_FLUSH_SIZE: usize = 100_000;
upsert!(Bar, "bars");
upsert_batch!(Bar, "bars");
delete_where_symbols!("bars");

View File

@@ -5,8 +5,6 @@ use clickhouse::{error::Error, Client};
use serde::Serialize;
use tokio::sync::Semaphore;
pub const BATCH_FLUSH_SIZE: usize = 500;
upsert!(News, "news");
upsert_batch!(News, "news");
optimize!("news");

View File

@@ -7,7 +7,10 @@ mod init;
mod routes;
mod threads;
use config::Config;
use config::{
Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, BATCH_BACKFILL_BARS_SIZE,
BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS, CLICKHOUSE_MAX_CONNECTIONS,
};
use dotenv::dotenv;
use log4rs::config::Deserializers;
use qrust::{create_send_await, database};
@@ -19,6 +22,14 @@ async fn main() {
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
let config = Config::arc_from_env();
let _ = *ALPACA_MODE;
let _ = *ALPACA_API_BASE;
let _ = *ALPACA_SOURCE;
let _ = *BATCH_BACKFILL_BARS_SIZE;
let _ = *BATCH_BACKFILL_NEWS_SIZE;
let _ = *BERT_MAX_INPUTS;
let _ = *CLICKHOUSE_MAX_CONNECTIONS;
try_join!(
database::backfills_bars::unfresh(
&config.clickhouse_client,

View File

@@ -1,6 +1,6 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE},
config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_BARS_SIZE},
database,
};
use async_trait::async_trait;
@@ -116,12 +116,12 @@ impl super::Handler for Handler {
let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
info!("Backfilling bars for {:?}.", symbols);
let mut bars = Vec::with_capacity(database::bars::BATCH_FLUSH_SIZE);
let mut bars = Vec::with_capacity(*BATCH_BACKFILL_BARS_SIZE);
let mut last_times = HashMap::new();
let mut next_page_token = None;
info!("Backfilling bars for {:?}.", symbols);
loop {
let message = alpaca::bars::get(
&self.config.alpaca_client,
@@ -144,45 +144,47 @@ impl super::Handler for Handler {
let message = message.unwrap();
for (symbol, bar_vec) in message.bars {
if let Some(last) = bar_vec.last() {
for (symbol, bars_vec) in message.bars {
if let Some(last) = bars_vec.last() {
last_times.insert(symbol.clone(), last.time);
}
for bar in bar_vec {
for bar in bars_vec {
bars.push(Bar::from((bar, symbol.clone())));
}
}
if bars.len() >= database::bars::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
database::bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&bars,
)
.await
.unwrap();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill { symbol, time })
.collect::<Vec<_>>();
database::backfills_bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
bars.clear();
if bars.len() < *BATCH_BACKFILL_BARS_SIZE && message.next_page_token.is_some() {
continue;
}
database::bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&bars,
)
.await
.unwrap();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill { symbol, time })
.collect::<Vec<_>>();
database::backfills_bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
bars.clear();
}
info!("Backfilled bars for {:?}.", symbols);

View File

@@ -1,10 +1,9 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE, BERT_MAX_INPUTS},
config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS},
database,
};
use async_trait::async_trait;
use futures_util::future::join_all;
use log::{error, info};
use qrust::{
alpaca,
@@ -16,7 +15,10 @@ use qrust::{
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{collections::HashMap, sync::Arc};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::{task::block_in_place, time::sleep};
pub struct Handler {
@@ -68,22 +70,26 @@ impl super::Handler for Handler {
sleep(run_delay).await;
}
#[allow(clippy::too_many_lines)]
#[allow(clippy::iter_with_drain)]
async fn backfill(&self, jobs: HashMap<String, Job>) {
if jobs.is_empty() {
return;
}
let symbols = jobs.keys().cloned().collect::<Vec<_>>();
let symbols_set = symbols.iter().collect::<std::collections::HashSet<_>>();
let symbols_set = symbols.clone().into_iter().collect::<HashSet<_>>();
let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
info!("Backfilling news for {:?}.", symbols);
let mut news = Vec::with_capacity(database::news::BATCH_FLUSH_SIZE);
let mut news = Vec::with_capacity(*BATCH_BACKFILL_NEWS_SIZE);
let mut batch = Vec::with_capacity(*BERT_MAX_INPUTS);
let mut predictions = Vec::with_capacity(*BERT_MAX_INPUTS);
let mut last_times = HashMap::new();
let mut next_page_token = None;
info!("Backfilling news for {:?}.", symbols);
loop {
let message = alpaca::news::get(
&self.config.alpaca_client,
@@ -116,70 +122,77 @@ impl super::Handler for Handler {
}
}
news.push(news_item);
batch.push(news_item);
}
if news.len() >= *BERT_MAX_INPUTS || message.next_page_token.is_none() {
let inputs = news
.iter()
.map(|news| format!("{}\n\n{}", news.headline, news.content))
.collect::<Vec<_>>();
let predictions =
join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
let sequence_classifier = self.config.sequence_classifier.lock().await;
block_in_place(|| {
sequence_classifier
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()
})
}))
.await
.into_iter()
.flatten();
news = news
.into_iter()
.zip(predictions)
.map(|(news, prediction)| News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
})
.collect::<Vec<_>>();
if batch.len() < *BERT_MAX_INPUTS
&& batch.len() < *BATCH_BACKFILL_NEWS_SIZE
&& message.next_page_token.is_some()
{
continue;
}
if news.len() >= database::news::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
database::news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&news,
)
.await
.unwrap();
let inputs = batch
.iter()
.map(|news| format!("{}\n\n{}", news.headline, news.content))
.collect::<Vec<_>>();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill { symbol, time })
.collect::<Vec<_>>();
for chunk in inputs.chunks(*BERT_MAX_INPUTS) {
let chunk_predictions = block_in_place(|| {
self.config
.sequence_classifier
.lock()
.unwrap()
.predict(chunk.iter().map(String::as_str).collect::<Vec<_>>())
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
});
database::backfills_news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
news.clear();
predictions.extend(chunk_predictions);
}
let zipped = batch
.drain(..)
.zip(predictions.drain(..))
.map(|(news, prediction)| News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
});
news.extend(zipped);
if news.len() < *BATCH_BACKFILL_NEWS_SIZE && message.next_page_token.is_some() {
continue;
}
database::news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&news,
)
.await
.unwrap();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill { symbol, time })
.collect::<Vec<_>>();
database::backfills_news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
news.clear();
}
info!("Backfilled news for {:?}.", symbols);

View File

@@ -82,15 +82,16 @@ impl super::Handler for Handler {
let input = format!("{}\n\n{}", news.headline, news.content);
let sequence_classifier = self.config.sequence_classifier.lock().await;
let prediction = block_in_place(|| {
sequence_classifier
self.config
.sequence_classifier
.lock()
.unwrap()
.predict(vec![input.as_str()])
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()[0]
});
drop(sequence_classifier);
let news = News {
sentiment: prediction.sentiment,