Prevent race conditions

- Replace the shared `backfilled` map with a `Guard` (subscribed symbols, pending subscriptions/unsubscriptions, and running backfill job handles) behind a single `RwLock`.
- Route subscribes, backfills, unsubscribes, and purges through the broadcast bus, and only trigger a backfill or purge once the websocket acknowledges the corresponding (un)subscription, so the websocket, clock, and backfill tasks no longer race on shared state.

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-01-17 17:16:37 +00:00
parent 36ee6030ce
commit 7200447bc5
23 changed files with 510 additions and 311 deletions


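For orientation before the diff: subscription state is now funnelled through a single `Guard` behind a `tokio::sync::RwLock`, and backfills/purges are only kicked off once the websocket acknowledges the matching subscribe/unsubscribe. The sketch below illustrates that handshake only; `Asset` and `Guard` are simplified stand-ins (the real `Guard` also tracks backfill `JoinHandle`s), and the symbols in `main` are purely illustrative.

```rust
// Sketch only: simplified stand-ins for the crate's Asset/Guard types.
// Assumes tokio = { version = "1", features = ["macros", "rt", "sync"] }.
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use tokio::sync::RwLock;

#[derive(Clone, Debug)]
struct Asset {
    symbol: String,
}

#[derive(Default)]
struct Guard {
    // Symbols the stream is (or is about to be) subscribed to.
    symbols: HashSet<String>,
    // Assets whose Subscribe/Unsubscribe frames were sent but not yet acked.
    // (The real Guard also holds the per-symbol backfill JoinHandles.)
    pending_subscriptions: HashMap<String, Asset>,
    pending_unsubscriptions: HashMap<String, Asset>,
}

// Broadcast-bus side: record the intent to subscribe before the Subscribe
// frame goes out, so a racing ack or bar never sees an unknown symbol.
async fn mark_pending_subscribe(guard: &Arc<RwLock<Guard>>, assets: Vec<Asset>) {
    let mut guard = guard.write().await;
    guard.symbols.extend(assets.iter().map(|a| a.symbol.clone()));
    guard
        .pending_subscriptions
        .extend(assets.into_iter().map(|a| (a.symbol.clone(), a)));
}

// Websocket side: when the server's subscription ack arrives, drain the
// pending set and keep only the assets the ack confirms; those are the ones
// that get a backfill job.
async fn confirm_subscriptions(
    guard: &Arc<RwLock<Guard>>,
    acked: &HashSet<String>,
) -> Vec<Asset> {
    let mut guard = guard.write().await;
    guard
        .pending_subscriptions
        .drain()
        .filter(|(symbol, _)| acked.contains(symbol))
        .map(|(_, asset)| asset)
        .collect()
}

#[tokio::main]
async fn main() {
    let guard = Arc::new(RwLock::new(Guard::default()));
    let assets = vec![
        Asset { symbol: "AAPL".into() },
        Asset { symbol: "MSFT".into() },
    ];

    mark_pending_subscribe(&guard, assets).await;

    // Pretend the websocket ack only confirmed AAPL so far.
    let acked: HashSet<String> = ["AAPL".to_string()].into_iter().collect();
    let confirmed = confirm_subscriptions(&guard, &acked).await;
    println!("backfill these: {confirmed:?}"); // [Asset { symbol: "AAPL" }]
}
```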
@@ -5,10 +5,10 @@ use crate::{
},
data::authenticate_websocket,
database,
state::{self, BroadcastMessage},
types::{
alpaca::{api, websocket, Source},
asset::{self, Asset},
Bar, BarValidity, BroadcastMessage, Class,
Asset, Backfill, Bar, Class,
},
utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE},
};
@@ -18,11 +18,12 @@ use futures_util::{
SinkExt, StreamExt,
};
use log::{error, info, warn};
use serde_json::from_str;
use serde_json::{from_str, to_string};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use time::OffsetDateTime;
use tokio::{
net::TcpStream,
spawn,
@@ -30,14 +31,22 @@ use tokio::{
broadcast::{Receiver, Sender},
RwLock,
},
task::JoinHandle,
time::sleep,
};
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub struct Guard {
symbols: HashSet<String>,
backfill_jobs: HashMap<String, JoinHandle<()>>,
pending_subscriptions: HashMap<String, Asset>,
pending_unsubscriptions: HashMap<String, Asset>,
}
pub async fn run(
app_config: Arc<Config>,
class: Class,
broadcast_sender: Sender<BroadcastMessage>,
broadcast_bus_sender: Sender<BroadcastMessage>,
) {
info!("Running live data threads for {:?}.", class);
@@ -54,208 +63,204 @@ pub async fn run(
authenticate_websocket(&app_config, &mut stream, &mut sink).await;
let sink = Arc::new(RwLock::new(sink));
spawn(broadcast_handler(
let guard = Arc::new(RwLock::new(Guard {
symbols: HashSet::new(),
backfill_jobs: HashMap::new(),
pending_subscriptions: HashMap::new(),
pending_unsubscriptions: HashMap::new(),
}));
spawn(broadcast_bus_handler(
app_config.clone(),
class,
sink.clone(),
broadcast_sender.subscribe(),
broadcast_bus_sender.subscribe(),
guard.clone(),
));
let assets = database::assets::select_where_class(&app_config.clickhouse_client, class).await;
broadcast_sender
spawn(websocket_handler(
app_config.clone(),
stream,
sink,
broadcast_bus_sender.clone(),
guard.clone(),
));
spawn(clock_handler(
app_config.clone(),
class,
broadcast_bus_sender.clone(),
));
let assets = database::assets::select_where_class(&app_config.clickhouse_client, &class).await;
broadcast_bus_sender
.send(BroadcastMessage::Asset((
asset::BroadcastMessage::Added,
state::asset::BroadcastMessage::Add,
assets,
)))
.unwrap();
websocket_handler(app_config, class, stream, sink).await;
unreachable!()
}
async fn broadcast_handler(
#[allow(clippy::too_many_lines)]
#[allow(clippy::significant_drop_tightening)]
pub async fn broadcast_bus_handler(
app_config: Arc<Config>,
class: Class,
sink: Arc<RwLock<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
mut broadcast_receiver: Receiver<BroadcastMessage>,
mut broadcast_bus_receiver: Receiver<BroadcastMessage>,
guard: Arc<RwLock<Guard>>,
) {
loop {
match broadcast_receiver.recv().await.unwrap() {
match broadcast_bus_receiver.recv().await.unwrap() {
BroadcastMessage::Asset((action, assets)) => {
let symbols = assets
let assets = assets
.into_iter()
.filter(|asset| asset.class == class)
.map(|asset| asset.symbol)
.collect::<Vec<_>>();
if symbols.is_empty() {
if assets.is_empty() {
continue;
}
sink.write()
.await
.send(Message::Text(
serde_json::to_string(&match action {
asset::BroadcastMessage::Added => {
websocket::data::outgoing::Message::Subscribe(
websocket::data::outgoing::subscribe::Message::new(
symbols.clone(),
),
)
let symbols = assets
.iter()
.map(|asset| asset.symbol.clone())
.collect::<Vec<_>>();
match action {
state::asset::BroadcastMessage::Add => {
database::assets::upsert_batch(&app_config.clickhouse_client, &assets)
.await;
info!("Added {:?}.", symbols);
let mut guard = guard.write().await;
guard.pending_subscriptions.extend(
assets
.into_iter()
.map(|asset| (asset.symbol.clone(), asset)),
);
guard.symbols.extend(symbols.clone());
sink.write()
.await
.send(Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
websocket::data::outgoing::subscribe::Message::new(symbols),
))
.unwrap(),
))
.await
.unwrap();
}
state::asset::BroadcastMessage::Delete => {
database::assets::delete_where_symbols(
&app_config.clickhouse_client,
&symbols,
)
.await;
info!("Deleted {:?}.", symbols);
let mut guard = guard.write().await;
guard.pending_unsubscriptions.extend(
assets
.into_iter()
.map(|asset| (asset.symbol.clone(), asset)),
);
guard.symbols = guard
.symbols
.clone()
.into_iter()
.filter(|symbol| !symbols.contains(symbol))
.collect::<HashSet<_, _>>();
sink.write()
.await
.send(Message::Text(
to_string(&websocket::data::outgoing::Message::Unsubscribe(
websocket::data::outgoing::subscribe::Message::new(symbols),
))
.unwrap(),
))
.await
.unwrap();
}
state::asset::BroadcastMessage::Backfill => {
info!("Creating backfill jobs for {:?}.", symbols);
let guard_clone = guard.clone();
let mut guard = guard.write().await;
for asset in assets {
let mut handles = Vec::new();
if let Some(backfill_job) = guard.backfill_jobs.remove(&asset.symbol) {
backfill_job.abort();
handles.push(backfill_job);
}
asset::BroadcastMessage::Deleted => {
websocket::data::outgoing::Message::Unsubscribe(
websocket::data::outgoing::subscribe::Message::new(
symbols.clone(),
),
)
for handle in handles {
handle.await.unwrap_err();
}
})
.unwrap(),
))
.await
.unwrap();
guard.backfill_jobs.insert(asset.symbol.clone(), {
let guard = guard_clone.clone();
let app_config = app_config.clone();
spawn(async move {
backfill(app_config, class, asset.clone()).await;
let mut guard = guard.write().await;
guard.backfill_jobs.remove(&asset.symbol);
})
});
}
}
state::asset::BroadcastMessage::Purge => {
let mut guard = guard.write().await;
info!("Purging {:?}.", symbols);
for asset in assets {
let mut handles = Vec::new();
if let Some(backfill_job) = guard.backfill_jobs.remove(&asset.symbol) {
backfill_job.abort();
handles.push(backfill_job);
}
for handle in handles {
handle.await.unwrap_err();
}
}
database::backfills::delete_where_symbols(
&app_config.clickhouse_client,
&symbols,
)
.await;
database::bars::delete_where_symbols(
&app_config.clickhouse_client,
&symbols,
)
.await;
}
}
}
}
}
}
async fn websocket_handler(
pub async fn clock_handler(
app_config: Arc<Config>,
class: Class,
mut stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
sink: Arc<RwLock<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
broadcast_bus_sender: Sender<BroadcastMessage>,
) {
let backfilled = Arc::new(RwLock::new(HashMap::new()));
loop {
match stream.next().await {
Some(Ok(Message::Text(data))) => {
let parsed_data = from_str::<Vec<websocket::data::incoming::Message>>(&data);
if let Err(e) = &parsed_data {
warn!(
"Unparsed websocket::data::incoming message: {:?}: {}",
data, e
);
}
for message in parsed_data.unwrap_or_default() {
spawn(websocket_handle_message(
app_config.clone(),
class,
backfilled.clone(),
message,
));
}
}
Some(Ok(Message::Ping(_))) => sink
.write()
.await
.send(Message::Pong(vec![]))
.await
.unwrap(),
Some(unknown) => error!("Unknown websocket::data::incoming message: {:?}", unknown),
None => panic!(),
}
}
}
async fn websocket_handle_message(
app_config: Arc<Config>,
class: Class,
backfilled: Arc<RwLock<HashMap<String, bool>>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let added_asset_symbols;
let deleted_asset_symbols;
{
let mut backfilled = backfilled.write().await;
let old_asset_symbols = backfilled.keys().cloned().collect::<HashSet<_>>();
let new_asset_symbols = message.bars.into_iter().collect::<HashSet<_>>();
added_asset_symbols = new_asset_symbols
.difference(&old_asset_symbols)
.cloned()
.collect::<HashSet<_>>();
for asset_symbol in &added_asset_symbols {
backfilled.insert(asset_symbol.clone(), false);
}
deleted_asset_symbols = old_asset_symbols
.difference(&new_asset_symbols)
.cloned()
.collect::<HashSet<_>>();
for asset_symbol in &deleted_asset_symbols {
backfilled.remove(asset_symbol);
}
drop(backfilled);
info!(
"Subscription update for {:?}: {:?} added, {:?} deleted.",
class, added_asset_symbols, deleted_asset_symbols
);
}
for asset_symbol in added_asset_symbols {
let asset = database::assets::select_where_symbol(
&app_config.clickhouse_client,
&asset_symbol,
)
.await
.unwrap();
database::bars::insert_validity_if_not_exists(
&app_config.clickhouse_client,
&BarValidity::none(asset.symbol.clone()),
)
.await;
spawn(backfill(app_config.clone(), backfilled.clone(), asset));
}
for asset_symbol in deleted_asset_symbols {
database::bars::delete_validity_where_symbol(
&app_config.clickhouse_client,
&asset_symbol,
)
.await;
database::bars::delete_where_symbol(&app_config.clickhouse_client, &asset_symbol)
.await;
}
}
websocket::data::incoming::Message::Bars(bar_message)
| websocket::data::incoming::Message::UpdatedBars(bar_message) => {
let bar = Bar::from(bar_message);
info!("websocket::Incoming bar for {}: {}", bar.symbol, bar.time);
database::bars::upsert(&app_config.clickhouse_client, &bar).await;
if *backfilled.read().await.get(&bar.symbol).unwrap() {
database::bars::upsert_validity(&app_config.clickhouse_client, &bar.into()).await;
}
}
websocket::data::incoming::Message::Success(_) => {}
}
}
pub async fn backfill(
app_config: Arc<Config>,
backfilled: Arc<RwLock<HashMap<String, bool>>>,
asset: Asset,
) {
let bar_validity =
database::bars::select_validity_where_symbol(&app_config.clickhouse_client, &asset.symbol)
.await
.unwrap();
let fetch_from = bar_validity.time_last + ONE_MINUTE;
let fetch_until = if app_config.alpaca_source == Source::Iex {
app_config.alpaca_rate_limit.until_ready().await;
let clock = app_config
.alpaca_client
@@ -267,15 +272,177 @@ pub async fn backfill(
.await
.unwrap();
if clock.is_open {
last_minute()
let sleep_until = duration_until(if clock.is_open {
if class == Class::UsEquity {
info!("Market is open, will close at {}.", clock.next_close);
}
clock.next_close
} else {
if class == Class::UsEquity {
info!("Market is closed, will reopen at {}.", clock.next_open);
}
clock.next_open
});
sleep(sleep_until).await;
let assets = database::assets::select(&app_config.clickhouse_client).await;
broadcast_bus_sender
.send(BroadcastMessage::Asset((
state::asset::BroadcastMessage::Backfill,
assets,
)))
.unwrap();
}
}
async fn websocket_handler(
app_config: Arc<Config>,
mut stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
sink: Arc<RwLock<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
broadcast_bus_sender: Sender<BroadcastMessage>,
guard: Arc<RwLock<Guard>>,
) {
loop {
let app_config = app_config.clone();
let sink = sink.clone();
let broadcast_bus_sender = broadcast_bus_sender.clone();
let guard = guard.clone();
let message = stream.next().await;
spawn(async move {
match message {
Some(Ok(Message::Text(data))) => {
let parsed_data = from_str::<Vec<websocket::data::incoming::Message>>(&data);
if let Ok(messages) = parsed_data {
for message in messages {
websocket_handle_message(
app_config.clone(),
broadcast_bus_sender.clone(),
guard.clone(),
message,
)
.await;
}
} else {
error!(
"Unparsed websocket::data::incoming message: {:?}: {}",
data,
parsed_data.err().unwrap()
);
}
}
Some(Ok(Message::Ping(_))) => sink
.write()
.await
.send(Message::Pong(vec![]))
.await
.unwrap(),
Some(unknown) => error!("Unknown websocket::data::incoming message: {:?}", unknown),
_ => panic!(),
}
});
}
}
#[allow(clippy::significant_drop_tightening)]
async fn websocket_handle_message(
app_config: Arc<Config>,
broadcast_bus_sender: Sender<BroadcastMessage>,
guard: Arc<RwLock<Guard>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let symbols = message.bars.into_iter().collect::<HashSet<_>>();
let mut guard = guard.write().await;
let newly_subscribed_assets = guard
.pending_subscriptions
.drain()
.filter(|(symbol, _)| symbols.contains(symbol))
.map(|(_, asset)| asset)
.collect::<Vec<_>>();
if !newly_subscribed_assets.is_empty() {
info!(
"Subscribed to {:?}.",
newly_subscribed_assets
.iter()
.map(|asset| asset.symbol.clone())
.collect::<Vec<_>>()
);
broadcast_bus_sender
.send(BroadcastMessage::Asset((
state::asset::BroadcastMessage::Backfill,
newly_subscribed_assets,
)))
.unwrap();
}
let newly_unsubscribed_assets = guard
.pending_unsubscriptions
.drain()
.filter(|(symbol, _)| !symbols.contains(symbol))
.map(|(_, asset)| asset)
.collect::<Vec<_>>();
if !newly_unsubscribed_assets.is_empty() {
info!(
"Unsubscribed from {:?}.",
newly_unsubscribed_assets
.iter()
.map(|asset| asset.symbol.clone())
.collect::<Vec<_>>()
);
broadcast_bus_sender
.send(BroadcastMessage::Asset((
state::asset::BroadcastMessage::Purge,
newly_unsubscribed_assets,
)))
.unwrap();
}
}
websocket::data::incoming::Message::Bars(bar_message)
| websocket::data::incoming::Message::UpdatedBars(bar_message) => {
let bar = Bar::from(bar_message);
let guard = guard.read().await;
let symbol_status = guard.symbols.get(&bar.symbol);
if symbol_status.is_none() {
warn!(
"Race condition: received bar for unsubscribed symbol: {:?}.",
bar.symbol
);
return;
}
info!("Received bar for {}: {}", bar.symbol, bar.time);
database::bars::upsert(&app_config.clickhouse_client, &bar).await;
}
websocket::data::incoming::Message::Success(_) => {}
}
}
pub async fn backfill(app_config: Arc<Config>, class: Class, asset: Asset) {
let latest_backfill = database::backfills::select_latest_where_symbol(
&app_config.clickhouse_client,
&asset.symbol,
)
.await;
let fetch_from = if let Some(backfill) = latest_backfill {
backfill.time + ONE_MINUTE
} else {
last_minute()
OffsetDateTime::UNIX_EPOCH
};
let fetch_until = last_minute();
if fetch_from > fetch_until {
return;
}
@@ -289,7 +456,7 @@ pub async fn backfill(
sleep(task_run_delay).await;
}
info!("Running historical data backfill for {}...", asset.symbol);
info!("Running historical data backfill for {}.", asset.symbol);
let mut bars = Vec::new();
let mut next_page_token = None;
@@ -298,7 +465,7 @@ pub async fn backfill(
app_config.alpaca_rate_limit.until_ready().await;
let message = app_config
.alpaca_client
.get(match asset.class {
.get(match class {
Class::UsEquity => ALPACA_STOCK_DATA_URL,
Class::Crypto => ALPACA_CRYPTO_DATA_URL,
})
@@ -330,12 +497,11 @@ pub async fn backfill(
}
database::bars::upsert_batch(&app_config.clickhouse_client, &bars).await;
if let Some(last_bar) = bars.last() {
database::bars::upsert_validity(&app_config.clickhouse_client, &last_bar.clone().into())
.await;
}
backfilled.write().await.insert(asset.symbol.clone(), true);
database::backfills::upsert(
&app_config.clickhouse_client,
&Backfill::new(asset.symbol.clone(), fetch_until),
)
.await;
info!("Backfilled historical data for {}.", asset.symbol);
}
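
For reference, the resume logic at the top of the new `backfill` reduces to the window computation sketched below. `Backfill`, `last_minute`, and `ONE_MINUTE` here are simplified stand-ins for the crate's own types and helpers; the real record is read from ClickHouse via `database::backfills`.

```rust
// Sketch of the incremental backfill window; assumes time = "0.3" with default
// features. Backfill, last_minute, and ONE_MINUTE are stand-ins, not the crate's API.
use time::{Duration, OffsetDateTime};

struct Backfill {
    // Last minute already covered by a previous backfill for this symbol.
    time: OffsetDateTime,
}

const ONE_MINUTE: Duration = Duration::minutes(1);

// Stand-in for last_minute(): truncate "now" to the current minute boundary.
fn last_minute() -> OffsetDateTime {
    OffsetDateTime::now_utc()
        .replace_second(0)
        .unwrap()
        .replace_nanosecond(0)
        .unwrap()
}

// Returns the [fetch_from, fetch_until] range to request, or None if the
// symbol is already up to date (fetch_from would lie in the future).
fn fetch_window(latest: Option<Backfill>) -> Option<(OffsetDateTime, OffsetDateTime)> {
    let fetch_from = match latest {
        Some(backfill) => backfill.time + ONE_MINUTE, // resume just past the last run
        None => OffsetDateTime::UNIX_EPOCH,           // never backfilled: start from the epoch
    };
    let fetch_until = last_minute();
    (fetch_from <= fetch_until).then_some((fetch_from, fetch_until))
}

fn main() {
    println!("first-ever backfill window: {:?}", fetch_window(None));
}
```

Together with the `database::backfills::upsert(..., Backfill::new(asset.symbol.clone(), fetch_until))` at the end of the function, this lets an aborted or re-run backfill pick up from the last completed run rather than refetching the full history.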