diff --git a/.gitignore b/.gitignore index d0dd9f5..f7003f9 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ # will have compiled files and executables debug/ target/ +log/ # These are backup files generated by rustfmt **/*.rs.bk diff --git a/log4rs.yaml b/log4rs.yaml index b191c37..a297395 100644 --- a/log4rs.yaml +++ b/log4rs.yaml @@ -4,7 +4,14 @@ appenders: encoder: pattern: "{d} {h({l})} {M}::{L} - {m}{n}" + file: + kind: file + path: "./log/output.log" + encoder: + pattern: "{d} {l} {M}::{L} - {m}{n}" + root: level: info appenders: - stdout + - file diff --git a/src/config.rs b/src/config.rs index 5b6a180..e77ad4c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,7 +13,7 @@ use rust_bert::{ resources::LocalResource, }; use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc}; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, Semaphore}; pub const ALPACA_STOCK_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars"; pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars"; @@ -51,17 +51,21 @@ lazy_static! { Mode::Paper => String::from("paper-api"), } ); - pub static ref MAX_BERT_INPUTS: usize = env::var("MAX_BERT_INPUTS") - .expect("MAX_BERT_INPUTS must be set.") + pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS") + .expect("BERT_MAX_INPUTS must be set.") .parse() - .expect("MAX_BERT_INPUTS must be a positive integer."); - + .expect("BERT_MAX_INPUTS must be a positive integer."); + pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS") + .expect("CLICKHOUSE_MAX_CONNECTIONS must be set.") + .parse() + .expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer."); } pub struct Config { pub alpaca_client: Client, pub alpaca_rate_limiter: DefaultDirectRateLimiter, pub clickhouse_client: clickhouse::Client, + pub clickhouse_concurrency_limiter: Arc, pub sequence_classifier: Mutex, } @@ -95,6 +99,7 @@ impl Config { env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."), ) .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")), + clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)), sequence_classifier: Mutex::new( SequenceClassificationModel::new(SequenceClassificationConfig::new( ModelType::Bert, diff --git a/src/database/assets.rs b/src/database/assets.rs index 04a9192..cd68d0a 100644 --- a/src/database/assets.rs +++ b/src/database/assets.rs @@ -1,8 +1,11 @@ +use std::sync::Arc; + use crate::{ delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch, }; use clickhouse::{error::Error, Client}; use serde::Serialize; +use tokio::sync::Semaphore; select!(Asset, "assets"); select_where_symbol!(Asset, "assets"); @@ -11,14 +14,16 @@ delete_where_symbols!("assets"); optimize!("assets"); pub async fn update_status_where_symbol( - clickhouse_client: &Client, + client: &Client, + concurrency_limiter: &Arc, symbol: &T, status: bool, ) -> Result<(), Error> where T: AsRef + Serialize + Send + Sync, { - clickhouse_client + let _ = concurrency_limiter.acquire().await.unwrap(); + client .query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?") .bind(status) .bind(symbol) @@ -27,14 +32,16 @@ where } pub async fn update_qty_where_symbol( - clickhouse_client: &Client, + client: &Client, + concurrency_limiter: &Arc, symbol: &T, qty: f64, ) -> Result<(), Error> where T: AsRef + Serialize + Send + Sync, { - clickhouse_client + let _ = concurrency_limiter.acquire().await.unwrap(); + client .query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?") .bind(qty) .bind(symbol) diff --git a/src/database/backfills_bars.rs b/src/database/backfills_bars.rs index 2d8017c..1360665 100644 --- a/src/database/backfills_bars.rs +++ b/src/database/backfills_bars.rs @@ -1,16 +1,20 @@ +use std::sync::Arc; + use crate::{ - cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert, + cleanup, delete_where_symbols, optimize, select_where_symbols, types::Backfill, upsert, }; use clickhouse::{error::Error, Client}; +use tokio::sync::Semaphore; -select_where_symbol!(Backfill, "backfills_bars"); +select_where_symbols!(Backfill, "backfills_bars"); upsert!(Backfill, "backfills_bars"); delete_where_symbols!("backfills_bars"); cleanup!("backfills_bars"); optimize!("backfills_bars"); -pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> { - clickhouse_client +pub async fn unfresh(client: &Client, concurrency_limiter: &Arc) -> Result<(), Error> { + let _ = concurrency_limiter.acquire().await.unwrap(); + client .query("ALTER TABLE backfills_bars UPDATE fresh = false WHERE true") .execute() .await diff --git a/src/database/backfills_news.rs b/src/database/backfills_news.rs index 9d4f2ef..5688eda 100644 --- a/src/database/backfills_news.rs +++ b/src/database/backfills_news.rs @@ -1,16 +1,20 @@ +use std::sync::Arc; + use crate::{ - cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert, + cleanup, delete_where_symbols, optimize, select_where_symbols, types::Backfill, upsert, }; use clickhouse::{error::Error, Client}; +use tokio::sync::Semaphore; -select_where_symbol!(Backfill, "backfills_news"); +select_where_symbols!(Backfill, "backfills_news"); upsert!(Backfill, "backfills_news"); delete_where_symbols!("backfills_news"); cleanup!("backfills_news"); optimize!("backfills_news"); -pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> { - clickhouse_client +pub async fn unfresh(client: &Client, concurrency_limiter: &Arc) -> Result<(), Error> { + let _ = concurrency_limiter.acquire().await.unwrap(); + client .query("ALTER TABLE backfills_news UPDATE fresh = false WHERE true") .execute() .await diff --git a/src/database/bars.rs b/src/database/bars.rs index 8be2374..ca9ae01 100644 --- a/src/database/bars.rs +++ b/src/database/bars.rs @@ -1,7 +1,21 @@ -use crate::{cleanup, delete_where_symbols, optimize, types::Bar, upsert, upsert_batch}; +use std::sync::Arc; + +use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch}; +use clickhouse::Client; +use tokio::sync::Semaphore; upsert!(Bar, "bars"); upsert_batch!(Bar, "bars"); delete_where_symbols!("bars"); -cleanup!("bars"); optimize!("bars"); + +pub async fn cleanup( + client: &Client, + concurrency_limiter: &Arc, +) -> Result<(), clickhouse::error::Error> { + let _ = concurrency_limiter.acquire().await.unwrap(); + client + .query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets) OR symbol NOT IN (SELECT symbol FROM backfills_bars)") + .execute() + .await +} diff --git a/src/database/calendar.rs b/src/database/calendar.rs index 20024ca..7001be1 100644 --- a/src/database/calendar.rs +++ b/src/database/calendar.rs @@ -1,11 +1,14 @@ +use std::sync::Arc; + use crate::{optimize, types::Calendar}; -use clickhouse::error::Error; -use tokio::try_join; +use clickhouse::{error::Error, Client}; +use tokio::{sync::Semaphore, try_join}; optimize!("calendar"); pub async fn upsert_batch_and_delete<'a, T>( - client: &clickhouse::Client, + client: &Client, + concurrency_limiter: &Arc, records: T, ) -> Result<(), Error> where @@ -34,5 +37,6 @@ where .await }; + let _ = concurrency_limiter.acquire_many(2).await.unwrap(); try_join!(upsert_future, delete_future).map(|_| ()) } diff --git a/src/database/mod.rs b/src/database/mod.rs index 9523349..136f353 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -15,7 +15,9 @@ macro_rules! select { ($record:ty, $table_name:expr) => { pub async fn select( client: &clickhouse::Client, + concurrency_limiter: &std::sync::Arc, ) -> Result, clickhouse::error::Error> { + let _ = concurrency_limiter.acquire().await.unwrap(); client .query(&format!("SELECT ?fields FROM {} FINAL", $table_name)) .fetch_all::<$record>() @@ -29,11 +31,13 @@ macro_rules! select_where_symbol { ($record:ty, $table_name:expr) => { pub async fn select_where_symbol( client: &clickhouse::Client, + concurrency_limiter: &std::sync::Arc, symbol: &T, ) -> Result, clickhouse::error::Error> where T: AsRef + serde::Serialize + Send + Sync, { + let _ = concurrency_limiter.acquire().await.unwrap(); client .query(&format!( "SELECT ?fields FROM {} FINAL WHERE symbol = ?", @@ -46,13 +50,39 @@ macro_rules! select_where_symbol { }; } +#[macro_export] +macro_rules! select_where_symbols { + ($record:ty, $table_name:expr) => { + pub async fn select_where_symbols( + client: &clickhouse::Client, + concurrency_limiter: &std::sync::Arc, + symbols: &[T], + ) -> Result, clickhouse::error::Error> + where + T: AsRef + serde::Serialize + Send + Sync, + { + let _ = concurrency_limiter.acquire().await.unwrap(); + client + .query(&format!( + "SELECT ?fields FROM {} FINAL WHERE symbol IN ?", + $table_name + )) + .bind(symbols) + .fetch_all::<$record>() + .await + } + }; +} + #[macro_export] macro_rules! upsert { ($record:ty, $table_name:expr) => { pub async fn upsert( client: &clickhouse::Client, + concurrency_limiter: &std::sync::Arc, record: &$record, ) -> Result<(), clickhouse::error::Error> { + let _ = concurrency_limiter.acquire().await.unwrap(); let mut insert = client.insert($table_name)?; insert.write(record).await?; insert.end().await @@ -65,12 +95,14 @@ macro_rules! upsert_batch { ($record:ty, $table_name:expr) => { pub async fn upsert_batch<'a, T>( client: &clickhouse::Client, + concurrency_limiter: &std::sync::Arc, records: T, ) -> Result<(), clickhouse::error::Error> where T: IntoIterator + Send + Sync, T::IntoIter: Send, { + let _ = concurrency_limiter.acquire().await.unwrap(); let mut insert = client.insert($table_name)?; for record in records { insert.write(record).await?; @@ -85,11 +117,13 @@ macro_rules! delete_where_symbols { ($table_name:expr) => { pub async fn delete_where_symbols( client: &clickhouse::Client, + concurrency_limiter: &std::sync::Arc, symbols: &[T], ) -> Result<(), clickhouse::error::Error> where T: AsRef + serde::Serialize + Send + Sync, { + let _ = concurrency_limiter.acquire().await.unwrap(); client .query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name)) .bind(symbols) @@ -102,7 +136,11 @@ macro_rules! delete_where_symbols { #[macro_export] macro_rules! cleanup { ($table_name:expr) => { - pub async fn cleanup(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> { + pub async fn cleanup( + client: &clickhouse::Client, + concurrency_limiter: &std::sync::Arc, + ) -> Result<(), clickhouse::error::Error> { + let _ = concurrency_limiter.acquire().await.unwrap(); client .query(&format!( "DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)", @@ -117,7 +155,11 @@ macro_rules! cleanup { #[macro_export] macro_rules! optimize { ($table_name:expr) => { - pub async fn optimize(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> { + pub async fn optimize( + client: &clickhouse::Client, + concurrency_limiter: &std::sync::Arc, + ) -> Result<(), clickhouse::error::Error> { + let _ = concurrency_limiter.acquire().await.unwrap(); client .query(&format!("OPTIMIZE TABLE {} FINAL", $table_name)) .execute() @@ -126,27 +168,33 @@ macro_rules! optimize { }; } -pub async fn cleanup_all(clickhouse_client: &Client) -> Result<(), Error> { +pub async fn cleanup_all( + clickhouse_client: &Client, + concurrency_limiter: &std::sync::Arc, +) -> Result<(), Error> { info!("Cleaning up database."); try_join!( - bars::cleanup(clickhouse_client), - news::cleanup(clickhouse_client), - backfills_bars::cleanup(clickhouse_client), - backfills_news::cleanup(clickhouse_client) + bars::cleanup(clickhouse_client, concurrency_limiter), + news::cleanup(clickhouse_client, concurrency_limiter), + backfills_bars::cleanup(clickhouse_client, concurrency_limiter), + backfills_news::cleanup(clickhouse_client, concurrency_limiter) ) .map(|_| ()) } -pub async fn optimize_all(clickhouse_client: &Client) -> Result<(), Error> { +pub async fn optimize_all( + clickhouse_client: &Client, + concurrency_limiter: &std::sync::Arc, +) -> Result<(), Error> { info!("Optimizing database."); try_join!( - assets::optimize(clickhouse_client), - bars::optimize(clickhouse_client), - news::optimize(clickhouse_client), - backfills_bars::optimize(clickhouse_client), - backfills_news::optimize(clickhouse_client), - orders::optimize(clickhouse_client), - calendar::optimize(clickhouse_client) + assets::optimize(clickhouse_client, concurrency_limiter), + bars::optimize(clickhouse_client, concurrency_limiter), + news::optimize(clickhouse_client, concurrency_limiter), + backfills_bars::optimize(clickhouse_client, concurrency_limiter), + backfills_news::optimize(clickhouse_client, concurrency_limiter), + orders::optimize(clickhouse_client, concurrency_limiter), + calendar::optimize(clickhouse_client, concurrency_limiter) ) .map(|_| ()) } diff --git a/src/database/news.rs b/src/database/news.rs index 4f40072..a028c21 100644 --- a/src/database/news.rs +++ b/src/database/news.rs @@ -1,24 +1,33 @@ +use std::sync::Arc; + use crate::{optimize, types::News, upsert, upsert_batch}; use clickhouse::{error::Error, Client}; use serde::Serialize; +use tokio::sync::Semaphore; upsert!(News, "news"); upsert_batch!(News, "news"); optimize!("news"); -pub async fn delete_where_symbols(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error> +pub async fn delete_where_symbols( + client: &Client, + concurrency_limiter: &Arc, + symbols: &[T], +) -> Result<(), Error> where T: AsRef + Serialize + Send + Sync, { - clickhouse_client + let _ = concurrency_limiter.acquire().await.unwrap(); + client .query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))") .bind(symbols) .execute() .await } -pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> { - clickhouse_client +pub async fn cleanup(client: &Client, concurrency_limiter: &Arc) -> Result<(), Error> { + let _ = concurrency_limiter.acquire().await.unwrap(); + client .query( "DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))", ) diff --git a/src/init.rs b/src/init.rs index 0fc425e..0a0c4c2 100644 --- a/src/init.rs +++ b/src/init.rs @@ -68,9 +68,13 @@ pub async fn rehydrate_orders(config: &Arc) { .flat_map(&alpaca::api::incoming::order::Order::normalize) .collect::>(); - database::orders::upsert_batch(&config.clickhouse_client, &orders) - .await - .unwrap(); + database::orders::upsert_batch( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + &orders, + ) + .await + .unwrap(); info!("Rehydrated order data."); } @@ -92,9 +96,12 @@ pub async fn rehydrate_positions(config: &Arc) { }; let assets_future = async { - database::assets::select(&config.clickhouse_client) - .await - .unwrap() + database::assets::select( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + ) + .await + .unwrap() }; let (mut positions, assets) = join!(positions_future, assets_future); @@ -111,9 +118,13 @@ pub async fn rehydrate_positions(config: &Arc) { }) .collect::>(); - database::assets::upsert_batch(&config.clickhouse_client, &assets) - .await - .unwrap(); + database::assets::upsert_batch( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + &assets, + ) + .await + .unwrap(); for position in positions.values() { warn!( diff --git a/src/main.rs b/src/main.rs index 445843b..41f6f58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,17 +22,29 @@ async fn main() { let config = Config::arc_from_env(); try_join!( - database::backfills_bars::unfresh(&config.clickhouse_client), - database::backfills_news::unfresh(&config.clickhouse_client) + database::backfills_bars::unfresh( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter + ), + database::backfills_news::unfresh( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter + ) ) .unwrap(); - database::cleanup_all(&config.clickhouse_client) - .await - .unwrap(); - database::optimize_all(&config.clickhouse_client) - .await - .unwrap(); + database::cleanup_all( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + ) + .await + .unwrap(); + database::optimize_all( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + ) + .await + .unwrap(); init::check_account(&config).await; join!( @@ -53,12 +65,15 @@ async fn main() { spawn(threads::clock::run(config.clone(), clock_sender)); - let assets = database::assets::select(&config.clickhouse_client) - .await - .unwrap() - .into_iter() - .map(|asset| (asset.symbol, asset.class)) - .collect::>(); + let assets = database::assets::select( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + ) + .await + .unwrap() + .into_iter() + .map(|asset| (asset.symbol, asset.class)) + .collect::>(); create_send_await!( data_sender, diff --git a/src/routes/assets.rs b/src/routes/assets.rs index 8f9b1e9..3e99ede 100644 --- a/src/routes/assets.rs +++ b/src/routes/assets.rs @@ -5,16 +5,22 @@ use crate::{ }; use axum::{extract::Path, Extension, Json}; use http::StatusCode; -use serde::Deserialize; -use std::sync::Arc; +use serde::{Deserialize, Serialize}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use tokio::sync::mpsc; pub async fn get( Extension(config): Extension>, ) -> Result<(StatusCode, Json>), StatusCode> { - let assets = database::assets::select(&config.clickhouse_client) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let assets = database::assets::select( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + ) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok((StatusCode::OK, Json(assets))) } @@ -23,9 +29,13 @@ pub async fn get_where_symbol( Extension(config): Extension>, Path(symbol): Path, ) -> Result<(StatusCode, Json), StatusCode> { - let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let asset = database::assets::select_where_symbol( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + &symbol, + ) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; asset.map_or(Err(StatusCode::NOT_FOUND), |asset| { Ok((StatusCode::OK, Json(asset))) @@ -33,19 +43,101 @@ pub async fn get_where_symbol( } #[derive(Deserialize)] -pub struct AddAssetRequest { - symbol: String, +pub struct AddAssetsRequest { + symbols: Vec, +} + +#[derive(Serialize)] +pub struct AddAssetsResponse { + added: Vec, + skipped: Vec, + failed: Vec, } pub async fn add( Extension(config): Extension>, Extension(data_sender): Extension>, - Json(request): Json, + Json(request): Json, +) -> Result<(StatusCode, Json), StatusCode> { + let database_symbols = database::assets::select( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + ) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .into_iter() + .map(|asset| asset.symbol) + .collect::>(); + + let mut alpaca_assets = alpaca::api::incoming::asset::get_by_symbols( + &config.alpaca_client, + &config.alpaca_rate_limiter, + &request.symbols, + None, + ) + .await + .map_err(|e| { + e.status() + .map_or(StatusCode::INTERNAL_SERVER_ERROR, |status| { + StatusCode::from_u16(status.as_u16()).unwrap() + }) + })? + .into_iter() + .map(|asset| (asset.symbol.clone(), asset)) + .collect::>(); + + let (assets, skipped, failed) = request.symbols.into_iter().fold( + (vec![], vec![], vec![]), + |(mut assets, mut skipped, mut failed), symbol| { + if database_symbols.contains(&symbol) { + skipped.push(symbol); + } else if let Some(asset) = alpaca_assets.remove(&symbol) { + if asset.status == alpaca::shared::asset::Status::Active + && asset.tradable + && asset.fractionable + { + assets.push((asset.symbol, asset.class.into())); + } else { + failed.push(asset.symbol); + } + } else { + failed.push(symbol); + } + + (assets, skipped, failed) + }, + ); + + create_send_await!( + data_sender, + threads::data::Message::new, + threads::data::Action::Add, + assets.clone() + ); + + Ok(( + StatusCode::CREATED, + Json(AddAssetsResponse { + added: assets.into_iter().map(|asset| asset.0).collect(), + skipped, + failed, + }), + )) +} + +pub async fn add_symbol( + Extension(config): Extension>, + Extension(data_sender): Extension>, + Path(symbol): Path, ) -> Result { - if database::assets::select_where_symbol(&config.clickhouse_client, &request.symbol) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? - .is_some() + if database::assets::select_where_symbol( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + &symbol, + ) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .is_some() { return Err(StatusCode::CONFLICT); } @@ -53,7 +145,7 @@ pub async fn add( let asset = alpaca::api::incoming::asset::get_by_symbol( &config.alpaca_client, &config.alpaca_rate_limiter, - &request.symbol, + &symbol, None, ) .await @@ -64,7 +156,10 @@ pub async fn add( }) })?; - if !asset.tradable || !asset.fractionable { + if asset.status != alpaca::shared::asset::Status::Active + || !asset.tradable + || !asset.fractionable + { return Err(StatusCode::FORBIDDEN); } @@ -83,10 +178,14 @@ pub async fn delete( Extension(data_sender): Extension>, Path(symbol): Path, ) -> Result { - let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? - .ok_or(StatusCode::NOT_FOUND)?; + let asset = database::assets::select_where_symbol( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + &symbol, + ) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::NOT_FOUND)?; create_send_await!( data_sender, diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 5d3cb5b..1839aa8 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -16,6 +16,7 @@ pub async fn run(config: Arc, data_sender: mpsc::Sender, sender: mpsc::Sender) { let sleep_future = sleep(sleep_until); let calendar_future = async { - database::calendar::upsert_batch_and_delete(&config.clickhouse_client, &calendar) - .await - .unwrap(); + database::calendar::upsert_batch_and_delete( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + &calendar, + ) + .await + .unwrap(); }; join!(sleep_future, calendar_future); diff --git a/src/threads/data/backfill.rs b/src/threads/data/backfill.rs index 565a2bf..9f86d10 100644 --- a/src/threads/data/backfill.rs +++ b/src/threads/data/backfill.rs @@ -2,7 +2,7 @@ use super::ThreadType; use crate::{ config::{ Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_SOURCE, ALPACA_STOCK_DATA_API_URL, - MAX_BERT_INPUTS, + BERT_MAX_INPUTS, }, database, types::{ @@ -30,24 +30,14 @@ pub enum Action { Purge, } -impl From for Option { - fn from(action: super::Action) -> Self { - match action { - super::Action::Add | super::Action::Enable => Some(Action::Backfill), - super::Action::Remove => Some(Action::Purge), - super::Action::Disable => None, - } - } -} - pub struct Message { - pub action: Option, + pub action: Action, pub symbols: Vec, pub response: oneshot::Sender<()>, } impl Message { - pub fn new(action: Option, symbols: Vec) -> (Self, oneshot::Receiver<()>) { + pub fn new(action: Action, symbols: Vec) -> (Self, oneshot::Receiver<()>) { let (sender, receiver) = oneshot::channel::<()>(); ( Self { @@ -62,10 +52,10 @@ impl Message { #[async_trait] pub trait Handler: Send + Sync { - async fn select_latest_backfill( + async fn select_latest_backfills( &self, - symbol: String, - ) -> Result, clickhouse::error::Error>; + symbols: &[String], + ) -> Result, clickhouse::error::Error>; async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>; async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>; async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime); @@ -94,9 +84,17 @@ async fn handle_backfill_message( let mut backfill_jobs = backfill_jobs.lock().await; match message.action { - Some(Action::Backfill) => { + Action::Backfill => { let log_string = handler.log_string(); + let backfills = handler + .select_latest_backfills(&message.symbols) + .await + .unwrap() + .into_iter() + .map(|backfill| (backfill.symbol.clone(), backfill)) + .collect::>(); + for symbol in message.symbols { if let Some(job) = backfill_jobs.get(&symbol) { if !job.is_finished() { @@ -108,33 +106,30 @@ async fn handle_backfill_message( } } + let fetch_from = backfills + .get(&symbol) + .map_or(OffsetDateTime::UNIX_EPOCH, |backfill| { + backfill.time + ONE_SECOND + }); + + let fetch_to = last_minute(); + + if fetch_from > fetch_to { + info!("No need to backfill {} {}.", symbol, log_string,); + return; + } + let handler = handler.clone(); backfill_jobs.insert( symbol.clone(), spawn(async move { - let fetch_from = match handler - .select_latest_backfill(symbol.clone()) - .await - .unwrap() - { - Some(latest_backfill) => latest_backfill.time + ONE_SECOND, - None => OffsetDateTime::UNIX_EPOCH, - }; - - let fetch_to = last_minute(); - - if fetch_from > fetch_to { - info!("No need to backfill {} {}.", symbol, log_string,); - return; - } - handler.queue_backfill(&symbol, fetch_to).await; handler.backfill(symbol, fetch_from, fetch_to).await; }), ); } } - Some(Action::Purge) => { + Action::Purge => { for symbol in &message.symbols { if let Some(job) = backfill_jobs.remove(symbol) { if !job.is_finished() { @@ -150,7 +145,6 @@ async fn handle_backfill_message( ) .unwrap(); } - None => {} } message.response.send(()).unwrap(); @@ -199,20 +193,34 @@ fn crypto_query_constructor( #[async_trait] impl Handler for BarHandler { - async fn select_latest_backfill( + async fn select_latest_backfills( &self, - symbol: String, - ) -> Result, clickhouse::error::Error> { - database::backfills_bars::select_where_symbol(&self.config.clickhouse_client, &symbol).await + symbols: &[String], + ) -> Result, clickhouse::error::Error> { + database::backfills_bars::select_where_symbols( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + symbols, + ) + .await } async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { - database::backfills_bars::delete_where_symbols(&self.config.clickhouse_client, symbols) - .await + database::backfills_bars::delete_where_symbols( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + symbols, + ) + .await } async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { - database::bars::delete_where_symbols(&self.config.clickhouse_client, symbols).await + database::bars::delete_where_symbols( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + symbols, + ) + .await } async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) { @@ -230,7 +238,7 @@ impl Handler for BarHandler { let mut next_page_token = None; loop { - let Ok(message) = alpaca::api::incoming::bar::get_historical( + let Ok(message) = alpaca::api::incoming::bar::get( &self.config.alpaca_client, &self.config.alpaca_rate_limiter, self.data_url, @@ -267,12 +275,20 @@ impl Handler for BarHandler { let backfill = bars.last().unwrap().clone().into(); - database::bars::upsert_batch(&self.config.clickhouse_client, &bars) - .await - .unwrap(); - database::backfills_bars::upsert(&self.config.clickhouse_client, &backfill) - .await - .unwrap(); + database::bars::upsert_batch( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + &bars, + ) + .await + .unwrap(); + database::backfills_bars::upsert( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + &backfill, + ) + .await + .unwrap(); info!("Backfilled bars for {}.", symbol); } @@ -288,20 +304,34 @@ struct NewsHandler { #[async_trait] impl Handler for NewsHandler { - async fn select_latest_backfill( + async fn select_latest_backfills( &self, - symbol: String, - ) -> Result, clickhouse::error::Error> { - database::backfills_news::select_where_symbol(&self.config.clickhouse_client, &symbol).await + symbols: &[String], + ) -> Result, clickhouse::error::Error> { + database::backfills_news::select_where_symbols( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + symbols, + ) + .await } async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { - database::backfills_news::delete_where_symbols(&self.config.clickhouse_client, symbols) - .await + database::backfills_news::delete_where_symbols( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + symbols, + ) + .await } async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { - database::news::delete_where_symbols(&self.config.clickhouse_client, symbols).await + database::news::delete_where_symbols( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + symbols, + ) + .await } async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) { @@ -317,7 +347,7 @@ impl Handler for NewsHandler { let mut next_page_token = None; loop { - let Ok(message) = alpaca::api::incoming::news::get_historical( + let Ok(message) = alpaca::api::incoming::news::get( &self.config.alpaca_client, &self.config.alpaca_rate_limiter, &alpaca::api::outgoing::news::News { @@ -355,7 +385,7 @@ impl Handler for NewsHandler { .map(|news| format!("{}\n\n{}", news.headline, news.content)) .collect::>(); - let predictions = join_all(inputs.chunks(*MAX_BERT_INPUTS).map(|inputs| async move { + let predictions = join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move { let sequence_classifier = self.config.sequence_classifier.lock().await; block_in_place(|| { sequence_classifier @@ -381,12 +411,20 @@ impl Handler for NewsHandler { let backfill = (news.last().unwrap().clone(), symbol.clone()).into(); - database::news::upsert_batch(&self.config.clickhouse_client, &news) - .await - .unwrap(); - database::backfills_news::upsert(&self.config.clickhouse_client, &backfill) - .await - .unwrap(); + database::news::upsert_batch( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + &news, + ) + .await + .unwrap(); + database::backfills_news::upsert( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + &backfill, + ) + .await + .unwrap(); info!("Backfilled news for {}.", symbol); } diff --git a/src/threads/data/mod.rs b/src/threads/data/mod.rs index 9822fda..82b25d8 100644 --- a/src/threads/data/mod.rs +++ b/src/threads/data/mod.rs @@ -9,18 +9,18 @@ use crate::{ }, create_send_await, database, types::{alpaca, Asset, Class}, - utils::backoff, }; -use futures_util::{future::join_all, StreamExt}; +use futures_util::StreamExt; use itertools::{Either, Itertools}; -use std::sync::Arc; +use log::error; +use std::{collections::HashMap, sync::Arc}; use tokio::{ join, select, spawn, sync::{mpsc, oneshot}, }; use tokio_tungstenite::connect_async; -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq, Eq)] #[allow(dead_code)] pub enum Action { Add, @@ -173,13 +173,6 @@ async fn handle_message( message.action.into(), us_equity_symbols.clone() ); - - create_send_await!( - bars_us_equity_backfill_sender, - backfill::Message::new, - message.action.into(), - us_equity_symbols - ); }; let bars_crypto_future = async { @@ -193,13 +186,6 @@ async fn handle_message( message.action.into(), crypto_symbols.clone() ); - - create_send_await!( - bars_crypto_backfill_sender, - backfill::Message::new, - message.action.into(), - crypto_symbols - ); }; let news_future = async { @@ -209,62 +195,127 @@ async fn handle_message( message.action.into(), symbols.clone() ); - - create_send_await!( - news_backfill_sender, - backfill::Message::new, - message.action.into(), - symbols.clone() - ); }; join!(bars_us_equity_future, bars_crypto_future, news_future); match message.action { Action::Add => { - let assets = join_all(symbols.into_iter().map(|symbol| { - let config = config.clone(); - async move { - let asset_future = async { - alpaca::api::incoming::asset::get_by_symbol( - &config.alpaca_client, - &config.alpaca_rate_limiter, - &symbol, - Some(backoff::infinite()), - ) - .await - .unwrap() - }; - - let position_future = async { - alpaca::api::incoming::position::get_by_symbol( - &config.alpaca_client, - &config.alpaca_rate_limiter, - &symbol, - Some(backoff::infinite()), - ) - .await - .unwrap() - }; - - let (asset, position) = join!(asset_future, position_future); - Asset::from((asset, position)) - } - })) - .await; - - database::assets::upsert_batch(&config.clickhouse_client, &assets) + let assets = async { + alpaca::api::incoming::asset::get_by_symbols( + &config.alpaca_client, + &config.alpaca_rate_limiter, + &symbols, + None, + ) .await - .unwrap(); + .unwrap() + .into_iter() + .map(|asset| (asset.symbol.clone(), asset)) + .collect::>() + }; + + let positions = async { + alpaca::api::incoming::position::get_by_symbols( + &config.alpaca_client, + &config.alpaca_rate_limiter, + &symbols, + None, + ) + .await + .unwrap() + .into_iter() + .map(|position| (position.symbol.clone(), position)) + .collect::>() + }; + + let (mut assets, mut positions) = join!(assets, positions); + + let mut batch = vec![]; + + for symbol in &symbols { + if let Some(asset) = assets.remove(symbol) { + let position = positions.remove(symbol); + batch.push(Asset::from((asset, position))); + } else { + error!("Failed to find asset for symbol: {}", symbol); + } + } + + database::assets::upsert_batch( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + &batch, + ) + .await + .unwrap(); } Action::Remove => { - database::assets::delete_where_symbols(&config.clickhouse_client, &symbols) - .await - .unwrap(); + database::assets::delete_where_symbols( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + &symbols, + ) + .await + .unwrap(); } _ => {} } + if message.action == Action::Disable { + message.response.send(()).unwrap(); + return; + } + + let bars_us_equity_future = async { + if us_equity_symbols.is_empty() { + return; + } + + create_send_await!( + bars_us_equity_backfill_sender, + backfill::Message::new, + match message.action { + Action::Add | Action::Enable => backfill::Action::Backfill, + Action::Remove => backfill::Action::Purge, + Action::Disable => unreachable!(), + }, + us_equity_symbols + ); + }; + + let bars_crypto_future = async { + if crypto_symbols.is_empty() { + return; + } + + create_send_await!( + bars_crypto_backfill_sender, + backfill::Message::new, + match message.action { + Action::Add | Action::Enable => backfill::Action::Backfill, + Action::Remove => backfill::Action::Purge, + Action::Disable => unreachable!(), + }, + crypto_symbols + ); + }; + + let news_future = async { + create_send_await!( + news_backfill_sender, + backfill::Message::new, + match message.action { + Action::Add | Action::Enable => backfill::Action::Backfill, + Action::Remove => backfill::Action::Purge, + Action::Disable => unreachable!(), + }, + symbols + ); + }; + + join!(bars_us_equity_future, bars_crypto_future, news_future); + message.response.send(()).unwrap(); } @@ -274,13 +325,19 @@ async fn handle_clock_message( bars_crypto_backfill_sender: mpsc::Sender, news_backfill_sender: mpsc::Sender, ) { - database::cleanup_all(&config.clickhouse_client) - .await - .unwrap(); + database::cleanup_all( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + ) + .await + .unwrap(); - let assets = database::assets::select(&config.clickhouse_client) - .await - .unwrap(); + let assets = database::assets::select( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + ) + .await + .unwrap(); let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets .clone() @@ -299,8 +356,8 @@ async fn handle_clock_message( create_send_await!( bars_us_equity_backfill_sender, backfill::Message::new, - Some(backfill::Action::Backfill), - us_equity_symbols.clone() + backfill::Action::Backfill, + us_equity_symbols ); }; @@ -308,8 +365,8 @@ async fn handle_clock_message( create_send_await!( bars_crypto_backfill_sender, backfill::Message::new, - Some(backfill::Action::Backfill), - crypto_symbols.clone() + backfill::Action::Backfill, + crypto_symbols ); }; @@ -317,7 +374,7 @@ async fn handle_clock_message( create_send_await!( news_backfill_sender, backfill::Message::new, - Some(backfill::Action::Backfill), + backfill::Action::Backfill, symbols ); }; diff --git a/src/threads/data/websocket.rs b/src/threads/data/websocket.rs index 9f7d6fd..1d4b256 100644 --- a/src/threads/data/websocket.rs +++ b/src/threads/data/websocket.rs @@ -268,9 +268,13 @@ impl Handler for BarsHandler { let bar = Bar::from(message); debug!("Received bar for {}: {}.", bar.symbol, bar.time); - database::bars::upsert(&self.config.clickhouse_client, &bar) - .await - .unwrap(); + database::bars::upsert( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + &bar, + ) + .await + .unwrap(); } websocket::data::incoming::Message::Status(message) => { debug!( @@ -283,6 +287,7 @@ impl Handler for BarsHandler { | websocket::data::incoming::status::Status::VolatilityTradingPause(_) => { database::assets::update_status_where_symbol( &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, &message.symbol, false, ) @@ -293,6 +298,7 @@ impl Handler for BarsHandler { | websocket::data::incoming::status::Status::TradingResumption(_) => { database::assets::update_status_where_symbol( &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, &message.symbol, true, ) @@ -398,9 +404,13 @@ impl Handler for NewsHandler { ..news }; - database::news::upsert(&self.config.clickhouse_client, &news) - .await - .unwrap(); + database::news::upsert( + &self.config.clickhouse_client, + &self.config.clickhouse_concurrency_limiter, + &news, + ) + .await + .unwrap(); } websocket::data::incoming::Message::Error(message) => { error!("Received error message: {}.", message.message); diff --git a/src/threads/trading/websocket.rs b/src/threads/trading/websocket.rs index 9fb5690..0f2de3f 100644 --- a/src/threads/trading/websocket.rs +++ b/src/threads/trading/websocket.rs @@ -52,9 +52,13 @@ async fn handle_websocket_message( let order = Order::from(message.order); - database::orders::upsert(&config.clickhouse_client, &order) - .await - .unwrap(); + database::orders::upsert( + &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, + &order, + ) + .await + .unwrap(); match message.event { websocket::trading::incoming::order::Event::Fill { position_qty, .. } @@ -63,6 +67,7 @@ async fn handle_websocket_message( } => { database::assets::update_qty_where_symbol( &config.clickhouse_client, + &config.clickhouse_concurrency_limiter, &order.symbol, position_qty, ) diff --git a/src/types/alpaca/api/incoming/account.rs b/src/types/alpaca/api/incoming/account.rs index 1463073..7b9962d 100644 --- a/src/types/alpaca/api/incoming/account.rs +++ b/src/types/alpaca/api/incoming/account.rs @@ -81,15 +81,15 @@ pub struct Account { } pub async fn get( - alpaca_client: &Client, - alpaca_rate_limiter: &DefaultDirectRateLimiter, + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, backoff: Option, ) -> Result { retry_notify( backoff.unwrap_or_default(), || async { - alpaca_rate_limiter.until_ready().await; - alpaca_client + rate_limiter.until_ready().await; + client .get(&format!("{}/account", *ALPACA_API_URL)) .send() .await? diff --git a/src/types/alpaca/api/incoming/asset.rs b/src/types/alpaca/api/incoming/asset.rs index 8df93f6..5cd9f16 100644 --- a/src/types/alpaca/api/incoming/asset.rs +++ b/src/types/alpaca/api/incoming/asset.rs @@ -3,20 +3,25 @@ use crate::{ config::ALPACA_API_URL, types::{ self, - alpaca::shared::asset::{Class, Exchange, Status}, + alpaca::{ + api::outgoing, + shared::asset::{Class, Exchange, Status}, + }, }, }; use backoff::{future::retry_notify, ExponentialBackoff}; use governor::DefaultDirectRateLimiter; +use itertools::Itertools; use log::warn; use reqwest::{Client, Error}; use serde::Deserialize; use serde_aux::field_attributes::deserialize_option_number_from_string; -use std::time::Duration; +use std::{collections::HashSet, time::Duration}; +use tokio::try_join; use uuid::Uuid; #[allow(clippy::struct_excessive_bools)] -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub struct Asset { pub id: Uuid, pub class: Class, @@ -47,17 +52,56 @@ impl From<(Asset, Option)> for types::Asset { } } +pub async fn get( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + query: &outgoing::asset::Asset, + backoff: Option, +) -> Result, Error> { + retry_notify( + backoff.unwrap_or_default(), + || async { + rate_limiter.until_ready().await; + client + .get(&format!("{}/assets", *ALPACA_API_URL)) + .query(query) + .send() + .await? + .error_for_status() + .map_err(|e| match e.status() { + Some( + reqwest::StatusCode::BAD_REQUEST + | reqwest::StatusCode::FORBIDDEN + | reqwest::StatusCode::NOT_FOUND, + ) => backoff::Error::Permanent(e), + _ => e.into(), + })? + .json::>() + .await + .map_err(backoff::Error::Permanent) + }, + |e, duration: Duration| { + warn!( + "Failed to get assets, will retry in {} seconds: {}", + duration.as_secs(), + e + ); + }, + ) + .await +} + pub async fn get_by_symbol( - alpaca_client: &Client, - alpaca_rate_limiter: &DefaultDirectRateLimiter, + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, symbol: &str, backoff: Option, ) -> Result { retry_notify( backoff.unwrap_or_default(), || async { - alpaca_rate_limiter.until_ready().await; - alpaca_client + rate_limiter.until_ready().await; + client .get(&format!("{}/assets/{}", *ALPACA_API_URL, symbol)) .send() .await? @@ -84,3 +128,43 @@ pub async fn get_by_symbol( ) .await } + +pub async fn get_by_symbols( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + symbols: &[String], + backoff: Option, +) -> Result, Error> { + if symbols.len() < 2 { + let symbol = symbols.first().unwrap(); + let asset = get_by_symbol(client, rate_limiter, symbol, backoff).await?; + return Ok(vec![asset]); + } + + let symbols = symbols.iter().collect::>(); + + let backoff_clone = backoff.clone(); + + let us_equity_query = outgoing::asset::Asset { + class: Some(Class::UsEquity), + ..Default::default() + }; + + let us_equity_assets = get(client, rate_limiter, &us_equity_query, backoff_clone); + + let crypto_query = outgoing::asset::Asset { + class: Some(Class::Crypto), + ..Default::default() + }; + + let crypto_assets = get(client, rate_limiter, &crypto_query, backoff); + + let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?; + + Ok(crypto_assets + .into_iter() + .chain(us_equity_assets) + .dedup_by(|a, b| a.symbol == b.symbol) + .filter(|asset| symbols.contains(&asset.symbol)) + .collect()) +} diff --git a/src/types/alpaca/api/incoming/bar.rs b/src/types/alpaca/api/incoming/bar.rs index 8bf6f6a..ece3265 100644 --- a/src/types/alpaca/api/incoming/bar.rs +++ b/src/types/alpaca/api/incoming/bar.rs @@ -50,9 +50,9 @@ pub struct Message { pub next_page_token: Option, } -pub async fn get_historical( - alpaca_client: &Client, - alpaca_rate_limiter: &DefaultDirectRateLimiter, +pub async fn get( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, data_url: &str, query: &outgoing::bar::Bar, backoff: Option, @@ -60,8 +60,8 @@ pub async fn get_historical( retry_notify( backoff.unwrap_or_default(), || async { - alpaca_rate_limiter.until_ready().await; - alpaca_client + rate_limiter.until_ready().await; + client .get(data_url) .query(query) .send() diff --git a/src/types/alpaca/api/incoming/calendar.rs b/src/types/alpaca/api/incoming/calendar.rs index 8879886..80e6c80 100644 --- a/src/types/alpaca/api/incoming/calendar.rs +++ b/src/types/alpaca/api/incoming/calendar.rs @@ -32,16 +32,16 @@ impl From for types::Calendar { } pub async fn get( - alpaca_client: &Client, - alpaca_rate_limiter: &DefaultDirectRateLimiter, + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, query: &outgoing::calendar::Calendar, backoff: Option, ) -> Result, Error> { retry_notify( backoff.unwrap_or_default(), || async { - alpaca_rate_limiter.until_ready().await; - alpaca_client + rate_limiter.until_ready().await; + client .get(&format!("{}/calendar", *ALPACA_API_URL)) .query(query) .send() diff --git a/src/types/alpaca/api/incoming/clock.rs b/src/types/alpaca/api/incoming/clock.rs index c23eabb..afc50ac 100644 --- a/src/types/alpaca/api/incoming/clock.rs +++ b/src/types/alpaca/api/incoming/clock.rs @@ -19,15 +19,15 @@ pub struct Clock { } pub async fn get( - alpaca_client: &Client, - alpaca_rate_limiter: &DefaultDirectRateLimiter, + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, backoff: Option, ) -> Result { retry_notify( backoff.unwrap_or_default(), || async { - alpaca_rate_limiter.until_ready().await; - alpaca_client + rate_limiter.until_ready().await; + client .get(&format!("{}/clock", *ALPACA_API_URL)) .send() .await? diff --git a/src/types/alpaca/api/incoming/news.rs b/src/types/alpaca/api/incoming/news.rs index 25e3e9b..61108f3 100644 --- a/src/types/alpaca/api/incoming/news.rs +++ b/src/types/alpaca/api/incoming/news.rs @@ -73,17 +73,17 @@ pub struct Message { pub next_page_token: Option, } -pub async fn get_historical( - alpaca_client: &Client, - alpaca_rate_limiter: &DefaultDirectRateLimiter, +pub async fn get( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, query: &outgoing::news::News, backoff: Option, ) -> Result { retry_notify( backoff.unwrap_or_default(), || async { - alpaca_rate_limiter.until_ready().await; - alpaca_client + rate_limiter.until_ready().await; + client .get(ALPACA_NEWS_DATA_API_URL) .query(query) .send() diff --git a/src/types/alpaca/api/incoming/order.rs b/src/types/alpaca/api/incoming/order.rs index dc68eb8..4d17435 100644 --- a/src/types/alpaca/api/incoming/order.rs +++ b/src/types/alpaca/api/incoming/order.rs @@ -11,16 +11,16 @@ use std::time::Duration; pub use shared::order::Order; pub async fn get( - alpaca_client: &Client, - alpaca_rate_limiter: &DefaultDirectRateLimiter, + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, query: &outgoing::order::Order, backoff: Option, ) -> Result, Error> { retry_notify( backoff.unwrap_or_default(), || async { - alpaca_rate_limiter.until_ready().await; - alpaca_client + rate_limiter.until_ready().await; + client .get(&format!("{}/orders", *ALPACA_API_URL)) .query(query) .send() diff --git a/src/types/alpaca/api/incoming/position.rs b/src/types/alpaca/api/incoming/position.rs index 7e29fb6..5b733fa 100644 --- a/src/types/alpaca/api/incoming/position.rs +++ b/src/types/alpaca/api/incoming/position.rs @@ -12,10 +12,10 @@ use log::warn; use reqwest::Client; use serde::Deserialize; use serde_aux::field_attributes::deserialize_number_from_string; -use std::time::Duration; +use std::{collections::HashSet, time::Duration}; use uuid::Uuid; -#[derive(Deserialize)] +#[derive(Deserialize, Clone, Copy)] #[serde(rename_all = "snake_case")] pub enum Side { Long, @@ -31,7 +31,7 @@ impl From for shared::order::Side { } } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub struct Position { pub asset_id: Uuid, #[serde(deserialize_with = "de::add_slash_to_symbol")] @@ -67,15 +67,15 @@ pub struct Position { } pub async fn get( - alpaca_client: &Client, - alpaca_rate_limiter: &DefaultDirectRateLimiter, + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, backoff: Option, ) -> Result, reqwest::Error> { retry_notify( backoff.unwrap_or_default(), || async { - alpaca_rate_limiter.until_ready().await; - alpaca_client + rate_limiter.until_ready().await; + client .get(&format!("{}/positions", *ALPACA_API_URL)) .send() .await? @@ -102,16 +102,16 @@ pub async fn get( } pub async fn get_by_symbol( - alpaca_client: &Client, - alpaca_rate_limiter: &DefaultDirectRateLimiter, + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, symbol: &str, backoff: Option, ) -> Result, reqwest::Error> { retry_notify( backoff.unwrap_or_default(), || async { - alpaca_rate_limiter.until_ready().await; - let response = alpaca_client + rate_limiter.until_ready().await; + let response = client .get(&format!("{}/positions/{}", *ALPACA_API_URL, symbol)) .send() .await?; @@ -143,3 +143,25 @@ pub async fn get_by_symbol( ) .await } + +pub async fn get_by_symbols( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + symbols: &[String], + backoff: Option, +) -> Result, reqwest::Error> { + if symbols.len() < 2 { + let symbol = symbols.first().unwrap(); + let position = get_by_symbol(client, rate_limiter, symbol, backoff).await?; + return Ok(position.into_iter().collect()); + } + + let symbols = symbols.iter().collect::>(); + + let positions = get(client, rate_limiter, backoff).await?; + + Ok(positions + .into_iter() + .filter(|position| symbols.contains(&position.symbol)) + .collect()) +} diff --git a/src/types/alpaca/api/outgoing/asset.rs b/src/types/alpaca/api/outgoing/asset.rs new file mode 100644 index 0000000..1efa07f --- /dev/null +++ b/src/types/alpaca/api/outgoing/asset.rs @@ -0,0 +1,21 @@ +use crate::types::alpaca::shared::asset::{Class, Exchange, Status}; +use serde::Serialize; + +#[derive(Serialize)] +pub struct Asset { + pub status: Option, + pub class: Option, + pub exchange: Option, + pub attributes: Option>, +} + +impl Default for Asset { + fn default() -> Self { + Self { + status: None, + class: Some(Class::UsEquity), + exchange: None, + attributes: None, + } + } +} diff --git a/src/types/alpaca/api/outgoing/mod.rs b/src/types/alpaca/api/outgoing/mod.rs index f67cfaf..57308d9 100644 --- a/src/types/alpaca/api/outgoing/mod.rs +++ b/src/types/alpaca/api/outgoing/mod.rs @@ -1,3 +1,4 @@ +pub mod asset; pub mod bar; pub mod calendar; pub mod news; diff --git a/src/types/alpaca/shared/asset.rs b/src/types/alpaca/shared/asset.rs index 735576e..e0755d9 100644 --- a/src/types/alpaca/shared/asset.rs +++ b/src/types/alpaca/shared/asset.rs @@ -1,7 +1,7 @@ use crate::{impl_from_enum, types}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum Class { UsEquity, @@ -10,7 +10,7 @@ pub enum Class { impl_from_enum!(types::Class, Class, UsEquity, Crypto); -#[derive(Deserialize)] +#[derive(Serialize, Deserialize, Clone, Copy)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum Exchange { Amex, @@ -36,7 +36,7 @@ impl_from_enum!( Crypto ); -#[derive(Deserialize)] +#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] #[serde(rename_all = "snake_case")] pub enum Status { Active, diff --git a/src/utils/de.rs b/src/utils/de.rs index 779d1af..cdc8d72 100644 --- a/src/utils/de.rs +++ b/src/utils/de.rs @@ -8,7 +8,8 @@ use std::fmt; use time::{format_description::OwnedFormatItem, macros::format_description, Time}; lazy_static! { - static ref RE_SLASH: Regex = Regex::new(r"^(.+)(BTC|USD.?)$").unwrap(); + // This *will* break in the future if a crypto pair with one letter is added + static ref RE_SLASH: Regex = Regex::new(r"^(.{2,})(BTC|USD.?)$").unwrap(); static ref FMT_HH_MM: OwnedFormatItem = format_description!("[hour]:[minute]").into(); }