Files
qrust/src/threads/data/websocket.rs
2024-02-05 00:30:32 +00:00

239 lines
7.9 KiB
Rust

use super::{backfill, Guard, ThreadType};
use crate::{
config::Config,
database,
types::{alpaca::websocket, news::Prediction, Bar, News, Subset},
utils::add_slash_to_pair,
};
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use log::{error, info, warn};
use serde_json::from_str;
use std::{collections::HashSet, sync::Arc};
use tokio::{
join,
net::TcpStream,
spawn,
sync::{mpsc, Mutex, RwLock},
task::block_in_place,
};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
/// Drives the websocket read loop for one data thread.
///
/// Every item read from `websocket_receiver` is handed to its own spawned
/// `handle_websocket_message` task, so the loop goes straight back to
/// awaiting the next item instead of waiting for the handler to finish.
///
/// The loop has no `break`; the only way out is one of the panics below.
///
/// # Panics
///
/// Panics if the stream yields `None` or yields an error item, because both
/// layers of the `Option<Result<..>>` returned by `next()` are unwrapped.
pub async fn run(
    app_config: Arc<Config>,
    thread_type: ThreadType,
    guard: Arc<RwLock<Guard>>,
    websocket_sender: Arc<
        Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>,
    >,
    mut websocket_receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    backfill_sender: mpsc::Sender<backfill::Message>,
) {
    loop {
        let next_item = websocket_receiver.next().await;
        let frame = next_item.unwrap().unwrap();

        // Each handler gets its own handles to the shared state; the spawned
        // task is detached (its join handle is not kept).
        let handler = handle_websocket_message(
            Arc::clone(&app_config),
            thread_type,
            Arc::clone(&guard),
            Arc::clone(&websocket_sender),
            backfill_sender.clone(),
            frame,
        );
        spawn(handler);
    }
}
/// Dispatches a single raw websocket frame.
///
/// * `Text` frames are deserialized as a JSON array of
///   `websocket::incoming::Message`; every element is processed in its own
///   spawned `handle_parsed_websocket_message` task. A payload that fails to
///   deserialize is logged at error level and dropped.
/// * `Ping` frames are answered with a `Pong` carrying the same payload,
///   sent through the shared `websocket_sender`.
/// * Any other frame kind is logged at error level as unexpected.
///
/// # Panics
///
/// Panics if sending the `Pong` reply fails.
async fn handle_websocket_message(
    app_config: Arc<Config>,
    thread_type: ThreadType,
    guard: Arc<RwLock<Guard>>,
    websocket_sender: Arc<
        Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>,
    >,
    backfill_sender: mpsc::Sender<backfill::Message>,
    message: tungstenite::Message,
) {
    match message {
        tungstenite::Message::Text(text) => {
            let parsed = from_str::<Vec<websocket::incoming::Message>>(&text);
            if let Ok(parsed_messages) = parsed {
                for parsed_message in parsed_messages {
                    spawn(handle_parsed_websocket_message(
                        app_config.clone(),
                        thread_type,
                        guard.clone(),
                        backfill_sender.clone(),
                        parsed_message,
                    ));
                }
            } else {
                // `parsed` is the `Err(..)` here, so the log shows the
                // deserialization error itself.
                error!(
                    "{:?} - Failed to deserialize websocket message: {:?}",
                    thread_type, parsed
                );
            }
        }
        tungstenite::Message::Ping(payload) => {
            // RFC 6455 §5.5.3: a Pong sent in response to a Ping must carry
            // the Ping's application data, so echo the payload back rather
            // than replying with an empty body.
            websocket_sender
                .lock()
                .await
                .send(tungstenite::Message::Pong(payload))
                .await
                .unwrap();
        }
        _ => error!(
            "{:?} - Unexpected websocket message: {:?}",
            thread_type, message
        ),
    }
}
#[allow(clippy::significant_drop_tightening)]
#[allow(clippy::too_many_lines)]
async fn handle_parsed_websocket_message(
app_config: Arc<Config>,
thread_type: ThreadType,
guard: Arc<RwLock<Guard>>,
backfill_sender: mpsc::Sender<backfill::Message>,
message: websocket::incoming::Message,
) {
match message {
websocket::incoming::Message::Subscription(message) => {
let symbols = match message {
websocket::incoming::subscription::Message::Market { bars, .. } => bars,
websocket::incoming::subscription::Message::News { news } => news
.into_iter()
.map(|symbol| add_slash_to_pair(&symbol))
.collect(),
};
let mut guard = guard.write().await;
let newly_subscribed = guard
.pending_subscriptions
.extract_if(|asset| symbols.contains(&asset.symbol))
.collect::<HashSet<_>>();
let newly_unsubscribed = guard
.pending_unsubscriptions
.extract_if(|asset| !symbols.contains(&asset.symbol))
.collect::<HashSet<_>>();
drop(guard);
let newly_subscribed_future = async {
if !newly_subscribed.is_empty() {
info!(
"{:?} - Subscribed to {:?}.",
thread_type,
newly_subscribed
.iter()
.map(|asset| asset.symbol.clone())
.collect::<Vec<_>>()
);
let (backfill_message, backfill_receiver) = backfill::Message::new(
backfill::Action::Backfill,
Subset::Some(newly_subscribed.into_iter().collect::<Vec<_>>()),
);
backfill_sender.send(backfill_message).await.unwrap();
backfill_receiver.await.unwrap();
}
};
let newly_unsubscribed_future = async {
if !newly_unsubscribed.is_empty() {
info!(
"{:?} - Unsubscribed from {:?}.",
thread_type,
newly_unsubscribed
.iter()
.map(|asset| asset.symbol.clone())
.collect::<Vec<_>>()
);
let (purge_message, purge_receiver) = backfill::Message::new(
backfill::Action::Purge,
Subset::Some(newly_unsubscribed.into_iter().collect::<Vec<_>>()),
);
backfill_sender.send(purge_message).await.unwrap();
purge_receiver.await.unwrap();
}
};
join!(newly_subscribed_future, newly_unsubscribed_future);
}
websocket::incoming::Message::Bar(message)
| websocket::incoming::Message::UpdatedBar(message) => {
let bar = Bar::from(message);
let guard = guard.read().await;
if !guard.assets.contains_right(&bar.symbol) {
warn!(
"{:?} - Race condition: received bar for unsubscribed symbol: {:?}.",
thread_type, bar.symbol
);
return;
}
info!(
"{:?} - Received bar for {}: {}.",
thread_type, bar.symbol, bar.time
);
database::bars::upsert(&app_config.clickhouse_client, &bar).await;
}
websocket::incoming::Message::News(message) => {
let news = News::from(message);
let guard = guard.read().await;
if !news
.symbols
.iter()
.any(|symbol| guard.assets.contains_right(symbol))
{
warn!(
"{:?} - Race condition: received news for unsubscribed symbols: {:?}.",
thread_type, news.symbols
);
return;
}
info!(
"{:?} - Received news for {:?}: {}.",
thread_type, news.symbols, news.time_created
);
let input = format!("{}\n\n{}", news.headline, news.content);
let sequence_classifier = app_config.sequence_classifier.lock().await;
let prediction = block_in_place(|| {
sequence_classifier
.predict(vec![input.as_str()])
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()[0]
});
drop(sequence_classifier);
let news = News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
};
database::news::upsert(&app_config.clickhouse_client, &news).await;
}
websocket::incoming::Message::Success(_) => {}
websocket::incoming::Message::Error(message) => {
error!(
"{:?} - Received error message: {}.",
thread_type, message.message
);
}
}
}