438 lines
14 KiB
Rust
438 lines
14 KiB
Rust
use super::ThreadType;
|
|
use crate::{
|
|
config::Config,
|
|
database,
|
|
types::{alpaca::websocket, news::Prediction, Bar, Class, News},
|
|
};
|
|
use async_trait::async_trait;
|
|
use futures_util::{
|
|
future::join_all,
|
|
stream::{SplitSink, SplitStream},
|
|
SinkExt, StreamExt,
|
|
};
|
|
use log::{debug, error, info};
|
|
use serde_json::{from_str, to_string};
|
|
use std::{collections::HashMap, sync::Arc};
|
|
use tokio::{
|
|
net::TcpStream,
|
|
select, spawn,
|
|
sync::{mpsc, oneshot, Mutex, RwLock},
|
|
task::block_in_place,
|
|
};
|
|
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
|
|
|
|
/// A change to request on the websocket's symbol subscriptions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Action {
    /// Start receiving data for the given symbols.
    Subscribe,
    /// Stop receiving data for the given symbols.
    Unsubscribe,
}
|
|
|
|
impl From<super::Action> for Option<Action> {
|
|
fn from(action: super::Action) -> Self {
|
|
match action {
|
|
super::Action::Add | super::Action::Enable => Some(Action::Subscribe),
|
|
super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct Message {
|
|
pub action: Option<Action>,
|
|
pub symbols: Vec<String>,
|
|
pub response: oneshot::Sender<()>,
|
|
}
|
|
|
|
impl Message {
|
|
pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
|
|
let (sender, receiver) = oneshot::channel();
|
|
(
|
|
Self {
|
|
action,
|
|
symbols,
|
|
response: sender,
|
|
},
|
|
receiver,
|
|
)
|
|
}
|
|
}
|
|
|
|
pub struct Pending {
|
|
pub subscriptions: HashMap<String, oneshot::Sender<()>>,
|
|
pub unsubscriptions: HashMap<String, oneshot::Sender<()>>,
|
|
}
|
|
|
|
#[async_trait]
|
|
pub trait Handler: Send + Sync {
|
|
fn create_subscription_message(
|
|
&self,
|
|
symbols: Vec<String>,
|
|
) -> websocket::data::outgoing::subscribe::Message;
|
|
async fn handle_websocket_message(
|
|
&self,
|
|
pending: Arc<RwLock<Pending>>,
|
|
message: websocket::data::incoming::Message,
|
|
);
|
|
}
|
|
|
|
pub async fn run(
|
|
handler: Arc<Box<dyn Handler>>,
|
|
mut receiver: mpsc::Receiver<Message>,
|
|
mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
|
|
websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>,
|
|
) {
|
|
let pending = Arc::new(RwLock::new(Pending {
|
|
subscriptions: HashMap::new(),
|
|
unsubscriptions: HashMap::new(),
|
|
}));
|
|
let websocket_sink = Arc::new(Mutex::new(websocket_sink));
|
|
|
|
loop {
|
|
select! {
|
|
Some(message) = receiver.recv() => {
|
|
spawn(handle_message(
|
|
handler.clone(),
|
|
pending.clone(),
|
|
websocket_sink.clone(),
|
|
message,
|
|
));
|
|
}
|
|
Some(Ok(message)) = websocket_stream.next() => {
|
|
match message {
|
|
tungstenite::Message::Text(message) => {
|
|
let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);
|
|
|
|
if parsed_message.is_err() {
|
|
error!("Failed to deserialize websocket message: {:?}", message);
|
|
continue;
|
|
}
|
|
|
|
for message in parsed_message.unwrap() {
|
|
let handler = handler.clone();
|
|
let pending = pending.clone();
|
|
spawn(async move {
|
|
handler.handle_websocket_message(pending, message).await;
|
|
});
|
|
}
|
|
}
|
|
tungstenite::Message::Ping(_) => {}
|
|
_ => error!("Unexpected websocket message: {:?}", message),
|
|
}
|
|
}
|
|
else => panic!("Communication channel unexpectedly closed.")
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_message(
|
|
handler: Arc<Box<dyn Handler>>,
|
|
pending: Arc<RwLock<Pending>>,
|
|
sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
|
|
message: Message,
|
|
) {
|
|
if message.symbols.is_empty() {
|
|
message.response.send(()).unwrap();
|
|
return;
|
|
}
|
|
|
|
match message.action {
|
|
Some(Action::Subscribe) => {
|
|
let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message
|
|
.symbols
|
|
.iter()
|
|
.map(|symbol| {
|
|
let (sender, receiver) = oneshot::channel();
|
|
((symbol.clone(), sender), receiver)
|
|
})
|
|
.unzip();
|
|
|
|
pending
|
|
.write()
|
|
.await
|
|
.subscriptions
|
|
.extend(pending_subscriptions);
|
|
|
|
sink.lock()
|
|
.await
|
|
.send(tungstenite::Message::Text(
|
|
to_string(&websocket::data::outgoing::Message::Subscribe(
|
|
handler.create_subscription_message(message.symbols),
|
|
))
|
|
.unwrap(),
|
|
))
|
|
.await
|
|
.unwrap();
|
|
|
|
join_all(receivers).await;
|
|
}
|
|
Some(Action::Unsubscribe) => {
|
|
let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
|
|
.symbols
|
|
.iter()
|
|
.map(|symbol| {
|
|
let (sender, receiver) = oneshot::channel();
|
|
((symbol.clone(), sender), receiver)
|
|
})
|
|
.unzip();
|
|
|
|
pending
|
|
.write()
|
|
.await
|
|
.unsubscriptions
|
|
.extend(pending_unsubscriptions);
|
|
|
|
sink.lock()
|
|
.await
|
|
.send(tungstenite::Message::Text(
|
|
to_string(&websocket::data::outgoing::Message::Unsubscribe(
|
|
handler.create_subscription_message(message.symbols.clone()),
|
|
))
|
|
.unwrap(),
|
|
))
|
|
.await
|
|
.unwrap();
|
|
|
|
join_all(receivers).await;
|
|
}
|
|
None => {}
|
|
}
|
|
|
|
message.response.send(()).unwrap();
|
|
}
|
|
|
|
struct BarsHandler {
|
|
config: Arc<Config>,
|
|
subscription_message_constructor:
|
|
fn(Vec<String>) -> websocket::data::outgoing::subscribe::Message,
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Handler for BarsHandler {
|
|
fn create_subscription_message(
|
|
&self,
|
|
symbols: Vec<String>,
|
|
) -> websocket::data::outgoing::subscribe::Message {
|
|
(self.subscription_message_constructor)(symbols)
|
|
}
|
|
|
|
async fn handle_websocket_message(
|
|
&self,
|
|
pending: Arc<RwLock<Pending>>,
|
|
message: websocket::data::incoming::Message,
|
|
) {
|
|
match message {
|
|
websocket::data::incoming::Message::Subscription(message) => {
|
|
let websocket::data::incoming::subscription::Message::Market {
|
|
bars: symbols, ..
|
|
} = message
|
|
else {
|
|
unreachable!()
|
|
};
|
|
|
|
let mut pending = pending.write().await;
|
|
|
|
let newly_subscribed = pending
|
|
.subscriptions
|
|
.extract_if(|symbol, _| symbols.contains(symbol))
|
|
.collect::<HashMap<_, _>>();
|
|
|
|
let newly_unsubscribed = pending
|
|
.unsubscriptions
|
|
.extract_if(|symbol, _| !symbols.contains(symbol))
|
|
.collect::<HashMap<_, _>>();
|
|
|
|
drop(pending);
|
|
|
|
if !newly_subscribed.is_empty() {
|
|
info!(
|
|
"Subscribed to bars for {:?}.",
|
|
newly_subscribed.keys().collect::<Vec<_>>()
|
|
);
|
|
|
|
for sender in newly_subscribed.into_values() {
|
|
sender.send(()).unwrap();
|
|
}
|
|
}
|
|
|
|
if !newly_unsubscribed.is_empty() {
|
|
info!(
|
|
"Unsubscribed from bars for {:?}.",
|
|
newly_unsubscribed.keys().collect::<Vec<_>>()
|
|
);
|
|
|
|
for sender in newly_unsubscribed.into_values() {
|
|
sender.send(()).unwrap();
|
|
}
|
|
}
|
|
}
|
|
websocket::data::incoming::Message::Bar(message)
|
|
| websocket::data::incoming::Message::UpdatedBar(message) => {
|
|
let bar = Bar::from(message);
|
|
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
|
|
|
|
database::bars::upsert(
|
|
&self.config.clickhouse_client,
|
|
&self.config.clickhouse_concurrency_limiter,
|
|
&bar,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
}
|
|
websocket::data::incoming::Message::Status(message) => {
|
|
debug!(
|
|
"Received status message for {}: {:?}.",
|
|
message.symbol, message.status
|
|
);
|
|
|
|
match message.status {
|
|
websocket::data::incoming::status::Status::TradingHalt(_)
|
|
| websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
|
|
database::assets::update_status_where_symbol(
|
|
&self.config.clickhouse_client,
|
|
&self.config.clickhouse_concurrency_limiter,
|
|
&message.symbol,
|
|
false,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
}
|
|
websocket::data::incoming::status::Status::Resume(_)
|
|
| websocket::data::incoming::status::Status::TradingResumption(_) => {
|
|
database::assets::update_status_where_symbol(
|
|
&self.config.clickhouse_client,
|
|
&self.config.clickhouse_concurrency_limiter,
|
|
&message.symbol,
|
|
true,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
websocket::data::incoming::Message::Error(message) => {
|
|
error!("Received error message: {}.", message.message);
|
|
}
|
|
_ => unreachable!(),
|
|
}
|
|
}
|
|
}
|
|
|
|
struct NewsHandler {
|
|
config: Arc<Config>,
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Handler for NewsHandler {
|
|
fn create_subscription_message(
|
|
&self,
|
|
symbols: Vec<String>,
|
|
) -> websocket::data::outgoing::subscribe::Message {
|
|
websocket::data::outgoing::subscribe::Message::new_news(symbols)
|
|
}
|
|
|
|
async fn handle_websocket_message(
|
|
&self,
|
|
pending: Arc<RwLock<Pending>>,
|
|
message: websocket::data::incoming::Message,
|
|
) {
|
|
match message {
|
|
websocket::data::incoming::Message::Subscription(message) => {
|
|
let websocket::data::incoming::subscription::Message::News { news: symbols } =
|
|
message
|
|
else {
|
|
unreachable!()
|
|
};
|
|
|
|
let mut pending = pending.write().await;
|
|
|
|
let newly_subscribed = pending
|
|
.subscriptions
|
|
.extract_if(|symbol, _| symbols.contains(symbol))
|
|
.collect::<HashMap<_, _>>();
|
|
|
|
let newly_unsubscribed = pending
|
|
.unsubscriptions
|
|
.extract_if(|symbol, _| !symbols.contains(symbol))
|
|
.collect::<HashMap<_, _>>();
|
|
|
|
drop(pending);
|
|
|
|
if !newly_subscribed.is_empty() {
|
|
info!(
|
|
"Subscribed to news for {:?}.",
|
|
newly_subscribed.keys().collect::<Vec<_>>()
|
|
);
|
|
|
|
for sender in newly_subscribed.into_values() {
|
|
sender.send(()).unwrap();
|
|
}
|
|
}
|
|
|
|
if !newly_unsubscribed.is_empty() {
|
|
info!(
|
|
"Unsubscribed from news for {:?}.",
|
|
newly_unsubscribed.keys().collect::<Vec<_>>()
|
|
);
|
|
|
|
for sender in newly_unsubscribed.into_values() {
|
|
sender.send(()).unwrap();
|
|
}
|
|
}
|
|
}
|
|
websocket::data::incoming::Message::News(message) => {
|
|
let news = News::from(message);
|
|
|
|
debug!(
|
|
"Received news for {:?}: {}.",
|
|
news.symbols, news.time_created
|
|
);
|
|
|
|
let input = format!("{}\n\n{}", news.headline, news.content);
|
|
|
|
let sequence_classifier = self.config.sequence_classifier.lock().await;
|
|
let prediction = block_in_place(|| {
|
|
sequence_classifier
|
|
.predict(vec![input.as_str()])
|
|
.into_iter()
|
|
.map(|label| Prediction::try_from(label).unwrap())
|
|
.collect::<Vec<_>>()[0]
|
|
});
|
|
drop(sequence_classifier);
|
|
|
|
let news = News {
|
|
sentiment: prediction.sentiment,
|
|
confidence: prediction.confidence,
|
|
..news
|
|
};
|
|
|
|
database::news::upsert(
|
|
&self.config.clickhouse_client,
|
|
&self.config.clickhouse_concurrency_limiter,
|
|
&news,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
}
|
|
websocket::data::incoming::Message::Error(message) => {
|
|
error!("Received error message: {}.", message.message);
|
|
}
|
|
_ => unreachable!(),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
|
|
match thread_type {
|
|
ThreadType::Bars(Class::UsEquity) => Box::new(BarsHandler {
|
|
config,
|
|
subscription_message_constructor:
|
|
websocket::data::outgoing::subscribe::Message::new_market_us_equity,
|
|
}),
|
|
ThreadType::Bars(Class::Crypto) => Box::new(BarsHandler {
|
|
config,
|
|
subscription_message_constructor:
|
|
websocket::data::outgoing::subscribe::Message::new_market_crypto,
|
|
}),
|
|
ThreadType::News => Box::new(NewsHandler { config }),
|
|
}
|
|
}
|