Clean up error propagation

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-02-08 18:13:52 +00:00
parent 52e88f4bc9
commit 76bf2fddcb
24 changed files with 465 additions and 325 deletions

14
Cargo.lock generated
View File

@@ -1180,9 +1180,9 @@ checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
[[package]]
name = "jobserver"
version = "0.1.27"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d"
checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6"
dependencies = [
"libc",
]
@@ -1419,19 +1419,18 @@ checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
version = "0.1.45"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.17"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a"
dependencies = [
"autocfg",
]
@@ -1666,6 +1665,7 @@ dependencies = [
"html-escape",
"http 1.0.0",
"itertools 0.12.1",
"lazy_static",
"log",
"log4rs",
"regex",

View File

@@ -54,3 +54,4 @@ html-escape = "0.2.13"
rust-bert = "0.22.0"
async-trait = "0.1.77"
itertools = "0.12.1"
lazy_static = "1.4.0"

View File

@@ -1,16 +1,18 @@
use crate::types::Asset;
use clickhouse::Client;
use clickhouse::{error::Error, Client};
use serde::Serialize;
pub async fn select(clickhouse_client: &Client) -> Vec<Asset> {
pub async fn select(clickhouse_client: &Client) -> Result<Vec<Asset>, Error> {
clickhouse_client
.query("SELECT ?fields FROM assets FINAL")
.fetch_all::<Asset>()
.await
.unwrap()
}
pub async fn select_where_symbol<T>(clickhouse_client: &Client, symbol: &T) -> Option<Asset>
pub async fn select_where_symbol<T>(
clickhouse_client: &Client,
symbol: &T,
) -> Result<Option<Asset>, Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
@@ -19,22 +21,21 @@ where
.bind(symbol)
.fetch_optional::<Asset>()
.await
.unwrap()
}
pub async fn upsert_batch<T>(clickhouse_client: &Client, assets: T)
pub async fn upsert_batch<T>(clickhouse_client: &Client, assets: T) -> Result<(), Error>
where
T: IntoIterator<Item = Asset> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = clickhouse_client.insert("assets").unwrap();
let mut insert = clickhouse_client.insert("assets")?;
for asset in assets {
insert.write(&asset).await.unwrap();
insert.write(&asset).await?;
}
insert.end().await.unwrap();
insert.end().await
}
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
@@ -43,5 +44,4 @@ where
.bind(symbols)
.execute()
.await
.unwrap();
}

View File

@@ -1,8 +1,8 @@
use crate::types::Backfill;
use clickhouse::Client;
use clickhouse::{error::Error, Client};
use serde::Serialize;
use std::fmt::Display;
use tokio::join;
use std::fmt::{Display, Formatter};
use tokio::try_join;
pub enum Table {
Bars,
@@ -10,7 +10,7 @@ pub enum Table {
}
impl Display for Table {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Bars => write!(f, "backfills_bars"),
Self::News => write!(f, "backfills_news"),
@@ -22,7 +22,7 @@ pub async fn select_latest_where_symbol<T>(
clickhouse_client: &Client,
table: &Table,
symbol: &T,
) -> Option<Backfill>
) -> Result<Option<Backfill>, Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
@@ -33,16 +33,23 @@ where
.bind(symbol)
.fetch_optional::<Backfill>()
.await
.unwrap()
}
pub async fn upsert(clickhouse_client: &Client, table: &Table, backfill: &Backfill) {
let mut insert = clickhouse_client.insert(&table.to_string()).unwrap();
insert.write(backfill).await.unwrap();
insert.end().await.unwrap();
pub async fn upsert(
clickhouse_client: &Client,
table: &Table,
backfill: &Backfill,
) -> Result<(), Error> {
let mut insert = clickhouse_client.insert(&table.to_string())?;
insert.write(backfill).await?;
insert.end().await
}
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, table: &Table, symbols: &[T])
pub async fn delete_where_symbols<T>(
clickhouse_client: &Client,
table: &Table,
symbols: &[T],
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
@@ -51,16 +58,14 @@ where
.bind(symbols)
.execute()
.await
.unwrap();
}
pub async fn cleanup(clickhouse_client: &Client) {
pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> {
let delete_bars_future = async {
clickhouse_client
.query("DELETE FROM backfills_bars WHERE symbol NOT IN (SELECT symbol FROM assets)")
.execute()
.await
.unwrap();
};
let delete_news_future = async {
@@ -68,8 +73,7 @@ pub async fn cleanup(clickhouse_client: &Client) {
.query("DELETE FROM backfills_news WHERE symbol NOT IN (SELECT symbol FROM assets)")
.execute()
.await
.unwrap();
};
join!(delete_bars_future, delete_news_future);
try_join!(delete_bars_future, delete_news_future).map(|_| ())
}

View File

@@ -1,26 +1,26 @@
use crate::types::Bar;
use clickhouse::Client;
use clickhouse::{error::Error, Client};
use serde::Serialize;
pub async fn upsert(clickhouse_client: &Client, bar: &Bar) {
let mut insert = clickhouse_client.insert("bars").unwrap();
insert.write(bar).await.unwrap();
insert.end().await.unwrap();
pub async fn upsert(clickhouse_client: &Client, bar: &Bar) -> Result<(), Error> {
let mut insert = clickhouse_client.insert("bars")?;
insert.write(bar).await?;
insert.end().await
}
pub async fn upsert_batch<T>(clickhouse_client: &Client, bars: T)
pub async fn upsert_batch<T>(clickhouse_client: &Client, bars: T) -> Result<(), Error>
where
T: IntoIterator<Item = Bar> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = clickhouse_client.insert("bars").unwrap();
let mut insert = clickhouse_client.insert("bars")?;
for bar in bars {
insert.write(&bar).await.unwrap();
insert.write(&bar).await?;
}
insert.end().await.unwrap();
insert.end().await
}
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
@@ -29,13 +29,11 @@ where
.bind(symbols)
.execute()
.await
.unwrap();
}
pub async fn cleanup(clickhouse_client: &Client) {
pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
.query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets)")
.execute()
.await
.unwrap();
}

View File

@@ -1,26 +1,26 @@
use crate::types::News;
use clickhouse::Client;
use clickhouse::{error::Error, Client};
use serde::Serialize;
pub async fn upsert(clickhouse_client: &Client, news: &News) {
let mut insert = clickhouse_client.insert("news").unwrap();
insert.write(news).await.unwrap();
insert.end().await.unwrap();
pub async fn upsert(clickhouse_client: &Client, news: &News) -> Result<(), Error> {
let mut insert = clickhouse_client.insert("news")?;
insert.write(news).await?;
insert.end().await
}
pub async fn upsert_batch<T>(clickhouse_client: &Client, news: T)
pub async fn upsert_batch<T>(clickhouse_client: &Client, news: T) -> Result<(), Error>
where
T: IntoIterator<Item = News> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = clickhouse_client.insert("news").unwrap();
let mut insert = clickhouse_client.insert("news")?;
for news in news {
insert.write(&news).await.unwrap();
insert.write(&news).await?;
}
insert.end().await.unwrap();
insert.end().await
}
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
@@ -29,15 +29,13 @@ where
.bind(symbols)
.execute()
.await
.unwrap();
}
pub async fn cleanup(clickhouse_client: &Client) {
pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
.query(
"DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))",
)
.execute()
.await
.unwrap();
}

View File

@@ -19,31 +19,34 @@ use tokio::{spawn, sync::mpsc};
async fn main() {
dotenv().ok();
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
let app_config = Config::arc_from_env();
let config = Config::arc_from_env();
cleanup(&app_config.clickhouse_client).await;
cleanup(&config.clickhouse_client).await.unwrap();
let (data_sender, data_receiver) = mpsc::channel::<threads::data::Message>(100);
let (clock_sender, clock_receiver) = mpsc::channel::<threads::clock::Message>(1);
spawn(threads::data::run(
app_config.clone(),
config.clone(),
data_receiver,
clock_receiver,
));
spawn(threads::clock::run(app_config.clone(), clock_sender));
spawn(threads::clock::run(config.clone(), clock_sender));
let assets = database::assets::select(&app_config.clickhouse_client)
let assets = database::assets::select(&config.clickhouse_client)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol, asset.class))
.collect::<Vec<_>>();
let (data_message, data_receiver) =
threads::data::Message::new(threads::data::Action::Add, assets);
data_sender.send(data_message).await.unwrap();
data_receiver.await.unwrap();
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
assets
);
routes::run(app_config, data_sender).await;
routes::run(config, data_sender).await;
}

View File

@@ -1,7 +1,7 @@
use crate::{
config::Config,
database, threads,
types::{alpaca::api::incoming, Asset},
create_send_await, database, threads,
types::{alpaca, Asset},
};
use axum::{extract::Path, Extension, Json};
use http::StatusCode;
@@ -10,17 +10,23 @@ use std::sync::Arc;
use tokio::sync::mpsc;
pub async fn get(
Extension(app_config): Extension<Arc<Config>>,
Extension(config): Extension<Arc<Config>>,
) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> {
let assets = database::assets::select(&app_config.clickhouse_client).await;
let assets = database::assets::select(&config.clickhouse_client)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok((StatusCode::OK, Json(assets)))
}
pub async fn get_where_symbol(
Extension(app_config): Extension<Arc<Config>>,
Extension(config): Extension<Arc<Config>>,
Path(symbol): Path<String>,
) -> Result<(StatusCode, Json<Asset>), StatusCode> {
let asset = database::assets::select_where_symbol(&app_config.clickhouse_client, &symbol).await;
let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
asset.map_or(Err(StatusCode::NOT_FOUND), |asset| {
Ok((StatusCode::OK, Json(asset)))
})
@@ -32,50 +38,58 @@ pub struct AddAssetRequest {
}
pub async fn add(
Extension(app_config): Extension<Arc<Config>>,
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Json(request): Json<AddAssetRequest>,
) -> Result<StatusCode, StatusCode> {
if database::assets::select_where_symbol(&app_config.clickhouse_client, &request.symbol)
if database::assets::select_where_symbol(&config.clickhouse_client, &request.symbol)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.is_some()
{
return Err(StatusCode::CONFLICT);
}
let asset = incoming::asset::get_by_symbol(&app_config, &request.symbol).await?;
let asset = alpaca::api::incoming::asset::get_by_symbol(&config, &request.symbol, None)
.await
.map_err(|e| {
e.status()
.map_or(StatusCode::INTERNAL_SERVER_ERROR, |status| {
StatusCode::from_u16(status.as_u16()).unwrap()
})
})?;
if !asset.tradable || !asset.fractionable {
return Err(StatusCode::FORBIDDEN);
}
let asset = Asset::from(asset);
let (data_message, data_response) = threads::data::Message::new(
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
vec![(asset.symbol, asset.class)],
vec![(asset.symbol, asset.class)]
);
data_sender.send(data_message).await.unwrap();
data_response.await.unwrap();
Ok(StatusCode::CREATED)
}
pub async fn delete(
Extension(app_config): Extension<Arc<Config>>,
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
let asset = database::assets::select_where_symbol(&app_config.clickhouse_client, &symbol)
let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
let (asset_status_message, asset_status_response) = threads::data::Message::new(
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Remove,
vec![(asset.symbol, asset.class)],
vec![(asset.symbol, asset.class)]
);
data_sender.send(asset_status_message).await.unwrap();
asset_status_response.await.unwrap();
Ok(StatusCode::NO_CONTENT)
}

View File

@@ -10,14 +10,14 @@ use log::info;
use std::{net::SocketAddr, sync::Arc};
use tokio::{net::TcpListener, sync::mpsc};
pub async fn run(app_config: Arc<Config>, data_sender: mpsc::Sender<threads::data::Message>) {
pub async fn run(config: Arc<Config>, data_sender: mpsc::Sender<threads::data::Message>) {
let app = Router::new()
.route("/health", get(health::get))
.route("/assets", get(assets::get))
.route("/assets/:symbol", get(assets::get_where_symbol))
.route("/assets", post(assets::add))
.route("/assets/:symbol", delete(assets::delete))
.layer(Extension(app_config))
.layer(Extension(config))
.layer(Extension(data_sender));
let addr = SocketAddr::from(([0, 0, 0, 0], 7878));

View File

@@ -1,4 +1,8 @@
use crate::{config::Config, types::alpaca, utils::duration_until};
use crate::{
config::Config,
types::alpaca,
utils::{backoff, duration_until},
};
use log::info;
use std::sync::Arc;
use time::OffsetDateTime;
@@ -30,9 +34,11 @@ impl From<alpaca::api::incoming::clock::Clock> for Message {
}
}
pub async fn run(app_config: Arc<Config>, sender: mpsc::Sender<Message>) {
pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
loop {
let clock = alpaca::api::incoming::clock::get(&app_config).await;
let clock = alpaca::api::incoming::clock::get(&config, Some(backoff::infinite()))
.await
.unwrap();
let sleep_until = duration_until(if clock.is_open {
info!("Market is open, will close at {}.", clock.next_close);

View File

@@ -9,7 +9,7 @@ use crate::{
Source,
},
news::Prediction,
Bar, Class, News,
Backfill, Bar, Class, News,
},
utils::{
duration_until, last_minute, remove_slash_from_pair, FIFTEEN_MINUTES, ONE_MINUTE,
@@ -18,14 +18,15 @@ use crate::{
};
use async_trait::async_trait;
use futures_util::future::join_all;
use log::{info, warn};
use log::{error, info, warn};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::{
join, spawn,
spawn,
sync::{mpsc, oneshot, Mutex},
task::{block_in_place, JoinHandle},
time::sleep,
try_join,
};
pub enum Action {
@@ -64,9 +65,12 @@ impl Message {
#[async_trait]
pub trait Handler: Send + Sync {
async fn select_latest_backfill(&self, symbol: String) -> Option<crate::types::Backfill>;
async fn delete_backfills(&self, symbol: &[String]);
async fn delete_data(&self, symbol: &[String]);
async fn select_latest_backfill(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error>;
async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime);
async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime);
fn log_string(&self) -> &'static str;
@@ -111,13 +115,14 @@ async fn handle_backfill_message(
backfill_jobs.insert(
symbol.clone(),
spawn(async move {
let fetch_from = handler
let fetch_from = match handler
.select_latest_backfill(symbol.clone())
.await
.as_ref()
.map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
backfill.time + ONE_SECOND
});
.unwrap()
{
Some(latest_backfill) => latest_backfill.time + ONE_SECOND,
None => OffsetDateTime::UNIX_EPOCH,
};
let fetch_to = last_minute();
@@ -142,10 +147,11 @@ async fn handle_backfill_message(
}
}
join!(
try_join!(
handler.delete_backfills(&message.symbols),
handler.delete_data(&message.symbols)
);
)
.unwrap();
}
}
@@ -153,10 +159,10 @@ async fn handle_backfill_message(
}
struct BarHandler {
app_config: Arc<Config>,
config: Arc<Config>,
data_url: &'static str,
api_query_constructor: fn(
app_config: &Arc<Config>,
config: &Arc<Config>,
symbol: String,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
@@ -165,7 +171,7 @@ struct BarHandler {
}
fn us_equity_query_constructor(
app_config: &Arc<Config>,
config: &Arc<Config>,
symbol: String,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
@@ -179,7 +185,7 @@ fn us_equity_query_constructor(
limit: Some(10000),
adjustment: None,
asof: None,
feed: Some(app_config.alpaca_source),
feed: Some(config.alpaca_source),
currency: None,
page_token: next_page_token,
sort: Some(Sort::Asc),
@@ -206,30 +212,33 @@ fn crypto_query_constructor(
#[async_trait]
impl Handler for BarHandler {
async fn select_latest_backfill(&self, symbol: String) -> Option<crate::types::Backfill> {
async fn select_latest_backfill(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error> {
database::backfills::select_latest_where_symbol(
&self.app_config.clickhouse_client,
&self.config.clickhouse_client,
&database::backfills::Table::Bars,
&symbol,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) {
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills::delete_where_symbols(
&self.app_config.clickhouse_client,
&self.config.clickhouse_client,
&database::backfills::Table::Bars,
symbols,
)
.await;
.await
}
async fn delete_data(&self, symbols: &[String]) {
database::bars::delete_where_symbols(&self.app_config.clickhouse_client, symbols).await;
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::bars::delete_where_symbols(&self.config.clickhouse_client, symbols).await
}
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
if self.app_config.alpaca_source == Source::Iex {
if self.config.alpaca_source == Source::Iex {
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
info!("Queing bar backfill for {} in {:?}.", symbol, run_delay);
sleep(run_delay).await;
@@ -243,18 +252,23 @@ impl Handler for BarHandler {
let mut next_page_token = None;
loop {
let message = alpaca::api::incoming::bar::get_historical(
&self.app_config,
let Ok(message) = alpaca::api::incoming::bar::get_historical(
&self.config,
self.data_url,
&(self.api_query_constructor)(
&self.app_config,
&self.config,
symbol.clone(),
fetch_from,
fetch_to,
next_page_token.clone(),
),
None,
)
.await;
.await
else {
error!("Failed to backfill bars for {}.", symbol);
return;
};
message.bars.into_iter().for_each(|(symbol, bar_vec)| {
for bar in bar_vec {
@@ -274,13 +288,17 @@ impl Handler for BarHandler {
}
let backfill = bars.last().unwrap().clone().into();
database::bars::upsert_batch(&self.app_config.clickhouse_client, bars).await;
database::bars::upsert_batch(&self.config.clickhouse_client, bars)
.await
.unwrap();
database::backfills::upsert(
&self.app_config.clickhouse_client,
&self.config.clickhouse_client,
&database::backfills::Table::Bars,
&backfill,
)
.await;
.await
.unwrap();
info!("Backfilled bars for {}.", symbol);
}
@@ -291,31 +309,34 @@ impl Handler for BarHandler {
}
struct NewsHandler {
app_config: Arc<Config>,
config: Arc<Config>,
}
#[async_trait]
impl Handler for NewsHandler {
async fn select_latest_backfill(&self, symbol: String) -> Option<crate::types::Backfill> {
async fn select_latest_backfill(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error> {
database::backfills::select_latest_where_symbol(
&self.app_config.clickhouse_client,
&self.config.clickhouse_client,
&database::backfills::Table::News,
&symbol,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) {
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills::delete_where_symbols(
&self.app_config.clickhouse_client,
&self.config.clickhouse_client,
&database::backfills::Table::News,
symbols,
)
.await;
.await
}
async fn delete_data(&self, symbols: &[String]) {
database::news::delete_where_symbols(&self.app_config.clickhouse_client, symbols).await;
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::news::delete_where_symbols(&self.config.clickhouse_client, symbols).await
}
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
@@ -331,8 +352,8 @@ impl Handler for NewsHandler {
let mut next_page_token = None;
loop {
let message = alpaca::api::incoming::news::get_historical(
&self.app_config,
let Ok(message) = alpaca::api::incoming::news::get_historical(
&self.config,
&api::outgoing::news::News {
symbols: vec![remove_slash_from_pair(&symbol)],
start: Some(fetch_from),
@@ -343,8 +364,13 @@ impl Handler for NewsHandler {
page_token: next_page_token.clone(),
sort: Some(Sort::Asc),
},
None,
)
.await;
.await
else {
error!("Failed to backfill news for {}.", symbol);
return;
};
message.news.into_iter().for_each(|news_item| {
news.push(News::from(news_item));
@@ -366,23 +392,19 @@ impl Handler for NewsHandler {
.map(|news| format!("{}\n\n{}", news.headline, news.content))
.collect::<Vec<_>>();
let predictions = join_all(
inputs
.chunks(self.app_config.max_bert_inputs)
.map(|inputs| {
let sequence_classifier = self.app_config.sequence_classifier.clone();
async move {
let sequence_classifier = sequence_classifier.lock().await;
block_in_place(|| {
sequence_classifier
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()
})
}
}),
)
let predictions = join_all(inputs.chunks(self.config.max_bert_inputs).map(|inputs| {
let sequence_classifier = self.config.sequence_classifier.clone();
async move {
let sequence_classifier = sequence_classifier.lock().await;
block_in_place(|| {
sequence_classifier
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()
})
}
}))
.await
.into_iter()
.flatten();
@@ -398,13 +420,17 @@ impl Handler for NewsHandler {
.collect::<Vec<_>>();
let backfill = (news.last().unwrap().clone(), symbol.clone()).into();
database::news::upsert_batch(&self.app_config.clickhouse_client, news).await;
database::news::upsert_batch(&self.config.clickhouse_client, news)
.await
.unwrap();
database::backfills::upsert(
&self.app_config.clickhouse_client,
&self.config.clickhouse_client,
&database::backfills::Table::News,
&backfill,
)
.await;
.await
.unwrap();
info!("Backfilled news for {}.", symbol);
}
@@ -414,18 +440,18 @@ impl Handler for NewsHandler {
}
}
pub fn create_handler(thread_type: ThreadType, app_config: Arc<Config>) -> Box<dyn Handler> {
pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
match thread_type {
ThreadType::Bars(Class::UsEquity) => Box::new(BarHandler {
app_config,
config,
data_url: ALPACA_STOCK_DATA_URL,
api_query_constructor: us_equity_query_constructor,
}),
ThreadType::Bars(Class::Crypto) => Box::new(BarHandler {
app_config,
config,
data_url: ALPACA_CRYPTO_DATA_URL,
api_query_constructor: crypto_query_constructor,
}),
ThreadType::News => Box::new(NewsHandler { app_config }),
ThreadType::News => Box::new(NewsHandler { config }),
}
}

View File

@@ -6,9 +6,9 @@ use crate::{
config::{
Config, ALPACA_CRYPTO_WEBSOCKET_URL, ALPACA_NEWS_WEBSOCKET_URL, ALPACA_STOCK_WEBSOCKET_URL,
},
database,
create_send_await, database,
types::{alpaca, Asset, Class},
utils::{authenticate, cleanup},
utils::{authenticate, backoff, cleanup},
};
use futures_util::{future::join_all, StreamExt};
use itertools::{Either, Itertools};
@@ -52,22 +52,22 @@ pub enum ThreadType {
}
pub async fn run(
app_config: Arc<Config>,
config: Arc<Config>,
mut receiver: mpsc::Receiver<Message>,
mut clock_receiver: mpsc::Receiver<clock::Message>,
) {
let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) =
init_thread(app_config.clone(), ThreadType::Bars(Class::UsEquity)).await;
init_thread(config.clone(), ThreadType::Bars(Class::UsEquity)).await;
let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) =
init_thread(app_config.clone(), ThreadType::Bars(Class::Crypto)).await;
init_thread(config.clone(), ThreadType::Bars(Class::Crypto)).await;
let (news_websocket_sender, news_backfill_sender) =
init_thread(app_config.clone(), ThreadType::News).await;
init_thread(config.clone(), ThreadType::News).await;
loop {
select! {
Some(message) = receiver.recv() => {
spawn(handle_message(
app_config.clone(),
config.clone(),
bars_us_equity_websocket_sender.clone(),
bars_us_equity_backfill_sender.clone(),
bars_crypto_websocket_sender.clone(),
@@ -79,7 +79,7 @@ pub async fn run(
}
Some(_) = clock_receiver.recv() => {
spawn(handle_clock_message(
app_config.clone(),
config.clone(),
bars_us_equity_backfill_sender.clone(),
bars_crypto_backfill_sender.clone(),
news_backfill_sender.clone(),
@@ -91,34 +91,33 @@ pub async fn run(
}
async fn init_thread(
app_config: Arc<Config>,
config: Arc<Config>,
thread_type: ThreadType,
) -> (
mpsc::Sender<websocket::Message>,
mpsc::Sender<backfill::Message>,
) {
let websocket_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => format!(
"{}/{}",
ALPACA_STOCK_WEBSOCKET_URL, &app_config.alpaca_source
),
ThreadType::Bars(Class::UsEquity) => {
format!("{}/{}", ALPACA_STOCK_WEBSOCKET_URL, &config.alpaca_source)
}
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_WEBSOCKET_URL.into(),
ThreadType::News => ALPACA_NEWS_WEBSOCKET_URL.into(),
};
let (websocket, _) = connect_async(websocket_url).await.unwrap();
let (mut websocket_sink, mut websocket_stream) = websocket.split();
authenticate(&app_config, &mut websocket_sink, &mut websocket_stream).await;
authenticate(&config, &mut websocket_sink, &mut websocket_stream).await;
let (backfill_sender, backfill_receiver) = mpsc::channel(100);
spawn(backfill::run(
Arc::new(backfill::create_handler(thread_type, app_config.clone())),
Arc::new(backfill::create_handler(thread_type, config.clone())),
backfill_receiver,
));
let (websocket_sender, websocket_receiver) = mpsc::channel(100);
spawn(websocket::run(
Arc::new(websocket::create_handler(thread_type, app_config.clone())),
Arc::new(websocket::create_handler(thread_type, config.clone())),
websocket_receiver,
websocket_stream,
websocket_sink,
@@ -127,17 +126,9 @@ async fn init_thread(
(websocket_sender, backfill_sender)
}
macro_rules! create_send_await {
($sender:expr, $action:expr, $($contents:expr),*) => {
let (message, receiver) = $action($($contents),*);
$sender.send(message).await.unwrap();
receiver.await.unwrap();
};
}
#[allow(clippy::too_many_arguments)]
async fn handle_message(
app_config: Arc<Config>,
config: Arc<Config>,
bars_us_equity_websocket_sender: mpsc::Sender<websocket::Message>,
bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
bars_crypto_websocket_sender: mpsc::Sender<websocket::Message>,
@@ -221,22 +212,30 @@ async fn handle_message(
match message.action {
Action::Add => {
let assets =
join_all(symbols.into_iter().map(|symbol| {
let app_config = app_config.clone();
async move {
alpaca::api::incoming::asset::get_by_symbol(&app_config, &symbol).await
}
}))
.await
.into_iter()
.map(|result| Asset::from(result.unwrap()))
.collect::<Vec<_>>();
let assets = join_all(symbols.into_iter().map(|symbol| {
let config = config.clone();
async move {
Asset::from(
alpaca::api::incoming::asset::get_by_symbol(
&config,
&symbol,
Some(backoff::infinite()),
)
.await
.unwrap(),
)
}
}))
.await;
database::assets::upsert_batch(&app_config.clickhouse_client, assets).await;
database::assets::upsert_batch(&config.clickhouse_client, assets)
.await
.unwrap();
}
Action::Remove => {
database::assets::delete_where_symbols(&app_config.clickhouse_client, &symbols).await;
database::assets::delete_where_symbols(&config.clickhouse_client, &symbols)
.await
.unwrap();
}
}
@@ -244,14 +243,16 @@ async fn handle_message(
}
async fn handle_clock_message(
app_config: Arc<Config>,
config: Arc<Config>,
bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
news_backfill_sender: mpsc::Sender<backfill::Message>,
) {
cleanup(&app_config.clickhouse_client).await;
cleanup(&config.clickhouse_client).await.unwrap();
let assets = database::assets::select(&app_config.clickhouse_client).await;
let assets = database::assets::select(&config.clickhouse_client)
.await
.unwrap();
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets
.clone()

View File

@@ -221,7 +221,7 @@ async fn handle_websocket_message(
}
struct BarsHandler {
app_config: Arc<Config>,
config: Arc<Config>,
}
#[async_trait]
@@ -286,7 +286,10 @@ impl Handler for BarsHandler {
| websocket::incoming::Message::UpdatedBar(message) => {
let bar = Bar::from(message);
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
database::bars::upsert(&self.app_config.clickhouse_client, &bar).await;
database::bars::upsert(&self.config.clickhouse_client, &bar)
.await
.unwrap();
}
websocket::incoming::Message::Success(_) => {}
websocket::incoming::Message::Error(message) => {
@@ -298,7 +301,7 @@ impl Handler for BarsHandler {
}
struct NewsHandler {
app_config: Arc<Config>,
config: Arc<Config>,
}
#[async_trait]
@@ -373,7 +376,7 @@ impl Handler for NewsHandler {
let input = format!("{}\n\n{}", news.headline, news.content);
let sequence_classifier = self.app_config.sequence_classifier.lock().await;
let sequence_classifier = self.config.sequence_classifier.lock().await;
let prediction = block_in_place(|| {
sequence_classifier
.predict(vec![input.as_str()])
@@ -388,7 +391,10 @@ impl Handler for NewsHandler {
confidence: prediction.confidence,
..news
};
database::news::upsert(&self.app_config.clickhouse_client, &news).await;
database::news::upsert(&self.config.clickhouse_client, &news)
.await
.unwrap();
}
websocket::incoming::Message::Success(_) => {}
websocket::incoming::Message::Error(message) => {
@@ -401,9 +407,9 @@ impl Handler for NewsHandler {
}
}
pub fn create_handler(thread_type: ThreadType, app_config: Arc<Config>) -> Box<dyn Handler> {
pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
match thread_type {
ThreadType::Bars(_) => Box::new(BarsHandler { app_config }),
ThreadType::News => Box::new(NewsHandler { app_config }),
ThreadType::Bars(_) => Box::new(BarsHandler { config }),
ThreadType::News => Box::new(NewsHandler { config }),
}
}

View File

@@ -1,11 +1,12 @@
use crate::{
config::{Config, ALPACA_ASSET_API_URL},
types::{self, alpaca::api::impl_from_enum},
impl_from_enum, types,
};
use backoff::{future::retry, ExponentialBackoff};
use http::StatusCode;
use backoff::{future::retry_notify, ExponentialBackoff};
use log::warn;
use reqwest::Error;
use serde::Deserialize;
use std::sync::Arc;
use std::{sync::Arc, time::Duration};
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
@@ -87,26 +88,38 @@ impl From<Asset> for types::Asset {
}
}
pub async fn get_by_symbol(app_config: &Arc<Config>, symbol: &str) -> Result<Asset, StatusCode> {
retry(ExponentialBackoff::default(), || async {
app_config.alpaca_rate_limit.until_ready().await;
app_config
.alpaca_client
.get(&format!("{ALPACA_ASSET_API_URL}/{symbol}"))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::NOT_FOUND) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Asset>()
.await
.map_err(backoff::Error::Permanent)
})
pub async fn get_by_symbol(
config: &Arc<Config>,
symbol: &str,
backoff: Option<ExponentialBackoff>,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
config.alpaca_rate_limit.until_ready().await;
config
.alpaca_client
.get(&format!("{ALPACA_ASSET_API_URL}/{symbol}"))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::FORBIDDEN | reqwest::StatusCode::NOT_FOUND) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Asset>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get asset, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::NOT_FOUND) => StatusCode::NOT_FOUND,
_ => panic!("Unexpected error: {e}."),
})
}

View File

@@ -2,9 +2,11 @@ use crate::{
config::Config,
types::{self, alpaca::api::outgoing},
};
use backoff::{future::retry, ExponentialBackoff};
use backoff::{future::retry_notify, ExponentialBackoff};
use log::warn;
use reqwest::Error;
use serde::Deserialize;
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, sync::Arc, time::Duration};
use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Deserialize)]
@@ -51,23 +53,37 @@ pub struct Message {
}
pub async fn get_historical(
app_config: &Arc<Config>,
config: &Arc<Config>,
data_url: &str,
query: &outgoing::bar::Bar,
) -> Message {
retry(ExponentialBackoff::default(), || async {
app_config.alpaca_rate_limit.until_ready().await;
app_config
.alpaca_client
.get(data_url)
.query(query)
.send()
.await?
.error_for_status()?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
})
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
config.alpaca_rate_limit.until_ready().await;
config
.alpaca_client
.get(data_url)
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::FORBIDDEN) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical bars, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
.unwrap()
}

View File

@@ -1,7 +1,9 @@
use crate::config::{Config, ALPACA_CLOCK_API_URL};
use backoff::{future::retry, ExponentialBackoff};
use backoff::{future::retry_notify, ExponentialBackoff};
use log::warn;
use reqwest::Error;
use serde::Deserialize;
use std::sync::Arc;
use std::{sync::Arc, time::Duration};
use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
@@ -15,18 +17,35 @@ pub struct Clock {
pub next_close: OffsetDateTime,
}
pub async fn get(app_config: &Arc<Config>) -> Clock {
retry(ExponentialBackoff::default(), || async {
app_config.alpaca_rate_limit.until_ready().await;
app_config
.alpaca_client
.get(ALPACA_CLOCK_API_URL)
.send()
.await?
.json::<Clock>()
.await
.map_err(backoff::Error::Permanent)
})
pub async fn get(
config: &Arc<Config>,
backoff: Option<ExponentialBackoff>,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
config.alpaca_rate_limit.until_ready().await;
config
.alpaca_client
.get(ALPACA_CLOCK_API_URL)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::FORBIDDEN) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Clock>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get clock, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
.unwrap()
}

View File

@@ -3,9 +3,11 @@ use crate::{
types::{self, alpaca::api::outgoing},
utils::{add_slash_to_pair, normalize_news_content},
};
use backoff::{future::retry, ExponentialBackoff};
use backoff::{future::retry_notify, ExponentialBackoff};
use log::warn;
use reqwest::Error;
use serde::Deserialize;
use std::sync::Arc;
use std::{sync::Arc, time::Duration};
use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
@@ -70,20 +72,33 @@ pub struct Message {
pub next_page_token: Option<String>,
}
pub async fn get_historical(app_config: &Arc<Config>, query: &outgoing::news::News) -> Message {
retry(ExponentialBackoff::default(), || async {
app_config.alpaca_rate_limit.until_ready().await;
app_config
.alpaca_client
.get(ALPACA_NEWS_DATA_URL)
.query(query)
.send()
.await?
.error_for_status()?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
})
pub async fn get_historical(
config: &Arc<Config>,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
config.alpaca_rate_limit.until_ready().await;
config
.alpaca_client
.get(ALPACA_NEWS_DATA_URL)
.query(query)
.send()
.await?
.error_for_status()?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical news, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
.unwrap()
}

View File

@@ -1,24 +1,2 @@
pub mod incoming;
pub mod outgoing;
macro_rules! impl_from_enum {
($source:ty, $target:ty, $( $variant:ident ),* ) => {
impl From<$source> for $target {
fn from(item: $source) -> Self {
match item {
$( <$source>::$variant => <$target>::$variant, )*
}
}
}
impl From<$target> for $source {
fn from(item: $target) -> Self {
match item {
$( <$target>::$variant => <$source>::$variant, )*
}
}
}
};
}
use impl_from_enum;

8
src/utils/backoff.rs Normal file
View File

@@ -0,0 +1,8 @@
use backoff::ExponentialBackoff;
/// Builds the retry policy used for operations that must never give up:
/// the default `ExponentialBackoff`, but with `max_elapsed_time` cleared so
/// the retry loop continues indefinitely instead of stopping after the
/// default elapsed-time cap.
pub fn infinite() -> ExponentialBackoff {
    let mut policy = ExponentialBackoff::default();
    policy.max_elapsed_time = None;
    policy
}

View File

@@ -1,11 +1,12 @@
use crate::database;
use clickhouse::Client;
use tokio::join;
use clickhouse::{error::Error, Client};
use tokio::try_join;
pub async fn cleanup(clickhouse_client: &Client) {
join!(
pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> {
try_join!(
database::bars::cleanup(clickhouse_client),
database::news::cleanup(clickhouse_client),
database::backfills::cleanup(clickhouse_client)
);
)
.map(|_| ())
}

29
src/utils/macros.rs Normal file
View File

@@ -0,0 +1,29 @@
#[macro_export]
/// Generates a pair of `From` implementations converting between two enums
/// that share the same set of unit-variant names. Each listed `$variant` is
/// mapped 1:1 in both directions, so `$source -> $target` and
/// `$target -> $source` both come for free from one invocation.
///
/// NOTE(review): only unit variants are supported — a variant carrying data
/// would not match the `<$source>::$variant => <$target>::$variant` arms.
macro_rules! impl_from_enum {
    ($source:ty, $target:ty, $( $variant:ident ),* ) => {
        // Forward direction: $source -> $target.
        impl From<$source> for $target {
            fn from(item: $source) -> Self {
                match item {
                    $( <$source>::$variant => <$target>::$variant, )*
                }
            }
        }

        // Reverse direction: $target -> $source.
        impl From<$target> for $source {
            fn from(item: $target) -> Self {
                match item {
                    $( <$target>::$variant => <$source>::$variant, )*
                }
            }
        }
    };
}
#[macro_export]
/// Request/response helper for channel-based handlers: calls
/// `$action($($contents),*)` — which is expected to return a
/// `(message, receiver)` pair — sends the message over `$sender`, then
/// awaits the receiver and yields its value as the expansion's result.
///
/// Both the channel send and the receiver are `.unwrap()`ed, so the macro
/// panics if the receiving task has shut down — presumably acceptable
/// because a dead handler is unrecoverable here (TODO confirm at call sites).
/// Must be expanded inside an `async` context since it `.await`s.
macro_rules! create_send_await {
    ($sender:expr, $action:expr, $($contents:expr),*) => {
        let (message, receiver) = $action($($contents),*);
        $sender.send(message).await.unwrap();
        receiver.await.unwrap()
    };
}

View File

@@ -1,4 +1,6 @@
pub mod backoff;
pub mod cleanup;
pub mod macros;
pub mod news;
pub mod time;
pub mod websocket;

View File

@@ -1,13 +1,17 @@
use html_escape::decode_html_entities;
use lazy_static::lazy_static;
use regex::Regex;
pub fn normalize_news_content(content: &str) -> String {
let re_tags = Regex::new("<[^>]+>").unwrap();
let re_spaces = Regex::new("[\\u00A0\\s]+").unwrap();
// Precompiled regexes, built once on first use instead of per call.
lazy_static! {
    // Matches HTML/XML-style tags (e.g. "<p>", "</a>"); used to strip
    // markup from news content.
    static ref RE_TAGS: Regex = Regex::new("<[^>]+>").unwrap();
    // Matches runs of whitespace including non-breaking spaces (U+00A0);
    // used to collapse them to a single space.
    static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap();
    // Splits a concatenated crypto pair into base + quote, where the quote
    // is "BTC" or "USD" plus an optional extra character (e.g. "USDT");
    // used to insert the "/" separator.
    static ref RE_SLASH: Regex = Regex::new(r"^(.+)(BTC|USD.?)$").unwrap();
}
pub fn normalize_news_content(content: &str) -> String {
let content = content.replace('\n', " ");
let content = re_tags.replace_all(&content, "");
let content = re_spaces.replace_all(&content, " ");
let content = RE_TAGS.replace_all(&content, "");
let content = RE_SPACES.replace_all(&content, " ");
let content = decode_html_entities(&content);
let content = content.trim();
@@ -15,9 +19,7 @@ pub fn normalize_news_content(content: &str) -> String {
}
pub fn add_slash_to_pair(pair: &str) -> String {
let regex = Regex::new(r"^(.+)(BTC|USD.?)$").unwrap();
regex.captures(pair).map_or_else(
RE_SLASH.captures(pair).map_or_else(
|| pair.to_string(),
|caps| format!("{}/{}", &caps[1], &caps[2]),
)

View File

@@ -10,7 +10,7 @@ use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub async fn authenticate(
app_config: &Arc<Config>,
config: &Arc<Config>,
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) {
@@ -28,8 +28,8 @@ pub async fn authenticate(
sink.send(Message::Text(
to_string(&websocket::outgoing::Message::Auth(
websocket::outgoing::auth::Message {
key: app_config.alpaca_api_key.clone(),
secret: app_config.alpaca_api_secret.clone(),
key: config.alpaca_api_key.clone(),
secret: config.alpaca_api_secret.clone(),
},
))
.unwrap(),