Remove asset_status thread

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-02-07 20:40:11 +00:00
parent 85eef2bf0b
commit 52e88f4bc9
23 changed files with 796 additions and 774 deletions

View File

@@ -1,51 +1,192 @@
use super::{backfill, Guard};
use super::ThreadType;
use crate::{
config::Config,
database,
types::{alpaca::websocket, news::Prediction, Bar, News, Subset},
types::{alpaca::websocket, news::Prediction, Bar, News},
utils::add_slash_to_pair,
};
use async_trait::async_trait;
use futures_util::{
future::join_all,
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use log::{debug, error, info, warn};
use serde_json::from_str;
use std::{collections::HashSet, sync::Arc};
use log::{debug, error, info};
use serde_json::{from_str, to_string};
use std::{collections::HashMap, sync::Arc};
use tokio::{
join,
net::TcpStream,
spawn,
sync::{mpsc, Mutex, RwLock},
select, spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
task::block_in_place,
};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
pub async fn run(
app_config: Arc<Config>,
guard: Arc<RwLock<Guard>>,
sender: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
mut receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
backfill_sender: mpsc::Sender<backfill::Message>,
) {
loop {
let message = receiver.next().await.unwrap().unwrap();
pub enum Action {
Subscribe,
Unsubscribe,
}
spawn(handle_websocket_message(
app_config.clone(),
guard.clone(),
sender.clone(),
backfill_sender.clone(),
message,
));
impl From<super::Action> for Action {
fn from(action: super::Action) -> Self {
match action {
super::Action::Add => Self::Subscribe,
super::Action::Remove => Self::Unsubscribe,
}
}
}
pub struct Message {
pub action: Action,
pub symbols: Vec<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Action, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel();
(
Self {
action,
symbols,
response: sender,
},
receiver,
)
}
}
pub struct Pending {
pub subscriptions: HashMap<String, oneshot::Sender<()>>,
pub unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}
#[async_trait]
pub trait Handler: Send + Sync {
fn create_subscription_message(
&self,
symbols: Vec<String>,
) -> websocket::outgoing::subscribe::Message;
async fn handle_parsed_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
message: websocket::incoming::Message,
);
}
pub async fn run(
handler: Arc<Box<dyn Handler>>,
mut receiver: mpsc::Receiver<Message>,
mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>,
) {
let pending = Arc::new(RwLock::new(Pending {
subscriptions: HashMap::new(),
unsubscriptions: HashMap::new(),
}));
let websocket_sink = Arc::new(Mutex::new(websocket_sink));
loop {
select! {
Some(message) = receiver.recv() => {
spawn(handle_message(
handler.clone(),
pending.clone(),
websocket_sink.clone(),
message,
));
}
Some(Ok(message)) = websocket_stream.next() => {
spawn(handle_websocket_message(
handler.clone(),
pending.clone(),
websocket_sink.clone(),
message,
));
}
else => panic!("Communication channel unexpectedly closed.")
}
}
}
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
pending: Arc<RwLock<Pending>>,
websocket_sender: Arc<
Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>,
>,
message: Message,
) {
match message.action {
Action::Subscribe => {
let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip();
pending
.write()
.await
.subscriptions
.extend(pending_subscriptions);
websocket_sender
.lock()
.await
.send(tungstenite::Message::Text(
to_string(&websocket::outgoing::Message::Subscribe(
handler.create_subscription_message(message.symbols),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
Action::Unsubscribe => {
let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip();
pending
.write()
.await
.unsubscriptions
.extend(pending_unsubscriptions);
websocket_sender
.lock()
.await
.send(tungstenite::Message::Text(
to_string(&websocket::outgoing::Message::Unsubscribe(
handler.create_subscription_message(message.symbols.clone()),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
}
message.response.send(()).unwrap();
}
async fn handle_websocket_message(
app_config: Arc<Config>,
guard: Arc<RwLock<Guard>>,
handler: Arc<Box<dyn Handler>>,
pending: Arc<RwLock<Pending>>,
sender: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
backfill_sender: mpsc::Sender<backfill::Message>,
message: tungstenite::Message,
) {
match message {
@@ -54,12 +195,14 @@ async fn handle_websocket_message(
if let Ok(message) = message {
for message in message {
spawn(handle_parsed_websocket_message(
app_config.clone(),
guard.clone(),
backfill_sender.clone(),
message,
));
let handler = handler.clone();
let pending = pending.clone();
spawn(async move {
handler
.handle_parsed_websocket_message(pending, message)
.await;
});
}
} else {
error!("Failed to deserialize websocket message: {:?}", message);
@@ -77,143 +220,190 @@ async fn handle_websocket_message(
}
}
#[allow(clippy::significant_drop_tightening)]
#[allow(clippy::too_many_lines)]
async fn handle_parsed_websocket_message(
struct BarsHandler {
app_config: Arc<Config>,
guard: Arc<RwLock<Guard>>,
backfill_sender: mpsc::Sender<backfill::Message>,
message: websocket::incoming::Message,
) {
match message {
websocket::incoming::Message::Subscription(message) => {
let (symbols, log_string) = match message {
websocket::incoming::subscription::Message::Market { bars, .. } => (bars, "bars"),
websocket::incoming::subscription::Message::News { news } => (
news.into_iter()
.map(|symbol| add_slash_to_pair(&symbol))
.collect(),
"news",
),
};
}
let mut guard = guard.write().await;
#[async_trait]
impl Handler for BarsHandler {
fn create_subscription_message(
&self,
symbols: Vec<String>,
) -> websocket::outgoing::subscribe::Message {
websocket::outgoing::subscribe::Message::new_market(symbols)
}
let newly_subscribed = guard
.pending_subscriptions
.extract_if(|asset| symbols.contains(&asset.symbol))
.collect::<HashSet<_>>();
async fn handle_parsed_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
message: websocket::incoming::Message,
) {
match message {
websocket::incoming::Message::Subscription(message) => {
let websocket::incoming::subscription::Message::Market { bars: symbols, .. } =
message
else {
unreachable!()
};
let newly_unsubscribed = guard
.pending_unsubscriptions
.extract_if(|asset| !symbols.contains(&asset.symbol))
.collect::<HashSet<_>>();
let mut pending = pending.write().await;
drop(guard);
let newly_subscribed = pending
.subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = pending
.unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
drop(pending);
let newly_subscribed_future = async {
if !newly_subscribed.is_empty() {
info!(
"Subscribed to {} for {:?}.",
log_string,
newly_subscribed
.iter()
.map(|asset| asset.symbol.clone())
.collect::<Vec<_>>()
"Subscribed to bars for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
let (backfill_message, backfill_receiver) = backfill::Message::new(
backfill::Action::Backfill,
Subset::Some(newly_subscribed.into_iter().collect::<Vec<_>>()),
);
backfill_sender.send(backfill_message).await.unwrap();
backfill_receiver.await.unwrap();
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
};
let newly_unsubscribed_future = async {
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from {} for {:?}.",
log_string,
newly_unsubscribed
.iter()
.map(|asset| asset.symbol.clone())
.collect::<Vec<_>>()
"Unsubscribed from bars for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
let (purge_message, purge_receiver) = backfill::Message::new(
backfill::Action::Purge,
Subset::Some(newly_unsubscribed.into_iter().collect::<Vec<_>>()),
);
backfill_sender.send(purge_message).await.unwrap();
purge_receiver.await.unwrap();
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
};
join!(newly_subscribed_future, newly_unsubscribed_future);
}
websocket::incoming::Message::Bar(message)
| websocket::incoming::Message::UpdatedBar(message) => {
let bar = Bar::from(message);
let guard = guard.read().await;
if !guard.assets.contains_right(&bar.symbol) {
warn!(
"Race condition: received bar for unsubscribed symbol: {:?}.",
bar.symbol
);
return;
}
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
database::bars::upsert(&app_config.clickhouse_client, &bar).await;
}
websocket::incoming::Message::News(message) => {
let news = News::from(message);
let guard = guard.read().await;
if !news
.symbols
.iter()
.any(|symbol| guard.assets.contains_right(symbol))
{
warn!(
"Race condition: received news for unsubscribed symbols: {:?}.",
news.symbols
);
return;
websocket::incoming::Message::Bar(message)
| websocket::incoming::Message::UpdatedBar(message) => {
let bar = Bar::from(message);
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
database::bars::upsert(&self.app_config.clickhouse_client, &bar).await;
}
debug!(
"Received news for {:?}: {}.",
news.symbols, news.time_created
);
let input = format!("{}\n\n{}", news.headline, news.content);
let sequence_classifier = app_config.sequence_classifier.lock().await;
let prediction = block_in_place(|| {
sequence_classifier
.predict(vec![input.as_str()])
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()[0]
});
drop(sequence_classifier);
let news = News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
};
database::news::upsert(&app_config.clickhouse_client, &news).await;
}
websocket::incoming::Message::Success(_) => {}
websocket::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
websocket::incoming::Message::Success(_) => {}
websocket::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
websocket::incoming::Message::News(_) => unreachable!(),
}
}
}
struct NewsHandler {
app_config: Arc<Config>,
}
#[async_trait]
impl Handler for NewsHandler {
fn create_subscription_message(
&self,
symbols: Vec<String>,
) -> websocket::outgoing::subscribe::Message {
websocket::outgoing::subscribe::Message::new_news(symbols)
}
async fn handle_parsed_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
message: websocket::incoming::Message,
) {
match message {
websocket::incoming::Message::Subscription(message) => {
let websocket::incoming::subscription::Message::News { news: symbols } = message
else {
unreachable!()
};
let symbols = symbols
.into_iter()
.map(|symbol| add_slash_to_pair(&symbol))
.collect::<Vec<_>>();
let mut pending = pending.write().await;
let newly_subscribed = pending
.subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = pending
.unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
drop(pending);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to news for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from news for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::incoming::Message::News(message) => {
let news = News::from(message);
debug!(
"Received news for {:?}: {}.",
news.symbols, news.time_created
);
let input = format!("{}\n\n{}", news.headline, news.content);
let sequence_classifier = self.app_config.sequence_classifier.lock().await;
let prediction = block_in_place(|| {
sequence_classifier
.predict(vec![input.as_str()])
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()[0]
});
drop(sequence_classifier);
let news = News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
};
database::news::upsert(&self.app_config.clickhouse_client, &news).await;
}
websocket::incoming::Message::Success(_) => {}
websocket::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
websocket::incoming::Message::Bar(_) | websocket::incoming::Message::UpdatedBar(_) => {
unreachable!()
}
}
}
}
pub fn create_handler(thread_type: ThreadType, app_config: Arc<Config>) -> Box<dyn Handler> {
match thread_type {
ThreadType::Bars(_) => Box::new(BarsHandler { app_config }),
ThreadType::News => Box::new(NewsHandler { app_config }),
}
}