Add FinBERT sentiment analysis

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-02-03 18:58:40 +00:00
parent 973917dad2
commit 65c9ae8b25
26 changed files with 31460 additions and 215 deletions

View File

@@ -4,18 +4,19 @@ use crate::{
database,
types::{
alpaca::{api, Source},
news::Prediction,
Asset, Bar, Class, News, Subset,
},
utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE},
};
use backoff::{future::retry, ExponentialBackoff};
use log::{error, info};
use log::{error, info, warn};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::{
join, spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
task::JoinHandle,
task::{spawn_blocking, JoinHandle},
time::sleep,
};
@@ -103,11 +104,14 @@ async fn handle_backfill_message(
match message.action {
Action::Backfill => {
for symbol in symbols {
if let Some(job) = backfill_jobs.remove(&symbol) {
if let Some(job) = backfill_jobs.get(&symbol) {
if !job.is_finished() {
job.abort();
warn!(
"{:?} - Backfill for {} is already running, skipping.",
thread_type, symbol
);
continue;
}
let _ = job.await;
}
let app_config = app_config.clone();
@@ -361,7 +365,41 @@ async fn execute_backfill_news(
return;
}
let backfill = (news.last().unwrap().clone(), symbol.clone()).into();
let app_config_clone = app_config.clone();
let inputs = news
.iter()
.map(|news| format!("{}\n\n{}", news.headline, news.content))
.collect::<Vec<_>>();
let predictions: Vec<Prediction> = spawn_blocking(move || {
inputs
.chunks(app_config_clone.max_bert_inputs)
.flat_map(|inputs| {
app_config_clone
.sequence_classifier
.lock()
.unwrap()
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()
})
.collect()
})
.await
.unwrap();
let news = news
.into_iter()
.zip(predictions.into_iter())
.map(|(news, prediction)| News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
})
.collect::<Vec<_>>();
let backfill = (news[0].clone(), symbol.clone()).into();
database::news::upsert_batch(&app_config.clickhouse_client, news).await;
database::backfills::upsert(&app_config.clickhouse_client, &thread_type, &backfill).await;