Remove asset_status thread
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
@@ -1,51 +1,192 @@
|
||||
use super::{backfill, Guard};
|
||||
use super::ThreadType;
|
||||
use crate::{
|
||||
config::Config,
|
||||
database,
|
||||
types::{alpaca::websocket, news::Prediction, Bar, News, Subset},
|
||||
types::{alpaca::websocket, news::Prediction, Bar, News},
|
||||
utils::add_slash_to_pair,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{
|
||||
future::join_all,
|
||||
stream::{SplitSink, SplitStream},
|
||||
SinkExt, StreamExt,
|
||||
};
|
||||
use log::{debug, error, info, warn};
|
||||
use serde_json::from_str;
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
use log::{debug, error, info};
|
||||
use serde_json::{from_str, to_string};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::{
|
||||
join,
|
||||
net::TcpStream,
|
||||
spawn,
|
||||
sync::{mpsc, Mutex, RwLock},
|
||||
select, spawn,
|
||||
sync::{mpsc, oneshot, Mutex, RwLock},
|
||||
task::block_in_place,
|
||||
};
|
||||
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
|
||||
|
||||
pub async fn run(
|
||||
app_config: Arc<Config>,
|
||||
guard: Arc<RwLock<Guard>>,
|
||||
sender: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
|
||||
mut receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
|
||||
backfill_sender: mpsc::Sender<backfill::Message>,
|
||||
) {
|
||||
loop {
|
||||
let message = receiver.next().await.unwrap().unwrap();
|
||||
pub enum Action {
|
||||
Subscribe,
|
||||
Unsubscribe,
|
||||
}
|
||||
|
||||
spawn(handle_websocket_message(
|
||||
app_config.clone(),
|
||||
guard.clone(),
|
||||
sender.clone(),
|
||||
backfill_sender.clone(),
|
||||
message,
|
||||
));
|
||||
impl From<super::Action> for Action {
|
||||
fn from(action: super::Action) -> Self {
|
||||
match action {
|
||||
super::Action::Add => Self::Subscribe,
|
||||
super::Action::Remove => Self::Unsubscribe,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Message {
|
||||
pub action: Action,
|
||||
pub symbols: Vec<String>,
|
||||
pub response: oneshot::Sender<()>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(action: Action, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
(
|
||||
Self {
|
||||
action,
|
||||
symbols,
|
||||
response: sender,
|
||||
},
|
||||
receiver,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Pending {
|
||||
pub subscriptions: HashMap<String, oneshot::Sender<()>>,
|
||||
pub unsubscriptions: HashMap<String, oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Handler: Send + Sync {
|
||||
fn create_subscription_message(
|
||||
&self,
|
||||
symbols: Vec<String>,
|
||||
) -> websocket::outgoing::subscribe::Message;
|
||||
async fn handle_parsed_websocket_message(
|
||||
&self,
|
||||
pending: Arc<RwLock<Pending>>,
|
||||
message: websocket::incoming::Message,
|
||||
);
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
handler: Arc<Box<dyn Handler>>,
|
||||
mut receiver: mpsc::Receiver<Message>,
|
||||
mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
|
||||
websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>,
|
||||
) {
|
||||
let pending = Arc::new(RwLock::new(Pending {
|
||||
subscriptions: HashMap::new(),
|
||||
unsubscriptions: HashMap::new(),
|
||||
}));
|
||||
let websocket_sink = Arc::new(Mutex::new(websocket_sink));
|
||||
|
||||
loop {
|
||||
select! {
|
||||
Some(message) = receiver.recv() => {
|
||||
spawn(handle_message(
|
||||
handler.clone(),
|
||||
pending.clone(),
|
||||
websocket_sink.clone(),
|
||||
message,
|
||||
));
|
||||
}
|
||||
Some(Ok(message)) = websocket_stream.next() => {
|
||||
spawn(handle_websocket_message(
|
||||
handler.clone(),
|
||||
pending.clone(),
|
||||
websocket_sink.clone(),
|
||||
message,
|
||||
));
|
||||
}
|
||||
else => panic!("Communication channel unexpectedly closed.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_message(
|
||||
handler: Arc<Box<dyn Handler>>,
|
||||
pending: Arc<RwLock<Pending>>,
|
||||
websocket_sender: Arc<
|
||||
Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>,
|
||||
>,
|
||||
message: Message,
|
||||
) {
|
||||
match message.action {
|
||||
Action::Subscribe => {
|
||||
let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message
|
||||
.symbols
|
||||
.iter()
|
||||
.map(|symbol| {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
((symbol.clone(), sender), receiver)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
pending
|
||||
.write()
|
||||
.await
|
||||
.subscriptions
|
||||
.extend(pending_subscriptions);
|
||||
|
||||
websocket_sender
|
||||
.lock()
|
||||
.await
|
||||
.send(tungstenite::Message::Text(
|
||||
to_string(&websocket::outgoing::Message::Subscribe(
|
||||
handler.create_subscription_message(message.symbols),
|
||||
))
|
||||
.unwrap(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
join_all(receivers).await;
|
||||
}
|
||||
Action::Unsubscribe => {
|
||||
let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
|
||||
.symbols
|
||||
.iter()
|
||||
.map(|symbol| {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
((symbol.clone(), sender), receiver)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
pending
|
||||
.write()
|
||||
.await
|
||||
.unsubscriptions
|
||||
.extend(pending_unsubscriptions);
|
||||
|
||||
websocket_sender
|
||||
.lock()
|
||||
.await
|
||||
.send(tungstenite::Message::Text(
|
||||
to_string(&websocket::outgoing::Message::Unsubscribe(
|
||||
handler.create_subscription_message(message.symbols.clone()),
|
||||
))
|
||||
.unwrap(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
join_all(receivers).await;
|
||||
}
|
||||
}
|
||||
|
||||
message.response.send(()).unwrap();
|
||||
}
|
||||
|
||||
async fn handle_websocket_message(
|
||||
app_config: Arc<Config>,
|
||||
guard: Arc<RwLock<Guard>>,
|
||||
handler: Arc<Box<dyn Handler>>,
|
||||
pending: Arc<RwLock<Pending>>,
|
||||
sender: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
|
||||
backfill_sender: mpsc::Sender<backfill::Message>,
|
||||
message: tungstenite::Message,
|
||||
) {
|
||||
match message {
|
||||
@@ -54,12 +195,14 @@ async fn handle_websocket_message(
|
||||
|
||||
if let Ok(message) = message {
|
||||
for message in message {
|
||||
spawn(handle_parsed_websocket_message(
|
||||
app_config.clone(),
|
||||
guard.clone(),
|
||||
backfill_sender.clone(),
|
||||
message,
|
||||
));
|
||||
let handler = handler.clone();
|
||||
let pending = pending.clone();
|
||||
|
||||
spawn(async move {
|
||||
handler
|
||||
.handle_parsed_websocket_message(pending, message)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
} else {
|
||||
error!("Failed to deserialize websocket message: {:?}", message);
|
||||
@@ -77,143 +220,190 @@ async fn handle_websocket_message(
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::significant_drop_tightening)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn handle_parsed_websocket_message(
|
||||
struct BarsHandler {
|
||||
app_config: Arc<Config>,
|
||||
guard: Arc<RwLock<Guard>>,
|
||||
backfill_sender: mpsc::Sender<backfill::Message>,
|
||||
message: websocket::incoming::Message,
|
||||
) {
|
||||
match message {
|
||||
websocket::incoming::Message::Subscription(message) => {
|
||||
let (symbols, log_string) = match message {
|
||||
websocket::incoming::subscription::Message::Market { bars, .. } => (bars, "bars"),
|
||||
websocket::incoming::subscription::Message::News { news } => (
|
||||
news.into_iter()
|
||||
.map(|symbol| add_slash_to_pair(&symbol))
|
||||
.collect(),
|
||||
"news",
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
let mut guard = guard.write().await;
|
||||
#[async_trait]
|
||||
impl Handler for BarsHandler {
|
||||
fn create_subscription_message(
|
||||
&self,
|
||||
symbols: Vec<String>,
|
||||
) -> websocket::outgoing::subscribe::Message {
|
||||
websocket::outgoing::subscribe::Message::new_market(symbols)
|
||||
}
|
||||
|
||||
let newly_subscribed = guard
|
||||
.pending_subscriptions
|
||||
.extract_if(|asset| symbols.contains(&asset.symbol))
|
||||
.collect::<HashSet<_>>();
|
||||
async fn handle_parsed_websocket_message(
|
||||
&self,
|
||||
pending: Arc<RwLock<Pending>>,
|
||||
message: websocket::incoming::Message,
|
||||
) {
|
||||
match message {
|
||||
websocket::incoming::Message::Subscription(message) => {
|
||||
let websocket::incoming::subscription::Message::Market { bars: symbols, .. } =
|
||||
message
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let newly_unsubscribed = guard
|
||||
.pending_unsubscriptions
|
||||
.extract_if(|asset| !symbols.contains(&asset.symbol))
|
||||
.collect::<HashSet<_>>();
|
||||
let mut pending = pending.write().await;
|
||||
|
||||
drop(guard);
|
||||
let newly_subscribed = pending
|
||||
.subscriptions
|
||||
.extract_if(|symbol, _| symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let newly_unsubscribed = pending
|
||||
.unsubscriptions
|
||||
.extract_if(|symbol, _| !symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
drop(pending);
|
||||
|
||||
let newly_subscribed_future = async {
|
||||
if !newly_subscribed.is_empty() {
|
||||
info!(
|
||||
"Subscribed to {} for {:?}.",
|
||||
log_string,
|
||||
newly_subscribed
|
||||
.iter()
|
||||
.map(|asset| asset.symbol.clone())
|
||||
.collect::<Vec<_>>()
|
||||
"Subscribed to bars for {:?}.",
|
||||
newly_subscribed.keys().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
let (backfill_message, backfill_receiver) = backfill::Message::new(
|
||||
backfill::Action::Backfill,
|
||||
Subset::Some(newly_subscribed.into_iter().collect::<Vec<_>>()),
|
||||
);
|
||||
|
||||
backfill_sender.send(backfill_message).await.unwrap();
|
||||
backfill_receiver.await.unwrap();
|
||||
for sender in newly_subscribed.into_values() {
|
||||
sender.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let newly_unsubscribed_future = async {
|
||||
if !newly_unsubscribed.is_empty() {
|
||||
info!(
|
||||
"Unsubscribed from {} for {:?}.",
|
||||
log_string,
|
||||
newly_unsubscribed
|
||||
.iter()
|
||||
.map(|asset| asset.symbol.clone())
|
||||
.collect::<Vec<_>>()
|
||||
"Unsubscribed from bars for {:?}.",
|
||||
newly_unsubscribed.keys().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
let (purge_message, purge_receiver) = backfill::Message::new(
|
||||
backfill::Action::Purge,
|
||||
Subset::Some(newly_unsubscribed.into_iter().collect::<Vec<_>>()),
|
||||
);
|
||||
|
||||
backfill_sender.send(purge_message).await.unwrap();
|
||||
purge_receiver.await.unwrap();
|
||||
for sender in newly_unsubscribed.into_values() {
|
||||
sender.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
join!(newly_subscribed_future, newly_unsubscribed_future);
|
||||
}
|
||||
websocket::incoming::Message::Bar(message)
|
||||
| websocket::incoming::Message::UpdatedBar(message) => {
|
||||
let bar = Bar::from(message);
|
||||
|
||||
let guard = guard.read().await;
|
||||
if !guard.assets.contains_right(&bar.symbol) {
|
||||
warn!(
|
||||
"Race condition: received bar for unsubscribed symbol: {:?}.",
|
||||
bar.symbol
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
|
||||
database::bars::upsert(&app_config.clickhouse_client, &bar).await;
|
||||
}
|
||||
websocket::incoming::Message::News(message) => {
|
||||
let news = News::from(message);
|
||||
|
||||
let guard = guard.read().await;
|
||||
if !news
|
||||
.symbols
|
||||
.iter()
|
||||
.any(|symbol| guard.assets.contains_right(symbol))
|
||||
{
|
||||
warn!(
|
||||
"Race condition: received news for unsubscribed symbols: {:?}.",
|
||||
news.symbols
|
||||
);
|
||||
return;
|
||||
websocket::incoming::Message::Bar(message)
|
||||
| websocket::incoming::Message::UpdatedBar(message) => {
|
||||
let bar = Bar::from(message);
|
||||
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
|
||||
database::bars::upsert(&self.app_config.clickhouse_client, &bar).await;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Received news for {:?}: {}.",
|
||||
news.symbols, news.time_created
|
||||
);
|
||||
|
||||
let input = format!("{}\n\n{}", news.headline, news.content);
|
||||
|
||||
let sequence_classifier = app_config.sequence_classifier.lock().await;
|
||||
let prediction = block_in_place(|| {
|
||||
sequence_classifier
|
||||
.predict(vec![input.as_str()])
|
||||
.into_iter()
|
||||
.map(|label| Prediction::try_from(label).unwrap())
|
||||
.collect::<Vec<_>>()[0]
|
||||
});
|
||||
drop(sequence_classifier);
|
||||
|
||||
let news = News {
|
||||
sentiment: prediction.sentiment,
|
||||
confidence: prediction.confidence,
|
||||
..news
|
||||
};
|
||||
database::news::upsert(&app_config.clickhouse_client, &news).await;
|
||||
}
|
||||
websocket::incoming::Message::Success(_) => {}
|
||||
websocket::incoming::Message::Error(message) => {
|
||||
error!("Received error message: {}.", message.message);
|
||||
websocket::incoming::Message::Success(_) => {}
|
||||
websocket::incoming::Message::Error(message) => {
|
||||
error!("Received error message: {}.", message.message);
|
||||
}
|
||||
websocket::incoming::Message::News(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct NewsHandler {
|
||||
app_config: Arc<Config>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Handler for NewsHandler {
|
||||
fn create_subscription_message(
|
||||
&self,
|
||||
symbols: Vec<String>,
|
||||
) -> websocket::outgoing::subscribe::Message {
|
||||
websocket::outgoing::subscribe::Message::new_news(symbols)
|
||||
}
|
||||
|
||||
async fn handle_parsed_websocket_message(
|
||||
&self,
|
||||
pending: Arc<RwLock<Pending>>,
|
||||
message: websocket::incoming::Message,
|
||||
) {
|
||||
match message {
|
||||
websocket::incoming::Message::Subscription(message) => {
|
||||
let websocket::incoming::subscription::Message::News { news: symbols } = message
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let symbols = symbols
|
||||
.into_iter()
|
||||
.map(|symbol| add_slash_to_pair(&symbol))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut pending = pending.write().await;
|
||||
|
||||
let newly_subscribed = pending
|
||||
.subscriptions
|
||||
.extract_if(|symbol, _| symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let newly_unsubscribed = pending
|
||||
.unsubscriptions
|
||||
.extract_if(|symbol, _| !symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
drop(pending);
|
||||
|
||||
if !newly_subscribed.is_empty() {
|
||||
info!(
|
||||
"Subscribed to news for {:?}.",
|
||||
newly_subscribed.keys().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
for sender in newly_subscribed.into_values() {
|
||||
sender.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
if !newly_unsubscribed.is_empty() {
|
||||
info!(
|
||||
"Unsubscribed from news for {:?}.",
|
||||
newly_unsubscribed.keys().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
for sender in newly_unsubscribed.into_values() {
|
||||
sender.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
websocket::incoming::Message::News(message) => {
|
||||
let news = News::from(message);
|
||||
|
||||
debug!(
|
||||
"Received news for {:?}: {}.",
|
||||
news.symbols, news.time_created
|
||||
);
|
||||
|
||||
let input = format!("{}\n\n{}", news.headline, news.content);
|
||||
|
||||
let sequence_classifier = self.app_config.sequence_classifier.lock().await;
|
||||
let prediction = block_in_place(|| {
|
||||
sequence_classifier
|
||||
.predict(vec![input.as_str()])
|
||||
.into_iter()
|
||||
.map(|label| Prediction::try_from(label).unwrap())
|
||||
.collect::<Vec<_>>()[0]
|
||||
});
|
||||
drop(sequence_classifier);
|
||||
|
||||
let news = News {
|
||||
sentiment: prediction.sentiment,
|
||||
confidence: prediction.confidence,
|
||||
..news
|
||||
};
|
||||
database::news::upsert(&self.app_config.clickhouse_client, &news).await;
|
||||
}
|
||||
websocket::incoming::Message::Success(_) => {}
|
||||
websocket::incoming::Message::Error(message) => {
|
||||
error!("Received error message: {}.", message.message);
|
||||
}
|
||||
websocket::incoming::Message::Bar(_) | websocket::incoming::Message::UpdatedBar(_) => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_handler(thread_type: ThreadType, app_config: Arc<Config>) -> Box<dyn Handler> {
|
||||
match thread_type {
|
||||
ThreadType::Bars(_) => Box::new(BarsHandler { app_config }),
|
||||
ThreadType::News => Box::new(NewsHandler { app_config }),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user