Files
qrust/src/bin/qrust/threads/data/websocket/bars.rs
2024-03-20 19:43:46 +00:00

172 lines
5.8 KiB
Rust

use super::State;
use crate::{
config::{Config, CLICKHOUSE_BATCH_BARS_SIZE},
database,
threads::data::ThreadType,
};
use async_trait::async_trait;
use clickhouse::inserter::Inserter;
use log::{debug, error, info};
use nonempty::NonEmpty;
use qrust::{
types::{alpaca::websocket, Bar, Class},
utils::ONE_SECOND,
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::{Mutex, RwLock};
/// Websocket message handler for bar data.
///
/// Holds everything the bar handler needs: the shared configuration, the
/// inserter that bar rows are written to, and the function used to build
/// outgoing subscribe messages.
pub struct Handler {
    /// Shared application configuration. This file reads its
    /// `clickhouse_client` and `clickhouse_concurrency_limiter` fields.
    pub config: Arc<Config>,

    /// Inserter for [`Bar`] rows, shared behind an async mutex so that it can
    /// be used both when handling messages and by `run_inserter`.
    pub inserter: Arc<Mutex<Inserter<Bar>>>,

    /// Function that turns a non-empty list of symbols into the outgoing
    /// subscribe message.
    pub subscription_message_constructor:
        fn(NonEmpty<String>) -> websocket::data::outgoing::subscribe::Message,
}
#[async_trait]
impl super::Handler for Handler {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message {
(self.subscription_message_constructor)(symbols)
}
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::Market {
bars: symbols, ..
} = message
else {
unreachable!()
};
let symbols = symbols.into_iter().collect::<HashSet<_>>();
let mut state = state.write().await;
let newly_subscribed = state
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = state
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to bars for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from bars for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::Bar(message)
| websocket::data::incoming::Message::UpdatedBar(message) => {
let bar = Bar::from(message);
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
self.inserter.lock().await.write(&bar).await.unwrap();
}
websocket::data::incoming::Message::Status(message) => {
debug!(
"Received status message for {}: {:?}.",
message.symbol, message.status
);
match message.status {
websocket::data::incoming::status::Status::TradingHalt(_)
| websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
false,
)
.await
.unwrap();
}
websocket::data::incoming::status::Status::Resume(_)
| websocket::data::incoming::status::Status::TradingResumption(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
true,
)
.await
.unwrap();
}
_ => {}
}
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
fn log_string(&self) -> &'static str {
"bars"
}
async fn run_inserter(&self) {
super::run_inserter(self.inserter.clone()).await;
}
}
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
let inserter = Arc::new(Mutex::new(
config
.clickhouse_client
.inserter("bars")
.unwrap()
.with_period(Some(ONE_SECOND))
.with_max_entries((*CLICKHOUSE_BATCH_BARS_SIZE).try_into().unwrap()),
));
let subscription_message_constructor = match thread_type {
ThreadType::Bars(Class::UsEquity) => {
websocket::data::outgoing::subscribe::Message::new_market_us_equity
}
ThreadType::Bars(Class::Crypto) => {
websocket::data::outgoing::subscribe::Message::new_market_crypto
}
_ => unreachable!(),
};
Box::new(Handler {
config,
inserter,
subscription_message_constructor,
})
}