Add finbert sentiment analysis
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
@@ -4,18 +4,19 @@ use crate::{
|
||||
database,
|
||||
types::{
|
||||
alpaca::{api, Source},
|
||||
news::Prediction,
|
||||
Asset, Bar, Class, News, Subset,
|
||||
},
|
||||
utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE},
|
||||
};
|
||||
use backoff::{future::retry, ExponentialBackoff};
|
||||
use log::{error, info};
|
||||
use log::{error, info, warn};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use time::OffsetDateTime;
|
||||
use tokio::{
|
||||
join, spawn,
|
||||
sync::{mpsc, oneshot, Mutex, RwLock},
|
||||
task::JoinHandle,
|
||||
task::{spawn_blocking, JoinHandle},
|
||||
time::sleep,
|
||||
};
|
||||
|
||||
@@ -103,11 +104,14 @@ async fn handle_backfill_message(
|
||||
match message.action {
|
||||
Action::Backfill => {
|
||||
for symbol in symbols {
|
||||
if let Some(job) = backfill_jobs.remove(&symbol) {
|
||||
if let Some(job) = backfill_jobs.get(&symbol) {
|
||||
if !job.is_finished() {
|
||||
job.abort();
|
||||
warn!(
|
||||
"{:?} - Backfill for {} is already running, skipping.",
|
||||
thread_type, symbol
|
||||
);
|
||||
continue;
|
||||
}
|
||||
let _ = job.await;
|
||||
}
|
||||
|
||||
let app_config = app_config.clone();
|
||||
@@ -361,7 +365,41 @@ async fn execute_backfill_news(
|
||||
return;
|
||||
}
|
||||
|
||||
let backfill = (news.last().unwrap().clone(), symbol.clone()).into();
|
||||
let app_config_clone = app_config.clone();
|
||||
let inputs = news
|
||||
.iter()
|
||||
.map(|news| format!("{}\n\n{}", news.headline, news.content))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let predictions: Vec<Prediction> = spawn_blocking(move || {
|
||||
inputs
|
||||
.chunks(app_config_clone.max_bert_inputs)
|
||||
.flat_map(|inputs| {
|
||||
app_config_clone
|
||||
.sequence_classifier
|
||||
.lock()
|
||||
.unwrap()
|
||||
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
|
||||
.into_iter()
|
||||
.map(|label| Prediction::try_from(label).unwrap())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let news = news
|
||||
.into_iter()
|
||||
.zip(predictions.into_iter())
|
||||
.map(|(news, prediction)| News {
|
||||
sentiment: prediction.sentiment,
|
||||
confidence: prediction.confidence,
|
||||
..news
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let backfill = (news[0].clone(), symbol.clone()).into();
|
||||
database::news::upsert_batch(&app_config.clickhouse_client, news).await;
|
||||
database::backfills::upsert(&app_config.clickhouse_client, &thread_type, &backfill).await;
|
||||
|
||||
|
@@ -27,16 +27,6 @@ pub struct Guard {
|
||||
pub pending_unsubscriptions: HashMap<String, Asset>,
|
||||
}
|
||||
|
||||
impl Guard {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
symbols: HashSet::new(),
|
||||
pending_subscriptions: HashMap::new(),
|
||||
pending_unsubscriptions: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum ThreadType {
|
||||
Bars(Class),
|
||||
@@ -86,7 +76,11 @@ async fn init_thread(
|
||||
mpsc::Sender<asset_status::Message>,
|
||||
mpsc::Sender<backfill::Message>,
|
||||
) {
|
||||
let guard = Arc::new(RwLock::new(Guard::new()));
|
||||
let guard = Arc::new(RwLock::new(Guard {
|
||||
symbols: HashSet::new(),
|
||||
pending_subscriptions: HashMap::new(),
|
||||
pending_unsubscriptions: HashMap::new(),
|
||||
}));
|
||||
|
||||
let websocket_url = match thread_type {
|
||||
ThreadType::Bars(Class::UsEquity) => format!(
|
||||
|
@@ -2,7 +2,7 @@ use super::{backfill, Guard, ThreadType};
|
||||
use crate::{
|
||||
config::Config,
|
||||
database,
|
||||
types::{alpaca::websocket, Bar, News, Subset},
|
||||
types::{alpaca::websocket, news::Prediction, Bar, News, Subset},
|
||||
};
|
||||
use futures_util::{
|
||||
stream::{SplitSink, SplitStream},
|
||||
@@ -19,6 +19,7 @@ use tokio::{
|
||||
net::TcpStream,
|
||||
spawn,
|
||||
sync::{mpsc, Mutex, RwLock},
|
||||
task::spawn_blocking,
|
||||
};
|
||||
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
|
||||
|
||||
@@ -93,6 +94,7 @@ async fn handle_websocket_message(
|
||||
}
|
||||
|
||||
#[allow(clippy::significant_drop_tightening)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn handle_parsed_websocket_message(
|
||||
app_config: Arc<Config>,
|
||||
thread_type: ThreadType,
|
||||
@@ -195,6 +197,28 @@ async fn handle_parsed_websocket_message(
|
||||
"{:?} - Received news for {:?}: {}.",
|
||||
thread_type, news.symbols, news.time_created
|
||||
);
|
||||
|
||||
let app_config_clone = app_config.clone();
|
||||
let input = format!("{}\n\n{}", news.headline, news.content);
|
||||
|
||||
let prediction = spawn_blocking(move || {
|
||||
app_config_clone
|
||||
.sequence_classifier
|
||||
.lock()
|
||||
.unwrap()
|
||||
.predict(vec![input.as_str()])
|
||||
.into_iter()
|
||||
.map(|label| Prediction::try_from(label).unwrap())
|
||||
.collect::<Vec<_>>()[0]
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let news = News {
|
||||
sentiment: prediction.sentiment,
|
||||
confidence: prediction.confidence,
|
||||
..news
|
||||
};
|
||||
database::news::upsert(&app_config.clickhouse_client, &news).await;
|
||||
}
|
||||
websocket::incoming::Message::Success(_) => {}
|
||||
|
Reference in New Issue
Block a user