Fix backfill sentiment batching bug
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
@@ -13,7 +13,7 @@ use rust_bert::{
|
|||||||
resources::LocalResource,
|
resources::LocalResource,
|
||||||
};
|
};
|
||||||
use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
|
use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
|
||||||
use tokio::sync::{Mutex, Semaphore};
|
use tokio::sync::Semaphore;
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
|
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
|
||||||
@@ -32,6 +32,14 @@ lazy_static! {
|
|||||||
env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
|
env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
|
||||||
pub static ref ALPACA_API_SECRET: String =
|
pub static ref ALPACA_API_SECRET: String =
|
||||||
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
|
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
|
||||||
|
pub static ref BATCH_BACKFILL_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE")
|
||||||
|
.expect("BATCH_BACKFILL_BARS_SIZE must be set.")
|
||||||
|
.parse()
|
||||||
|
.expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer.");
|
||||||
|
pub static ref BATCH_BACKFILL_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE")
|
||||||
|
.expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
|
||||||
|
.parse()
|
||||||
|
.expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
|
||||||
pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS")
|
pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS")
|
||||||
.expect("BERT_MAX_INPUTS must be set.")
|
.expect("BERT_MAX_INPUTS must be set.")
|
||||||
.parse()
|
.parse()
|
||||||
@@ -47,7 +55,7 @@ pub struct Config {
|
|||||||
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
|
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
|
||||||
pub clickhouse_client: clickhouse::Client,
|
pub clickhouse_client: clickhouse::Client,
|
||||||
pub clickhouse_concurrency_limiter: Arc<Semaphore>,
|
pub clickhouse_concurrency_limiter: Arc<Semaphore>,
|
||||||
pub sequence_classifier: Mutex<SequenceClassificationModel>,
|
pub sequence_classifier: std::sync::Mutex<SequenceClassificationModel>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
@@ -81,7 +89,7 @@ impl Config {
|
|||||||
)
|
)
|
||||||
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
|
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
|
||||||
clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
|
clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
|
||||||
sequence_classifier: Mutex::new(
|
sequence_classifier: std::sync::Mutex::new(
|
||||||
SequenceClassificationModel::new(SequenceClassificationConfig::new(
|
SequenceClassificationModel::new(SequenceClassificationConfig::new(
|
||||||
ModelType::Bert,
|
ModelType::Bert,
|
||||||
ModelResource::Torch(Box::new(LocalResource {
|
ModelResource::Torch(Box::new(LocalResource {
|
||||||
|
@@ -4,8 +4,6 @@ use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
|
|||||||
use clickhouse::Client;
|
use clickhouse::Client;
|
||||||
use tokio::sync::Semaphore;
|
use tokio::sync::Semaphore;
|
||||||
|
|
||||||
pub const BATCH_FLUSH_SIZE: usize = 100_000;
|
|
||||||
|
|
||||||
upsert!(Bar, "bars");
|
upsert!(Bar, "bars");
|
||||||
upsert_batch!(Bar, "bars");
|
upsert_batch!(Bar, "bars");
|
||||||
delete_where_symbols!("bars");
|
delete_where_symbols!("bars");
|
||||||
|
@@ -5,8 +5,6 @@ use clickhouse::{error::Error, Client};
|
|||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use tokio::sync::Semaphore;
|
use tokio::sync::Semaphore;
|
||||||
|
|
||||||
pub const BATCH_FLUSH_SIZE: usize = 500;
|
|
||||||
|
|
||||||
upsert!(News, "news");
|
upsert!(News, "news");
|
||||||
upsert_batch!(News, "news");
|
upsert_batch!(News, "news");
|
||||||
optimize!("news");
|
optimize!("news");
|
||||||
|
13
src/main.rs
13
src/main.rs
@@ -7,7 +7,10 @@ mod init;
|
|||||||
mod routes;
|
mod routes;
|
||||||
mod threads;
|
mod threads;
|
||||||
|
|
||||||
use config::Config;
|
use config::{
|
||||||
|
Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, BATCH_BACKFILL_BARS_SIZE,
|
||||||
|
BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS, CLICKHOUSE_MAX_CONNECTIONS,
|
||||||
|
};
|
||||||
use dotenv::dotenv;
|
use dotenv::dotenv;
|
||||||
use log4rs::config::Deserializers;
|
use log4rs::config::Deserializers;
|
||||||
use qrust::{create_send_await, database};
|
use qrust::{create_send_await, database};
|
||||||
@@ -19,6 +22,14 @@ async fn main() {
|
|||||||
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
|
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
|
||||||
let config = Config::arc_from_env();
|
let config = Config::arc_from_env();
|
||||||
|
|
||||||
|
let _ = *ALPACA_MODE;
|
||||||
|
let _ = *ALPACA_API_BASE;
|
||||||
|
let _ = *ALPACA_SOURCE;
|
||||||
|
let _ = *BATCH_BACKFILL_BARS_SIZE;
|
||||||
|
let _ = *BATCH_BACKFILL_NEWS_SIZE;
|
||||||
|
let _ = *BERT_MAX_INPUTS;
|
||||||
|
let _ = *CLICKHOUSE_MAX_CONNECTIONS;
|
||||||
|
|
||||||
try_join!(
|
try_join!(
|
||||||
database::backfills_bars::unfresh(
|
database::backfills_bars::unfresh(
|
||||||
&config.clickhouse_client,
|
&config.clickhouse_client,
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
use super::Job;
|
use super::Job;
|
||||||
use crate::{
|
use crate::{
|
||||||
config::{Config, ALPACA_SOURCE},
|
config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_BARS_SIZE},
|
||||||
database,
|
database,
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
@@ -116,12 +116,12 @@ impl super::Handler for Handler {
|
|||||||
let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
|
let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
|
||||||
let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
|
let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
|
||||||
|
|
||||||
info!("Backfilling bars for {:?}.", symbols);
|
let mut bars = Vec::with_capacity(*BATCH_BACKFILL_BARS_SIZE);
|
||||||
|
|
||||||
let mut bars = Vec::with_capacity(database::bars::BATCH_FLUSH_SIZE);
|
|
||||||
let mut last_times = HashMap::new();
|
let mut last_times = HashMap::new();
|
||||||
let mut next_page_token = None;
|
let mut next_page_token = None;
|
||||||
|
|
||||||
|
info!("Backfilling bars for {:?}.", symbols);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let message = alpaca::bars::get(
|
let message = alpaca::bars::get(
|
||||||
&self.config.alpaca_client,
|
&self.config.alpaca_client,
|
||||||
@@ -144,45 +144,47 @@ impl super::Handler for Handler {
|
|||||||
|
|
||||||
let message = message.unwrap();
|
let message = message.unwrap();
|
||||||
|
|
||||||
for (symbol, bar_vec) in message.bars {
|
for (symbol, bars_vec) in message.bars {
|
||||||
if let Some(last) = bar_vec.last() {
|
if let Some(last) = bars_vec.last() {
|
||||||
last_times.insert(symbol.clone(), last.time);
|
last_times.insert(symbol.clone(), last.time);
|
||||||
}
|
}
|
||||||
|
|
||||||
for bar in bar_vec {
|
for bar in bars_vec {
|
||||||
bars.push(Bar::from((bar, symbol.clone())));
|
bars.push(Bar::from((bar, symbol.clone())));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if bars.len() >= database::bars::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
|
if bars.len() < *BATCH_BACKFILL_BARS_SIZE && message.next_page_token.is_some() {
|
||||||
database::bars::upsert_batch(
|
continue;
|
||||||
&self.config.clickhouse_client,
|
|
||||||
&self.config.clickhouse_concurrency_limiter,
|
|
||||||
&bars,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let backfilled = last_times
|
|
||||||
.drain()
|
|
||||||
.map(|(symbol, time)| Backfill { symbol, time })
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
database::backfills_bars::upsert_batch(
|
|
||||||
&self.config.clickhouse_client,
|
|
||||||
&self.config.clickhouse_concurrency_limiter,
|
|
||||||
&backfilled,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
if message.next_page_token.is_none() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
next_page_token = message.next_page_token;
|
|
||||||
bars.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
database::bars::upsert_batch(
|
||||||
|
&self.config.clickhouse_client,
|
||||||
|
&self.config.clickhouse_concurrency_limiter,
|
||||||
|
&bars,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let backfilled = last_times
|
||||||
|
.drain()
|
||||||
|
.map(|(symbol, time)| Backfill { symbol, time })
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
database::backfills_bars::upsert_batch(
|
||||||
|
&self.config.clickhouse_client,
|
||||||
|
&self.config.clickhouse_concurrency_limiter,
|
||||||
|
&backfilled,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
if message.next_page_token.is_none() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
next_page_token = message.next_page_token;
|
||||||
|
bars.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Backfilled bars for {:?}.", symbols);
|
info!("Backfilled bars for {:?}.", symbols);
|
||||||
|
@@ -1,10 +1,9 @@
|
|||||||
use super::Job;
|
use super::Job;
|
||||||
use crate::{
|
use crate::{
|
||||||
config::{Config, ALPACA_SOURCE, BERT_MAX_INPUTS},
|
config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS},
|
||||||
database,
|
database,
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures_util::future::join_all;
|
|
||||||
use log::{error, info};
|
use log::{error, info};
|
||||||
use qrust::{
|
use qrust::{
|
||||||
alpaca,
|
alpaca,
|
||||||
@@ -16,7 +15,10 @@ use qrust::{
|
|||||||
},
|
},
|
||||||
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
|
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
|
||||||
};
|
};
|
||||||
use std::{collections::HashMap, sync::Arc};
|
use std::{
|
||||||
|
collections::{HashMap, HashSet},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
use tokio::{task::block_in_place, time::sleep};
|
use tokio::{task::block_in_place, time::sleep};
|
||||||
|
|
||||||
pub struct Handler {
|
pub struct Handler {
|
||||||
@@ -68,22 +70,26 @@ impl super::Handler for Handler {
|
|||||||
sleep(run_delay).await;
|
sleep(run_delay).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
|
#[allow(clippy::iter_with_drain)]
|
||||||
async fn backfill(&self, jobs: HashMap<String, Job>) {
|
async fn backfill(&self, jobs: HashMap<String, Job>) {
|
||||||
if jobs.is_empty() {
|
if jobs.is_empty() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let symbols = jobs.keys().cloned().collect::<Vec<_>>();
|
let symbols = jobs.keys().cloned().collect::<Vec<_>>();
|
||||||
let symbols_set = symbols.iter().collect::<std::collections::HashSet<_>>();
|
let symbols_set = symbols.clone().into_iter().collect::<HashSet<_>>();
|
||||||
let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
|
let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
|
||||||
let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
|
let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
|
||||||
|
|
||||||
info!("Backfilling news for {:?}.", symbols);
|
let mut news = Vec::with_capacity(*BATCH_BACKFILL_NEWS_SIZE);
|
||||||
|
let mut batch = Vec::with_capacity(*BERT_MAX_INPUTS);
|
||||||
let mut news = Vec::with_capacity(database::news::BATCH_FLUSH_SIZE);
|
let mut predictions = Vec::with_capacity(*BERT_MAX_INPUTS);
|
||||||
let mut last_times = HashMap::new();
|
let mut last_times = HashMap::new();
|
||||||
let mut next_page_token = None;
|
let mut next_page_token = None;
|
||||||
|
|
||||||
|
info!("Backfilling news for {:?}.", symbols);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let message = alpaca::news::get(
|
let message = alpaca::news::get(
|
||||||
&self.config.alpaca_client,
|
&self.config.alpaca_client,
|
||||||
@@ -116,70 +122,77 @@ impl super::Handler for Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
news.push(news_item);
|
batch.push(news_item);
|
||||||
}
|
}
|
||||||
|
|
||||||
if news.len() >= *BERT_MAX_INPUTS || message.next_page_token.is_none() {
|
if batch.len() < *BERT_MAX_INPUTS
|
||||||
let inputs = news
|
&& batch.len() < *BATCH_BACKFILL_NEWS_SIZE
|
||||||
.iter()
|
&& message.next_page_token.is_some()
|
||||||
.map(|news| format!("{}\n\n{}", news.headline, news.content))
|
{
|
||||||
.collect::<Vec<_>>();
|
continue;
|
||||||
|
|
||||||
let predictions =
|
|
||||||
join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
|
|
||||||
let sequence_classifier = self.config.sequence_classifier.lock().await;
|
|
||||||
block_in_place(|| {
|
|
||||||
sequence_classifier
|
|
||||||
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
|
|
||||||
.into_iter()
|
|
||||||
.map(|label| Prediction::try_from(label).unwrap())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}))
|
|
||||||
.await
|
|
||||||
.into_iter()
|
|
||||||
.flatten();
|
|
||||||
|
|
||||||
news = news
|
|
||||||
.into_iter()
|
|
||||||
.zip(predictions)
|
|
||||||
.map(|(news, prediction)| News {
|
|
||||||
sentiment: prediction.sentiment,
|
|
||||||
confidence: prediction.confidence,
|
|
||||||
..news
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if news.len() >= database::news::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
|
let inputs = batch
|
||||||
database::news::upsert_batch(
|
.iter()
|
||||||
&self.config.clickhouse_client,
|
.map(|news| format!("{}\n\n{}", news.headline, news.content))
|
||||||
&self.config.clickhouse_concurrency_limiter,
|
.collect::<Vec<_>>();
|
||||||
&news,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let backfilled = last_times
|
for chunk in inputs.chunks(*BERT_MAX_INPUTS) {
|
||||||
.drain()
|
let chunk_predictions = block_in_place(|| {
|
||||||
.map(|(symbol, time)| Backfill { symbol, time })
|
self.config
|
||||||
.collect::<Vec<_>>();
|
.sequence_classifier
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.predict(chunk.iter().map(String::as_str).collect::<Vec<_>>())
|
||||||
|
.into_iter()
|
||||||
|
.map(|label| Prediction::try_from(label).unwrap())
|
||||||
|
});
|
||||||
|
|
||||||
database::backfills_news::upsert_batch(
|
predictions.extend(chunk_predictions);
|
||||||
&self.config.clickhouse_client,
|
|
||||||
&self.config.clickhouse_concurrency_limiter,
|
|
||||||
&backfilled,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
if message.next_page_token.is_none() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
next_page_token = message.next_page_token;
|
|
||||||
news.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let zipped = batch
|
||||||
|
.drain(..)
|
||||||
|
.zip(predictions.drain(..))
|
||||||
|
.map(|(news, prediction)| News {
|
||||||
|
sentiment: prediction.sentiment,
|
||||||
|
confidence: prediction.confidence,
|
||||||
|
..news
|
||||||
|
});
|
||||||
|
|
||||||
|
news.extend(zipped);
|
||||||
|
|
||||||
|
if news.len() < *BATCH_BACKFILL_NEWS_SIZE && message.next_page_token.is_some() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
database::news::upsert_batch(
|
||||||
|
&self.config.clickhouse_client,
|
||||||
|
&self.config.clickhouse_concurrency_limiter,
|
||||||
|
&news,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let backfilled = last_times
|
||||||
|
.drain()
|
||||||
|
.map(|(symbol, time)| Backfill { symbol, time })
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
database::backfills_news::upsert_batch(
|
||||||
|
&self.config.clickhouse_client,
|
||||||
|
&self.config.clickhouse_concurrency_limiter,
|
||||||
|
&backfilled,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
if message.next_page_token.is_none() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
next_page_token = message.next_page_token;
|
||||||
|
news.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Backfilled news for {:?}.", symbols);
|
info!("Backfilled news for {:?}.", symbols);
|
||||||
|
@@ -82,15 +82,16 @@ impl super::Handler for Handler {
|
|||||||
|
|
||||||
let input = format!("{}\n\n{}", news.headline, news.content);
|
let input = format!("{}\n\n{}", news.headline, news.content);
|
||||||
|
|
||||||
let sequence_classifier = self.config.sequence_classifier.lock().await;
|
|
||||||
let prediction = block_in_place(|| {
|
let prediction = block_in_place(|| {
|
||||||
sequence_classifier
|
self.config
|
||||||
|
.sequence_classifier
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
.predict(vec![input.as_str()])
|
.predict(vec![input.as_str()])
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|label| Prediction::try_from(label).unwrap())
|
.map(|label| Prediction::try_from(label).unwrap())
|
||||||
.collect::<Vec<_>>()[0]
|
.collect::<Vec<_>>()[0]
|
||||||
});
|
});
|
||||||
drop(sequence_classifier);
|
|
||||||
|
|
||||||
let news = News {
|
let news = News {
|
||||||
sentiment: prediction.sentiment,
|
sentiment: prediction.sentiment,
|
||||||
|
Reference in New Issue
Block a user