Add route for adding multiple assets

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-09 20:13:36 +00:00
parent 080f91b044
commit 681d7393d7
31 changed files with 754 additions and 282 deletions

.gitignore
View File

@@ -2,6 +2,7 @@
 # will have compiled files and executables
 debug/
 target/
+log/

 # These are backup files generated by rustfmt
 **/*.rs.bk

View File

@@ -4,7 +4,14 @@ appenders:
     encoder:
       pattern: "{d} {h({l})} {M}::{L} - {m}{n}"
+  file:
+    kind: file
+    path: "./log/output.log"
+    encoder:
+      pattern: "{d} {l} {M}::{L} - {m}{n}"

 root:
   level: info
   appenders:
     - stdout
+    - file

View File

@@ -13,7 +13,7 @@ use rust_bert::{
     resources::LocalResource,
 };
 use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
-use tokio::sync::Mutex;
+use tokio::sync::{Mutex, Semaphore};

 pub const ALPACA_STOCK_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
 pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
@@ -51,17 +51,21 @@ lazy_static! {
             Mode::Paper => String::from("paper-api"),
         }
     );
-    pub static ref MAX_BERT_INPUTS: usize = env::var("MAX_BERT_INPUTS")
-        .expect("MAX_BERT_INPUTS must be set.")
-        .parse()
-        .expect("MAX_BERT_INPUTS must be a positive integer.");
+    pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS")
+        .expect("BERT_MAX_INPUTS must be set.")
+        .parse()
+        .expect("BERT_MAX_INPUTS must be a positive integer.");
+    pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS")
+        .expect("CLICKHOUSE_MAX_CONNECTIONS must be set.")
+        .parse()
+        .expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer.");
 }

 pub struct Config {
     pub alpaca_client: Client,
     pub alpaca_rate_limiter: DefaultDirectRateLimiter,
     pub clickhouse_client: clickhouse::Client,
+    pub clickhouse_concurrency_limiter: Arc<Semaphore>,
     pub sequence_classifier: Mutex<SequenceClassificationModel>,
 }
@@ -95,6 +99,7 @@ impl Config {
                 env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
             )
             .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
+            clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
             sequence_classifier: Mutex::new(
                 SequenceClassificationModel::new(SequenceClassificationConfig::new(
                     ModelType::Bert,
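The cap introduced here boils down to a tokio Semaphore sized from CLICKHOUSE_MAX_CONNECTIONS, shared through an Arc and acquired before each query. A minimal, self-contained sketch of that gating pattern follows (the function name and the limit of 10 are illustrative, not from this commit); note that how long a permit is held depends on the binding: a named `_permit` lives to the end of its scope, while `let _ = ...` drops it immediately.

use std::sync::Arc;
use tokio::sync::Semaphore;

// Run `query` while holding one of the limiter's permits.
async fn gated<F: std::future::Future>(limiter: &Arc<Semaphore>, query: F) -> F::Output {
    // The permit is released when `_permit` goes out of scope, i.e. after
    // `query` completes; `let _ = ...` would release it before the query runs.
    let _permit = limiter.acquire().await.unwrap();
    query.await
}

#[tokio::main]
async fn main() {
    // e.g. CLICKHOUSE_MAX_CONNECTIONS=10
    let limiter = Arc::new(Semaphore::new(10));
    assert_eq!(gated(&limiter, async { 1 + 1 }).await, 2);
}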

View File

@@ -1,8 +1,11 @@
+use std::sync::Arc;
+
 use crate::{
     delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch,
 };
 use clickhouse::{error::Error, Client};
 use serde::Serialize;
+use tokio::sync::Semaphore;

 select!(Asset, "assets");
 select_where_symbol!(Asset, "assets");
@@ -11,14 +14,16 @@ delete_where_symbols!("assets");
 optimize!("assets");

 pub async fn update_status_where_symbol<T>(
-    clickhouse_client: &Client,
+    client: &Client,
+    concurrency_limiter: &Arc<Semaphore>,
     symbol: &T,
     status: bool,
 ) -> Result<(), Error>
 where
     T: AsRef<str> + Serialize + Send + Sync,
 {
-    clickhouse_client
+    let _ = concurrency_limiter.acquire().await.unwrap();
+    client
         .query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?")
         .bind(status)
         .bind(symbol)
@@ -27,14 +32,16 @@ where
 }

 pub async fn update_qty_where_symbol<T>(
-    clickhouse_client: &Client,
+    client: &Client,
+    concurrency_limiter: &Arc<Semaphore>,
     symbol: &T,
     qty: f64,
 ) -> Result<(), Error>
 where
     T: AsRef<str> + Serialize + Send + Sync,
 {
-    clickhouse_client
+    let _ = concurrency_limiter.acquire().await.unwrap();
+    client
         .query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?")
         .bind(qty)
         .bind(symbol)

View File

@@ -1,16 +1,20 @@
+use std::sync::Arc;
+
 use crate::{
-    cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
+    cleanup, delete_where_symbols, optimize, select_where_symbols, types::Backfill, upsert,
 };
 use clickhouse::{error::Error, Client};
+use tokio::sync::Semaphore;

-select_where_symbol!(Backfill, "backfills_bars");
+select_where_symbols!(Backfill, "backfills_bars");
 upsert!(Backfill, "backfills_bars");
 delete_where_symbols!("backfills_bars");
 cleanup!("backfills_bars");
 optimize!("backfills_bars");

-pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
-    clickhouse_client
+pub async fn unfresh(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
+    let _ = concurrency_limiter.acquire().await.unwrap();
+    client
         .query("ALTER TABLE backfills_bars UPDATE fresh = false WHERE true")
         .execute()
         .await

View File

@@ -1,16 +1,20 @@
+use std::sync::Arc;
+
 use crate::{
-    cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
+    cleanup, delete_where_symbols, optimize, select_where_symbols, types::Backfill, upsert,
 };
 use clickhouse::{error::Error, Client};
+use tokio::sync::Semaphore;

-select_where_symbol!(Backfill, "backfills_news");
+select_where_symbols!(Backfill, "backfills_news");
 upsert!(Backfill, "backfills_news");
 delete_where_symbols!("backfills_news");
 cleanup!("backfills_news");
 optimize!("backfills_news");

-pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
-    clickhouse_client
+pub async fn unfresh(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
+    let _ = concurrency_limiter.acquire().await.unwrap();
+    client
         .query("ALTER TABLE backfills_news UPDATE fresh = false WHERE true")
         .execute()
         .await

View File

@@ -1,7 +1,21 @@
-use crate::{cleanup, delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
+use std::sync::Arc;
+
+use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
+use clickhouse::Client;
+use tokio::sync::Semaphore;

 upsert!(Bar, "bars");
 upsert_batch!(Bar, "bars");
 delete_where_symbols!("bars");
-cleanup!("bars");
 optimize!("bars");
+
+pub async fn cleanup(
+    client: &Client,
+    concurrency_limiter: &Arc<Semaphore>,
+) -> Result<(), clickhouse::error::Error> {
+    let _ = concurrency_limiter.acquire().await.unwrap();
+    client
+        .query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets) OR symbol NOT IN (SELECT symbol FROM backfills_bars)")
+        .execute()
+        .await
+}

View File

@@ -1,11 +1,14 @@
+use std::sync::Arc;
+
 use crate::{optimize, types::Calendar};
-use clickhouse::error::Error;
-use tokio::try_join;
+use clickhouse::{error::Error, Client};
+use tokio::{sync::Semaphore, try_join};

 optimize!("calendar");

 pub async fn upsert_batch_and_delete<'a, T>(
-    client: &clickhouse::Client,
+    client: &Client,
+    concurrency_limiter: &Arc<Semaphore>,
     records: T,
 ) -> Result<(), Error>
 where
@@ -34,5 +37,6 @@ where
         .await
     };

+    let _ = concurrency_limiter.acquire_many(2).await.unwrap();
     try_join!(upsert_future, delete_future).map(|_| ())
 }
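Since upsert_batch_and_delete runs two statements concurrently under try_join!, it reserves two permits up front with acquire_many. A toy sketch of that reservation, under assumed names (the limit of 4 and the stub futures are illustrative):

use std::sync::Arc;
use tokio::{sync::Semaphore, try_join};

#[tokio::main]
async fn main() {
    let limiter = Arc::new(Semaphore::new(4));

    // Take both permits atomically so the pair of queries counts as two
    // connections against the cap and cannot interleave with other acquirers.
    let _permits = limiter.acquire_many(2).await.unwrap();

    let upsert = async { Ok::<(), ()>(()) };
    let delete = async { Ok::<(), ()>(()) };
    try_join!(upsert, delete).unwrap();
}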

View File

@@ -15,7 +15,9 @@ macro_rules! select {
     ($record:ty, $table_name:expr) => {
         pub async fn select(
             client: &clickhouse::Client,
+            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
         ) -> Result<Vec<$record>, clickhouse::error::Error> {
+            let _ = concurrency_limiter.acquire().await.unwrap();
             client
                 .query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
                 .fetch_all::<$record>()
@@ -29,11 +31,13 @@ macro_rules! select_where_symbol {
     ($record:ty, $table_name:expr) => {
         pub async fn select_where_symbol<T>(
             client: &clickhouse::Client,
+            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
             symbol: &T,
         ) -> Result<Option<$record>, clickhouse::error::Error>
         where
             T: AsRef<str> + serde::Serialize + Send + Sync,
         {
+            let _ = concurrency_limiter.acquire().await.unwrap();
             client
                 .query(&format!(
                     "SELECT ?fields FROM {} FINAL WHERE symbol = ?",
@@ -46,13 +50,39 @@ macro_rules! select_where_symbol {
     };
 }

+#[macro_export]
+macro_rules! select_where_symbols {
+    ($record:ty, $table_name:expr) => {
+        pub async fn select_where_symbols<T>(
+            client: &clickhouse::Client,
+            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
+            symbols: &[T],
+        ) -> Result<Vec<$record>, clickhouse::error::Error>
+        where
+            T: AsRef<str> + serde::Serialize + Send + Sync,
+        {
+            let _ = concurrency_limiter.acquire().await.unwrap();
+            client
+                .query(&format!(
+                    "SELECT ?fields FROM {} FINAL WHERE symbol IN ?",
+                    $table_name
+                ))
+                .bind(symbols)
+                .fetch_all::<$record>()
+                .await
+        }
+    };
+}
+
 #[macro_export]
 macro_rules! upsert {
     ($record:ty, $table_name:expr) => {
         pub async fn upsert(
             client: &clickhouse::Client,
+            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
             record: &$record,
         ) -> Result<(), clickhouse::error::Error> {
+            let _ = concurrency_limiter.acquire().await.unwrap();
             let mut insert = client.insert($table_name)?;
             insert.write(record).await?;
             insert.end().await
@@ -65,12 +95,14 @@ macro_rules! upsert_batch {
     ($record:ty, $table_name:expr) => {
         pub async fn upsert_batch<'a, T>(
             client: &clickhouse::Client,
+            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
             records: T,
         ) -> Result<(), clickhouse::error::Error>
         where
             T: IntoIterator<Item = &'a $record> + Send + Sync,
             T::IntoIter: Send,
         {
+            let _ = concurrency_limiter.acquire().await.unwrap();
             let mut insert = client.insert($table_name)?;
             for record in records {
                 insert.write(record).await?;
@@ -85,11 +117,13 @@ macro_rules! delete_where_symbols {
     ($table_name:expr) => {
         pub async fn delete_where_symbols<T>(
             client: &clickhouse::Client,
+            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
             symbols: &[T],
         ) -> Result<(), clickhouse::error::Error>
         where
             T: AsRef<str> + serde::Serialize + Send + Sync,
         {
+            let _ = concurrency_limiter.acquire().await.unwrap();
             client
                 .query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
                 .bind(symbols)
@@ -102,7 +136,11 @@ macro_rules! delete_where_symbols {
 #[macro_export]
 macro_rules! cleanup {
     ($table_name:expr) => {
-        pub async fn cleanup(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
+        pub async fn cleanup(
+            client: &clickhouse::Client,
+            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
+        ) -> Result<(), clickhouse::error::Error> {
+            let _ = concurrency_limiter.acquire().await.unwrap();
             client
                 .query(&format!(
                     "DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
@@ -117,7 +155,11 @@ macro_rules! cleanup {
 #[macro_export]
 macro_rules! optimize {
     ($table_name:expr) => {
-        pub async fn optimize(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
+        pub async fn optimize(
+            client: &clickhouse::Client,
+            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
+        ) -> Result<(), clickhouse::error::Error> {
+            let _ = concurrency_limiter.acquire().await.unwrap();
             client
                 .query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
                 .execute()
@@ -126,27 +168,33 @@ macro_rules! optimize {
     };
 }

-pub async fn cleanup_all(clickhouse_client: &Client) -> Result<(), Error> {
+pub async fn cleanup_all(
+    clickhouse_client: &Client,
+    concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
+) -> Result<(), Error> {
     info!("Cleaning up database.");
     try_join!(
-        bars::cleanup(clickhouse_client),
-        news::cleanup(clickhouse_client),
-        backfills_bars::cleanup(clickhouse_client),
-        backfills_news::cleanup(clickhouse_client)
+        bars::cleanup(clickhouse_client, concurrency_limiter),
+        news::cleanup(clickhouse_client, concurrency_limiter),
+        backfills_bars::cleanup(clickhouse_client, concurrency_limiter),
+        backfills_news::cleanup(clickhouse_client, concurrency_limiter)
     )
     .map(|_| ())
 }

-pub async fn optimize_all(clickhouse_client: &Client) -> Result<(), Error> {
+pub async fn optimize_all(
+    clickhouse_client: &Client,
+    concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
+) -> Result<(), Error> {
     info!("Optimizing database.");
     try_join!(
-        assets::optimize(clickhouse_client),
-        bars::optimize(clickhouse_client),
-        news::optimize(clickhouse_client),
-        backfills_bars::optimize(clickhouse_client),
-        backfills_news::optimize(clickhouse_client),
-        orders::optimize(clickhouse_client),
-        calendar::optimize(clickhouse_client)
+        assets::optimize(clickhouse_client, concurrency_limiter),
+        bars::optimize(clickhouse_client, concurrency_limiter),
+        news::optimize(clickhouse_client, concurrency_limiter),
+        backfills_bars::optimize(clickhouse_client, concurrency_limiter),
+        backfills_news::optimize(clickhouse_client, concurrency_limiter),
+        orders::optimize(clickhouse_client, concurrency_limiter),
+        calendar::optimize(clickhouse_client, concurrency_limiter)
     )
     .map(|_| ())
 }

View File

@@ -1,24 +1,33 @@
+use std::sync::Arc;
+
 use crate::{optimize, types::News, upsert, upsert_batch};
 use clickhouse::{error::Error, Client};
 use serde::Serialize;
+use tokio::sync::Semaphore;

 upsert!(News, "news");
 upsert_batch!(News, "news");
 optimize!("news");

-pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error>
+pub async fn delete_where_symbols<T>(
+    client: &Client,
+    concurrency_limiter: &Arc<Semaphore>,
+    symbols: &[T],
+) -> Result<(), Error>
 where
     T: AsRef<str> + Serialize + Send + Sync,
 {
-    clickhouse_client
+    let _ = concurrency_limiter.acquire().await.unwrap();
+    client
         .query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))")
         .bind(symbols)
         .execute()
         .await
 }

-pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> {
-    clickhouse_client
+pub async fn cleanup(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
+    let _ = concurrency_limiter.acquire().await.unwrap();
+    client
         .query(
             "DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))",
         )

View File

@@ -68,9 +68,13 @@ pub async fn rehydrate_orders(config: &Arc<Config>) {
         .flat_map(&alpaca::api::incoming::order::Order::normalize)
         .collect::<Vec<_>>();

-    database::orders::upsert_batch(&config.clickhouse_client, &orders)
-        .await
-        .unwrap();
+    database::orders::upsert_batch(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+        &orders,
+    )
+    .await
+    .unwrap();

     info!("Rehydrated order data.");
 }
@@ -92,9 +96,12 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
     };

     let assets_future = async {
-        database::assets::select(&config.clickhouse_client)
-            .await
-            .unwrap()
+        database::assets::select(
+            &config.clickhouse_client,
+            &config.clickhouse_concurrency_limiter,
+        )
+        .await
+        .unwrap()
     };

     let (mut positions, assets) = join!(positions_future, assets_future);
@@ -111,9 +118,13 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
         })
         .collect::<Vec<_>>();

-    database::assets::upsert_batch(&config.clickhouse_client, &assets)
-        .await
-        .unwrap();
+    database::assets::upsert_batch(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+        &assets,
+    )
+    .await
+    .unwrap();

     for position in positions.values() {
         warn!(

View File

@@ -22,17 +22,29 @@ async fn main() {
     let config = Config::arc_from_env();

     try_join!(
-        database::backfills_bars::unfresh(&config.clickhouse_client),
-        database::backfills_news::unfresh(&config.clickhouse_client)
+        database::backfills_bars::unfresh(
+            &config.clickhouse_client,
+            &config.clickhouse_concurrency_limiter
+        ),
+        database::backfills_news::unfresh(
+            &config.clickhouse_client,
+            &config.clickhouse_concurrency_limiter
+        )
     )
     .unwrap();

-    database::cleanup_all(&config.clickhouse_client)
-        .await
-        .unwrap();
-    database::optimize_all(&config.clickhouse_client)
-        .await
-        .unwrap();
+    database::cleanup_all(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+    )
+    .await
+    .unwrap();
+    database::optimize_all(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+    )
+    .await
+    .unwrap();

     init::check_account(&config).await;

     join!(
@@ -53,12 +65,15 @@ async fn main() {
     spawn(threads::clock::run(config.clone(), clock_sender));

-    let assets = database::assets::select(&config.clickhouse_client)
-        .await
-        .unwrap()
-        .into_iter()
-        .map(|asset| (asset.symbol, asset.class))
-        .collect::<Vec<_>>();
+    let assets = database::assets::select(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+    )
+    .await
+    .unwrap()
+    .into_iter()
+    .map(|asset| (asset.symbol, asset.class))
+    .collect::<Vec<_>>();

     create_send_await!(
         data_sender,

View File

@@ -5,16 +5,22 @@ use crate::{
 };
 use axum::{extract::Path, Extension, Json};
 use http::StatusCode;
-use serde::Deserialize;
-use std::sync::Arc;
+use serde::{Deserialize, Serialize};
+use std::{
+    collections::{HashMap, HashSet},
+    sync::Arc,
+};
 use tokio::sync::mpsc;

 pub async fn get(
     Extension(config): Extension<Arc<Config>>,
 ) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> {
-    let assets = database::assets::select(&config.clickhouse_client)
-        .await
-        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+    let assets = database::assets::select(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+    )
+    .await
+    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

     Ok((StatusCode::OK, Json(assets)))
 }
@@ -23,9 +29,13 @@ pub async fn get_where_symbol(
     Extension(config): Extension<Arc<Config>>,
     Path(symbol): Path<String>,
 ) -> Result<(StatusCode, Json<Asset>), StatusCode> {
-    let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol)
-        .await
-        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+    let asset = database::assets::select_where_symbol(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+        &symbol,
+    )
+    .await
+    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

     asset.map_or(Err(StatusCode::NOT_FOUND), |asset| {
         Ok((StatusCode::OK, Json(asset)))
@@ -33,19 +43,101 @@ pub async fn get_where_symbol(
 }

 #[derive(Deserialize)]
-pub struct AddAssetRequest {
-    symbol: String,
+pub struct AddAssetsRequest {
+    symbols: Vec<String>,
+}
+
+#[derive(Serialize)]
+pub struct AddAssetsResponse {
+    added: Vec<String>,
+    skipped: Vec<String>,
+    failed: Vec<String>,
 }

 pub async fn add(
     Extension(config): Extension<Arc<Config>>,
     Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
-    Json(request): Json<AddAssetRequest>,
+    Json(request): Json<AddAssetsRequest>,
+) -> Result<(StatusCode, Json<AddAssetsResponse>), StatusCode> {
+    let database_symbols = database::assets::select(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+    )
+    .await
+    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
+    .into_iter()
+    .map(|asset| asset.symbol)
+    .collect::<HashSet<_>>();
+
+    let mut alpaca_assets = alpaca::api::incoming::asset::get_by_symbols(
+        &config.alpaca_client,
+        &config.alpaca_rate_limiter,
+        &request.symbols,
+        None,
+    )
+    .await
+    .map_err(|e| {
+        e.status()
+            .map_or(StatusCode::INTERNAL_SERVER_ERROR, |status| {
+                StatusCode::from_u16(status.as_u16()).unwrap()
+            })
+    })?
+    .into_iter()
+    .map(|asset| (asset.symbol.clone(), asset))
+    .collect::<HashMap<_, _>>();
+
+    let (assets, skipped, failed) = request.symbols.into_iter().fold(
+        (vec![], vec![], vec![]),
+        |(mut assets, mut skipped, mut failed), symbol| {
+            if database_symbols.contains(&symbol) {
+                skipped.push(symbol);
+            } else if let Some(asset) = alpaca_assets.remove(&symbol) {
+                if asset.status == alpaca::shared::asset::Status::Active
+                    && asset.tradable
+                    && asset.fractionable
+                {
+                    assets.push((asset.symbol, asset.class.into()));
+                } else {
+                    failed.push(asset.symbol);
+                }
+            } else {
+                failed.push(symbol);
+            }
+
+            (assets, skipped, failed)
+        },
+    );
+
+    create_send_await!(
+        data_sender,
+        threads::data::Message::new,
+        threads::data::Action::Add,
+        assets.clone()
+    );
+
+    Ok((
+        StatusCode::CREATED,
+        Json(AddAssetsResponse {
+            added: assets.into_iter().map(|asset| asset.0).collect(),
+            skipped,
+            failed,
+        }),
+    ))
+}
+
+pub async fn add_symbol(
+    Extension(config): Extension<Arc<Config>>,
+    Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
+    Path(symbol): Path<String>,
 ) -> Result<StatusCode, StatusCode> {
-    if database::assets::select_where_symbol(&config.clickhouse_client, &request.symbol)
-        .await
-        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
-        .is_some()
+    if database::assets::select_where_symbol(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+        &symbol,
+    )
+    .await
+    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
+    .is_some()
     {
         return Err(StatusCode::CONFLICT);
     }
@@ -53,7 +145,7 @@ pub async fn add(
     let asset = alpaca::api::incoming::asset::get_by_symbol(
         &config.alpaca_client,
         &config.alpaca_rate_limiter,
-        &request.symbol,
+        &symbol,
         None,
     )
     .await
@@ -64,7 +156,10 @@ pub async fn add(
         })
     })?;

-    if !asset.tradable || !asset.fractionable {
+    if asset.status != alpaca::shared::asset::Status::Active
+        || !asset.tradable
+        || !asset.fractionable
+    {
         return Err(StatusCode::FORBIDDEN);
     }
@@ -83,10 +178,14 @@ pub async fn delete(
     Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
     Path(symbol): Path<String>,
 ) -> Result<StatusCode, StatusCode> {
-    let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol)
-        .await
-        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
-        .ok_or(StatusCode::NOT_FOUND)?;
+    let asset = database::assets::select_where_symbol(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+        &symbol,
+    )
+    .await
+    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
+    .ok_or(StatusCode::NOT_FOUND)?;

     create_send_await!(
         data_sender,

View File

@@ -16,6 +16,7 @@ pub async fn run(config: Arc<Config>, data_sender: mpsc::Sender<threads::data::M
         .route("/assets", get(assets::get))
         .route("/assets/:symbol", get(assets::get_where_symbol))
         .route("/assets", post(assets::add))
+        .route("/assets/:symbol", post(assets::add_symbol))
         .route("/assets/:symbol", delete(assets::delete))
         .layer(Extension(config))
         .layer(Extension(data_sender));
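For reference, the bulk route accepts a JSON body of symbols and reports the per-symbol outcome. A hedged client-side sketch follows; the bind address and port are assumptions, since the listener setup is outside this diff:

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();

    // POST /assets adds several symbols at once; POST /assets/:symbol
    // remains available for single additions via `add_symbol`.
    let response = client
        .post("http://localhost:3000/assets") // hypothetical address
        .json(&json!({ "symbols": ["AAPL", "BTC/USD", "NOPE"] }))
        .send()
        .await?;

    // Expected shape: {"added":[...],"skipped":[...],"failed":[...]}
    println!("{}", response.text().await?);
    Ok(())
}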

View File

@@ -74,9 +74,13 @@ pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
         let sleep_future = sleep(sleep_until);

         let calendar_future = async {
-            database::calendar::upsert_batch_and_delete(&config.clickhouse_client, &calendar)
-                .await
-                .unwrap();
+            database::calendar::upsert_batch_and_delete(
+                &config.clickhouse_client,
+                &config.clickhouse_concurrency_limiter,
+                &calendar,
+            )
+            .await
+            .unwrap();
         };

         join!(sleep_future, calendar_future);

View File

@@ -2,7 +2,7 @@ use super::ThreadType;
 use crate::{
     config::{
         Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_SOURCE, ALPACA_STOCK_DATA_API_URL,
-        MAX_BERT_INPUTS,
+        BERT_MAX_INPUTS,
     },
     database,
     types::{
@@ -30,24 +30,14 @@ pub enum Action {
     Purge,
 }

-impl From<super::Action> for Option<Action> {
-    fn from(action: super::Action) -> Self {
-        match action {
-            super::Action::Add | super::Action::Enable => Some(Action::Backfill),
-            super::Action::Remove => Some(Action::Purge),
-            super::Action::Disable => None,
-        }
-    }
-}
-
 pub struct Message {
-    pub action: Option<Action>,
+    pub action: Action,
     pub symbols: Vec<String>,
     pub response: oneshot::Sender<()>,
 }

 impl Message {
-    pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
+    pub fn new(action: Action, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
         let (sender, receiver) = oneshot::channel::<()>();
         (
             Self {
@@ -62,10 +52,10 @@ impl Message {
 #[async_trait]
 pub trait Handler: Send + Sync {
-    async fn select_latest_backfill(
+    async fn select_latest_backfills(
         &self,
-        symbol: String,
-    ) -> Result<Option<Backfill>, clickhouse::error::Error>;
+        symbols: &[String],
+    ) -> Result<Vec<Backfill>, clickhouse::error::Error>;
     async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
     async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
     async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime);
@@ -94,9 +84,17 @@ async fn handle_backfill_message(
     let mut backfill_jobs = backfill_jobs.lock().await;

     match message.action {
-        Some(Action::Backfill) => {
+        Action::Backfill => {
             let log_string = handler.log_string();

+            let backfills = handler
+                .select_latest_backfills(&message.symbols)
+                .await
+                .unwrap()
+                .into_iter()
+                .map(|backfill| (backfill.symbol.clone(), backfill))
+                .collect::<HashMap<_, _>>();
+
             for symbol in message.symbols {
                 if let Some(job) = backfill_jobs.get(&symbol) {
                     if !job.is_finished() {
@@ -108,33 +106,30 @@ async fn handle_backfill_message(
                     }
                 }

+                let fetch_from = backfills
+                    .get(&symbol)
+                    .map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
+                        backfill.time + ONE_SECOND
+                    });
+
+                let fetch_to = last_minute();
+
+                if fetch_from > fetch_to {
+                    info!("No need to backfill {} {}.", symbol, log_string,);
+                    return;
+                }
+
                 let handler = handler.clone();

                 backfill_jobs.insert(
                     symbol.clone(),
                     spawn(async move {
-                        let fetch_from = match handler
-                            .select_latest_backfill(symbol.clone())
-                            .await
-                            .unwrap()
-                        {
-                            Some(latest_backfill) => latest_backfill.time + ONE_SECOND,
-                            None => OffsetDateTime::UNIX_EPOCH,
-                        };
-
-                        let fetch_to = last_minute();
-
-                        if fetch_from > fetch_to {
-                            info!("No need to backfill {} {}.", symbol, log_string,);
-                            return;
-                        }
-
                         handler.queue_backfill(&symbol, fetch_to).await;
                         handler.backfill(symbol, fetch_from, fetch_to).await;
                     }),
                 );
             }
         }
-        Some(Action::Purge) => {
+        Action::Purge => {
             for symbol in &message.symbols {
                 if let Some(job) = backfill_jobs.remove(symbol) {
                     if !job.is_finished() {
@@ -150,7 +145,6 @@ async fn handle_backfill_message(
             )
             .unwrap();
         }
-        None => {}
     }

     message.response.send(()).unwrap();
@@ -199,20 +193,34 @@ fn crypto_query_constructor(
 #[async_trait]
 impl Handler for BarHandler {
-    async fn select_latest_backfill(
+    async fn select_latest_backfills(
         &self,
-        symbol: String,
-    ) -> Result<Option<Backfill>, clickhouse::error::Error> {
-        database::backfills_bars::select_where_symbol(&self.config.clickhouse_client, &symbol).await
+        symbols: &[String],
+    ) -> Result<Vec<Backfill>, clickhouse::error::Error> {
+        database::backfills_bars::select_where_symbols(
+            &self.config.clickhouse_client,
+            &self.config.clickhouse_concurrency_limiter,
+            symbols,
+        )
+        .await
     }

     async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
-        database::backfills_bars::delete_where_symbols(&self.config.clickhouse_client, symbols)
-            .await
+        database::backfills_bars::delete_where_symbols(
+            &self.config.clickhouse_client,
+            &self.config.clickhouse_concurrency_limiter,
+            symbols,
+        )
+        .await
     }

     async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
-        database::bars::delete_where_symbols(&self.config.clickhouse_client, symbols).await
+        database::bars::delete_where_symbols(
+            &self.config.clickhouse_client,
+            &self.config.clickhouse_concurrency_limiter,
+            symbols,
+        )
+        .await
     }

     async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
@@ -230,7 +238,7 @@ impl Handler for BarHandler {
         let mut next_page_token = None;

         loop {
-            let Ok(message) = alpaca::api::incoming::bar::get_historical(
+            let Ok(message) = alpaca::api::incoming::bar::get(
                 &self.config.alpaca_client,
                 &self.config.alpaca_rate_limiter,
                 self.data_url,
@@ -267,12 +275,20 @@ impl Handler for BarHandler {
             let backfill = bars.last().unwrap().clone().into();

-            database::bars::upsert_batch(&self.config.clickhouse_client, &bars)
-                .await
-                .unwrap();
-            database::backfills_bars::upsert(&self.config.clickhouse_client, &backfill)
-                .await
-                .unwrap();
+            database::bars::upsert_batch(
+                &self.config.clickhouse_client,
+                &self.config.clickhouse_concurrency_limiter,
+                &bars,
+            )
+            .await
+            .unwrap();
+            database::backfills_bars::upsert(
+                &self.config.clickhouse_client,
+                &self.config.clickhouse_concurrency_limiter,
+                &backfill,
+            )
+            .await
+            .unwrap();

             info!("Backfilled bars for {}.", symbol);
         }
@@ -288,20 +304,34 @@ struct NewsHandler {
 #[async_trait]
 impl Handler for NewsHandler {
-    async fn select_latest_backfill(
+    async fn select_latest_backfills(
         &self,
-        symbol: String,
-    ) -> Result<Option<Backfill>, clickhouse::error::Error> {
-        database::backfills_news::select_where_symbol(&self.config.clickhouse_client, &symbol).await
+        symbols: &[String],
+    ) -> Result<Vec<Backfill>, clickhouse::error::Error> {
+        database::backfills_news::select_where_symbols(
+            &self.config.clickhouse_client,
+            &self.config.clickhouse_concurrency_limiter,
+            symbols,
+        )
+        .await
     }

     async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
-        database::backfills_news::delete_where_symbols(&self.config.clickhouse_client, symbols)
-            .await
+        database::backfills_news::delete_where_symbols(
+            &self.config.clickhouse_client,
+            &self.config.clickhouse_concurrency_limiter,
+            symbols,
+        )
+        .await
     }

     async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
-        database::news::delete_where_symbols(&self.config.clickhouse_client, symbols).await
+        database::news::delete_where_symbols(
+            &self.config.clickhouse_client,
+            &self.config.clickhouse_concurrency_limiter,
+            symbols,
+        )
+        .await
     }

     async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
@@ -317,7 +347,7 @@ impl Handler for NewsHandler {
         let mut next_page_token = None;

         loop {
-            let Ok(message) = alpaca::api::incoming::news::get_historical(
+            let Ok(message) = alpaca::api::incoming::news::get(
                 &self.config.alpaca_client,
                 &self.config.alpaca_rate_limiter,
                 &alpaca::api::outgoing::news::News {
@@ -355,7 +385,7 @@ impl Handler for NewsHandler {
             .map(|news| format!("{}\n\n{}", news.headline, news.content))
             .collect::<Vec<_>>();

-        let predictions = join_all(inputs.chunks(*MAX_BERT_INPUTS).map(|inputs| async move {
+        let predictions = join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
             let sequence_classifier = self.config.sequence_classifier.lock().await;
             block_in_place(|| {
                 sequence_classifier
@@ -381,12 +411,20 @@ impl Handler for NewsHandler {
             let backfill = (news.last().unwrap().clone(), symbol.clone()).into();

-            database::news::upsert_batch(&self.config.clickhouse_client, &news)
-                .await
-                .unwrap();
-            database::backfills_news::upsert(&self.config.clickhouse_client, &backfill)
-                .await
-                .unwrap();
+            database::news::upsert_batch(
+                &self.config.clickhouse_client,
+                &self.config.clickhouse_concurrency_limiter,
+                &news,
+            )
+            .await
+            .unwrap();
+            database::backfills_news::upsert(
+                &self.config.clickhouse_client,
+                &self.config.clickhouse_concurrency_limiter,
+                &backfill,
+            )
+            .await
+            .unwrap();

             info!("Backfilled news for {}.", symbol);
         }

View File

@@ -9,18 +9,18 @@ use crate::{
     },
     create_send_await, database,
     types::{alpaca, Asset, Class},
-    utils::backoff,
 };
-use futures_util::{future::join_all, StreamExt};
+use futures_util::StreamExt;
 use itertools::{Either, Itertools};
-use std::sync::Arc;
+use log::error;
+use std::{collections::HashMap, sync::Arc};
 use tokio::{
     join, select, spawn,
     sync::{mpsc, oneshot},
 };
 use tokio_tungstenite::connect_async;

-#[derive(Clone, Copy)]
+#[derive(Clone, Copy, PartialEq, Eq)]
 #[allow(dead_code)]
 pub enum Action {
     Add,
@@ -173,13 +173,6 @@ async fn handle_message(
             message.action.into(),
             us_equity_symbols.clone()
         );
-
-        create_send_await!(
-            bars_us_equity_backfill_sender,
-            backfill::Message::new,
-            message.action.into(),
-            us_equity_symbols
-        );
     };

     let bars_crypto_future = async {
@@ -193,13 +186,6 @@ async fn handle_message(
             message.action.into(),
             crypto_symbols.clone()
         );
-
-        create_send_await!(
-            bars_crypto_backfill_sender,
-            backfill::Message::new,
-            message.action.into(),
-            crypto_symbols
-        );
     };

     let news_future = async {
@@ -209,62 +195,127 @@ async fn handle_message(
             message.action.into(),
             symbols.clone()
         );
-
-        create_send_await!(
-            news_backfill_sender,
-            backfill::Message::new,
-            message.action.into(),
-            symbols.clone()
-        );
     };

     join!(bars_us_equity_future, bars_crypto_future, news_future);

     match message.action {
         Action::Add => {
-            let assets = join_all(symbols.into_iter().map(|symbol| {
-                let config = config.clone();
-                async move {
-                    let asset_future = async {
-                        alpaca::api::incoming::asset::get_by_symbol(
-                            &config.alpaca_client,
-                            &config.alpaca_rate_limiter,
-                            &symbol,
-                            Some(backoff::infinite()),
-                        )
-                        .await
-                        .unwrap()
-                    };
-
-                    let position_future = async {
-                        alpaca::api::incoming::position::get_by_symbol(
-                            &config.alpaca_client,
-                            &config.alpaca_rate_limiter,
-                            &symbol,
-                            Some(backoff::infinite()),
-                        )
-                        .await
-                        .unwrap()
-                    };
-
-                    let (asset, position) = join!(asset_future, position_future);
-                    Asset::from((asset, position))
-                }
-            }))
-            .await;
-
-            database::assets::upsert_batch(&config.clickhouse_client, &assets)
-                .await
-                .unwrap();
+            let assets = async {
+                alpaca::api::incoming::asset::get_by_symbols(
+                    &config.alpaca_client,
+                    &config.alpaca_rate_limiter,
+                    &symbols,
+                    None,
+                )
+                .await
+                .unwrap()
+                .into_iter()
+                .map(|asset| (asset.symbol.clone(), asset))
+                .collect::<HashMap<_, _>>()
+            };
+
+            let positions = async {
+                alpaca::api::incoming::position::get_by_symbols(
+                    &config.alpaca_client,
+                    &config.alpaca_rate_limiter,
+                    &symbols,
+                    None,
+                )
+                .await
+                .unwrap()
+                .into_iter()
+                .map(|position| (position.symbol.clone(), position))
+                .collect::<HashMap<_, _>>()
+            };
+
+            let (mut assets, mut positions) = join!(assets, positions);
+
+            let mut batch = vec![];
+            for symbol in &symbols {
+                if let Some(asset) = assets.remove(symbol) {
+                    let position = positions.remove(symbol);
+                    batch.push(Asset::from((asset, position)));
+                } else {
+                    error!("Failed to find asset for symbol: {}", symbol);
+                }
+            }
+
+            database::assets::upsert_batch(
+                &config.clickhouse_client,
+                &config.clickhouse_concurrency_limiter,
+                &batch,
+            )
+            .await
+            .unwrap();
         }
         Action::Remove => {
-            database::assets::delete_where_symbols(&config.clickhouse_client, &symbols)
-                .await
-                .unwrap();
+            database::assets::delete_where_symbols(
+                &config.clickhouse_client,
+                &config.clickhouse_concurrency_limiter,
+                &symbols,
+            )
+            .await
+            .unwrap();
         }
         _ => {}
     }

+    if message.action == Action::Disable {
+        message.response.send(()).unwrap();
+        return;
+    }
+
+    let bars_us_equity_future = async {
+        if us_equity_symbols.is_empty() {
+            return;
+        }
+
+        create_send_await!(
+            bars_us_equity_backfill_sender,
+            backfill::Message::new,
+            match message.action {
+                Action::Add | Action::Enable => backfill::Action::Backfill,
+                Action::Remove => backfill::Action::Purge,
+                Action::Disable => unreachable!(),
+            },
+            us_equity_symbols
+        );
+    };
+
+    let bars_crypto_future = async {
+        if crypto_symbols.is_empty() {
+            return;
+        }
+
+        create_send_await!(
+            bars_crypto_backfill_sender,
+            backfill::Message::new,
+            match message.action {
+                Action::Add | Action::Enable => backfill::Action::Backfill,
+                Action::Remove => backfill::Action::Purge,
+                Action::Disable => unreachable!(),
+            },
+            crypto_symbols
+        );
+    };
+
+    let news_future = async {
+        create_send_await!(
+            news_backfill_sender,
+            backfill::Message::new,
+            match message.action {
+                Action::Add | Action::Enable => backfill::Action::Backfill,
+                Action::Remove => backfill::Action::Purge,
+                Action::Disable => unreachable!(),
+            },
+            symbols
+        );
+    };
+
+    join!(bars_us_equity_future, bars_crypto_future, news_future);
+
     message.response.send(()).unwrap();
 }
@@ -274,13 +325,19 @@ async fn handle_clock_message(
     bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
     news_backfill_sender: mpsc::Sender<backfill::Message>,
 ) {
-    database::cleanup_all(&config.clickhouse_client)
-        .await
-        .unwrap();
+    database::cleanup_all(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+    )
+    .await
+    .unwrap();

-    let assets = database::assets::select(&config.clickhouse_client)
-        .await
-        .unwrap();
+    let assets = database::assets::select(
+        &config.clickhouse_client,
+        &config.clickhouse_concurrency_limiter,
+    )
+    .await
+    .unwrap();

     let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets
         .clone()
@@ -299,8 +356,8 @@ async fn handle_clock_message(
         create_send_await!(
             bars_us_equity_backfill_sender,
             backfill::Message::new,
-            Some(backfill::Action::Backfill),
-            us_equity_symbols.clone()
+            backfill::Action::Backfill,
+            us_equity_symbols
         );
     };
@@ -308,8 +365,8 @@ async fn handle_clock_message(
         create_send_await!(
             bars_crypto_backfill_sender,
             backfill::Message::new,
-            Some(backfill::Action::Backfill),
-            crypto_symbols.clone()
+            backfill::Action::Backfill,
+            crypto_symbols
         );
     };
@@ -317,7 +374,7 @@ async fn handle_clock_message(
         create_send_await!(
             news_backfill_sender,
             backfill::Message::new,
-            Some(backfill::Action::Backfill),
+            backfill::Action::Backfill,
             symbols
         );
     };

View File

@@ -268,9 +268,13 @@ impl Handler for BarsHandler {
                 let bar = Bar::from(message);
                 debug!("Received bar for {}: {}.", bar.symbol, bar.time);

-                database::bars::upsert(&self.config.clickhouse_client, &bar)
-                    .await
-                    .unwrap();
+                database::bars::upsert(
+                    &self.config.clickhouse_client,
+                    &self.config.clickhouse_concurrency_limiter,
+                    &bar,
+                )
+                .await
+                .unwrap();
             }
             websocket::data::incoming::Message::Status(message) => {
                 debug!(
@@ -283,6 +287,7 @@ impl Handler for BarsHandler {
                     | websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
                         database::assets::update_status_where_symbol(
                             &self.config.clickhouse_client,
+                            &self.config.clickhouse_concurrency_limiter,
                             &message.symbol,
                             false,
                         )
@@ -293,6 +298,7 @@ impl Handler for BarsHandler {
                     | websocket::data::incoming::status::Status::TradingResumption(_) => {
                         database::assets::update_status_where_symbol(
                             &self.config.clickhouse_client,
+                            &self.config.clickhouse_concurrency_limiter,
                             &message.symbol,
                             true,
                         )
@@ -398,9 +404,13 @@ impl Handler for NewsHandler {
                     ..news
                 };

-                database::news::upsert(&self.config.clickhouse_client, &news)
-                    .await
-                    .unwrap();
+                database::news::upsert(
+                    &self.config.clickhouse_client,
+                    &self.config.clickhouse_concurrency_limiter,
+                    &news,
+                )
+                .await
+                .unwrap();
             }
             websocket::data::incoming::Message::Error(message) => {
                 error!("Received error message: {}.", message.message);

View File

@@ -52,9 +52,13 @@ async fn handle_websocket_message(
             let order = Order::from(message.order);

-            database::orders::upsert(&config.clickhouse_client, &order)
-                .await
-                .unwrap();
+            database::orders::upsert(
+                &config.clickhouse_client,
+                &config.clickhouse_concurrency_limiter,
+                &order,
+            )
+            .await
+            .unwrap();

             match message.event {
                 websocket::trading::incoming::order::Event::Fill { position_qty, .. }
@@ -63,6 +67,7 @@ async fn handle_websocket_message(
                 } => {
                     database::assets::update_qty_where_symbol(
                         &config.clickhouse_client,
+                        &config.clickhouse_concurrency_limiter,
                         &order.symbol,
                         position_qty,
                     )

View File

@@ -81,15 +81,15 @@ pub struct Account {
 }

 pub async fn get(
-    alpaca_client: &Client,
-    alpaca_rate_limiter: &DefaultDirectRateLimiter,
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
     backoff: Option<ExponentialBackoff>,
 ) -> Result<Account, Error> {
     retry_notify(
         backoff.unwrap_or_default(),
         || async {
-            alpaca_rate_limiter.until_ready().await;
-            alpaca_client
+            rate_limiter.until_ready().await;
+            client
                 .get(&format!("{}/account", *ALPACA_API_URL))
                 .send()
                 .await?

View File

@@ -3,20 +3,25 @@ use crate::{
     config::ALPACA_API_URL,
     types::{
         self,
-        alpaca::shared::asset::{Class, Exchange, Status},
+        alpaca::{
+            api::outgoing,
+            shared::asset::{Class, Exchange, Status},
+        },
     },
 };
 use backoff::{future::retry_notify, ExponentialBackoff};
 use governor::DefaultDirectRateLimiter;
+use itertools::Itertools;
 use log::warn;
 use reqwest::{Client, Error};
 use serde::Deserialize;
 use serde_aux::field_attributes::deserialize_option_number_from_string;
-use std::time::Duration;
+use std::{collections::HashSet, time::Duration};
+use tokio::try_join;
 use uuid::Uuid;

 #[allow(clippy::struct_excessive_bools)]
-#[derive(Deserialize)]
+#[derive(Deserialize, Clone)]
 pub struct Asset {
     pub id: Uuid,
     pub class: Class,
@@ -47,17 +52,56 @@ impl From<(Asset, Option<Position>)> for types::Asset {
     }
 }

+pub async fn get(
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
+    query: &outgoing::asset::Asset,
+    backoff: Option<ExponentialBackoff>,
+) -> Result<Vec<Asset>, Error> {
+    retry_notify(
+        backoff.unwrap_or_default(),
+        || async {
+            rate_limiter.until_ready().await;
+            client
+                .get(&format!("{}/assets", *ALPACA_API_URL))
+                .query(query)
+                .send()
+                .await?
+                .error_for_status()
+                .map_err(|e| match e.status() {
+                    Some(
+                        reqwest::StatusCode::BAD_REQUEST
+                        | reqwest::StatusCode::FORBIDDEN
+                        | reqwest::StatusCode::NOT_FOUND,
+                    ) => backoff::Error::Permanent(e),
+                    _ => e.into(),
+                })?
+                .json::<Vec<Asset>>()
+                .await
+                .map_err(backoff::Error::Permanent)
+        },
+        |e, duration: Duration| {
+            warn!(
+                "Failed to get assets, will retry in {} seconds: {}",
+                duration.as_secs(),
+                e
+            );
+        },
+    )
+    .await
+}
+
 pub async fn get_by_symbol(
-    alpaca_client: &Client,
-    alpaca_rate_limiter: &DefaultDirectRateLimiter,
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
     symbol: &str,
     backoff: Option<ExponentialBackoff>,
 ) -> Result<Asset, Error> {
     retry_notify(
         backoff.unwrap_or_default(),
         || async {
-            alpaca_rate_limiter.until_ready().await;
-            alpaca_client
+            rate_limiter.until_ready().await;
+            client
                 .get(&format!("{}/assets/{}", *ALPACA_API_URL, symbol))
                 .send()
                 .await?
@@ -84,3 +128,43 @@ pub async fn get_by_symbol(
     )
     .await
 }
+
+pub async fn get_by_symbols(
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
+    symbols: &[String],
+    backoff: Option<ExponentialBackoff>,
+) -> Result<Vec<Asset>, Error> {
+    if symbols.len() < 2 {
+        let symbol = symbols.first().unwrap();
+        let asset = get_by_symbol(client, rate_limiter, symbol, backoff).await?;
+        return Ok(vec![asset]);
+    }
+
+    let symbols = symbols.iter().collect::<HashSet<_>>();
+    let backoff_clone = backoff.clone();
+
+    let us_equity_query = outgoing::asset::Asset {
+        class: Some(Class::UsEquity),
+        ..Default::default()
+    };
+    let us_equity_assets = get(client, rate_limiter, &us_equity_query, backoff_clone);
+
+    let crypto_query = outgoing::asset::Asset {
+        class: Some(Class::Crypto),
+        ..Default::default()
+    };
+    let crypto_assets = get(client, rate_limiter, &crypto_query, backoff);
+
+    let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;
+
+    Ok(crypto_assets
+        .into_iter()
+        .chain(us_equity_assets)
+        .dedup_by(|a, b| a.symbol == b.symbol)
+        .filter(|asset| symbols.contains(&asset.symbol))
+        .collect())
+}
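get_by_symbols trades per-symbol round trips for two class-wide list requests that run concurrently and are filtered locally against the requested set. A reduced, dependency-free sketch of that fetch-all-then-filter shape (the stub data and helpers are made up for illustration):

use std::collections::HashSet;

// Stand-ins for the class-wide Alpaca list calls (illustrative data only).
fn us_equity_assets() -> Vec<String> {
    vec!["AAPL".into(), "MSFT".into()]
}
fn crypto_assets() -> Vec<String> {
    vec!["BTC/USD".into()]
}

// Fetch everything once, then keep only the requested symbols.
fn get_by_symbols(symbols: &[String]) -> Vec<String> {
    let wanted: HashSet<&String> = symbols.iter().collect();
    crypto_assets()
        .into_iter()
        .chain(us_equity_assets())
        .filter(|symbol| wanted.contains(symbol))
        .collect()
}

fn main() {
    let got = get_by_symbols(&["AAPL".to_string(), "BTC/USD".to_string()]);
    assert_eq!(got, vec!["BTC/USD".to_string(), "AAPL".to_string()]);
}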

View File

@@ -50,9 +50,9 @@ pub struct Message {
     pub next_page_token: Option<String>,
 }

-pub async fn get_historical(
-    alpaca_client: &Client,
-    alpaca_rate_limiter: &DefaultDirectRateLimiter,
+pub async fn get(
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
     data_url: &str,
     query: &outgoing::bar::Bar,
     backoff: Option<ExponentialBackoff>,
@@ -60,8 +60,8 @@ pub async fn get_historical(
     retry_notify(
         backoff.unwrap_or_default(),
         || async {
-            alpaca_rate_limiter.until_ready().await;
-            alpaca_client
+            rate_limiter.until_ready().await;
+            client
                 .get(data_url)
                 .query(query)
                 .send()

View File

@@ -32,16 +32,16 @@ impl From<Calendar> for types::Calendar {
 }

 pub async fn get(
-    alpaca_client: &Client,
-    alpaca_rate_limiter: &DefaultDirectRateLimiter,
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
     query: &outgoing::calendar::Calendar,
     backoff: Option<ExponentialBackoff>,
 ) -> Result<Vec<Calendar>, Error> {
     retry_notify(
         backoff.unwrap_or_default(),
         || async {
-            alpaca_rate_limiter.until_ready().await;
-            alpaca_client
+            rate_limiter.until_ready().await;
+            client
                 .get(&format!("{}/calendar", *ALPACA_API_URL))
                 .query(query)
                 .send()

View File

@@ -19,15 +19,15 @@ pub struct Clock {
 }

 pub async fn get(
-    alpaca_client: &Client,
-    alpaca_rate_limiter: &DefaultDirectRateLimiter,
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
     backoff: Option<ExponentialBackoff>,
 ) -> Result<Clock, Error> {
     retry_notify(
         backoff.unwrap_or_default(),
         || async {
-            alpaca_rate_limiter.until_ready().await;
-            alpaca_client
+            rate_limiter.until_ready().await;
+            client
                 .get(&format!("{}/clock", *ALPACA_API_URL))
                 .send()
                 .await?

View File

@@ -73,17 +73,17 @@ pub struct Message {
     pub next_page_token: Option<String>,
 }

-pub async fn get_historical(
-    alpaca_client: &Client,
-    alpaca_rate_limiter: &DefaultDirectRateLimiter,
+pub async fn get(
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
     query: &outgoing::news::News,
     backoff: Option<ExponentialBackoff>,
 ) -> Result<Message, Error> {
     retry_notify(
         backoff.unwrap_or_default(),
         || async {
-            alpaca_rate_limiter.until_ready().await;
-            alpaca_client
+            rate_limiter.until_ready().await;
+            client
                 .get(ALPACA_NEWS_DATA_API_URL)
                 .query(query)
                 .send()

View File

@@ -11,16 +11,16 @@ use std::time::Duration;
 pub use shared::order::Order;

 pub async fn get(
-    alpaca_client: &Client,
-    alpaca_rate_limiter: &DefaultDirectRateLimiter,
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
     query: &outgoing::order::Order,
     backoff: Option<ExponentialBackoff>,
 ) -> Result<Vec<Order>, Error> {
     retry_notify(
         backoff.unwrap_or_default(),
         || async {
-            alpaca_rate_limiter.until_ready().await;
-            alpaca_client
+            rate_limiter.until_ready().await;
+            client
                 .get(&format!("{}/orders", *ALPACA_API_URL))
                 .query(query)
                 .send()

View File

@@ -12,10 +12,10 @@ use log::warn;
 use reqwest::Client;
 use serde::Deserialize;
 use serde_aux::field_attributes::deserialize_number_from_string;
-use std::time::Duration;
+use std::{collections::HashSet, time::Duration};
 use uuid::Uuid;

-#[derive(Deserialize)]
+#[derive(Deserialize, Clone, Copy)]
 #[serde(rename_all = "snake_case")]
 pub enum Side {
     Long,
@@ -31,7 +31,7 @@ impl From<Side> for shared::order::Side {
     }
 }

-#[derive(Deserialize)]
+#[derive(Deserialize, Clone)]
 pub struct Position {
     pub asset_id: Uuid,
     #[serde(deserialize_with = "de::add_slash_to_symbol")]
@@ -67,15 +67,15 @@ pub struct Position {
 }

 pub async fn get(
-    alpaca_client: &Client,
-    alpaca_rate_limiter: &DefaultDirectRateLimiter,
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
     backoff: Option<ExponentialBackoff>,
 ) -> Result<Vec<Position>, reqwest::Error> {
     retry_notify(
         backoff.unwrap_or_default(),
         || async {
-            alpaca_rate_limiter.until_ready().await;
-            alpaca_client
+            rate_limiter.until_ready().await;
+            client
                 .get(&format!("{}/positions", *ALPACA_API_URL))
                 .send()
                 .await?
@@ -102,16 +102,16 @@ pub async fn get(
 }

 pub async fn get_by_symbol(
-    alpaca_client: &Client,
-    alpaca_rate_limiter: &DefaultDirectRateLimiter,
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
     symbol: &str,
     backoff: Option<ExponentialBackoff>,
 ) -> Result<Option<Position>, reqwest::Error> {
     retry_notify(
         backoff.unwrap_or_default(),
         || async {
-            alpaca_rate_limiter.until_ready().await;
-            let response = alpaca_client
+            rate_limiter.until_ready().await;
+            let response = client
                 .get(&format!("{}/positions/{}", *ALPACA_API_URL, symbol))
                 .send()
                 .await?;
@@ -143,3 +143,25 @@ pub async fn get_by_symbol(
     )
     .await
 }
+
+pub async fn get_by_symbols(
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
+    symbols: &[String],
+    backoff: Option<ExponentialBackoff>,
+) -> Result<Vec<Position>, reqwest::Error> {
+    if symbols.len() < 2 {
+        let symbol = symbols.first().unwrap();
+        let position = get_by_symbol(client, rate_limiter, symbol, backoff).await?;
+        return Ok(position.into_iter().collect());
+    }
+
+    let symbols = symbols.iter().collect::<HashSet<_>>();
+    let positions = get(client, rate_limiter, backoff).await?;
+
+    Ok(positions
+        .into_iter()
+        .filter(|position| symbols.contains(&position.symbol))
+        .collect())
+}

View File

@@ -0,0 +1,21 @@
+use crate::types::alpaca::shared::asset::{Class, Exchange, Status};
+use serde::Serialize;
+
+#[derive(Serialize)]
+pub struct Asset {
+    pub status: Option<Status>,
+    pub class: Option<Class>,
+    pub exchange: Option<Exchange>,
+    pub attributes: Option<Vec<String>>,
+}
+
+impl Default for Asset {
+    fn default() -> Self {
+        Self {
+            status: None,
+            class: Some(Class::UsEquity),
+            exchange: None,
+            attributes: None,
+        }
+    }
+}

View File

@@ -1,3 +1,4 @@
+pub mod asset;
 pub mod bar;
 pub mod calendar;
 pub mod news;

View File

@@ -1,7 +1,7 @@
 use crate::{impl_from_enum, types};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};

-#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
+#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
 #[serde(rename_all = "snake_case")]
 pub enum Class {
     UsEquity,
@@ -10,7 +10,7 @@ pub enum Class {

 impl_from_enum!(types::Class, Class, UsEquity, Crypto);

-#[derive(Deserialize)]
+#[derive(Serialize, Deserialize, Clone, Copy)]
 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
 pub enum Exchange {
     Amex,
@@ -36,7 +36,7 @@ impl_from_enum!(
     Crypto
 );

-#[derive(Deserialize)]
+#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
 #[serde(rename_all = "snake_case")]
 pub enum Status {
     Active,

View File

@@ -8,7 +8,8 @@ use std::fmt;
 use time::{format_description::OwnedFormatItem, macros::format_description, Time};

 lazy_static! {
-    static ref RE_SLASH: Regex = Regex::new(r"^(.+)(BTC|USD.?)$").unwrap();
+    // This *will* break in the future if a crypto pair with one letter is added
+    static ref RE_SLASH: Regex = Regex::new(r"^(.{2,})(BTC|USD.?)$").unwrap();
     static ref FMT_HH_MM: OwnedFormatItem = format_description!("[hour]:[minute]").into();
 }
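The tightened RE_SLASH requires at least two characters before the quote currency, which is exactly the caveat the new comment records. A quick check of both sides of that boundary (the symbols here are examples):

use regex::Regex;

fn main() {
    // Same pattern as RE_SLASH above: a base of 2+ chars, then BTC or USD
    // with an optional trailing letter (USDT, USDC, ...).
    let re = Regex::new(r"^(.{2,})(BTC|USD.?)$").unwrap();

    let caps = re.captures("AAVEUSDT").unwrap();
    assert_eq!(&caps[1], "AAVE"); // base
    assert_eq!(&caps[2], "USDT"); // quote

    // A hypothetical one-letter base no longer matches: the documented break.
    assert!(re.captures("XUSD").is_none());
}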