239 lines
7.9 KiB
Rust
239 lines
7.9 KiB
Rust
use super::{backfill, Guard, ThreadType};
|
|
use crate::{
|
|
config::Config,
|
|
database,
|
|
types::{alpaca::websocket, news::Prediction, Bar, News, Subset},
|
|
utils::add_slash_to_pair,
|
|
};
|
|
use futures_util::{
|
|
stream::{SplitSink, SplitStream},
|
|
SinkExt, StreamExt,
|
|
};
|
|
use log::{error, info, warn};
|
|
use serde_json::from_str;
|
|
use std::{collections::HashSet, sync::Arc};
|
|
use tokio::{
|
|
join,
|
|
net::TcpStream,
|
|
spawn,
|
|
sync::{mpsc, Mutex, RwLock},
|
|
task::block_in_place,
|
|
};
|
|
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
|
|
|
|
pub async fn run(
|
|
app_config: Arc<Config>,
|
|
thread_type: ThreadType,
|
|
guard: Arc<RwLock<Guard>>,
|
|
websocket_sender: Arc<
|
|
Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>,
|
|
>,
|
|
mut websocket_receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
|
|
backfill_sender: mpsc::Sender<backfill::Message>,
|
|
) {
|
|
loop {
|
|
let message = websocket_receiver.next().await.unwrap().unwrap();
|
|
|
|
spawn(handle_websocket_message(
|
|
app_config.clone(),
|
|
thread_type,
|
|
guard.clone(),
|
|
websocket_sender.clone(),
|
|
backfill_sender.clone(),
|
|
message,
|
|
));
|
|
}
|
|
}
|
|
|
|
async fn handle_websocket_message(
|
|
app_config: Arc<Config>,
|
|
thread_type: ThreadType,
|
|
guard: Arc<RwLock<Guard>>,
|
|
websocket_sender: Arc<
|
|
Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>,
|
|
>,
|
|
backfill_sender: mpsc::Sender<backfill::Message>,
|
|
message: tungstenite::Message,
|
|
) {
|
|
match message {
|
|
tungstenite::Message::Text(message) => {
|
|
let message = from_str::<Vec<websocket::incoming::Message>>(&message);
|
|
|
|
if let Ok(message) = message {
|
|
for message in message {
|
|
spawn(handle_parsed_websocket_message(
|
|
app_config.clone(),
|
|
thread_type,
|
|
guard.clone(),
|
|
backfill_sender.clone(),
|
|
message,
|
|
));
|
|
}
|
|
} else {
|
|
error!(
|
|
"{:?} - Failed to deserialize websocket message: {:?}",
|
|
thread_type, message
|
|
);
|
|
}
|
|
}
|
|
tungstenite::Message::Ping(_) => {
|
|
websocket_sender
|
|
.lock()
|
|
.await
|
|
.send(tungstenite::Message::Pong(vec![]))
|
|
.await
|
|
.unwrap();
|
|
}
|
|
_ => error!(
|
|
"{:?} - Unexpected websocket message: {:?}",
|
|
thread_type, message
|
|
),
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::significant_drop_tightening)]
|
|
#[allow(clippy::too_many_lines)]
|
|
async fn handle_parsed_websocket_message(
|
|
app_config: Arc<Config>,
|
|
thread_type: ThreadType,
|
|
guard: Arc<RwLock<Guard>>,
|
|
backfill_sender: mpsc::Sender<backfill::Message>,
|
|
message: websocket::incoming::Message,
|
|
) {
|
|
match message {
|
|
websocket::incoming::Message::Subscription(message) => {
|
|
let symbols = match message {
|
|
websocket::incoming::subscription::Message::Market { bars, .. } => bars,
|
|
websocket::incoming::subscription::Message::News { news } => news
|
|
.into_iter()
|
|
.map(|symbol| add_slash_to_pair(&symbol))
|
|
.collect(),
|
|
};
|
|
|
|
let mut guard = guard.write().await;
|
|
|
|
let newly_subscribed = guard
|
|
.pending_subscriptions
|
|
.extract_if(|asset| symbols.contains(&asset.symbol))
|
|
.collect::<HashSet<_>>();
|
|
|
|
let newly_unsubscribed = guard
|
|
.pending_unsubscriptions
|
|
.extract_if(|asset| !symbols.contains(&asset.symbol))
|
|
.collect::<HashSet<_>>();
|
|
|
|
drop(guard);
|
|
|
|
let newly_subscribed_future = async {
|
|
if !newly_subscribed.is_empty() {
|
|
info!(
|
|
"{:?} - Subscribed to {:?}.",
|
|
thread_type,
|
|
newly_subscribed
|
|
.iter()
|
|
.map(|asset| asset.symbol.clone())
|
|
.collect::<Vec<_>>()
|
|
);
|
|
|
|
let (backfill_message, backfill_receiver) = backfill::Message::new(
|
|
backfill::Action::Backfill,
|
|
Subset::Some(newly_subscribed.into_iter().collect::<Vec<_>>()),
|
|
);
|
|
|
|
backfill_sender.send(backfill_message).await.unwrap();
|
|
backfill_receiver.await.unwrap();
|
|
}
|
|
};
|
|
|
|
let newly_unsubscribed_future = async {
|
|
if !newly_unsubscribed.is_empty() {
|
|
info!(
|
|
"{:?} - Unsubscribed from {:?}.",
|
|
thread_type,
|
|
newly_unsubscribed
|
|
.iter()
|
|
.map(|asset| asset.symbol.clone())
|
|
.collect::<Vec<_>>()
|
|
);
|
|
|
|
let (purge_message, purge_receiver) = backfill::Message::new(
|
|
backfill::Action::Purge,
|
|
Subset::Some(newly_unsubscribed.into_iter().collect::<Vec<_>>()),
|
|
);
|
|
|
|
backfill_sender.send(purge_message).await.unwrap();
|
|
purge_receiver.await.unwrap();
|
|
}
|
|
};
|
|
|
|
join!(newly_subscribed_future, newly_unsubscribed_future);
|
|
}
|
|
websocket::incoming::Message::Bar(message)
|
|
| websocket::incoming::Message::UpdatedBar(message) => {
|
|
let bar = Bar::from(message);
|
|
|
|
let guard = guard.read().await;
|
|
if !guard.assets.contains_right(&bar.symbol) {
|
|
warn!(
|
|
"{:?} - Race condition: received bar for unsubscribed symbol: {:?}.",
|
|
thread_type, bar.symbol
|
|
);
|
|
return;
|
|
}
|
|
|
|
info!(
|
|
"{:?} - Received bar for {}: {}.",
|
|
thread_type, bar.symbol, bar.time
|
|
);
|
|
database::bars::upsert(&app_config.clickhouse_client, &bar).await;
|
|
}
|
|
websocket::incoming::Message::News(message) => {
|
|
let news = News::from(message);
|
|
|
|
let guard = guard.read().await;
|
|
if !news
|
|
.symbols
|
|
.iter()
|
|
.any(|symbol| guard.assets.contains_right(symbol))
|
|
{
|
|
warn!(
|
|
"{:?} - Race condition: received news for unsubscribed symbols: {:?}.",
|
|
thread_type, news.symbols
|
|
);
|
|
return;
|
|
}
|
|
|
|
info!(
|
|
"{:?} - Received news for {:?}: {}.",
|
|
thread_type, news.symbols, news.time_created
|
|
);
|
|
|
|
let input = format!("{}\n\n{}", news.headline, news.content);
|
|
|
|
let sequence_classifier = app_config.sequence_classifier.lock().await;
|
|
let prediction = block_in_place(|| {
|
|
sequence_classifier
|
|
.predict(vec![input.as_str()])
|
|
.into_iter()
|
|
.map(|label| Prediction::try_from(label).unwrap())
|
|
.collect::<Vec<_>>()[0]
|
|
});
|
|
drop(sequence_classifier);
|
|
|
|
let news = News {
|
|
sentiment: prediction.sentiment,
|
|
confidence: prediction.confidence,
|
|
..news
|
|
};
|
|
database::news::upsert(&app_config.clickhouse_client, &news).await;
|
|
}
|
|
websocket::incoming::Message::Success(_) => {}
|
|
websocket::incoming::Message::Error(message) => {
|
|
error!(
|
|
"{:?} - Received error message: {}.",
|
|
thread_type, message.message
|
|
);
|
|
}
|
|
}
|
|
}
|