diff --git a/src/init.rs b/src/init.rs index c6f654c..e8d48bb 100644 --- a/src/init.rs +++ b/src/init.rs @@ -3,13 +3,13 @@ use crate::{ database, }; use log::{info, warn}; -use qrust::types::alpaca; +use qrust::{alpaca, types}; use std::{collections::HashMap, sync::Arc}; use time::OffsetDateTime; use tokio::join; pub async fn check_account(config: &Arc) { - let account = alpaca::api::incoming::account::get( + let account = alpaca::account::get( &config.alpaca_client, &config.alpaca_rate_limiter, None, @@ -19,7 +19,7 @@ pub async fn check_account(config: &Arc) { .unwrap(); assert!( - !(account.status != alpaca::api::incoming::account::Status::Active), + !(account.status != types::alpaca::api::incoming::account::Status::Active), "Account status is not active: {:?}.", account.status ); @@ -46,11 +46,11 @@ pub async fn rehydrate_orders(config: &Arc) { let mut orders = vec![]; let mut after = OffsetDateTime::UNIX_EPOCH; - while let Some(message) = alpaca::api::incoming::order::get( + while let Some(message) = alpaca::orders::get( &config.alpaca_client, &config.alpaca_rate_limiter, - &alpaca::api::outgoing::order::Order { - status: Some(alpaca::api::outgoing::order::Status::All), + &types::alpaca::api::outgoing::order::Order { + status: Some(types::alpaca::api::outgoing::order::Status::All), after: Some(after), ..Default::default() }, @@ -67,7 +67,7 @@ pub async fn rehydrate_orders(config: &Arc) { let orders = orders .into_iter() - .flat_map(&alpaca::api::incoming::order::Order::normalize) + .flat_map(&types::alpaca::api::incoming::order::Order::normalize) .collect::>(); database::orders::upsert_batch( @@ -85,7 +85,7 @@ pub async fn rehydrate_positions(config: &Arc) { info!("Rehydrating position data."); let positions_future = async { - alpaca::api::incoming::position::get( + alpaca::positions::get( &config.alpaca_client, &config.alpaca_rate_limiter, None, diff --git a/src/lib/alpaca/account.rs b/src/lib/alpaca/account.rs new file mode 100644 index 0000000..c13a325 --- /dev/null +++ b/src/lib/alpaca/account.rs @@ -0,0 +1,42 @@ +use crate::types::alpaca::api::incoming::account::Account; +use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; +use log::warn; +use reqwest::{Client, Error}; +use std::time::Duration; + +pub async fn get( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + backoff: Option, + api_base: &str, +) -> Result { + retry_notify( + backoff.unwrap_or_default(), + || async { + rate_limiter.until_ready().await; + client + .get(&format!("https://{}.alpaca.markets/v2/account", api_base)) + .send() + .await? + .error_for_status() + .map_err(|e| match e.status() { + Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { + backoff::Error::Permanent(e) + } + _ => e.into(), + })? + .json::() + .await + .map_err(backoff::Error::Permanent) + }, + |e, duration: Duration| { + warn!( + "Failed to get account, will retry in {} seconds: {}.", + duration.as_secs(), + e + ); + }, + ) + .await +} diff --git a/src/lib/alpaca/assets.rs b/src/lib/alpaca/assets.rs new file mode 100644 index 0000000..e8b3e5f --- /dev/null +++ b/src/lib/alpaca/assets.rs @@ -0,0 +1,139 @@ +use crate::types::alpaca::api::{ + incoming::asset::{Asset, Class}, + outgoing, +}; +use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; +use itertools::Itertools; +use log::warn; +use reqwest::{Client, Error}; +use std::{collections::HashSet, time::Duration}; +use tokio::try_join; + +pub async fn get( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + query: &outgoing::asset::Asset, + backoff: Option, + api_base: &str, +) -> Result, Error> { + retry_notify( + backoff.unwrap_or_default(), + || async { + rate_limiter.until_ready().await; + client + .get(&format!("https://{}.alpaca.markets/v2/assets", api_base)) + .query(query) + .send() + .await? + .error_for_status() + .map_err(|e| match e.status() { + Some( + reqwest::StatusCode::BAD_REQUEST + | reqwest::StatusCode::FORBIDDEN + | reqwest::StatusCode::NOT_FOUND, + ) => backoff::Error::Permanent(e), + _ => e.into(), + })? + .json::>() + .await + .map_err(backoff::Error::Permanent) + }, + |e, duration: Duration| { + warn!( + "Failed to get assets, will retry in {} seconds: {}.", + duration.as_secs(), + e + ); + }, + ) + .await +} + +pub async fn get_by_symbol( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + symbol: &str, + backoff: Option, + api_base: &str, +) -> Result { + retry_notify( + backoff.unwrap_or_default(), + || async { + rate_limiter.until_ready().await; + client + .get(&format!( + "https://{}.alpaca.markets/v2/assets/{}", + api_base, symbol + )) + .send() + .await? + .error_for_status() + .map_err(|e| match e.status() { + Some( + reqwest::StatusCode::BAD_REQUEST + | reqwest::StatusCode::FORBIDDEN + | reqwest::StatusCode::NOT_FOUND, + ) => backoff::Error::Permanent(e), + _ => e.into(), + })? + .json::() + .await + .map_err(backoff::Error::Permanent) + }, + |e, duration: Duration| { + warn!( + "Failed to get asset, will retry in {} seconds: {}.", + duration.as_secs(), + e + ); + }, + ) + .await +} + +pub async fn get_by_symbols( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + symbols: &[String], + backoff: Option, + api_base: &str, +) -> Result, Error> { + if symbols.len() == 1 { + let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?; + return Ok(vec![asset]); + } + + let symbols = symbols.iter().collect::>(); + + let backoff_clone = backoff.clone(); + + let us_equity_query = outgoing::asset::Asset { + class: Some(Class::UsEquity), + ..Default::default() + }; + + let us_equity_assets = get( + client, + rate_limiter, + &us_equity_query, + backoff_clone, + api_base, + ); + + let crypto_query = outgoing::asset::Asset { + class: Some(Class::Crypto), + ..Default::default() + }; + + let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base); + + let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?; + + Ok(crypto_assets + .into_iter() + .chain(us_equity_assets) + .dedup_by(|a, b| a.symbol == b.symbol) + .filter(|asset| symbols.contains(&asset.symbol)) + .collect()) +} diff --git a/src/lib/alpaca/bars.rs b/src/lib/alpaca/bars.rs new file mode 100644 index 0000000..e3c896b --- /dev/null +++ b/src/lib/alpaca/bars.rs @@ -0,0 +1,53 @@ +use crate::types::alpaca::api::{incoming::bar::Bar, outgoing}; +use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; +use log::warn; +use reqwest::{Client, Error}; +use serde::Deserialize; +use std::{collections::HashMap, time::Duration}; + +pub const MAX_LIMIT: i64 = 10_000; + +#[derive(Deserialize)] +pub struct Message { + pub bars: HashMap>, + pub next_page_token: Option, +} + +pub async fn get( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + data_url: &str, + query: &outgoing::bar::Bar, + backoff: Option, +) -> Result { + retry_notify( + backoff.unwrap_or_default(), + || async { + rate_limiter.until_ready().await; + client + .get(data_url) + .query(query) + .send() + .await? + .error_for_status() + .map_err(|e| match e.status() { + Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { + backoff::Error::Permanent(e) + } + _ => e.into(), + })? + .json::() + .await + .map_err(backoff::Error::Permanent) + }, + |e, duration: Duration| { + warn!( + "Failed to get historical bars, will retry in {} seconds: {}.", + duration.as_secs(), + e + ); + }, + ) + .await +} diff --git a/src/lib/alpaca/calendar.rs b/src/lib/alpaca/calendar.rs new file mode 100644 index 0000000..8f4cff7 --- /dev/null +++ b/src/lib/alpaca/calendar.rs @@ -0,0 +1,44 @@ +use crate::types::alpaca::api::{incoming::calendar::Calendar, outgoing}; +use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; +use log::warn; +use reqwest::{Client, Error}; +use std::time::Duration; + +pub async fn get( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + query: &outgoing::calendar::Calendar, + backoff: Option, + api_base: &str, +) -> Result, Error> { + retry_notify( + backoff.unwrap_or_default(), + || async { + rate_limiter.until_ready().await; + client + .get(&format!("https://{}.alpaca.markets/v2/calendar", api_base)) + .query(query) + .send() + .await? + .error_for_status() + .map_err(|e| match e.status() { + Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { + backoff::Error::Permanent(e) + } + _ => e.into(), + })? + .json::>() + .await + .map_err(backoff::Error::Permanent) + }, + |e, duration: Duration| { + warn!( + "Failed to get calendar, will retry in {} seconds: {}.", + duration.as_secs(), + e + ); + }, + ) + .await +} diff --git a/src/lib/alpaca/clock.rs b/src/lib/alpaca/clock.rs new file mode 100644 index 0000000..4b79997 --- /dev/null +++ b/src/lib/alpaca/clock.rs @@ -0,0 +1,42 @@ +use crate::types::alpaca::api::incoming::clock::Clock; +use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; +use log::warn; +use reqwest::{Client, Error}; +use std::time::Duration; + +pub async fn get( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + backoff: Option, + api_base: &str, +) -> Result { + retry_notify( + backoff.unwrap_or_default(), + || async { + rate_limiter.until_ready().await; + client + .get(&format!("https://{}.alpaca.markets/v2/clock", api_base)) + .send() + .await? + .error_for_status() + .map_err(|e| match e.status() { + Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { + backoff::Error::Permanent(e) + } + _ => e.into(), + })? + .json::() + .await + .map_err(backoff::Error::Permanent) + }, + |e, duration: Duration| { + warn!( + "Failed to get clock, will retry in {} seconds: {}.", + duration.as_secs(), + e + ); + }, + ) + .await +} diff --git a/src/lib/alpaca/mod.rs b/src/lib/alpaca/mod.rs new file mode 100644 index 0000000..90dd11b --- /dev/null +++ b/src/lib/alpaca/mod.rs @@ -0,0 +1,8 @@ +pub mod account; +pub mod assets; +pub mod bars; +pub mod calendar; +pub mod clock; +pub mod news; +pub mod orders; +pub mod positions; diff --git a/src/lib/alpaca/news.rs b/src/lib/alpaca/news.rs new file mode 100644 index 0000000..6464845 --- /dev/null +++ b/src/lib/alpaca/news.rs @@ -0,0 +1,52 @@ +use crate::types::alpaca::api::{incoming::news::News, outgoing, ALPACA_NEWS_DATA_API_URL}; +use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; +use log::warn; +use reqwest::{Client, Error}; +use serde::Deserialize; +use std::time::Duration; + +pub const MAX_LIMIT: i64 = 50; + +#[derive(Deserialize)] +pub struct Message { + pub news: Vec, + pub next_page_token: Option, +} + +pub async fn get( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + query: &outgoing::news::News, + backoff: Option, +) -> Result { + retry_notify( + backoff.unwrap_or_default(), + || async { + rate_limiter.until_ready().await; + client + .get(ALPACA_NEWS_DATA_API_URL) + .query(query) + .send() + .await? + .error_for_status() + .map_err(|e| match e.status() { + Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { + backoff::Error::Permanent(e) + } + _ => e.into(), + })? + .json::() + .await + .map_err(backoff::Error::Permanent) + }, + |e, duration: Duration| { + warn!( + "Failed to get historical news, will retry in {} seconds: {}.", + duration.as_secs(), + e + ); + }, + ) + .await +} diff --git a/src/lib/alpaca/orders.rs b/src/lib/alpaca/orders.rs new file mode 100644 index 0000000..8c65738 --- /dev/null +++ b/src/lib/alpaca/orders.rs @@ -0,0 +1,46 @@ +use crate::types::alpaca::{api::outgoing, shared::order}; +use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; +use log::warn; +use reqwest::{Client, Error}; +use std::time::Duration; + +pub use order::Order; + +pub async fn get( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + query: &outgoing::order::Order, + backoff: Option, + api_base: &str, +) -> Result, Error> { + retry_notify( + backoff.unwrap_or_default(), + || async { + rate_limiter.until_ready().await; + client + .get(&format!("https://{}.alpaca.markets/v2/orders", api_base)) + .query(query) + .send() + .await? + .error_for_status() + .map_err(|e| match e.status() { + Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { + backoff::Error::Permanent(e) + } + _ => e.into(), + })? + .json::>() + .await + .map_err(backoff::Error::Permanent) + }, + |e, duration: Duration| { + warn!( + "Failed to get orders, will retry in {} seconds: {}.", + duration.as_secs(), + e + ); + }, + ) + .await +} diff --git a/src/lib/alpaca/positions.rs b/src/lib/alpaca/positions.rs new file mode 100644 index 0000000..79b619d --- /dev/null +++ b/src/lib/alpaca/positions.rs @@ -0,0 +1,111 @@ +use crate::types::alpaca::api::incoming::position::Position; +use backoff::{future::retry_notify, ExponentialBackoff}; +use governor::DefaultDirectRateLimiter; +use log::warn; +use reqwest::Client; +use std::{collections::HashSet, time::Duration}; + +pub async fn get( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + backoff: Option, + api_base: &str, +) -> Result, reqwest::Error> { + retry_notify( + backoff.unwrap_or_default(), + || async { + rate_limiter.until_ready().await; + client + .get(&format!("https://{}.alpaca.markets/v2/positions", api_base)) + .send() + .await? + .error_for_status() + .map_err(|e| match e.status() { + Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { + backoff::Error::Permanent(e) + } + _ => e.into(), + })? + .json::>() + .await + .map_err(backoff::Error::Permanent) + }, + |e, duration: Duration| { + warn!( + "Failed to get positions, will retry in {} seconds: {}.", + duration.as_secs(), + e + ); + }, + ) + .await +} + +pub async fn get_by_symbol( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + symbol: &str, + backoff: Option, + api_base: &str, +) -> Result, reqwest::Error> { + retry_notify( + backoff.unwrap_or_default(), + || async { + rate_limiter.until_ready().await; + let response = client + .get(&format!( + "https://{}.alpaca.markets/v2/positions/{}", + api_base, symbol + )) + .send() + .await?; + + if response.status() == reqwest::StatusCode::NOT_FOUND { + return Ok(None); + } + + response + .error_for_status() + .map_err(|e| match e.status() { + Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { + backoff::Error::Permanent(e) + } + _ => e.into(), + })? + .json::() + .await + .map_err(backoff::Error::Permanent) + .map(Some) + }, + |e, duration: Duration| { + warn!( + "Failed to get position, will retry in {} seconds: {}.", + duration.as_secs(), + e + ); + }, + ) + .await +} + +pub async fn get_by_symbols( + client: &Client, + rate_limiter: &DefaultDirectRateLimiter, + symbols: &[String], + backoff: Option, + api_base: &str, +) -> Result, reqwest::Error> { + if symbols.len() == 1 { + let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?; + return Ok(position.into_iter().collect()); + } + + let symbols = symbols.iter().collect::>(); + + let positions = get(client, rate_limiter, backoff, api_base).await?; + + Ok(positions + .into_iter() + .filter(|position| symbols.contains(&position.symbol)) + .collect()) +} diff --git a/src/lib/mod.rs b/src/lib/mod.rs index 284cd23..4fb7569 100644 --- a/src/lib/mod.rs +++ b/src/lib/mod.rs @@ -1,3 +1,4 @@ +pub mod alpaca; pub mod database; pub mod types; pub mod utils; diff --git a/src/lib/types/alpaca/api/incoming/account.rs b/src/lib/types/alpaca/api/incoming/account.rs index 831d55c..be5da49 100644 --- a/src/lib/types/alpaca/api/incoming/account.rs +++ b/src/lib/types/alpaca/api/incoming/account.rs @@ -1,12 +1,7 @@ -use backoff::{future::retry_notify, ExponentialBackoff}; -use governor::DefaultDirectRateLimiter; -use log::warn; -use reqwest::{Client, Error}; use serde::Deserialize; use serde_aux::field_attributes::{ deserialize_number_from_string, deserialize_option_number_from_string, }; -use std::time::Duration; use time::OffsetDateTime; use uuid::Uuid; @@ -78,39 +73,3 @@ pub struct Account { #[serde(deserialize_with = "deserialize_number_from_string")] pub regt_buying_power: f64, } - -pub async fn get( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - backoff: Option, - api_base: &str, -) -> Result { - retry_notify( - backoff.unwrap_or_default(), - || async { - rate_limiter.until_ready().await; - client - .get(&format!("https://{}.alpaca.markets/v2/account", api_base)) - .send() - .await? - .error_for_status() - .map_err(|e| match e.status() { - Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { - backoff::Error::Permanent(e) - } - _ => e.into(), - })? - .json::() - .await - .map_err(backoff::Error::Permanent) - }, - |e, duration: Duration| { - warn!( - "Failed to get account, will retry in {} seconds: {}", - duration.as_secs(), - e - ); - }, - ) - .await -} diff --git a/src/lib/types/alpaca/api/incoming/asset.rs b/src/lib/types/alpaca/api/incoming/asset.rs index 093c43e..712dc50 100644 --- a/src/lib/types/alpaca/api/incoming/asset.rs +++ b/src/lib/types/alpaca/api/incoming/asset.rs @@ -1,22 +1,11 @@ use super::position::Position; -use crate::types::{ - self, - alpaca::{ - api::outgoing, - shared::asset::{Class, Exchange, Status}, - }, -}; -use backoff::{future::retry_notify, ExponentialBackoff}; -use governor::DefaultDirectRateLimiter; -use itertools::Itertools; -use log::warn; -use reqwest::{Client, Error}; +use crate::types::{self, alpaca::shared::asset}; use serde::Deserialize; use serde_aux::field_attributes::deserialize_option_number_from_string; -use std::{collections::HashSet, time::Duration}; -use tokio::try_join; use uuid::Uuid; +pub use asset::{Class, Exchange, Status}; + #[allow(clippy::struct_excessive_bools)] #[derive(Deserialize, Clone)] pub struct Asset { @@ -48,131 +37,3 @@ impl From<(Asset, Option)> for types::Asset { } } } - -pub async fn get( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - query: &outgoing::asset::Asset, - backoff: Option, - api_base: &str, -) -> Result, Error> { - retry_notify( - backoff.unwrap_or_default(), - || async { - rate_limiter.until_ready().await; - client - .get(&format!("https://{}.alpaca.markets/v2/assets", api_base)) - .query(query) - .send() - .await? - .error_for_status() - .map_err(|e| match e.status() { - Some( - reqwest::StatusCode::BAD_REQUEST - | reqwest::StatusCode::FORBIDDEN - | reqwest::StatusCode::NOT_FOUND, - ) => backoff::Error::Permanent(e), - _ => e.into(), - })? - .json::>() - .await - .map_err(backoff::Error::Permanent) - }, - |e, duration: Duration| { - warn!( - "Failed to get assets, will retry in {} seconds: {}", - duration.as_secs(), - e - ); - }, - ) - .await -} - -pub async fn get_by_symbol( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - symbol: &str, - backoff: Option, - api_base: &str, -) -> Result { - retry_notify( - backoff.unwrap_or_default(), - || async { - rate_limiter.until_ready().await; - client - .get(&format!( - "https://{}.alpaca.markets/v2/assets/{}", - api_base, symbol - )) - .send() - .await? - .error_for_status() - .map_err(|e| match e.status() { - Some( - reqwest::StatusCode::BAD_REQUEST - | reqwest::StatusCode::FORBIDDEN - | reqwest::StatusCode::NOT_FOUND, - ) => backoff::Error::Permanent(e), - _ => e.into(), - })? - .json::() - .await - .map_err(backoff::Error::Permanent) - }, - |e, duration: Duration| { - warn!( - "Failed to get asset, will retry in {} seconds: {}", - duration.as_secs(), - e - ); - }, - ) - .await -} - -pub async fn get_by_symbols( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - symbols: &[String], - backoff: Option, - api_base: &str, -) -> Result, Error> { - if symbols.len() == 1 { - let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?; - return Ok(vec![asset]); - } - - let symbols = symbols.iter().collect::>(); - - let backoff_clone = backoff.clone(); - - let us_equity_query = outgoing::asset::Asset { - class: Some(Class::UsEquity), - ..Default::default() - }; - - let us_equity_assets = get( - client, - rate_limiter, - &us_equity_query, - backoff_clone, - api_base, - ); - - let crypto_query = outgoing::asset::Asset { - class: Some(Class::Crypto), - ..Default::default() - }; - - let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base); - - let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?; - - Ok(crypto_assets - .into_iter() - .chain(us_equity_assets) - .dedup_by(|a, b| a.symbol == b.symbol) - .filter(|asset| symbols.contains(&asset.symbol)) - .collect()) -} diff --git a/src/lib/types/alpaca/api/incoming/bar.rs b/src/lib/types/alpaca/api/incoming/bar.rs index ece3265..94dd5c0 100644 --- a/src/lib/types/alpaca/api/incoming/bar.rs +++ b/src/lib/types/alpaca/api/incoming/bar.rs @@ -1,10 +1,5 @@ -use crate::types::{self, alpaca::api::outgoing}; -use backoff::{future::retry_notify, ExponentialBackoff}; -use governor::DefaultDirectRateLimiter; -use log::warn; -use reqwest::{Client, Error}; +use crate::types; use serde::Deserialize; -use std::{collections::HashMap, time::Duration}; use time::OffsetDateTime; #[derive(Deserialize)] @@ -43,47 +38,3 @@ impl From<(Bar, String)> for types::Bar { } } } - -#[derive(Deserialize)] -pub struct Message { - pub bars: HashMap>, - pub next_page_token: Option, -} - -pub async fn get( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - data_url: &str, - query: &outgoing::bar::Bar, - backoff: Option, -) -> Result { - retry_notify( - backoff.unwrap_or_default(), - || async { - rate_limiter.until_ready().await; - client - .get(data_url) - .query(query) - .send() - .await? - .error_for_status() - .map_err(|e| match e.status() { - Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { - backoff::Error::Permanent(e) - } - _ => e.into(), - })? - .json::() - .await - .map_err(backoff::Error::Permanent) - }, - |e, duration: Duration| { - warn!( - "Failed to get historical bars, will retry in {} seconds: {}", - duration.as_secs(), - e - ); - }, - ) - .await -} diff --git a/src/lib/types/alpaca/api/incoming/calendar.rs b/src/lib/types/alpaca/api/incoming/calendar.rs index 65c540b..7e59c3f 100644 --- a/src/lib/types/alpaca/api/incoming/calendar.rs +++ b/src/lib/types/alpaca/api/incoming/calendar.rs @@ -1,13 +1,8 @@ use crate::{ - types::{self, alpaca::api::outgoing}, + types, utils::{de, time::EST_OFFSET}, }; -use backoff::{future::retry_notify, ExponentialBackoff}; -use governor::DefaultDirectRateLimiter; -use log::warn; -use reqwest::{Client, Error}; use serde::Deserialize; -use std::time::Duration; use time::{Date, OffsetDateTime, Time}; #[derive(Deserialize)] @@ -29,41 +24,3 @@ impl From for types::Calendar { } } } - -pub async fn get( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - query: &outgoing::calendar::Calendar, - backoff: Option, - api_base: &str, -) -> Result, Error> { - retry_notify( - backoff.unwrap_or_default(), - || async { - rate_limiter.until_ready().await; - client - .get(&format!("https://{}.alpaca.markets/v2/calendar", api_base)) - .query(query) - .send() - .await? - .error_for_status() - .map_err(|e| match e.status() { - Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { - backoff::Error::Permanent(e) - } - _ => e.into(), - })? - .json::>() - .await - .map_err(backoff::Error::Permanent) - }, - |e, duration: Duration| { - warn!( - "Failed to get calendar, will retry in {} seconds: {}", - duration.as_secs(), - e - ); - }, - ) - .await -} diff --git a/src/lib/types/alpaca/api/incoming/clock.rs b/src/lib/types/alpaca/api/incoming/clock.rs index 8543283..fb45711 100644 --- a/src/lib/types/alpaca/api/incoming/clock.rs +++ b/src/lib/types/alpaca/api/incoming/clock.rs @@ -1,9 +1,4 @@ -use backoff::{future::retry_notify, ExponentialBackoff}; -use governor::DefaultDirectRateLimiter; -use log::warn; -use reqwest::{Client, Error}; use serde::Deserialize; -use std::time::Duration; use time::OffsetDateTime; #[derive(Deserialize)] @@ -16,39 +11,3 @@ pub struct Clock { #[serde(with = "time::serde::rfc3339")] pub next_close: OffsetDateTime, } - -pub async fn get( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - backoff: Option, - api_base: &str, -) -> Result { - retry_notify( - backoff.unwrap_or_default(), - || async { - rate_limiter.until_ready().await; - client - .get(&format!("https://{}.alpaca.markets/v2/clock", api_base)) - .send() - .await? - .error_for_status() - .map_err(|e| match e.status() { - Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { - backoff::Error::Permanent(e) - } - _ => e.into(), - })? - .json::() - .await - .map_err(backoff::Error::Permanent) - }, - |e, duration: Duration| { - warn!( - "Failed to get clock, will retry in {} seconds: {}", - duration.as_secs(), - e - ); - }, - ) - .await -} diff --git a/src/lib/types/alpaca/api/incoming/news.rs b/src/lib/types/alpaca/api/incoming/news.rs index 180c0e5..e2f2528 100644 --- a/src/lib/types/alpaca/api/incoming/news.rs +++ b/src/lib/types/alpaca/api/incoming/news.rs @@ -1,19 +1,8 @@ use crate::{ - types::{ - self, - alpaca::{ - api::{outgoing, ALPACA_NEWS_DATA_API_URL}, - shared::news::normalize_html_content, - }, - }, + types::{self, alpaca::shared::news::normalize_html_content}, utils::de, }; -use backoff::{future::retry_notify, ExponentialBackoff}; -use governor::DefaultDirectRateLimiter; -use log::warn; -use reqwest::{Client, Error}; use serde::Deserialize; -use std::time::Duration; use time::OffsetDateTime; #[derive(Deserialize)] @@ -68,46 +57,3 @@ impl From for types::News { } } } - -#[derive(Deserialize)] -pub struct Message { - pub news: Vec, - pub next_page_token: Option, -} - -pub async fn get( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - query: &outgoing::news::News, - backoff: Option, -) -> Result { - retry_notify( - backoff.unwrap_or_default(), - || async { - rate_limiter.until_ready().await; - client - .get(ALPACA_NEWS_DATA_API_URL) - .query(query) - .send() - .await? - .error_for_status() - .map_err(|e| match e.status() { - Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { - backoff::Error::Permanent(e) - } - _ => e.into(), - })? - .json::() - .await - .map_err(backoff::Error::Permanent) - }, - |e, duration: Duration| { - warn!( - "Failed to get historical news, will retry in {} seconds: {}", - duration.as_secs(), - e - ); - }, - ) - .await -} diff --git a/src/lib/types/alpaca/api/incoming/order.rs b/src/lib/types/alpaca/api/incoming/order.rs index 3f6dbf0..bcc418e 100644 --- a/src/lib/types/alpaca/api/incoming/order.rs +++ b/src/lib/types/alpaca/api/incoming/order.rs @@ -1,45 +1,3 @@ -use crate::types::alpaca::{api::outgoing, shared}; -use backoff::{future::retry_notify, ExponentialBackoff}; -use governor::DefaultDirectRateLimiter; -use log::warn; -use reqwest::{Client, Error}; -pub use shared::order::Order; -use std::time::Duration; +use crate::types::alpaca::shared::order; -pub async fn get( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - query: &outgoing::order::Order, - backoff: Option, - api_base: &str, -) -> Result, Error> { - retry_notify( - backoff.unwrap_or_default(), - || async { - rate_limiter.until_ready().await; - client - .get(&format!("https://{}.alpaca.markets/v2/orders", api_base)) - .query(query) - .send() - .await? - .error_for_status() - .map_err(|e| match e.status() { - Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { - backoff::Error::Permanent(e) - } - _ => e.into(), - })? - .json::>() - .await - .map_err(backoff::Error::Permanent) - }, - |e, duration: Duration| { - warn!( - "Failed to get orders, will retry in {} seconds: {}", - duration.as_secs(), - e - ); - }, - ) - .await -} +pub use order::{Order, Side}; diff --git a/src/lib/types/alpaca/api/incoming/position.rs b/src/lib/types/alpaca/api/incoming/position.rs index f95113c..722087a 100644 --- a/src/lib/types/alpaca/api/incoming/position.rs +++ b/src/lib/types/alpaca/api/incoming/position.rs @@ -1,17 +1,12 @@ use crate::{ - types::alpaca::shared::{ - self, + types::alpaca::api::incoming::{ asset::{Class, Exchange}, + order, }, utils::de, }; -use backoff::{future::retry_notify, ExponentialBackoff}; -use governor::DefaultDirectRateLimiter; -use log::warn; -use reqwest::Client; use serde::Deserialize; use serde_aux::field_attributes::deserialize_number_from_string; -use std::{collections::HashSet, time::Duration}; use uuid::Uuid; #[derive(Deserialize, Clone, Copy)] @@ -21,7 +16,7 @@ pub enum Side { Short, } -impl From for shared::order::Side { +impl From for order::Side { fn from(side: Side) -> Self { match side { Side::Long => Self::Buy, @@ -64,110 +59,3 @@ pub struct Position { pub change_today: f64, pub asset_marginable: bool, } - -pub const ALPACA_API_URL_TEMPLATE: &str = ""; - -pub async fn get( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - backoff: Option, - api_base: &str, -) -> Result, reqwest::Error> { - retry_notify( - backoff.unwrap_or_default(), - || async { - rate_limiter.until_ready().await; - client - .get(&format!("https://{}.alpaca.markets/v2/positions", api_base)) - .send() - .await? - .error_for_status() - .map_err(|e| match e.status() { - Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { - backoff::Error::Permanent(e) - } - _ => e.into(), - })? - .json::>() - .await - .map_err(backoff::Error::Permanent) - }, - |e, duration: Duration| { - warn!( - "Failed to get positions, will retry in {} seconds: {}", - duration.as_secs(), - e - ); - }, - ) - .await -} - -pub async fn get_by_symbol( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - symbol: &str, - backoff: Option, - api_base: &str, -) -> Result, reqwest::Error> { - retry_notify( - backoff.unwrap_or_default(), - || async { - rate_limiter.until_ready().await; - let response = client - .get(&format!( - "https://{}.alpaca.markets/v2/positions/{}", - api_base, symbol - )) - .send() - .await?; - - if response.status() == reqwest::StatusCode::NOT_FOUND { - return Ok(None); - } - - response - .error_for_status() - .map_err(|e| match e.status() { - Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { - backoff::Error::Permanent(e) - } - _ => e.into(), - })? - .json::() - .await - .map_err(backoff::Error::Permanent) - .map(Some) - }, - |e, duration: Duration| { - warn!( - "Failed to get position, will retry in {} seconds: {}", - duration.as_secs(), - e - ); - }, - ) - .await -} - -pub async fn get_by_symbols( - client: &Client, - rate_limiter: &DefaultDirectRateLimiter, - symbols: &[String], - backoff: Option, - api_base: &str, -) -> Result, reqwest::Error> { - if symbols.len() == 1 { - let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?; - return Ok(position.into_iter().collect()); - } - - let symbols = symbols.iter().collect::>(); - - let positions = get(client, rate_limiter, backoff, api_base).await?; - - Ok(positions - .into_iter() - .filter(|position| symbols.contains(&position.symbol)) - .collect()) -} diff --git a/src/lib/types/alpaca/api/outgoing/asset.rs b/src/lib/types/alpaca/api/outgoing/asset.rs index 1efa07f..34395cf 100644 --- a/src/lib/types/alpaca/api/outgoing/asset.rs +++ b/src/lib/types/alpaca/api/outgoing/asset.rs @@ -1,6 +1,8 @@ -use crate::types::alpaca::shared::asset::{Class, Exchange, Status}; +use crate::types::alpaca::shared::asset; use serde::Serialize; +pub use asset::{Class, Exchange, Status}; + #[derive(Serialize)] pub struct Asset { pub status: Option, diff --git a/src/lib/types/alpaca/api/outgoing/bar.rs b/src/lib/types/alpaca/api/outgoing/bar.rs index 09e851f..aa05c5b 100644 --- a/src/lib/types/alpaca/api/outgoing/bar.rs +++ b/src/lib/types/alpaca/api/outgoing/bar.rs @@ -1,12 +1,13 @@ use crate::{ - types::alpaca::shared::{Sort, Source}, + alpaca::bars::MAX_LIMIT, + types::alpaca::shared, utils::{ser, ONE_MINUTE}, }; use serde::Serialize; use std::time::Duration; use time::OffsetDateTime; -pub const MAX_LIMIT: i64 = 10_000; +pub use shared::{Sort, Source}; #[derive(Serialize)] #[serde(rename_all = "snake_case")] diff --git a/src/lib/types/alpaca/api/outgoing/news.rs b/src/lib/types/alpaca/api/outgoing/news.rs index c271fe1..b4b9227 100644 --- a/src/lib/types/alpaca/api/outgoing/news.rs +++ b/src/lib/types/alpaca/api/outgoing/news.rs @@ -1,9 +1,7 @@ -use crate::{types::alpaca::shared::Sort, utils::ser}; +use crate::{alpaca::news::MAX_LIMIT, types::alpaca::shared::Sort, utils::ser}; use serde::Serialize; use time::OffsetDateTime; -pub const MAX_LIMIT: i64 = 50; - #[derive(Serialize)] pub struct News { #[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")] diff --git a/src/lib/types/alpaca/api/outgoing/order.rs b/src/lib/types/alpaca/api/outgoing/order.rs index 4b66d72..a6751ea 100644 --- a/src/lib/types/alpaca/api/outgoing/order.rs +++ b/src/lib/types/alpaca/api/outgoing/order.rs @@ -1,10 +1,12 @@ use crate::{ - types::alpaca::shared::{order::Side, Sort}, + types::alpaca::shared::{order, Sort}, utils::ser, }; use serde::Serialize; use time::OffsetDateTime; +pub use order::Side; + #[derive(Serialize)] #[serde(rename_all = "snake_case")] #[allow(dead_code)] diff --git a/src/lib/types/alpaca/websocket/trading/incoming/order.rs b/src/lib/types/alpaca/websocket/trading/incoming/order.rs index e5f012b..2477451 100644 --- a/src/lib/types/alpaca/websocket/trading/incoming/order.rs +++ b/src/lib/types/alpaca/websocket/trading/incoming/order.rs @@ -1,10 +1,10 @@ -use crate::types::alpaca::shared; +use crate::types::alpaca::shared::order; use serde::Deserialize; use serde_aux::prelude::deserialize_number_from_string; use time::OffsetDateTime; use uuid::Uuid; -pub use shared::order::Order; +pub use order::Order; #[derive(Deserialize, Debug, PartialEq)] #[serde(rename_all = "snake_case")] diff --git a/src/routes/assets.rs b/src/routes/assets.rs index 4b2e057..64b1aa2 100644 --- a/src/routes/assets.rs +++ b/src/routes/assets.rs @@ -4,7 +4,10 @@ use crate::{ }; use axum::{extract::Path, Extension, Json}; use http::StatusCode; -use qrust::types::{alpaca, Asset}; +use qrust::{ + alpaca, + types::{self, Asset}, +}; use serde::{Deserialize, Serialize}; use std::{ collections::{HashMap, HashSet}, @@ -69,7 +72,7 @@ pub async fn add( .map(|asset| asset.symbol) .collect::>(); - let mut alpaca_assets = alpaca::api::incoming::asset::get_by_symbols( + let mut alpaca_assets = alpaca::assets::get_by_symbols( &config.alpaca_client, &config.alpaca_rate_limiter, &request.symbols, @@ -94,7 +97,7 @@ pub async fn add( if database_symbols.contains(&symbol) { skipped.push(symbol); } else if let Some(asset) = alpaca_assets.remove(&symbol) { - if asset.status == alpaca::shared::asset::Status::Active + if asset.status == types::alpaca::api::incoming::asset::Status::Active && asset.tradable && asset.fractionable { @@ -144,7 +147,7 @@ pub async fn add_symbol( return Err(StatusCode::CONFLICT); } - let asset = alpaca::api::incoming::asset::get_by_symbol( + let asset = alpaca::assets::get_by_symbol( &config.alpaca_client, &config.alpaca_rate_limiter, &symbol, @@ -159,7 +162,7 @@ pub async fn add_symbol( }) })?; - if asset.status != alpaca::shared::asset::Status::Active + if asset.status != types::alpaca::api::incoming::asset::Status::Active || !asset.tradable || !asset.fractionable { diff --git a/src/threads/clock.rs b/src/threads/clock.rs index 819d964..2829573 100644 --- a/src/threads/clock.rs +++ b/src/threads/clock.rs @@ -4,7 +4,8 @@ use crate::{ }; use log::info; use qrust::{ - types::{alpaca, Calendar}, + alpaca, + types::{self, Calendar}, utils::{backoff, duration_until}, }; use std::sync::Arc; @@ -21,8 +22,8 @@ pub struct Message { pub next_switch: OffsetDateTime, } -impl From for Message { - fn from(clock: alpaca::api::incoming::clock::Clock) -> Self { +impl From for Message { + fn from(clock: types::alpaca::api::incoming::clock::Clock) -> Self { if clock.is_open { Self { status: Status::Open, @@ -40,7 +41,7 @@ impl From for Message { pub async fn run(config: Arc, sender: mpsc::Sender) { loop { let clock_future = async { - alpaca::api::incoming::clock::get( + alpaca::clock::get( &config.alpaca_client, &config.alpaca_rate_limiter, Some(backoff::infinite()), @@ -51,10 +52,10 @@ pub async fn run(config: Arc, sender: mpsc::Sender) { }; let calendar_future = async { - alpaca::api::incoming::calendar::get( + alpaca::calendar::get( &config.alpaca_client, &config.alpaca_rate_limiter, - &alpaca::api::outgoing::calendar::Calendar::default(), + &types::alpaca::api::outgoing::calendar::Calendar::default(), Some(backoff::infinite()), &ALPACA_API_BASE, ) diff --git a/src/threads/data/backfill/bars.rs b/src/threads/data/backfill/bars.rs index 0d82359..702d188 100644 --- a/src/threads/data/backfill/bars.rs +++ b/src/threads/data/backfill/bars.rs @@ -6,11 +6,10 @@ use crate::{ use async_trait::async_trait; use log::{error, info}; use qrust::{ + alpaca, types::{ - alpaca::{ - self, - shared::{Sort, Source}, - }, + self, + alpaca::shared::{Sort, Source}, Backfill, Bar, }, utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE}, @@ -27,7 +26,7 @@ pub struct Handler { fetch_from: OffsetDateTime, fetch_to: OffsetDateTime, next_page_token: Option, - ) -> alpaca::api::outgoing::bar::Bar, + ) -> types::alpaca::api::outgoing::bar::Bar, } pub fn us_equity_query_constructor( @@ -35,8 +34,8 @@ pub fn us_equity_query_constructor( fetch_from: OffsetDateTime, fetch_to: OffsetDateTime, next_page_token: Option, -) -> alpaca::api::outgoing::bar::Bar { - alpaca::api::outgoing::bar::Bar::UsEquity(alpaca::api::outgoing::bar::UsEquity { +) -> types::alpaca::api::outgoing::bar::Bar { + types::alpaca::api::outgoing::bar::Bar::UsEquity(types::alpaca::api::outgoing::bar::UsEquity { symbols, start: Some(fetch_from), end: Some(fetch_to), @@ -52,8 +51,8 @@ pub fn crypto_query_constructor( fetch_from: OffsetDateTime, fetch_to: OffsetDateTime, next_page_token: Option, -) -> alpaca::api::outgoing::bar::Bar { - alpaca::api::outgoing::bar::Bar::Crypto(alpaca::api::outgoing::bar::Crypto { +) -> types::alpaca::api::outgoing::bar::Bar { + types::alpaca::api::outgoing::bar::Bar::Crypto(types::alpaca::api::outgoing::bar::Crypto { symbols, start: Some(fetch_from), end: Some(fetch_to), @@ -124,7 +123,7 @@ impl super::Handler for Handler { let mut next_page_token = None; loop { - let message = alpaca::api::incoming::bar::get( + let message = alpaca::bars::get( &self.config.alpaca_client, &self.config.alpaca_rate_limiter, self.data_url, @@ -190,7 +189,7 @@ impl super::Handler for Handler { } fn max_limit(&self) -> i64 { - alpaca::api::outgoing::bar::MAX_LIMIT + alpaca::bars::MAX_LIMIT } fn log_string(&self) -> &'static str { diff --git a/src/threads/data/backfill/news.rs b/src/threads/data/backfill/news.rs index 4f7e23d..e23b3ff 100644 --- a/src/threads/data/backfill/news.rs +++ b/src/threads/data/backfill/news.rs @@ -7,11 +7,10 @@ use async_trait::async_trait; use futures_util::future::join_all; use log::{error, info}; use qrust::{ + alpaca, types::{ - alpaca::{ - self, - shared::{Sort, Source}, - }, + self, + alpaca::shared::{Sort, Source}, news::Prediction, Backfill, News, }, @@ -86,10 +85,10 @@ impl super::Handler for Handler { let mut next_page_token = None; loop { - let message = alpaca::api::incoming::news::get( + let message = alpaca::news::get( &self.config.alpaca_client, &self.config.alpaca_rate_limiter, - &alpaca::api::outgoing::news::News { + &types::alpaca::api::outgoing::news::News { symbols: symbols.clone(), start: Some(fetch_from), end: Some(fetch_to), @@ -187,7 +186,7 @@ impl super::Handler for Handler { } fn max_limit(&self) -> i64 { - alpaca::api::outgoing::news::MAX_LIMIT + alpaca::news::MAX_LIMIT } fn log_string(&self) -> &'static str { diff --git a/src/threads/data/mod.rs b/src/threads/data/mod.rs index 1b99bda..940ef0b 100644 --- a/src/threads/data/mod.rs +++ b/src/threads/data/mod.rs @@ -3,28 +3,26 @@ mod websocket; use super::clock; use crate::{ - config::{Config, ALPACA_API_BASE, ALPACA_API_KEY, ALPACA_API_SECRET, ALPACA_SOURCE}, + config::{Config, ALPACA_API_BASE, ALPACA_SOURCE}, create_send_await, database, }; -use futures_util::StreamExt; use itertools::{Either, Itertools}; use log::error; -use qrust::types::{ - alpaca::{ - self, - websocket::{ +use qrust::{ + alpaca, + types::{ + alpaca::websocket::{ ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL, ALPACA_US_EQUITY_DATA_WEBSOCKET_URL, }, + Asset, Class, }, - Asset, Class, }; use std::{collections::HashMap, sync::Arc}; use tokio::{ join, select, spawn, sync::{mpsc, oneshot}, }; -use tokio_tungstenite::connect_async; #[derive(Clone, Copy, PartialEq, Eq)] #[allow(dead_code)] @@ -67,11 +65,10 @@ pub async fn run( mut clock_receiver: mpsc::Receiver, ) { let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) = - init_thread(config.clone(), ThreadType::Bars(Class::UsEquity)).await; + init_thread(&config, ThreadType::Bars(Class::UsEquity)); let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) = - init_thread(config.clone(), ThreadType::Bars(Class::Crypto)).await; - let (news_websocket_sender, news_backfill_sender) = - init_thread(config.clone(), ThreadType::News).await; + init_thread(&config, ThreadType::Bars(Class::Crypto)); + let (news_websocket_sender, news_backfill_sender) = init_thread(&config, ThreadType::News); loop { select! { @@ -100,8 +97,8 @@ pub async fn run( } } -async fn init_thread( - config: Arc, +fn init_thread( + config: &Arc, thread_type: ThreadType, ) -> ( mpsc::Sender, @@ -115,16 +112,6 @@ async fn init_thread( ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(), }; - let (websocket, _) = connect_async(websocket_url).await.unwrap(); - let (mut websocket_sink, mut websocket_stream) = websocket.split(); - alpaca::websocket::data::authenticate( - &mut websocket_sink, - &mut websocket_stream, - (*ALPACA_API_KEY).to_string(), - (*ALPACA_API_SECRET).to_string(), - ) - .await; - let (backfill_sender, backfill_receiver) = mpsc::channel(100); spawn(backfill::run( Arc::new(backfill::create_handler(thread_type, config.clone())), @@ -135,8 +122,7 @@ async fn init_thread( spawn(websocket::run( Arc::new(websocket::create_handler(thread_type, config.clone())), websocket_receiver, - websocket_stream, - websocket_sink, + websocket_url, )); (websocket_sender, backfill_sender) @@ -214,7 +200,7 @@ async fn handle_message( match message.action { Action::Add => { let assets = async { - alpaca::api::incoming::asset::get_by_symbols( + alpaca::assets::get_by_symbols( &config.alpaca_client, &config.alpaca_rate_limiter, &symbols, @@ -229,7 +215,7 @@ async fn handle_message( }; let positions = async { - alpaca::api::incoming::position::get_by_symbols( + alpaca::positions::get_by_symbols( &config.alpaca_client, &config.alpaca_rate_limiter, &symbols, @@ -252,7 +238,7 @@ async fn handle_message( let position = positions.remove(symbol); batch.push(Asset::from((asset, position))); } else { - error!("Failed to find asset for symbol: {}", symbol); + error!("Failed to find asset for symbol: {}.", symbol); } } diff --git a/src/threads/data/websocket/bars.rs b/src/threads/data/websocket/bars.rs index 0a616be..ef5fc4a 100644 --- a/src/threads/data/websocket/bars.rs +++ b/src/threads/data/websocket/bars.rs @@ -1,9 +1,12 @@ -use super::Pending; +use super::State; use crate::{config::Config, database}; use async_trait::async_trait; use log::{debug, error, info}; use qrust::types::{alpaca::websocket, Bar}; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use tokio::sync::RwLock; pub struct Handler { @@ -23,7 +26,7 @@ impl super::Handler for Handler { async fn handle_websocket_message( &self, - pending: Arc>, + state: Arc>, message: websocket::data::incoming::Message, ) { match message { @@ -35,19 +38,24 @@ impl super::Handler for Handler { unreachable!() }; - let mut pending = pending.write().await; + let symbols = symbols.into_iter().collect::>(); + let mut state = state.write().await; - let newly_subscribed = pending - .subscriptions + let newly_subscribed = state + .pending_subscriptions .extract_if(|symbol, _| symbols.contains(symbol)) .collect::>(); - let newly_unsubscribed = pending - .unsubscriptions + let newly_unsubscribed = state + .pending_unsubscriptions .extract_if(|symbol, _| !symbols.contains(symbol)) .collect::>(); - drop(pending); + state + .active_subscriptions + .extend(newly_subscribed.keys().cloned()); + + drop(state); if !newly_subscribed.is_empty() { info!( @@ -122,4 +130,8 @@ impl super::Handler for Handler { _ => unreachable!(), } } + + fn log_string(&self) -> &'static str { + "bars" + } } diff --git a/src/threads/data/websocket/mod.rs b/src/threads/data/websocket/mod.rs index 47e1cff..995145e 100644 --- a/src/threads/data/websocket/mod.rs +++ b/src/threads/data/websocket/mod.rs @@ -2,23 +2,27 @@ mod bars; mod news; use super::ThreadType; -use crate::config::Config; +use crate::config::{Config, ALPACA_API_KEY, ALPACA_API_SECRET}; use async_trait::async_trait; -use futures_util::{ - future::join_all, - stream::{SplitSink, SplitStream}, - SinkExt, StreamExt, -}; +use backoff::{future::retry_notify, ExponentialBackoff}; +use futures_util::{future::join_all, SinkExt, StreamExt}; use log::error; -use qrust::types::{alpaca::websocket, Class}; +use qrust::types::{ + alpaca::{self, websocket}, + Class, +}; use serde_json::{from_str, to_string}; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + time::Duration, +}; use tokio::{ net::TcpStream, select, spawn, - sync::{mpsc, oneshot, Mutex, RwLock}, + sync::{mpsc, oneshot, RwLock}, }; -use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; +use tokio_tungstenite::{connect_async, tungstenite, MaybeTlsStream, WebSocketStream}; pub enum Action { Subscribe, @@ -54,9 +58,10 @@ impl Message { } } -pub struct Pending { - pub subscriptions: HashMap>, - pub unsubscriptions: HashMap>, +pub struct State { + pub active_subscriptions: HashSet, + pub pending_subscriptions: HashMap>, + pub pending_unsubscriptions: HashMap>, } #[async_trait] @@ -67,53 +72,64 @@ pub trait Handler: Send + Sync { ) -> websocket::data::outgoing::subscribe::Message; async fn handle_websocket_message( &self, - pending: Arc>, + state: Arc>, message: websocket::data::incoming::Message, ); + fn log_string(&self) -> &'static str; } pub async fn run( handler: Arc>, mut receiver: mpsc::Receiver, - mut websocket_stream: SplitStream>>, - websocket_sink: SplitSink>, tungstenite::Message>, + websocket_url: String, ) { - let pending = Arc::new(RwLock::new(Pending { - subscriptions: HashMap::new(), - unsubscriptions: HashMap::new(), + let state = Arc::new(RwLock::new(State { + active_subscriptions: HashSet::new(), + pending_subscriptions: HashMap::new(), + pending_unsubscriptions: HashMap::new(), })); - let websocket_sink = Arc::new(Mutex::new(websocket_sink)); + + let (sink_sender, sink_receiver) = mpsc::channel(100); + let (stream_sender, mut stream_receiver) = mpsc::channel(10_000); + + spawn(run_connection( + handler.clone(), + sink_receiver, + stream_sender, + websocket_url.clone(), + state.clone(), + )); loop { select! { Some(message) = receiver.recv() => { spawn(handle_message( handler.clone(), - pending.clone(), - websocket_sink.clone(), + state.clone(), + sink_sender.clone(), message, )); } - Some(Ok(message)) = websocket_stream.next() => { + Some(message) = stream_receiver.recv() => { match message { tungstenite::Message::Text(message) => { let parsed_message = from_str::>(&message); if parsed_message.is_err() { - error!("Failed to deserialize websocket message: {:?}", message); + error!("Failed to deserialize websocket message: {:?}.", message); continue; } for message in parsed_message.unwrap() { let handler = handler.clone(); - let pending = pending.clone(); + let state = state.clone(); spawn(async move { - handler.handle_websocket_message(pending, message).await; + handler.handle_websocket_message(state, message).await; }); } } tungstenite::Message::Ping(_) => {} - _ => error!("Unexpected websocket message: {:?}", message), + _ => error!("Unexpected websocket message: {:?}.", message), } } else => panic!("Communication channel unexpectedly closed.") @@ -121,10 +137,142 @@ pub async fn run( } } +#[allow(clippy::too_many_lines)] +async fn run_connection( + handler: Arc>, + mut sink_receiver: mpsc::Receiver, + stream_sender: mpsc::Sender, + websocket_url: String, + state: Arc>, +) { + let mut peek = None; + + 'connection: loop { + let (websocket, _): (WebSocketStream>, _) = retry_notify( + ExponentialBackoff::default(), + || async { + connect_async(websocket_url.clone()) + .await + .map_err(Into::into) + }, + |e, duration: Duration| { + error!( + "Failed to connect to {} websocket, will retry in {} seconds: {}.", + handler.log_string(), + duration.as_secs(), + e + ); + }, + ) + .await + .unwrap(); + + let (mut sink, mut stream) = websocket.split(); + alpaca::websocket::data::authenticate( + &mut sink, + &mut stream, + (*ALPACA_API_KEY).to_string(), + (*ALPACA_API_SECRET).to_string(), + ) + .await; + + let mut state = state.write().await; + + state + .pending_unsubscriptions + .drain() + .for_each(|(_, sender)| { + sender.send(()).unwrap(); + }); + + let (recovered_subscriptions, receivers) = state + .active_subscriptions + .iter() + .map(|symbol| { + let (sender, receiver) = oneshot::channel(); + ((symbol.clone(), sender), receiver) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + state.pending_subscriptions.extend(recovered_subscriptions); + + let pending_subscriptions = state + .pending_subscriptions + .keys() + .cloned() + .collect::>(); + + drop(state); + + if !pending_subscriptions.is_empty() { + if let Err(err) = sink + .send(tungstenite::Message::Text( + to_string(&websocket::data::outgoing::Message::Subscribe( + handler.create_subscription_message(pending_subscriptions), + )) + .unwrap(), + )) + .await + { + error!("Failed to send websocket message: {:?}.", err); + continue; + } + } + + join_all(receivers).await; + + if peek.is_some() { + if let Err(err) = sink.send(peek.clone().unwrap()).await { + error!("Failed to send websocket message: {:?}.", err); + continue; + } + peek = None; + } + + loop { + select! { + Some(message) = sink_receiver.recv() => { + peek = Some(message.clone()); + + if let Err(err) = sink.send(message).await { + error!("Failed to send websocket message: {:?}.", err); + continue 'connection; + }; + + peek = None; + } + message = stream.next() => { + if message.is_none() { + error!("Websocket stream unexpectedly closed."); + continue 'connection; + } + + let message = message.unwrap(); + + if let Err(err) = message { + error!("Failed to receive websocket message: {:?}.", err); + continue 'connection; + } + + let message = message.unwrap(); + + if message.is_close() { + error!("Websocket connection closed."); + continue 'connection; + } + + stream_sender.send(message).await.unwrap(); + } + else => error!("Communication channel unexpectedly closed.") + } + } + } +} + async fn handle_message( handler: Arc>, - pending: Arc>, - sink: Arc>, tungstenite::Message>>>, + pending: Arc>, + sink_sender: mpsc::Sender, message: Message, ) { if message.symbols.is_empty() { @@ -134,23 +282,22 @@ async fn handle_message( match message.action { Some(Action::Subscribe) => { - let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message + let (pending_subscriptions, receivers) = message .symbols .iter() .map(|symbol| { let (sender, receiver) = oneshot::channel(); ((symbol.clone(), sender), receiver) }) - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); pending .write() .await - .subscriptions + .pending_subscriptions .extend(pending_subscriptions); - sink.lock() - .await + sink_sender .send(tungstenite::Message::Text( to_string(&websocket::data::outgoing::Message::Subscribe( handler.create_subscription_message(message.symbols), @@ -175,11 +322,10 @@ async fn handle_message( pending .write() .await - .unsubscriptions + .pending_unsubscriptions .extend(pending_unsubscriptions); - sink.lock() - .await + sink_sender .send(tungstenite::Message::Text( to_string(&websocket::data::outgoing::Message::Unsubscribe( handler.create_subscription_message(message.symbols.clone()), diff --git a/src/threads/data/websocket/news.rs b/src/threads/data/websocket/news.rs index 0975378..08a32cf 100644 --- a/src/threads/data/websocket/news.rs +++ b/src/threads/data/websocket/news.rs @@ -1,4 +1,4 @@ -use super::Pending; +use super::State; use crate::{config::Config, database}; use async_trait::async_trait; use log::{debug, error, info}; @@ -21,7 +21,7 @@ impl super::Handler for Handler { async fn handle_websocket_message( &self, - pending: Arc>, + state: Arc>, message: websocket::data::incoming::Message, ) { match message { @@ -32,19 +32,23 @@ impl super::Handler for Handler { unreachable!() }; - let mut pending = pending.write().await; + let mut state = state.write().await; - let newly_subscribed = pending - .subscriptions + let newly_subscribed = state + .pending_subscriptions .extract_if(|symbol, _| symbols.contains(symbol)) .collect::>(); - let newly_unsubscribed = pending - .unsubscriptions + let newly_unsubscribed = state + .pending_unsubscriptions .extract_if(|symbol, _| !symbols.contains(symbol)) .collect::>(); - drop(pending); + state + .active_subscriptions + .extend(newly_subscribed.keys().cloned()); + + drop(state); if !newly_subscribed.is_empty() { info!( @@ -108,4 +112,8 @@ impl super::Handler for Handler { _ => unreachable!(), } } + + fn log_string(&self) -> &'static str { + "news" + } } diff --git a/src/threads/trading/websocket.rs b/src/threads/trading/websocket.rs index aea886c..1507d68 100644 --- a/src/threads/trading/websocket.rs +++ b/src/threads/trading/websocket.rs @@ -21,7 +21,7 @@ pub async fn run( ); if parsed_message.is_err() { - error!("Failed to deserialize websocket message: {:?}", message); + error!("Failed to deserialize websocket message: {:?}.", message); continue; } @@ -31,7 +31,7 @@ pub async fn run( )); } tungstenite::Message::Ping(_) => {} - _ => error!("Unexpected websocket message: {:?}", message), + _ => error!("Unexpected websocket message: {:?}.", message), } } } @@ -43,7 +43,7 @@ async fn handle_websocket_message( match message { websocket::trading::incoming::Message::Order(message) => { debug!( - "Received order message for {}: {:?}", + "Received order message for {}: {:?}.", message.order.symbol, message.event );