Improve error handling

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-01-20 11:57:55 +00:00
parent 7200447bc5
commit 2d14fe35c8
12 changed files with 191 additions and 154 deletions

View File

@@ -1,6 +1,9 @@
use crate::types::alpaca::Source;
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use reqwest::{header::HeaderMap, Client};
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client,
};
use std::{env, num::NonZeroU32, sync::Arc};
pub const ALPACA_ASSET_API_URL: &str = "https://api.alpaca.markets/v2/assets";
@@ -36,20 +39,24 @@ impl Config {
let clickhouse_db = env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.");
Self {
alpaca_api_key: alpaca_api_key.clone(),
alpaca_api_secret: alpaca_api_secret.clone(),
alpaca_client: Client::builder()
.default_headers({
let mut headers = HeaderMap::new();
headers.insert("APCA-API-KEY-ID", alpaca_api_key.parse().unwrap());
headers.insert("APCA-API-SECRET-KEY", alpaca_api_secret.parse().unwrap());
headers
})
.default_headers(HeaderMap::from_iter([
(
HeaderName::from_static("apca-api-key-id"),
HeaderValue::from_str(&alpaca_api_key)
.expect("Alpaca API key must not contain invalid characters."),
),
(
HeaderName::from_static("apca-api-secret-key"),
HeaderValue::from_str(&alpaca_api_secret)
.expect("Alpaca API secret must not contain invalid characters."),
),
]))
.build()
.unwrap(),
alpaca_rate_limit: RateLimiter::direct(Quota::per_minute(match alpaca_source {
Source::Iex => NonZeroU32::new(180).unwrap(),
Source::Sip => NonZeroU32::new(900).unwrap(),
Source::Iex => unsafe { NonZeroU32::new_unchecked(200) },
Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) },
})),
alpaca_source,
clickhouse_client: clickhouse::Client::default()
@@ -57,6 +64,8 @@ impl Config {
.with_user(clickhouse_user)
.with_password(clickhouse_password)
.with_database(clickhouse_db),
alpaca_api_key,
alpaca_api_secret,
}
}

View File

@@ -1,7 +1,6 @@
use crate::{
config::{
Config, ALPACA_CLOCK_API_URL, ALPACA_CRYPTO_DATA_URL, ALPACA_CRYPTO_WEBSOCKET_URL,
ALPACA_STOCK_DATA_URL, ALPACA_STOCK_WEBSOCKET_URL,
Config, ALPACA_CLOCK_API_URL, ALPACA_CRYPTO_WEBSOCKET_URL, ALPACA_STOCK_WEBSOCKET_URL,
},
data::authenticate_websocket,
database,
@@ -12,7 +11,7 @@ use crate::{
},
utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE},
};
use core::panic;
use backoff::{future::retry, ExponentialBackoff};
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
@@ -29,7 +28,7 @@ use tokio::{
spawn,
sync::{
broadcast::{Receiver, Sender},
RwLock,
Mutex, RwLock,
},
task::JoinHandle,
time::sleep,
@@ -61,7 +60,7 @@ pub async fn run(
let (stream, _) = connect_async(websocket_url).await.unwrap();
let (mut sink, mut stream) = stream.split();
authenticate_websocket(&app_config, &mut stream, &mut sink).await;
let sink = Arc::new(RwLock::new(sink));
let sink = Arc::new(Mutex::new(sink));
let guard = Arc::new(RwLock::new(Guard {
symbols: HashSet::new(),
@@ -106,17 +105,14 @@ pub async fn run(
pub async fn broadcast_bus_handler(
app_config: Arc<Config>,
class: Class,
sink: Arc<RwLock<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
mut broadcast_bus_receiver: Receiver<BroadcastMessage>,
guard: Arc<RwLock<Guard>>,
) {
loop {
match broadcast_bus_receiver.recv().await.unwrap() {
BroadcastMessage::Asset((action, assets)) => {
let assets = assets
.into_iter()
.filter(|asset| asset.class == class)
.collect::<Vec<_>>();
BroadcastMessage::Asset((action, mut assets)) => {
assets.retain(|asset| asset.class == class);
if assets.is_empty() {
continue;
@@ -144,7 +140,7 @@ pub async fn broadcast_bus_handler(
guard.symbols.extend(symbols.clone());
sink.write()
sink.lock()
.await
.send(Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
@@ -172,14 +168,9 @@ pub async fn broadcast_bus_handler(
.map(|asset| (asset.symbol.clone(), asset)),
);
guard.symbols = guard
.symbols
.clone()
.into_iter()
.filter(|symbol| !symbols.contains(symbol))
.collect::<HashSet<_, _>>();
guard.symbols.retain(|symbol| !symbols.contains(symbol));
sink.write()
sink.lock()
.await
.send(Message::Text(
to_string(&websocket::data::outgoing::Message::Unsubscribe(
@@ -191,20 +182,15 @@ pub async fn broadcast_bus_handler(
.unwrap();
}
state::asset::BroadcastMessage::Backfill => {
info!("Creating backfill jobs for {:?}.", symbols);
let guard_clone = guard.clone();
let mut guard = guard.write().await;
info!("Creating backfill jobs for {:?}.", symbols);
for asset in assets {
let mut handles = Vec::new();
if let Some(backfill_job) = guard.backfill_jobs.remove(&asset.symbol) {
backfill_job.abort();
handles.push(backfill_job);
}
for handle in handles {
handle.await.unwrap_err();
backfill_job.await.unwrap_err();
}
guard.backfill_jobs.insert(asset.symbol.clone(), {
@@ -226,14 +212,9 @@ pub async fn broadcast_bus_handler(
info!("Purging {:?}.", symbols);
for asset in assets {
let mut handles = Vec::new();
if let Some(backfill_job) = guard.backfill_jobs.remove(&asset.symbol) {
backfill_job.abort();
handles.push(backfill_job);
}
for handle in handles {
handle.await.unwrap_err();
backfill_job.await.unwrap_err();
}
}
@@ -261,16 +242,20 @@ pub async fn clock_handler(
broadcast_bus_sender: Sender<BroadcastMessage>,
) {
loop {
app_config.alpaca_rate_limit.until_ready().await;
let clock = app_config
.alpaca_client
.get(ALPACA_CLOCK_API_URL)
.send()
.await
.unwrap()
.json::<api::incoming::clock::Clock>()
.await
.unwrap();
let clock = retry(ExponentialBackoff::default(), || async {
app_config.alpaca_rate_limit.until_ready().await;
app_config
.alpaca_client
.get(ALPACA_CLOCK_API_URL)
.send()
.await?
.error_for_status()?
.json::<api::incoming::clock::Clock>()
.await
.map_err(backoff::Error::Permanent)
})
.await
.unwrap();
let sleep_until = duration_until(if clock.is_open {
if class == Class::UsEquity {
@@ -299,7 +284,7 @@ pub async fn clock_handler(
async fn websocket_handler(
app_config: Arc<Config>,
mut stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
sink: Arc<RwLock<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
broadcast_bus_sender: Sender<BroadcastMessage>,
guard: Arc<RwLock<Guard>>,
) {
@@ -308,11 +293,11 @@ async fn websocket_handler(
let sink = sink.clone();
let broadcast_bus_sender = broadcast_bus_sender.clone();
let guard = guard.clone();
let message = stream.next().await;
let message = stream.next().await.expect("Websocket stream closed.");
spawn(async move {
match message {
Some(Ok(Message::Text(data))) => {
Ok(Message::Text(data)) => {
let parsed_data = from_str::<Vec<websocket::data::incoming::Message>>(&data);
if let Ok(messages) = parsed_data {
@@ -327,20 +312,16 @@ async fn websocket_handler(
}
} else {
error!(
"Unparsed websocket::data::incoming message: {:?}: {}",
"Unparsed websocket message: {:?}: {}.",
data,
parsed_data.err().unwrap()
parsed_data.unwrap_err()
);
}
}
Some(Ok(Message::Ping(_))) => sink
.write()
.await
.send(Message::Pong(vec![]))
.await
.unwrap(),
Some(unknown) => error!("Unknown websocket::data::incoming message: {:?}", unknown),
_ => panic!(),
Ok(Message::Ping(_)) => {
sink.lock().await.send(Message::Pong(vec![])).await.unwrap();
}
_ => error!("Unknown websocket message: {:?}.", message),
}
});
}
@@ -361,8 +342,7 @@ async fn websocket_handle_message(
let newly_subscribed_assets = guard
.pending_subscriptions
.drain()
.filter(|(symbol, _)| symbols.contains(symbol))
.extract_if(|symbol, _| symbols.contains(symbol))
.map(|(_, asset)| asset)
.collect::<Vec<_>>();
@@ -385,8 +365,7 @@ async fn websocket_handle_message(
let newly_unsubscribed_assets = guard
.pending_unsubscriptions
.drain()
.filter(|(symbol, _)| !symbols.contains(symbol))
.extract_if(|symbol, _| !symbols.contains(symbol))
.map(|(_, asset)| asset)
.collect::<Vec<_>>();
@@ -422,7 +401,7 @@ async fn websocket_handle_message(
return;
}
info!("Received bar for {}: {}", bar.symbol, bar.time);
info!("Received bar for {}: {}.", bar.symbol, bar.time);
database::bars::upsert(&app_config.clickhouse_client, &bar).await;
}
websocket::data::incoming::Message::Success(_) => {}
@@ -462,27 +441,38 @@ pub async fn backfill(app_config: Arc<Config>, class: Class, asset: Asset) {
let mut next_page_token = None;
loop {
app_config.alpaca_rate_limit.until_ready().await;
let message = app_config
.alpaca_client
.get(match class {
Class::UsEquity => ALPACA_STOCK_DATA_URL,
Class::Crypto => ALPACA_CRYPTO_DATA_URL,
})
.query(&api::outgoing::bar::Bar::new(
vec![asset.symbol.clone()],
ONE_MINUTE,
fetch_from,
fetch_until,
10000,
next_page_token,
))
.send()
.await
.unwrap()
.json::<api::incoming::bar::Message>()
.await
.unwrap();
let message = retry(ExponentialBackoff::default(), || async {
app_config.alpaca_rate_limit.until_ready().await;
app_config
.alpaca_client
.get(class.get_data_url())
.query(&api::outgoing::bar::Bar::new(
vec![asset.symbol.clone()],
ONE_MINUTE,
fetch_from,
fetch_until,
10000,
next_page_token.clone(),
))
.send()
.await?
.error_for_status()?
.json::<api::incoming::bar::Message>()
.await
.map_err(backoff::Error::Permanent)
})
.await;
let message = match message {
Ok(message) => message,
Err(e) => {
error!(
"Failed to backfill historical data for {}: {}.",
asset.symbol, e
);
return;
}
};
message.bars.into_iter().for_each(|(symbol, bar_vec)| {
bar_vec.unwrap_or_default().into_iter().for_each(|bar| {

View File

@@ -24,7 +24,7 @@ async fn authenticate_websocket(
== Some(&websocket::data::incoming::Message::Success(
websocket::data::incoming::success::Message::Connected,
)) => {}
_ => panic!(),
_ => panic!("Failed to connect to Alpaca websocket."),
}
sink.send(Message::Text(
@@ -47,6 +47,6 @@ async fn authenticate_websocket(
== Some(&websocket::data::incoming::Message::Success(
websocket::data::incoming::success::Message::Authenticated,
)) => {}
_ => panic!(),
_ => panic!("Failed to authenticate with Alpaca websocket."),
};
}

View File

@@ -1,5 +1,6 @@
#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
#![allow(clippy::missing_docs_in_private_items)]
#![feature(hash_extract_if)]
mod config;
mod data;
@@ -14,38 +15,30 @@ use config::Config;
use dotenv::dotenv;
use log4rs::config::Deserializers;
use state::BroadcastMessage;
use std::error::Error;
use tokio::{spawn, sync::broadcast};
use types::Class;
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
async fn main() {
dotenv().ok();
log4rs::init_file("log4rs.yaml", Deserializers::default())?;
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
let app_config = Config::arc_from_env();
let mut threads = Vec::new();
cleanup(&app_config.clickhouse_client).await;
let (broadcast_bus, _) = broadcast::channel::<BroadcastMessage>(100);
threads.push(spawn(data::market::run(
spawn(data::market::run(
app_config.clone(),
Class::UsEquity,
broadcast_bus.clone(),
)));
));
threads.push(spawn(data::market::run(
spawn(data::market::run(
app_config.clone(),
Class::Crypto,
broadcast_bus.clone(),
)));
));
threads.push(spawn(routes::run(app_config.clone(), broadcast_bus)));
for thread in threads {
thread.await?;
}
unreachable!()
routes::run(app_config, broadcast_bus).await;
}

View File

@@ -8,6 +8,8 @@ use crate::{
},
};
use axum::{extract::Path, Extension, Json};
use backoff::{future::retry, ExponentialBackoff};
use core::panic;
use http::StatusCode;
use serde::Deserialize;
use std::sync::Arc;
@@ -47,27 +49,34 @@ pub async fn add(
return Err(StatusCode::CONFLICT);
}
app_config.alpaca_rate_limit.until_ready().await;
let asset = app_config
.alpaca_client
.get(&format!("{}/{}", ALPACA_ASSET_API_URL, request.symbol))
.send()
.await
.map_err(|e| {
if e.status() == Some(reqwest::StatusCode::NOT_FOUND) {
StatusCode::NOT_FOUND
} else {
panic!()
}
})
.unwrap();
let asset = retry(ExponentialBackoff::default(), || async {
app_config.alpaca_rate_limit.until_ready().await;
app_config
.alpaca_client
.get(&format!("{}/{}", ALPACA_ASSET_API_URL, request.symbol))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::NOT_FOUND) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<incoming::asset::Asset>()
.await
.map_err(backoff::Error::Permanent)
})
.await
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::NOT_FOUND) => StatusCode::NOT_FOUND,
_ => panic!("Unexpected error: {}.", e),
})?;
let asset = asset.json::<incoming::asset::Asset>().await.unwrap();
if asset.status != Status::Active || !asset.tradable || !asset.fractionable {
return Err(StatusCode::FORBIDDEN);
}
let asset = Asset::from(asset);
broadcast_bus_sender
.send(BroadcastMessage::Asset((
state::asset::BroadcastMessage::Add,
@@ -85,8 +94,7 @@ pub async fn delete(
) -> Result<StatusCode, StatusCode> {
let asset = database::assets::select_where_symbol(&app_config.clickhouse_client, &symbol)
.await
.ok_or(StatusCode::NOT_FOUND)
.unwrap();
.ok_or(StatusCode::NOT_FOUND)?;
broadcast_bus_sender
.send(BroadcastMessage::Asset((

View File

@@ -20,7 +20,7 @@ pub async fn run(app_config: Arc<Config>, broadcast_sender: Sender<BroadcastMess
let addr = SocketAddr::from(([0, 0, 0, 0], 7878));
let listener = TcpListener::bind(addr).await.unwrap();
info!("Listening on {}.", addr);
serve(listener, app).await.unwrap();
unreachable!()
}

View File

@@ -1,3 +1,4 @@
use crate::config::{ALPACA_CRYPTO_DATA_URL, ALPACA_STOCK_DATA_URL};
use clickhouse::Row;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
@@ -10,6 +11,15 @@ pub enum Class {
Crypto = 2,
}
impl Class {
pub const fn get_data_url(self) -> &'static str {
match self {
Self::UsEquity => ALPACA_STOCK_DATA_URL,
Self::Crypto => ALPACA_CRYPTO_DATA_URL,
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(u8)]
pub enum Exchange {