Separate data management code
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
114
src/threads/data/websocket/news.rs
Normal file
114
src/threads/data/websocket/news.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use super::Pending;
|
||||
use crate::{
|
||||
config::Config,
|
||||
database,
|
||||
types::{alpaca::websocket, news::Prediction, News},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use log::{debug, error, info};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::{sync::RwLock, task::block_in_place};
|
||||
|
||||
pub struct Handler {
|
||||
pub config: Arc<Config>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Handler for Handler {
|
||||
fn create_subscription_message(
|
||||
&self,
|
||||
symbols: Vec<String>,
|
||||
) -> websocket::data::outgoing::subscribe::Message {
|
||||
websocket::data::outgoing::subscribe::Message::new_news(symbols)
|
||||
}
|
||||
|
||||
async fn handle_websocket_message(
|
||||
&self,
|
||||
pending: Arc<RwLock<Pending>>,
|
||||
message: websocket::data::incoming::Message,
|
||||
) {
|
||||
match message {
|
||||
websocket::data::incoming::Message::Subscription(message) => {
|
||||
let websocket::data::incoming::subscription::Message::News { news: symbols } =
|
||||
message
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let mut pending = pending.write().await;
|
||||
|
||||
let newly_subscribed = pending
|
||||
.subscriptions
|
||||
.extract_if(|symbol, _| symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let newly_unsubscribed = pending
|
||||
.unsubscriptions
|
||||
.extract_if(|symbol, _| !symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
drop(pending);
|
||||
|
||||
if !newly_subscribed.is_empty() {
|
||||
info!(
|
||||
"Subscribed to news for {:?}.",
|
||||
newly_subscribed.keys().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
for sender in newly_subscribed.into_values() {
|
||||
sender.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
if !newly_unsubscribed.is_empty() {
|
||||
info!(
|
||||
"Unsubscribed from news for {:?}.",
|
||||
newly_unsubscribed.keys().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
for sender in newly_unsubscribed.into_values() {
|
||||
sender.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
websocket::data::incoming::Message::News(message) => {
|
||||
let news = News::from(message);
|
||||
|
||||
debug!(
|
||||
"Received news for {:?}: {}.",
|
||||
news.symbols, news.time_created
|
||||
);
|
||||
|
||||
let input = format!("{}\n\n{}", news.headline, news.content);
|
||||
|
||||
let sequence_classifier = self.config.sequence_classifier.lock().await;
|
||||
let prediction = block_in_place(|| {
|
||||
sequence_classifier
|
||||
.predict(vec![input.as_str()])
|
||||
.into_iter()
|
||||
.map(|label| Prediction::try_from(label).unwrap())
|
||||
.collect::<Vec<_>>()[0]
|
||||
});
|
||||
drop(sequence_classifier);
|
||||
|
||||
let news = News {
|
||||
sentiment: prediction.sentiment,
|
||||
confidence: prediction.confidence,
|
||||
..news
|
||||
};
|
||||
|
||||
database::news::upsert(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&news,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
websocket::data::incoming::Message::Error(message) => {
|
||||
error!("Received error message: {}.", message.message);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user