Fix backfill sentiment batching bug
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
@@ -13,7 +13,7 @@ use rust_bert::{
|
||||
resources::LocalResource,
|
||||
};
|
||||
use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
|
||||
use tokio::sync::{Mutex, Semaphore};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
|
||||
@@ -32,6 +32,14 @@ lazy_static! {
|
||||
env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
|
||||
pub static ref ALPACA_API_SECRET: String =
|
||||
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
|
||||
pub static ref BATCH_BACKFILL_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE")
|
||||
.expect("BATCH_BACKFILL_BARS_SIZE must be set.")
|
||||
.parse()
|
||||
.expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer.");
|
||||
pub static ref BATCH_BACKFILL_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE")
|
||||
.expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
|
||||
.parse()
|
||||
.expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
|
||||
pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS")
|
||||
.expect("BERT_MAX_INPUTS must be set.")
|
||||
.parse()
|
||||
@@ -47,7 +55,7 @@ pub struct Config {
|
||||
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
|
||||
pub clickhouse_client: clickhouse::Client,
|
||||
pub clickhouse_concurrency_limiter: Arc<Semaphore>,
|
||||
pub sequence_classifier: Mutex<SequenceClassificationModel>,
|
||||
pub sequence_classifier: std::sync::Mutex<SequenceClassificationModel>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -81,7 +89,7 @@ impl Config {
|
||||
)
|
||||
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
|
||||
clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
|
||||
sequence_classifier: Mutex::new(
|
||||
sequence_classifier: std::sync::Mutex::new(
|
||||
SequenceClassificationModel::new(SequenceClassificationConfig::new(
|
||||
ModelType::Bert,
|
||||
ModelResource::Torch(Box::new(LocalResource {
|
||||
|
@@ -4,8 +4,6 @@ use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
|
||||
use clickhouse::Client;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
pub const BATCH_FLUSH_SIZE: usize = 100_000;
|
||||
|
||||
upsert!(Bar, "bars");
|
||||
upsert_batch!(Bar, "bars");
|
||||
delete_where_symbols!("bars");
|
||||
|
@@ -5,8 +5,6 @@ use clickhouse::{error::Error, Client};
|
||||
use serde::Serialize;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
pub const BATCH_FLUSH_SIZE: usize = 500;
|
||||
|
||||
upsert!(News, "news");
|
||||
upsert_batch!(News, "news");
|
||||
optimize!("news");
|
||||
|
13
src/main.rs
13
src/main.rs
@@ -7,7 +7,10 @@ mod init;
|
||||
mod routes;
|
||||
mod threads;
|
||||
|
||||
use config::Config;
|
||||
use config::{
|
||||
Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, BATCH_BACKFILL_BARS_SIZE,
|
||||
BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS, CLICKHOUSE_MAX_CONNECTIONS,
|
||||
};
|
||||
use dotenv::dotenv;
|
||||
use log4rs::config::Deserializers;
|
||||
use qrust::{create_send_await, database};
|
||||
@@ -19,6 +22,14 @@ async fn main() {
|
||||
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
|
||||
let config = Config::arc_from_env();
|
||||
|
||||
let _ = *ALPACA_MODE;
|
||||
let _ = *ALPACA_API_BASE;
|
||||
let _ = *ALPACA_SOURCE;
|
||||
let _ = *BATCH_BACKFILL_BARS_SIZE;
|
||||
let _ = *BATCH_BACKFILL_NEWS_SIZE;
|
||||
let _ = *BERT_MAX_INPUTS;
|
||||
let _ = *CLICKHOUSE_MAX_CONNECTIONS;
|
||||
|
||||
try_join!(
|
||||
database::backfills_bars::unfresh(
|
||||
&config.clickhouse_client,
|
||||
|
@@ -1,6 +1,6 @@
|
||||
use super::Job;
|
||||
use crate::{
|
||||
config::{Config, ALPACA_SOURCE},
|
||||
config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_BARS_SIZE},
|
||||
database,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
@@ -116,12 +116,12 @@ impl super::Handler for Handler {
|
||||
let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
|
||||
let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
|
||||
|
||||
info!("Backfilling bars for {:?}.", symbols);
|
||||
|
||||
let mut bars = Vec::with_capacity(database::bars::BATCH_FLUSH_SIZE);
|
||||
let mut bars = Vec::with_capacity(*BATCH_BACKFILL_BARS_SIZE);
|
||||
let mut last_times = HashMap::new();
|
||||
let mut next_page_token = None;
|
||||
|
||||
info!("Backfilling bars for {:?}.", symbols);
|
||||
|
||||
loop {
|
||||
let message = alpaca::bars::get(
|
||||
&self.config.alpaca_client,
|
||||
@@ -144,45 +144,47 @@ impl super::Handler for Handler {
|
||||
|
||||
let message = message.unwrap();
|
||||
|
||||
for (symbol, bar_vec) in message.bars {
|
||||
if let Some(last) = bar_vec.last() {
|
||||
for (symbol, bars_vec) in message.bars {
|
||||
if let Some(last) = bars_vec.last() {
|
||||
last_times.insert(symbol.clone(), last.time);
|
||||
}
|
||||
|
||||
for bar in bar_vec {
|
||||
for bar in bars_vec {
|
||||
bars.push(Bar::from((bar, symbol.clone())));
|
||||
}
|
||||
}
|
||||
|
||||
if bars.len() >= database::bars::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
|
||||
database::bars::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&bars,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let backfilled = last_times
|
||||
.drain()
|
||||
.map(|(symbol, time)| Backfill { symbol, time })
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
database::backfills_bars::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&backfilled,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if message.next_page_token.is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
next_page_token = message.next_page_token;
|
||||
bars.clear();
|
||||
if bars.len() < *BATCH_BACKFILL_BARS_SIZE && message.next_page_token.is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
database::bars::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&bars,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let backfilled = last_times
|
||||
.drain()
|
||||
.map(|(symbol, time)| Backfill { symbol, time })
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
database::backfills_bars::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&backfilled,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if message.next_page_token.is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
next_page_token = message.next_page_token;
|
||||
bars.clear();
|
||||
}
|
||||
|
||||
info!("Backfilled bars for {:?}.", symbols);
|
||||
|
@@ -1,10 +1,9 @@
|
||||
use super::Job;
|
||||
use crate::{
|
||||
config::{Config, ALPACA_SOURCE, BERT_MAX_INPUTS},
|
||||
config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS},
|
||||
database,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::future::join_all;
|
||||
use log::{error, info};
|
||||
use qrust::{
|
||||
alpaca,
|
||||
@@ -16,7 +15,10 @@ use qrust::{
|
||||
},
|
||||
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
|
||||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::{task::block_in_place, time::sleep};
|
||||
|
||||
pub struct Handler {
|
||||
@@ -68,22 +70,26 @@ impl super::Handler for Handler {
|
||||
sleep(run_delay).await;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
#[allow(clippy::iter_with_drain)]
|
||||
async fn backfill(&self, jobs: HashMap<String, Job>) {
|
||||
if jobs.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let symbols = jobs.keys().cloned().collect::<Vec<_>>();
|
||||
let symbols_set = symbols.iter().collect::<std::collections::HashSet<_>>();
|
||||
let symbols_set = symbols.clone().into_iter().collect::<HashSet<_>>();
|
||||
let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
|
||||
let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
|
||||
|
||||
info!("Backfilling news for {:?}.", symbols);
|
||||
|
||||
let mut news = Vec::with_capacity(database::news::BATCH_FLUSH_SIZE);
|
||||
let mut news = Vec::with_capacity(*BATCH_BACKFILL_NEWS_SIZE);
|
||||
let mut batch = Vec::with_capacity(*BERT_MAX_INPUTS);
|
||||
let mut predictions = Vec::with_capacity(*BERT_MAX_INPUTS);
|
||||
let mut last_times = HashMap::new();
|
||||
let mut next_page_token = None;
|
||||
|
||||
info!("Backfilling news for {:?}.", symbols);
|
||||
|
||||
loop {
|
||||
let message = alpaca::news::get(
|
||||
&self.config.alpaca_client,
|
||||
@@ -116,70 +122,77 @@ impl super::Handler for Handler {
|
||||
}
|
||||
}
|
||||
|
||||
news.push(news_item);
|
||||
batch.push(news_item);
|
||||
}
|
||||
|
||||
if news.len() >= *BERT_MAX_INPUTS || message.next_page_token.is_none() {
|
||||
let inputs = news
|
||||
.iter()
|
||||
.map(|news| format!("{}\n\n{}", news.headline, news.content))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let predictions =
|
||||
join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
|
||||
let sequence_classifier = self.config.sequence_classifier.lock().await;
|
||||
block_in_place(|| {
|
||||
sequence_classifier
|
||||
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
|
||||
.into_iter()
|
||||
.map(|label| Prediction::try_from(label).unwrap())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}))
|
||||
.await
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
news = news
|
||||
.into_iter()
|
||||
.zip(predictions)
|
||||
.map(|(news, prediction)| News {
|
||||
sentiment: prediction.sentiment,
|
||||
confidence: prediction.confidence,
|
||||
..news
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if batch.len() < *BERT_MAX_INPUTS
|
||||
&& batch.len() < *BATCH_BACKFILL_NEWS_SIZE
|
||||
&& message.next_page_token.is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if news.len() >= database::news::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
|
||||
database::news::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&news,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let inputs = batch
|
||||
.iter()
|
||||
.map(|news| format!("{}\n\n{}", news.headline, news.content))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let backfilled = last_times
|
||||
.drain()
|
||||
.map(|(symbol, time)| Backfill { symbol, time })
|
||||
.collect::<Vec<_>>();
|
||||
for chunk in inputs.chunks(*BERT_MAX_INPUTS) {
|
||||
let chunk_predictions = block_in_place(|| {
|
||||
self.config
|
||||
.sequence_classifier
|
||||
.lock()
|
||||
.unwrap()
|
||||
.predict(chunk.iter().map(String::as_str).collect::<Vec<_>>())
|
||||
.into_iter()
|
||||
.map(|label| Prediction::try_from(label).unwrap())
|
||||
});
|
||||
|
||||
database::backfills_news::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&backfilled,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if message.next_page_token.is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
next_page_token = message.next_page_token;
|
||||
news.clear();
|
||||
predictions.extend(chunk_predictions);
|
||||
}
|
||||
|
||||
let zipped = batch
|
||||
.drain(..)
|
||||
.zip(predictions.drain(..))
|
||||
.map(|(news, prediction)| News {
|
||||
sentiment: prediction.sentiment,
|
||||
confidence: prediction.confidence,
|
||||
..news
|
||||
});
|
||||
|
||||
news.extend(zipped);
|
||||
|
||||
if news.len() < *BATCH_BACKFILL_NEWS_SIZE && message.next_page_token.is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
database::news::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&news,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let backfilled = last_times
|
||||
.drain()
|
||||
.map(|(symbol, time)| Backfill { symbol, time })
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
database::backfills_news::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&backfilled,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if message.next_page_token.is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
next_page_token = message.next_page_token;
|
||||
news.clear();
|
||||
}
|
||||
|
||||
info!("Backfilled news for {:?}.", symbols);
|
||||
|
@@ -82,15 +82,16 @@ impl super::Handler for Handler {
|
||||
|
||||
let input = format!("{}\n\n{}", news.headline, news.content);
|
||||
|
||||
let sequence_classifier = self.config.sequence_classifier.lock().await;
|
||||
let prediction = block_in_place(|| {
|
||||
sequence_classifier
|
||||
self.config
|
||||
.sequence_classifier
|
||||
.lock()
|
||||
.unwrap()
|
||||
.predict(vec![input.as_str()])
|
||||
.into_iter()
|
||||
.map(|label| Prediction::try_from(label).unwrap())
|
||||
.collect::<Vec<_>>()[0]
|
||||
});
|
||||
drop(sequence_classifier);
|
||||
|
||||
let news = News {
|
||||
sentiment: prediction.sentiment,
|
||||
|
Reference in New Issue
Block a user