From cdaa2d20a9c581c5270434e205cbe88895840732 Mon Sep 17 00:00:00 2001 From: Nikolaos Karaolidis Date: Thu, 15 Feb 2024 01:09:16 +0000 Subject: [PATCH] Update random bits and bobs Signed-off-by: Nikolaos Karaolidis --- Cargo.lock | 104 ++++++++------------- src/config.rs | 106 ++++++++++------------ src/init.rs | 68 ++++++++++---- src/main.rs | 22 +++-- src/routes/assets.rs | 21 +++-- src/threads/clock.rs | 10 +- src/threads/data/backfill.rs | 57 ++++++------ src/threads/data/mod.rs | 26 +++--- src/threads/data/websocket.rs | 15 +-- src/threads/trading/mod.rs | 3 +- src/threads/trading/websocket.rs | 6 +- src/types/alpaca/api/incoming/account.rs | 15 +-- src/types/alpaca/api/incoming/asset.rs | 15 +-- src/types/alpaca/api/incoming/bar.rs | 18 ++-- src/types/alpaca/api/incoming/clock.rs | 15 +-- src/types/alpaca/api/incoming/news.rs | 15 +-- src/types/alpaca/api/incoming/order.rs | 15 +-- src/types/alpaca/api/incoming/position.rs | 22 +++-- src/types/alpaca/websocket/data/mod.rs | 11 ++- src/types/alpaca/websocket/trading/mod.rs | 11 ++- 20 files changed, 292 insertions(+), 283 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e965978..fcc5acb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "aes" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac1f845298e95f983ff1944b728ae08b8cebab80d684f0a832ed0fc74dfa27e2" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" dependencies = [ "cfg-if", "cipher", @@ -30,9 +30,9 @@ dependencies = [ [[package]] name = "ahash" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" +checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff" dependencies = [ "cfg-if", "once_cell", @@ -304,15 +304,13 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.33" +version = "0.4.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f13690e35a5e4ace198e7beea2895d29f3a9cc55015fcebe6336bd2010af9eb" +checksum = "5bc015644b92d5890fab7489e49d21f879d5c990186827d42ec511919404f38b" dependencies = [ "android-tzdata", "iana-time-zone", - "js-sys", "num-traits", - "wasm-bindgen", "windows-targets 0.52.0", ] @@ -418,9 +416,9 @@ dependencies = [ [[package]] name = "crc32fast" -version = "1.3.2" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" dependencies = [ "cfg-if", ] @@ -494,7 +492,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if", - "hashbrown 0.14.3", + "hashbrown", "lock_api", "once_cell", "parking_lot_core", @@ -573,9 +571,9 @@ checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" [[package]] name = "either" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" [[package]] name = "encode_unicode" @@ -834,7 +832,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.11", - "indexmap 2.2.2", + "indexmap", "slab", "tokio", "tokio-util", @@ -853,7 +851,7 @@ dependencies = [ "futures-sink", "futures-util", "http 1.0.0", - "indexmap 2.2.2", + "indexmap", "slab", "tokio", "tokio-util", @@ -870,12 +868,6 @@ dependencies = [ "crunchy", ] -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - [[package]] name = "hashbrown" version = "0.14.3" @@ -897,9 +889,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0c62115964e08cb8039170eb33c1d0e2388a256930279edca206fff675f82c3" +checksum = "bd5256b483761cd23699d0da46cc6fd2ee3be420bbe6d020ae4a091e70b7e9fd" [[package]] name = "hmac" @@ -1100,22 +1092,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.3" +version = "2.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", -] - -[[package]] -name = "indexmap" -version = "2.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" +checksum = "233cf39063f058ea2caae4091bf4a3ef70a653afbc026f5c4a4135d114e3c177" dependencies = [ "equivalent", - "hashbrown 0.14.3", + "hashbrown", ] [[package]] @@ -1219,12 +1201,6 @@ dependencies = [ "redox_syscall", ] -[[package]] -name = "linked-hash-map" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" - [[package]] name = "linux-raw-sys" version = "0.4.13" @@ -1258,9 +1234,9 @@ checksum = "a94d21414c1f4a51209ad204c1776a3d0765002c76c6abcb602a6f09f1e881c7" [[package]] name = "log4rs" -version = "1.2.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d36ca1786d9e79b8193a68d480a0907b612f109537115c6ff655a3a1967533fd" +checksum = "0816135ae15bd0391cf284eab37e6e3ee0a6ee63d2ceeb659862bd8d0a984ca6" dependencies = [ "anyhow", "arc-swap", @@ -1271,7 +1247,9 @@ dependencies = [ "libc", "log", "log-mdc", + "once_cell", "parking_lot", + "rand", "serde", "serde-value", "serde_json", @@ -1620,9 +1598,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" [[package]] name = "powerfmt" @@ -1887,7 +1865,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19599f60a688b5160247ee9c37a6af8b0c742ee8b160c5b44acc0f0eb265a59f" dependencies = [ "csv", - "hashbrown 0.14.3", + "hashbrown", "itertools 0.11.0", "lazy_static", "protobuf", @@ -2098,14 +2076,15 @@ dependencies = [ [[package]] name = "serde_yaml" -version = "0.8.26" +version = "0.9.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578a7433b776b56a35785ed5ce9a7e777ac0598aac5a6dd1b4b18a307c7fc71b" +checksum = "adf8a49373e98a4c5f0ceb5d05aa7c648d75f63774981ed95b7c7443bbd50c6e" dependencies = [ - "indexmap 1.9.3", + "indexmap", + "itoa", "ryu", "serde", - "yaml-rust", + "unsafe-libyaml", ] [[package]] @@ -2269,18 +2248,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" +checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", @@ -2555,6 +2534,12 @@ dependencies = [ "destructure_traitobject", ] +[[package]] +name = "unsafe-libyaml" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab4c90930b95a82d00dc9e9ac071b4991924390d46cbd0dfe566148667605e4b" + [[package]] name = "url" version = "2.5.0" @@ -2875,15 +2860,6 @@ dependencies = [ "rustix", ] -[[package]] -name = "yaml-rust" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56c1936c4cc7a1c9ab21a1ebb602eb942ba868cbd44a99cb7cdc5892335e1c85" -dependencies = [ - "linked-hash-map", -] - [[package]] name = "zerocopy" version = "0.7.32" diff --git a/src/config.rs b/src/config.rs index c8d9fe9..5b6a180 100644 --- a/src/config.rs +++ b/src/config.rs @@ -15,27 +15,6 @@ use rust_bert::{ use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc}; use tokio::sync::Mutex; -lazy_static! { - pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE") - .expect("ALPACA_MODE must be set.") - .parse() - .expect("ALPACA_MODE must be 'live' or 'paper'"); - static ref ALPACA_URL_SUBDOMAIN: String = match *ALPACA_MODE { - Mode::Live => String::from("api"), - Mode::Paper => String::from("paper-api"), - }; - #[derive(Debug)] - pub static ref ALPACA_API_URL: String = format!( - "https://{subdomain}.alpaca.markets/v2", - subdomain = *ALPACA_URL_SUBDOMAIN - ); - #[derive(Debug)] - pub static ref ALPACA_WEBSOCKET_URL: String = format!( - "wss://{subdomain}.alpaca.markets/stream", - subdomain = *ALPACA_URL_SUBDOMAIN - ); -} - pub const ALPACA_STOCK_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars"; pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars"; pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news"; @@ -45,69 +24,78 @@ pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta3/crypto/us"; pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news"; +lazy_static! { + pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE") + .expect("ALPACA_MODE must be set.") + .parse() + .expect("ALPACA_MODE must be 'live' or 'paper'"); + pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE") + .expect("ALPACA_SOURCE must be set.") + .parse() + .expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'"); + pub static ref ALPACA_API_KEY: String = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set."); + pub static ref ALPACA_API_SECRET: String = env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set."); + #[derive(Debug)] + pub static ref ALPACA_API_URL: String = format!( + "https://{}.alpaca.markets/v2", + match *ALPACA_MODE { + Mode::Live => String::from("api"), + Mode::Paper => String::from("paper-api"), + } + ); + #[derive(Debug)] + pub static ref ALPACA_WEBSOCKET_URL: String = format!( + "wss://{}.alpaca.markets/stream", + match *ALPACA_MODE { + Mode::Live => String::from("api"), + Mode::Paper => String::from("paper-api"), + } + ); + pub static ref MAX_BERT_INPUTS: usize = env::var("MAX_BERT_INPUTS") + .expect("MAX_BERT_INPUTS must be set.") + .parse() + .expect("MAX_BERT_INPUTS must be a positive integer."); + +} + pub struct Config { - pub alpaca_api_key: String, - pub alpaca_api_secret: String, pub alpaca_client: Client, - pub alpaca_rate_limit: DefaultDirectRateLimiter, - pub alpaca_source: Source, + pub alpaca_rate_limiter: DefaultDirectRateLimiter, pub clickhouse_client: clickhouse::Client, - pub max_bert_inputs: usize, - pub sequence_classifier: Arc>, + pub sequence_classifier: Mutex, } impl Config { pub fn from_env() -> Self { - let alpaca_api_key = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set."); - let alpaca_api_secret = - env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set."); - let alpaca_source: Source = env::var("ALPACA_SOURCE") - .expect("ALPACA_SOURCE must be set.") - .parse() - .expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'."); - - let clickhouse_url = env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."); - let clickhouse_user = env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."); - let clickhouse_password = - env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."); - let clickhouse_db = env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set."); - - let max_bert_inputs: usize = env::var("MAX_BERT_INPUTS") - .expect("MAX_BERT_INPUTS must be set.") - .parse() - .expect("MAX_BERT_INPUTS must be a positive integer."); - Self { alpaca_client: Client::builder() .default_headers(HeaderMap::from_iter([ ( HeaderName::from_static("apca-api-key-id"), - HeaderValue::from_str(&alpaca_api_key) + HeaderValue::from_str(&ALPACA_API_KEY) .expect("Alpaca API key must not contain invalid characters."), ), ( HeaderName::from_static("apca-api-secret-key"), - HeaderValue::from_str(&alpaca_api_secret) + HeaderValue::from_str(&ALPACA_API_SECRET) .expect("Alpaca API secret must not contain invalid characters."), ), ])) .build() .unwrap(), - alpaca_rate_limit: RateLimiter::direct(Quota::per_minute(match alpaca_source { + alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE { Source::Iex => unsafe { NonZeroU32::new_unchecked(200) }, Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) }, Source::Otc => unimplemented!("OTC rate limit not implemented."), })), - alpaca_source, clickhouse_client: clickhouse::Client::default() - .with_url(clickhouse_url) - .with_user(clickhouse_user) - .with_password(clickhouse_password) - .with_database(clickhouse_db), - alpaca_api_key, - alpaca_api_secret, - max_bert_inputs, - sequence_classifier: Arc::new(Mutex::new( + .with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set.")) + .with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set.")) + .with_password( + env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."), + ) + .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")), + sequence_classifier: Mutex::new( SequenceClassificationModel::new(SequenceClassificationConfig::new( ModelType::Bert, ModelResource::Torch(Box::new(LocalResource { @@ -125,7 +113,7 @@ impl Config { None, )) .unwrap(), - )), + ), } } diff --git a/src/init.rs b/src/init.rs index e8d8b55..9219fd0 100644 --- a/src/init.rs +++ b/src/init.rs @@ -4,13 +4,18 @@ use crate::{ types::alpaca::{self, api, shared::Sort}, }; use log::{info, warn}; -use std::{collections::HashSet, sync::Arc}; +use std::{collections::HashMap, sync::Arc}; use time::OffsetDateTime; +use tokio::join; pub async fn check_account(config: &Arc) { - let account = alpaca::api::incoming::account::get(config, None) - .await - .unwrap(); + let account = alpaca::api::incoming::account::get( + &config.alpaca_client, + &config.alpaca_rate_limiter, + None, + ) + .await + .unwrap(); assert!( !(account.status != alpaca::api::incoming::account::Status::Active), @@ -41,7 +46,8 @@ pub async fn rehydrate_orders(config: &Arc) { let mut after = OffsetDateTime::UNIX_EPOCH; while let Some(message) = api::incoming::order::get( - config, + &config.alpaca_client, + &config.alpaca_rate_limiter, &api::outgoing::order::Order { status: Some(api::outgoing::order::Status::All), limit: Some(500), @@ -74,30 +80,52 @@ pub async fn rehydrate_orders(config: &Arc) { info!("Rehydrated order data."); } -pub async fn check_positions(config: &Arc) { +pub async fn rehydrate_positions(config: &Arc) { + info!("Rehydrating position data."); + let positions_future = async { - alpaca::api::incoming::position::get(config, None) - .await - .unwrap() + alpaca::api::incoming::position::get( + &config.alpaca_client, + &config.alpaca_rate_limiter, + None, + ) + .await + .unwrap() + .into_iter() + .map(|position| (position.symbol.clone(), position)) + .collect::>() }; let assets_future = async { database::assets::select(&config.clickhouse_client) .await .unwrap() - .into_iter() - .map(|asset| asset.symbol) - .collect::>() }; - let (positions, assets) = tokio::join!(positions_future, assets_future); + let (mut positions, assets) = join!(positions_future, assets_future); - for position in positions { - if !assets.contains(&position.symbol) { - warn!( - "Position for unmonitored asset: {}, {} shares.", - position.symbol, position.qty - ); - } + let assets = assets + .into_iter() + .map(|mut asset| { + if let Some(position) = positions.remove(&asset.symbol) { + asset.qty = position.qty_available; + } else { + asset.qty = 0.0; + } + asset + }) + .collect::>(); + + database::assets::upsert_batch(&config.clickhouse_client, &assets) + .await + .unwrap(); + + for position in positions.values() { + warn!( + "Position for unmonitored asset: {}, {} shares.", + position.symbol, position.qty + ); } + + info!("Rehydrated position data."); } diff --git a/src/main.rs b/src/main.rs index 6d1a730..445843b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,7 +13,7 @@ mod utils; use config::Config; use dotenv::dotenv; use log4rs::config::Deserializers; -use tokio::{spawn, sync::mpsc, try_join}; +use tokio::{join, spawn, sync::mpsc, try_join}; #[tokio::main] async fn main() { @@ -21,6 +21,12 @@ async fn main() { log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap(); let config = Config::arc_from_env(); + try_join!( + database::backfills_bars::unfresh(&config.clickhouse_client), + database::backfills_news::unfresh(&config.clickhouse_client) + ) + .unwrap(); + database::cleanup_all(&config.clickhouse_client) .await .unwrap(); @@ -28,15 +34,11 @@ async fn main() { .await .unwrap(); - try_join!( - database::backfills_bars::unfresh(&config.clickhouse_client), - database::backfills_news::unfresh(&config.clickhouse_client) - ) - .unwrap(); - init::check_account(&config).await; - init::rehydrate_orders(&config).await; - init::check_positions(&config).await; + join!( + init::rehydrate_orders(&config), + init::rehydrate_positions(&config) + ); spawn(threads::trading::run(config.clone())); @@ -61,7 +63,7 @@ async fn main() { create_send_await!( data_sender, threads::data::Message::new, - threads::data::Action::Add, + threads::data::Action::Enable, assets ); diff --git a/src/routes/assets.rs b/src/routes/assets.rs index 533c656..8f9b1e9 100644 --- a/src/routes/assets.rs +++ b/src/routes/assets.rs @@ -50,14 +50,19 @@ pub async fn add( return Err(StatusCode::CONFLICT); } - let asset = alpaca::api::incoming::asset::get_by_symbol(&config, &request.symbol, None) - .await - .map_err(|e| { - e.status() - .map_or(StatusCode::INTERNAL_SERVER_ERROR, |status| { - StatusCode::from_u16(status.as_u16()).unwrap() - }) - })?; + let asset = alpaca::api::incoming::asset::get_by_symbol( + &config.alpaca_client, + &config.alpaca_rate_limiter, + &request.symbol, + None, + ) + .await + .map_err(|e| { + e.status() + .map_or(StatusCode::INTERNAL_SERVER_ERROR, |status| { + StatusCode::from_u16(status.as_u16()).unwrap() + }) + })?; if !asset.tradable || !asset.fractionable { return Err(StatusCode::FORBIDDEN); diff --git a/src/threads/clock.rs b/src/threads/clock.rs index f0adeb9..c849028 100644 --- a/src/threads/clock.rs +++ b/src/threads/clock.rs @@ -36,9 +36,13 @@ impl From for Message { pub async fn run(config: Arc, sender: mpsc::Sender) { loop { - let clock = alpaca::api::incoming::clock::get(&config, Some(backoff::infinite())) - .await - .unwrap(); + let clock = alpaca::api::incoming::clock::get( + &config.alpaca_client, + &config.alpaca_rate_limiter, + Some(backoff::infinite()), + ) + .await + .unwrap(); let sleep_until = duration_until(if clock.is_open { info!("Market is open, will close at {}.", clock.next_close); diff --git a/src/threads/data/backfill.rs b/src/threads/data/backfill.rs index ee48354..c3caf27 100644 --- a/src/threads/data/backfill.rs +++ b/src/threads/data/backfill.rs @@ -1,6 +1,9 @@ use super::ThreadType; use crate::{ - config::{Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_STOCK_DATA_API_URL}, + config::{ + Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_SOURCE, ALPACA_STOCK_DATA_API_URL, + MAX_BERT_INPUTS, + }, database, types::{ alpaca::{ @@ -30,23 +33,24 @@ pub enum Action { Purge, } -impl From for Action { +impl From for Option { fn from(action: super::Action) -> Self { match action { - super::Action::Add => Self::Backfill, - super::Action::Remove => Self::Purge, + super::Action::Add | super::Action::Enable => Some(Action::Backfill), + super::Action::Remove => Some(Action::Purge), + super::Action::Disable => None, } } } pub struct Message { - pub action: Action, + pub action: Option, pub symbols: Vec, pub response: oneshot::Sender<()>, } impl Message { - pub fn new(action: Action, symbols: Vec) -> (Self, oneshot::Receiver<()>) { + pub fn new(action: Option, symbols: Vec) -> (Self, oneshot::Receiver<()>) { let (sender, receiver) = oneshot::channel::<()>(); ( Self { @@ -77,7 +81,6 @@ pub async fn run(handler: Arc>, mut receiver: mpsc::Receiver { + Some(Action::Backfill) => { let log_string = handler.log_string(); for symbol in message.symbols { @@ -134,7 +137,7 @@ async fn handle_backfill_message( ); } } - Action::Purge => { + Some(Action::Purge) => { for symbol in &message.symbols { if let Some(job) = backfill_jobs.remove(symbol) { if !job.is_finished() { @@ -150,6 +153,7 @@ async fn handle_backfill_message( ) .unwrap(); } + None => {} } message.response.send(()).unwrap(); @@ -159,7 +163,6 @@ struct BarHandler { config: Arc, data_url: &'static str, api_query_constructor: fn( - config: &Arc, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime, @@ -168,7 +171,6 @@ struct BarHandler { } fn us_equity_query_constructor( - config: &Arc, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime, @@ -182,7 +184,7 @@ fn us_equity_query_constructor( limit: Some(10000), adjustment: None, asof: None, - feed: Some(config.alpaca_source), + feed: Some(*ALPACA_SOURCE), currency: None, page_token: next_page_token, sort: Some(Sort::Asc), @@ -190,7 +192,6 @@ fn us_equity_query_constructor( } fn crypto_query_constructor( - _: &Arc, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime, @@ -226,7 +227,7 @@ impl Handler for BarHandler { } async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) { - if self.config.alpaca_source == Source::Iex { + if *ALPACA_SOURCE == Source::Iex { let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE); info!("Queing bar backfill for {} in {:?}.", symbol, run_delay); sleep(run_delay).await; @@ -241,10 +242,10 @@ impl Handler for BarHandler { loop { let Ok(message) = api::incoming::bar::get_historical( - &self.config, + &self.config.alpaca_client, + &self.config.alpaca_rate_limiter, self.data_url, &(self.api_query_constructor)( - &self.config, symbol.clone(), fetch_from, fetch_to, @@ -328,7 +329,8 @@ impl Handler for NewsHandler { loop { let Ok(message) = api::incoming::news::get_historical( - &self.config, + &self.config.alpaca_client, + &self.config.alpaca_rate_limiter, &api::outgoing::news::News { symbols: vec![symbol.clone()], start: Some(fetch_from), @@ -367,18 +369,15 @@ impl Handler for NewsHandler { .map(|news| format!("{}\n\n{}", news.headline, news.content)) .collect::>(); - let predictions = join_all(inputs.chunks(self.config.max_bert_inputs).map(|inputs| { - let sequence_classifier = self.config.sequence_classifier.clone(); - async move { - let sequence_classifier = sequence_classifier.lock().await; - block_in_place(|| { - sequence_classifier - .predict(inputs.iter().map(String::as_str).collect::>()) - .into_iter() - .map(|label| Prediction::try_from(label).unwrap()) - .collect::>() - }) - } + let predictions = join_all(inputs.chunks(*MAX_BERT_INPUTS).map(|inputs| async move { + let sequence_classifier = self.config.sequence_classifier.lock().await; + block_in_place(|| { + sequence_classifier + .predict(inputs.iter().map(String::as_str).collect::>()) + .into_iter() + .map(|label| Prediction::try_from(label).unwrap()) + .collect::>() + }) })) .await .into_iter() diff --git a/src/threads/data/mod.rs b/src/threads/data/mod.rs index 8d78fc2..322d16f 100644 --- a/src/threads/data/mod.rs +++ b/src/threads/data/mod.rs @@ -4,7 +4,7 @@ mod websocket; use super::clock; use crate::{ config::{ - Config, ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL, + Config, ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL, ALPACA_SOURCE, ALPACA_STOCK_DATA_WEBSOCKET_URL, }, create_send_await, database, @@ -21,9 +21,12 @@ use tokio::{ use tokio_tungstenite::connect_async; #[derive(Clone, Copy)] +#[allow(dead_code)] pub enum Action { Add, + Enable, Remove, + Disable, } pub struct Message { @@ -100,10 +103,7 @@ async fn init_thread( ) { let websocket_url = match thread_type { ThreadType::Bars(Class::UsEquity) => { - format!( - "{}/{}", - ALPACA_STOCK_DATA_WEBSOCKET_URL, &config.alpaca_source - ) + format!("{}/{}", ALPACA_STOCK_DATA_WEBSOCKET_URL, *ALPACA_SOURCE) } ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_WEBSOCKET_URL.into(), ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(), @@ -111,8 +111,7 @@ async fn init_thread( let (websocket, _) = connect_async(websocket_url).await.unwrap(); let (mut websocket_sink, mut websocket_stream) = websocket.split(); - alpaca::websocket::data::authenticate(&config, &mut websocket_sink, &mut websocket_stream) - .await; + alpaca::websocket::data::authenticate(&mut websocket_sink, &mut websocket_stream).await; let (backfill_sender, backfill_receiver) = mpsc::channel(100); spawn(backfill::run( @@ -223,7 +222,8 @@ async fn handle_message( async move { let asset_future = async { alpaca::api::incoming::asset::get_by_symbol( - &config, + &config.alpaca_client, + &config.alpaca_rate_limiter, &symbol, Some(backoff::infinite()), ) @@ -233,7 +233,8 @@ async fn handle_message( let position_future = async { alpaca::api::incoming::position::get_by_symbol( - &config, + &config.alpaca_rate_limiter, + &config.alpaca_client, &symbol, Some(backoff::infinite()), ) @@ -256,6 +257,7 @@ async fn handle_message( .await .unwrap(); } + _ => {} } message.response.send(()).unwrap(); @@ -292,7 +294,7 @@ async fn handle_clock_message( create_send_await!( bars_us_equity_backfill_sender, backfill::Message::new, - backfill::Action::Backfill, + Some(backfill::Action::Backfill), us_equity_symbols.clone() ); }; @@ -301,7 +303,7 @@ async fn handle_clock_message( create_send_await!( bars_crypto_backfill_sender, backfill::Message::new, - backfill::Action::Backfill, + Some(backfill::Action::Backfill), crypto_symbols.clone() ); }; @@ -310,7 +312,7 @@ async fn handle_clock_message( create_send_await!( news_backfill_sender, backfill::Message::new, - backfill::Action::Backfill, + Some(backfill::Action::Backfill), symbols ); }; diff --git a/src/threads/data/websocket.rs b/src/threads/data/websocket.rs index 26bc702..6375793 100644 --- a/src/threads/data/websocket.rs +++ b/src/threads/data/websocket.rs @@ -26,23 +26,23 @@ pub enum Action { Unsubscribe, } -impl From for Action { +impl From for Option { fn from(action: super::Action) -> Self { match action { - super::Action::Add => Self::Subscribe, - super::Action::Remove => Self::Unsubscribe, + super::Action::Add | super::Action::Enable => Some(Action::Subscribe), + super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe), } } } pub struct Message { - pub action: Action, + pub action: Option, pub symbols: Vec, pub response: oneshot::Sender<()>, } impl Message { - pub fn new(action: Action, symbols: Vec) -> (Self, oneshot::Receiver<()>) { + pub fn new(action: Option, symbols: Vec) -> (Self, oneshot::Receiver<()>) { let (sender, receiver) = oneshot::channel(); ( Self { @@ -115,7 +115,7 @@ async fn handle_message( message: Message, ) { match message.action { - Action::Subscribe => { + Some(Action::Subscribe) => { let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message .symbols .iter() @@ -144,7 +144,7 @@ async fn handle_message( join_all(receivers).await; } - Action::Unsubscribe => { + Some(Action::Unsubscribe) => { let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message .symbols .iter() @@ -173,6 +173,7 @@ async fn handle_message( join_all(receivers).await; } + None => {} } message.response.send(()).unwrap(); diff --git a/src/threads/trading/mod.rs b/src/threads/trading/mod.rs index 69db55e..5b4840c 100644 --- a/src/threads/trading/mod.rs +++ b/src/threads/trading/mod.rs @@ -13,8 +13,7 @@ pub async fn run(config: Arc) { let (websocket, _) = connect_async(&*ALPACA_WEBSOCKET_URL).await.unwrap(); let (mut websocket_sink, mut websocket_stream) = websocket.split(); - alpaca::websocket::trading::authenticate(&config, &mut websocket_sink, &mut websocket_stream) - .await; + alpaca::websocket::trading::authenticate(&mut websocket_sink, &mut websocket_stream).await; alpaca::websocket::trading::subscribe(&mut websocket_sink, &mut websocket_stream).await; spawn(websocket::run(config, websocket_stream, websocket_sink)); diff --git a/src/threads/trading/websocket.rs b/src/threads/trading/websocket.rs index bace439..ff66bc8 100644 --- a/src/threads/trading/websocket.rs +++ b/src/threads/trading/websocket.rs @@ -22,10 +22,8 @@ pub async fn run( loop { let message = websocket_stream.next().await.unwrap().unwrap(); - let config = config.clone(); - spawn(handle_websocket_message( - config, + config.clone(), websocket_sink.clone(), message, )); @@ -42,7 +40,7 @@ async fn handle_websocket_message( if let Ok(message) = from_str::( &String::from_utf8_lossy(&message), ) { - spawn(handle_parsed_websocket_message(config.clone(), message)); + handle_parsed_websocket_message(config, message).await; } else { error!("Failed to deserialize websocket message: {:?}", message); } diff --git a/src/types/alpaca/api/incoming/account.rs b/src/types/alpaca/api/incoming/account.rs index fccb963..580ba47 100644 --- a/src/types/alpaca/api/incoming/account.rs +++ b/src/types/alpaca/api/incoming/account.rs @@ -1,12 +1,13 @@ -use crate::config::{Config, ALPACA_API_URL}; +use crate::config::ALPACA_API_URL; use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; use log::warn; -use reqwest::Error; +use reqwest::{Client, Error}; use serde::Deserialize; use serde_aux::field_attributes::{ deserialize_number_from_string, deserialize_option_number_from_string, }; -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use time::OffsetDateTime; use uuid::Uuid; @@ -80,15 +81,15 @@ pub struct Account { } pub async fn get( - config: &Arc, + alpaca_client: &Client, + alpaca_rate_limiter: &DefaultDirectRateLimiter, backoff: Option, ) -> Result { retry_notify( backoff.unwrap_or_default(), || async { - config.alpaca_rate_limit.until_ready().await; - config - .alpaca_client + alpaca_rate_limiter.until_ready().await; + alpaca_client .get(&format!("{}/account", *ALPACA_API_URL)) .send() .await? diff --git a/src/types/alpaca/api/incoming/asset.rs b/src/types/alpaca/api/incoming/asset.rs index 93177f9..fb7367c 100644 --- a/src/types/alpaca/api/incoming/asset.rs +++ b/src/types/alpaca/api/incoming/asset.rs @@ -1,17 +1,18 @@ use super::position::Position; use crate::{ - config::{Config, ALPACA_API_URL}, + config::ALPACA_API_URL, types::{ self, alpaca::shared::asset::{Class, Exchange, Status}, }, }; use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; use log::warn; -use reqwest::Error; +use reqwest::{Client, Error}; use serde::Deserialize; use serde_aux::field_attributes::deserialize_option_number_from_string; -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use uuid::Uuid; #[allow(clippy::struct_excessive_bools)] @@ -47,16 +48,16 @@ impl From<(Asset, Option)> for types::Asset { } pub async fn get_by_symbol( - config: &Arc, + alpaca_client: &Client, + alpaca_rate_limiter: &DefaultDirectRateLimiter, symbol: &str, backoff: Option, ) -> Result { retry_notify( backoff.unwrap_or_default(), || async { - config.alpaca_rate_limit.until_ready().await; - config - .alpaca_client + alpaca_rate_limiter.until_ready().await; + alpaca_client .get(&format!("{}/assets/{}", *ALPACA_API_URL, symbol)) .send() .await? diff --git a/src/types/alpaca/api/incoming/bar.rs b/src/types/alpaca/api/incoming/bar.rs index 4511cef..4ad1182 100644 --- a/src/types/alpaca/api/incoming/bar.rs +++ b/src/types/alpaca/api/incoming/bar.rs @@ -1,12 +1,10 @@ -use crate::{ - config::Config, - types::{self, alpaca::api::outgoing}, -}; +use crate::types::{self, alpaca::api::outgoing}; use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; use log::warn; -use reqwest::Error; +use reqwest::{Client, Error}; use serde::Deserialize; -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{collections::HashMap, time::Duration}; use time::OffsetDateTime; #[derive(Deserialize)] @@ -53,7 +51,8 @@ pub struct Message { } pub async fn get_historical( - config: &Arc, + alpaca_client: &Client, + alpaca_rate_limiter: &DefaultDirectRateLimiter, data_url: &str, query: &outgoing::bar::Bar, backoff: Option, @@ -61,9 +60,8 @@ pub async fn get_historical( retry_notify( backoff.unwrap_or_default(), || async { - config.alpaca_rate_limit.until_ready().await; - config - .alpaca_client + alpaca_rate_limiter.until_ready().await; + alpaca_client .get(data_url) .query(query) .send() diff --git a/src/types/alpaca/api/incoming/clock.rs b/src/types/alpaca/api/incoming/clock.rs index 33ec1ec..23d1841 100644 --- a/src/types/alpaca/api/incoming/clock.rs +++ b/src/types/alpaca/api/incoming/clock.rs @@ -1,9 +1,10 @@ -use crate::config::{Config, ALPACA_API_URL}; +use crate::config::ALPACA_API_URL; use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; use log::warn; -use reqwest::Error; +use reqwest::{Client, Error}; use serde::Deserialize; -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use time::OffsetDateTime; #[derive(Deserialize)] @@ -18,15 +19,15 @@ pub struct Clock { } pub async fn get( - config: &Arc, + alpaca_client: &Client, + alpaca_rate_limiter: &DefaultDirectRateLimiter, backoff: Option, ) -> Result { retry_notify( backoff.unwrap_or_default(), || async { - config.alpaca_rate_limit.until_ready().await; - config - .alpaca_client + alpaca_rate_limiter.until_ready().await; + alpaca_client .get(&format!("{}/clock", *ALPACA_API_URL)) .send() .await? diff --git a/src/types/alpaca/api/incoming/news.rs b/src/types/alpaca/api/incoming/news.rs index 548237d..b682718 100644 --- a/src/types/alpaca/api/incoming/news.rs +++ b/src/types/alpaca/api/incoming/news.rs @@ -1,5 +1,5 @@ use crate::{ - config::{Config, ALPACA_NEWS_DATA_API_URL}, + config::ALPACA_NEWS_DATA_API_URL, types::{ self, alpaca::{api::outgoing, shared::news::normalize_html_content}, @@ -7,10 +7,11 @@ use crate::{ utils::de, }; use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; use log::warn; -use reqwest::Error; +use reqwest::{Client, Error}; use serde::Deserialize; -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use time::OffsetDateTime; #[derive(Deserialize)] @@ -73,16 +74,16 @@ pub struct Message { } pub async fn get_historical( - config: &Arc, + alpaca_client: &Client, + alpaca_rate_limiter: &DefaultDirectRateLimiter, query: &outgoing::news::News, backoff: Option, ) -> Result { retry_notify( backoff.unwrap_or_default(), || async { - config.alpaca_rate_limit.until_ready().await; - config - .alpaca_client + alpaca_rate_limiter.until_ready().await; + alpaca_client .get(ALPACA_NEWS_DATA_API_URL) .query(query) .send() diff --git a/src/types/alpaca/api/incoming/order.rs b/src/types/alpaca/api/incoming/order.rs index 65e2fd8..1967104 100644 --- a/src/types/alpaca/api/incoming/order.rs +++ b/src/types/alpaca/api/incoming/order.rs @@ -1,25 +1,26 @@ use crate::{ - config::{Config, ALPACA_API_URL}, + config::ALPACA_API_URL, types::alpaca::{api::outgoing, shared}, }; use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; use log::warn; -use reqwest::Error; -use std::{sync::Arc, time::Duration}; +use reqwest::{Client, Error}; +use std::time::Duration; pub use shared::order::Order; pub async fn get( - config: &Arc, + alpaca_client: &Client, + alpaca_rate_limiter: &DefaultDirectRateLimiter, query: &outgoing::order::Order, backoff: Option, ) -> Result, Error> { retry_notify( backoff.unwrap_or_default(), || async { - config.alpaca_rate_limit.until_ready().await; - config - .alpaca_client + alpaca_rate_limiter.until_ready().await; + alpaca_client .get(&format!("{}/orders", *ALPACA_API_URL)) .query(query) .send() diff --git a/src/types/alpaca/api/incoming/position.rs b/src/types/alpaca/api/incoming/position.rs index b18f6a6..5ba8c73 100644 --- a/src/types/alpaca/api/incoming/position.rs +++ b/src/types/alpaca/api/incoming/position.rs @@ -1,5 +1,5 @@ use crate::{ - config::{Config, ALPACA_API_URL}, + config::ALPACA_API_URL, types::alpaca::shared::{ self, asset::{Class, Exchange}, @@ -7,10 +7,12 @@ use crate::{ utils::de, }; use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; use log::warn; +use reqwest::Client; use serde::Deserialize; use serde_aux::field_attributes::deserialize_number_from_string; -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use uuid::Uuid; #[derive(Deserialize)] @@ -65,15 +67,15 @@ pub struct Position { } pub async fn get( - config: &Arc, + alpaca_client: &Client, + alpaca_rate_limiter: &DefaultDirectRateLimiter, backoff: Option, ) -> Result, reqwest::Error> { retry_notify( backoff.unwrap_or_default(), || async { - config.alpaca_rate_limit.until_ready().await; - config - .alpaca_client + alpaca_rate_limiter.until_ready().await; + alpaca_client .get(&format!("{}/positions", *ALPACA_API_URL)) .send() .await? @@ -98,16 +100,16 @@ pub async fn get( } pub async fn get_by_symbol( - config: &Arc, + alpaca_rate_limiter: &DefaultDirectRateLimiter, + alpaca_client: &Client, symbol: &str, backoff: Option, ) -> Result, reqwest::Error> { retry_notify( backoff.unwrap_or_default(), || async { - config.alpaca_rate_limit.until_ready().await; - let response = config - .alpaca_client + alpaca_rate_limiter.until_ready().await; + let response = alpaca_client .get(&format!("{}/positions/{}", *ALPACA_API_URL, symbol)) .send() .await?; diff --git a/src/types/alpaca/websocket/data/mod.rs b/src/types/alpaca/websocket/data/mod.rs index 62fd831..7d9a289 100644 --- a/src/types/alpaca/websocket/data/mod.rs +++ b/src/types/alpaca/websocket/data/mod.rs @@ -1,19 +1,20 @@ pub mod incoming; pub mod outgoing; -use crate::{config::Config, types::alpaca::websocket}; +use crate::{ + config::{ALPACA_API_KEY, ALPACA_API_SECRET}, + types::alpaca::websocket, +}; use core::panic; use futures_util::{ stream::{SplitSink, SplitStream}, SinkExt, StreamExt, }; use serde_json::{from_str, to_string}; -use std::sync::Arc; use tokio::net::TcpStream; use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream}; pub async fn authenticate( - config: &Arc, sink: &mut SplitSink>, Message>, stream: &mut SplitStream>>, ) { @@ -31,8 +32,8 @@ pub async fn authenticate( sink.send(Message::Text( to_string(&websocket::data::outgoing::Message::Auth( websocket::auth::Message { - key: config.alpaca_api_key.clone(), - secret: config.alpaca_api_secret.clone(), + key: (*ALPACA_API_KEY).clone(), + secret: (*ALPACA_API_SECRET).clone(), }, )) .unwrap(), diff --git a/src/types/alpaca/websocket/trading/mod.rs b/src/types/alpaca/websocket/trading/mod.rs index 638f6e8..6de6dd6 100644 --- a/src/types/alpaca/websocket/trading/mod.rs +++ b/src/types/alpaca/websocket/trading/mod.rs @@ -1,27 +1,28 @@ pub mod incoming; pub mod outgoing; -use crate::{config::Config, types::alpaca::websocket}; +use crate::{ + config::{ALPACA_API_KEY, ALPACA_API_SECRET}, + types::alpaca::websocket, +}; use core::panic; use futures_util::{ stream::{SplitSink, SplitStream}, SinkExt, StreamExt, }; use serde_json::{from_str, to_string}; -use std::sync::Arc; use tokio::net::TcpStream; use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream}; pub async fn authenticate( - config: &Arc, sink: &mut SplitSink>, Message>, stream: &mut SplitStream>>, ) { sink.send(Message::Text( to_string(&websocket::trading::outgoing::Message::Auth( websocket::auth::Message { - key: config.alpaca_api_key.clone(), - secret: config.alpaca_api_secret.clone(), + key: (*ALPACA_API_KEY).clone(), + secret: (*ALPACA_API_SECRET).clone(), }, )) .unwrap(),