Separate data management code

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-03-10 16:59:21 +00:00
parent acfc0ca4c9
commit a15fd2c3c9
9 changed files with 1265 additions and 1190 deletions

View File

@@ -0,0 +1,114 @@
use super::Pending;
use crate::{
config::Config,
database,
types::{alpaca::websocket, news::Prediction, News},
};
use async_trait::async_trait;
use log::{debug, error, info};
use std::{collections::HashMap, sync::Arc};
use tokio::{sync::RwLock, task::block_in_place};
pub struct Handler {
pub config: Arc<Config>,
}
#[async_trait]
impl super::Handler for Handler {
fn create_subscription_message(
&self,
symbols: Vec<String>,
) -> websocket::data::outgoing::subscribe::Message {
websocket::data::outgoing::subscribe::Message::new_news(symbols)
}
async fn handle_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::News { news: symbols } =
message
else {
unreachable!()
};
let mut pending = pending.write().await;
let newly_subscribed = pending
.subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = pending
.unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
drop(pending);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to news for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from news for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::News(message) => {
let news = News::from(message);
debug!(
"Received news for {:?}: {}.",
news.symbols, news.time_created
);
let input = format!("{}\n\n{}", news.headline, news.content);
let sequence_classifier = self.config.sequence_classifier.lock().await;
let prediction = block_in_place(|| {
sequence_classifier
.predict(vec![input.as_str()])
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()[0]
});
drop(sequence_classifier);
let news = News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
};
database::news::upsert(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&news,
)
.await
.unwrap();
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
}