Add automatic websocket reconnection

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-03-11 23:41:06 +00:00
parent d02f958865
commit d2d20e2978
33 changed files with 838 additions and 664 deletions

View File

@@ -3,13 +3,13 @@ use crate::{
database, database,
}; };
use log::{info, warn}; use log::{info, warn};
use qrust::types::alpaca; use qrust::{alpaca, types};
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime; use time::OffsetDateTime;
use tokio::join; use tokio::join;
pub async fn check_account(config: &Arc<Config>) { pub async fn check_account(config: &Arc<Config>) {
let account = alpaca::api::incoming::account::get( let account = alpaca::account::get(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
None, None,
@@ -19,7 +19,7 @@ pub async fn check_account(config: &Arc<Config>) {
.unwrap(); .unwrap();
assert!( assert!(
!(account.status != alpaca::api::incoming::account::Status::Active), !(account.status != types::alpaca::api::incoming::account::Status::Active),
"Account status is not active: {:?}.", "Account status is not active: {:?}.",
account.status account.status
); );
@@ -46,11 +46,11 @@ pub async fn rehydrate_orders(config: &Arc<Config>) {
let mut orders = vec![]; let mut orders = vec![];
let mut after = OffsetDateTime::UNIX_EPOCH; let mut after = OffsetDateTime::UNIX_EPOCH;
while let Some(message) = alpaca::api::incoming::order::get( while let Some(message) = alpaca::orders::get(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
&alpaca::api::outgoing::order::Order { &types::alpaca::api::outgoing::order::Order {
status: Some(alpaca::api::outgoing::order::Status::All), status: Some(types::alpaca::api::outgoing::order::Status::All),
after: Some(after), after: Some(after),
..Default::default() ..Default::default()
}, },
@@ -67,7 +67,7 @@ pub async fn rehydrate_orders(config: &Arc<Config>) {
let orders = orders let orders = orders
.into_iter() .into_iter()
.flat_map(&alpaca::api::incoming::order::Order::normalize) .flat_map(&types::alpaca::api::incoming::order::Order::normalize)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
database::orders::upsert_batch( database::orders::upsert_batch(
@@ -85,7 +85,7 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
info!("Rehydrating position data."); info!("Rehydrating position data.");
let positions_future = async { let positions_future = async {
alpaca::api::incoming::position::get( alpaca::positions::get(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
None, None,

42
src/lib/alpaca/account.rs Normal file
View File

@@ -0,0 +1,42 @@
use crate::types::alpaca::api::incoming::account::Account;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/account", api_base))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Account>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

139
src/lib/alpaca/assets.rs Normal file
View File

@@ -0,0 +1,139 @@
use crate::types::alpaca::api::{
incoming::asset::{Asset, Class},
outgoing,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use itertools::Itertools;
use log::warn;
use reqwest::{Client, Error};
use std::{collections::HashSet, time::Duration};
use tokio::try_join;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::asset::Asset,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/assets", api_base))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(
reqwest::StatusCode::BAD_REQUEST
| reqwest::StatusCode::FORBIDDEN
| reqwest::StatusCode::NOT_FOUND,
) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Vec<Asset>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get assets, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!(
"https://{}.alpaca.markets/v2/assets/{}",
api_base, symbol
))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(
reqwest::StatusCode::BAD_REQUEST
| reqwest::StatusCode::FORBIDDEN
| reqwest::StatusCode::NOT_FOUND,
) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Asset>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get asset, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
if symbols.len() == 1 {
let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(vec![asset]);
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let backoff_clone = backoff.clone();
let us_equity_query = outgoing::asset::Asset {
class: Some(Class::UsEquity),
..Default::default()
};
let us_equity_assets = get(
client,
rate_limiter,
&us_equity_query,
backoff_clone,
api_base,
);
let crypto_query = outgoing::asset::Asset {
class: Some(Class::Crypto),
..Default::default()
};
let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base);
let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;
Ok(crypto_assets
.into_iter()
.chain(us_equity_assets)
.dedup_by(|a, b| a.symbol == b.symbol)
.filter(|asset| symbols.contains(&asset.symbol))
.collect())
}

53
src/lib/alpaca/bars.rs Normal file
View File

@@ -0,0 +1,53 @@
use crate::types::alpaca::api::{incoming::bar::Bar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::{collections::HashMap, time::Duration};
pub const MAX_LIMIT: i64 = 10_000;
#[derive(Deserialize)]
pub struct Message {
pub bars: HashMap<String, Vec<Bar>>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
data_url: &str,
query: &outgoing::bar::Bar,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(data_url)
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical bars, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -0,0 +1,44 @@
use crate::types::alpaca::api::{incoming::calendar::Calendar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::calendar::Calendar,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Calendar>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/calendar", api_base))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Calendar>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get calendar, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

42
src/lib/alpaca/clock.rs Normal file
View File

@@ -0,0 +1,42 @@
use crate::types::alpaca::api::incoming::clock::Clock;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/clock", api_base))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Clock>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get clock, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

8
src/lib/alpaca/mod.rs Normal file
View File

@@ -0,0 +1,8 @@
pub mod account;
pub mod assets;
pub mod bars;
pub mod calendar;
pub mod clock;
pub mod news;
pub mod orders;
pub mod positions;

52
src/lib/alpaca/news.rs Normal file
View File

@@ -0,0 +1,52 @@
use crate::types::alpaca::api::{incoming::news::News, outgoing, ALPACA_NEWS_DATA_API_URL};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
pub const MAX_LIMIT: i64 = 50;
#[derive(Deserialize)]
pub struct Message {
pub news: Vec<News>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(ALPACA_NEWS_DATA_API_URL)
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical news, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

46
src/lib/alpaca/orders.rs Normal file
View File

@@ -0,0 +1,46 @@
use crate::types::alpaca::{api::outgoing, shared::order};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub use order::Order;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::order::Order,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Order>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/orders", api_base))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Order>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get orders, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

111
src/lib/alpaca/positions.rs Normal file
View File

@@ -0,0 +1,111 @@
use crate::types::alpaca::api::incoming::position::Position;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::Client;
use std::{collections::HashSet, time::Duration};
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/positions", api_base))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Position>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get positions, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Option<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
let response = client
.get(&format!(
"https://{}.alpaca.markets/v2/positions/{}",
api_base, symbol
))
.send()
.await?;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Ok(None);
}
response
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Position>()
.await
.map_err(backoff::Error::Permanent)
.map(Some)
},
|e, duration: Duration| {
warn!(
"Failed to get position, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
if symbols.len() == 1 {
let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(position.into_iter().collect());
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let positions = get(client, rate_limiter, backoff, api_base).await?;
Ok(positions
.into_iter()
.filter(|position| symbols.contains(&position.symbol))
.collect())
}

View File

@@ -1,3 +1,4 @@
pub mod alpaca;
pub mod database; pub mod database;
pub mod types; pub mod types;
pub mod utils; pub mod utils;

View File

@@ -1,12 +1,7 @@
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize; use serde::Deserialize;
use serde_aux::field_attributes::{ use serde_aux::field_attributes::{
deserialize_number_from_string, deserialize_option_number_from_string, deserialize_number_from_string, deserialize_option_number_from_string,
}; };
use std::time::Duration;
use time::OffsetDateTime; use time::OffsetDateTime;
use uuid::Uuid; use uuid::Uuid;
@@ -78,39 +73,3 @@ pub struct Account {
#[serde(deserialize_with = "deserialize_number_from_string")] #[serde(deserialize_with = "deserialize_number_from_string")]
pub regt_buying_power: f64, pub regt_buying_power: f64,
} }
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/account", api_base))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Account>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,22 +1,11 @@
use super::position::Position; use super::position::Position;
use crate::types::{ use crate::types::{self, alpaca::shared::asset};
self,
alpaca::{
api::outgoing,
shared::asset::{Class, Exchange, Status},
},
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use itertools::Itertools;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize; use serde::Deserialize;
use serde_aux::field_attributes::deserialize_option_number_from_string; use serde_aux::field_attributes::deserialize_option_number_from_string;
use std::{collections::HashSet, time::Duration};
use tokio::try_join;
use uuid::Uuid; use uuid::Uuid;
pub use asset::{Class, Exchange, Status};
#[allow(clippy::struct_excessive_bools)] #[allow(clippy::struct_excessive_bools)]
#[derive(Deserialize, Clone)] #[derive(Deserialize, Clone)]
pub struct Asset { pub struct Asset {
@@ -48,131 +37,3 @@ impl From<(Asset, Option<Position>)> for types::Asset {
} }
} }
} }
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::asset::Asset,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/assets", api_base))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(
reqwest::StatusCode::BAD_REQUEST
| reqwest::StatusCode::FORBIDDEN
| reqwest::StatusCode::NOT_FOUND,
) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Vec<Asset>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get assets, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!(
"https://{}.alpaca.markets/v2/assets/{}",
api_base, symbol
))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(
reqwest::StatusCode::BAD_REQUEST
| reqwest::StatusCode::FORBIDDEN
| reqwest::StatusCode::NOT_FOUND,
) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Asset>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get asset, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
if symbols.len() == 1 {
let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(vec![asset]);
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let backoff_clone = backoff.clone();
let us_equity_query = outgoing::asset::Asset {
class: Some(Class::UsEquity),
..Default::default()
};
let us_equity_assets = get(
client,
rate_limiter,
&us_equity_query,
backoff_clone,
api_base,
);
let crypto_query = outgoing::asset::Asset {
class: Some(Class::Crypto),
..Default::default()
};
let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base);
let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;
Ok(crypto_assets
.into_iter()
.chain(us_equity_assets)
.dedup_by(|a, b| a.symbol == b.symbol)
.filter(|asset| symbols.contains(&asset.symbol))
.collect())
}

View File

@@ -1,10 +1,5 @@
use crate::types::{self, alpaca::api::outgoing}; use crate::types;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize; use serde::Deserialize;
use std::{collections::HashMap, time::Duration};
use time::OffsetDateTime; use time::OffsetDateTime;
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -43,47 +38,3 @@ impl From<(Bar, String)> for types::Bar {
} }
} }
} }
#[derive(Deserialize)]
pub struct Message {
pub bars: HashMap<String, Vec<Bar>>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
data_url: &str,
query: &outgoing::bar::Bar,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(data_url)
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical bars, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,13 +1,8 @@
use crate::{ use crate::{
types::{self, alpaca::api::outgoing}, types,
utils::{de, time::EST_OFFSET}, utils::{de, time::EST_OFFSET},
}; };
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize; use serde::Deserialize;
use std::time::Duration;
use time::{Date, OffsetDateTime, Time}; use time::{Date, OffsetDateTime, Time};
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -29,41 +24,3 @@ impl From<Calendar> for types::Calendar {
} }
} }
} }
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::calendar::Calendar,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Calendar>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/calendar", api_base))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Calendar>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get calendar, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,9 +1,4 @@
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize; use serde::Deserialize;
use std::time::Duration;
use time::OffsetDateTime; use time::OffsetDateTime;
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -16,39 +11,3 @@ pub struct Clock {
#[serde(with = "time::serde::rfc3339")] #[serde(with = "time::serde::rfc3339")]
pub next_close: OffsetDateTime, pub next_close: OffsetDateTime,
} }
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/clock", api_base))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Clock>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get clock, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,19 +1,8 @@
use crate::{ use crate::{
types::{ types::{self, alpaca::shared::news::normalize_html_content},
self,
alpaca::{
api::{outgoing, ALPACA_NEWS_DATA_API_URL},
shared::news::normalize_html_content,
},
},
utils::de, utils::de,
}; };
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize; use serde::Deserialize;
use std::time::Duration;
use time::OffsetDateTime; use time::OffsetDateTime;
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -68,46 +57,3 @@ impl From<News> for types::News {
} }
} }
} }
#[derive(Deserialize)]
pub struct Message {
pub news: Vec<News>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(ALPACA_NEWS_DATA_API_URL)
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical news, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,45 +1,3 @@
use crate::types::alpaca::{api::outgoing, shared}; use crate::types::alpaca::shared::order;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
pub use shared::order::Order;
use std::time::Duration;
pub async fn get( pub use order::{Order, Side};
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::order::Order,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Order>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/orders", api_base))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Order>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get orders, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,17 +1,12 @@
use crate::{ use crate::{
types::alpaca::shared::{ types::alpaca::api::incoming::{
self,
asset::{Class, Exchange}, asset::{Class, Exchange},
order,
}, },
utils::de, utils::de,
}; };
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::Client;
use serde::Deserialize; use serde::Deserialize;
use serde_aux::field_attributes::deserialize_number_from_string; use serde_aux::field_attributes::deserialize_number_from_string;
use std::{collections::HashSet, time::Duration};
use uuid::Uuid; use uuid::Uuid;
#[derive(Deserialize, Clone, Copy)] #[derive(Deserialize, Clone, Copy)]
@@ -21,7 +16,7 @@ pub enum Side {
Short, Short,
} }
impl From<Side> for shared::order::Side { impl From<Side> for order::Side {
fn from(side: Side) -> Self { fn from(side: Side) -> Self {
match side { match side {
Side::Long => Self::Buy, Side::Long => Self::Buy,
@@ -64,110 +59,3 @@ pub struct Position {
pub change_today: f64, pub change_today: f64,
pub asset_marginable: bool, pub asset_marginable: bool,
} }
pub const ALPACA_API_URL_TEMPLATE: &str = "";
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/positions", api_base))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Position>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get positions, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Option<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
let response = client
.get(&format!(
"https://{}.alpaca.markets/v2/positions/{}",
api_base, symbol
))
.send()
.await?;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Ok(None);
}
response
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Position>()
.await
.map_err(backoff::Error::Permanent)
.map(Some)
},
|e, duration: Duration| {
warn!(
"Failed to get position, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
if symbols.len() == 1 {
let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(position.into_iter().collect());
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let positions = get(client, rate_limiter, backoff, api_base).await?;
Ok(positions
.into_iter()
.filter(|position| symbols.contains(&position.symbol))
.collect())
}

View File

@@ -1,6 +1,8 @@
use crate::types::alpaca::shared::asset::{Class, Exchange, Status}; use crate::types::alpaca::shared::asset;
use serde::Serialize; use serde::Serialize;
pub use asset::{Class, Exchange, Status};
#[derive(Serialize)] #[derive(Serialize)]
pub struct Asset { pub struct Asset {
pub status: Option<Status>, pub status: Option<Status>,

View File

@@ -1,12 +1,13 @@
use crate::{ use crate::{
types::alpaca::shared::{Sort, Source}, alpaca::bars::MAX_LIMIT,
types::alpaca::shared,
utils::{ser, ONE_MINUTE}, utils::{ser, ONE_MINUTE},
}; };
use serde::Serialize; use serde::Serialize;
use std::time::Duration; use std::time::Duration;
use time::OffsetDateTime; use time::OffsetDateTime;
pub const MAX_LIMIT: i64 = 10_000; pub use shared::{Sort, Source};
#[derive(Serialize)] #[derive(Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]

View File

@@ -1,9 +1,7 @@
use crate::{types::alpaca::shared::Sort, utils::ser}; use crate::{alpaca::news::MAX_LIMIT, types::alpaca::shared::Sort, utils::ser};
use serde::Serialize; use serde::Serialize;
use time::OffsetDateTime; use time::OffsetDateTime;
pub const MAX_LIMIT: i64 = 50;
#[derive(Serialize)] #[derive(Serialize)]
pub struct News { pub struct News {
#[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")] #[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")]

View File

@@ -1,10 +1,12 @@
use crate::{ use crate::{
types::alpaca::shared::{order::Side, Sort}, types::alpaca::shared::{order, Sort},
utils::ser, utils::ser,
}; };
use serde::Serialize; use serde::Serialize;
use time::OffsetDateTime; use time::OffsetDateTime;
pub use order::Side;
#[derive(Serialize)] #[derive(Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[allow(dead_code)] #[allow(dead_code)]

View File

@@ -1,10 +1,10 @@
use crate::types::alpaca::shared; use crate::types::alpaca::shared::order;
use serde::Deserialize; use serde::Deserialize;
use serde_aux::prelude::deserialize_number_from_string; use serde_aux::prelude::deserialize_number_from_string;
use time::OffsetDateTime; use time::OffsetDateTime;
use uuid::Uuid; use uuid::Uuid;
pub use shared::order::Order; pub use order::Order;
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]

View File

@@ -4,7 +4,10 @@ use crate::{
}; };
use axum::{extract::Path, Extension, Json}; use axum::{extract::Path, Extension, Json};
use http::StatusCode; use http::StatusCode;
use qrust::types::{alpaca, Asset}; use qrust::{
alpaca,
types::{self, Asset},
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
@@ -69,7 +72,7 @@ pub async fn add(
.map(|asset| asset.symbol) .map(|asset| asset.symbol)
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
let mut alpaca_assets = alpaca::api::incoming::asset::get_by_symbols( let mut alpaca_assets = alpaca::assets::get_by_symbols(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
&request.symbols, &request.symbols,
@@ -94,7 +97,7 @@ pub async fn add(
if database_symbols.contains(&symbol) { if database_symbols.contains(&symbol) {
skipped.push(symbol); skipped.push(symbol);
} else if let Some(asset) = alpaca_assets.remove(&symbol) { } else if let Some(asset) = alpaca_assets.remove(&symbol) {
if asset.status == alpaca::shared::asset::Status::Active if asset.status == types::alpaca::api::incoming::asset::Status::Active
&& asset.tradable && asset.tradable
&& asset.fractionable && asset.fractionable
{ {
@@ -144,7 +147,7 @@ pub async fn add_symbol(
return Err(StatusCode::CONFLICT); return Err(StatusCode::CONFLICT);
} }
let asset = alpaca::api::incoming::asset::get_by_symbol( let asset = alpaca::assets::get_by_symbol(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
&symbol, &symbol,
@@ -159,7 +162,7 @@ pub async fn add_symbol(
}) })
})?; })?;
if asset.status != alpaca::shared::asset::Status::Active if asset.status != types::alpaca::api::incoming::asset::Status::Active
|| !asset.tradable || !asset.tradable
|| !asset.fractionable || !asset.fractionable
{ {

View File

@@ -4,7 +4,8 @@ use crate::{
}; };
use log::info; use log::info;
use qrust::{ use qrust::{
types::{alpaca, Calendar}, alpaca,
types::{self, Calendar},
utils::{backoff, duration_until}, utils::{backoff, duration_until},
}; };
use std::sync::Arc; use std::sync::Arc;
@@ -21,8 +22,8 @@ pub struct Message {
pub next_switch: OffsetDateTime, pub next_switch: OffsetDateTime,
} }
impl From<alpaca::api::incoming::clock::Clock> for Message { impl From<types::alpaca::api::incoming::clock::Clock> for Message {
fn from(clock: alpaca::api::incoming::clock::Clock) -> Self { fn from(clock: types::alpaca::api::incoming::clock::Clock) -> Self {
if clock.is_open { if clock.is_open {
Self { Self {
status: Status::Open, status: Status::Open,
@@ -40,7 +41,7 @@ impl From<alpaca::api::incoming::clock::Clock> for Message {
pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) { pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
loop { loop {
let clock_future = async { let clock_future = async {
alpaca::api::incoming::clock::get( alpaca::clock::get(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
Some(backoff::infinite()), Some(backoff::infinite()),
@@ -51,10 +52,10 @@ pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
}; };
let calendar_future = async { let calendar_future = async {
alpaca::api::incoming::calendar::get( alpaca::calendar::get(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
&alpaca::api::outgoing::calendar::Calendar::default(), &types::alpaca::api::outgoing::calendar::Calendar::default(),
Some(backoff::infinite()), Some(backoff::infinite()),
&ALPACA_API_BASE, &ALPACA_API_BASE,
) )

View File

@@ -6,11 +6,10 @@ use crate::{
use async_trait::async_trait; use async_trait::async_trait;
use log::{error, info}; use log::{error, info};
use qrust::{ use qrust::{
alpaca,
types::{ types::{
alpaca::{ self,
self, alpaca::shared::{Sort, Source},
shared::{Sort, Source},
},
Backfill, Bar, Backfill, Bar,
}, },
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE}, utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
@@ -27,7 +26,7 @@ pub struct Handler {
fetch_from: OffsetDateTime, fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime, fetch_to: OffsetDateTime,
next_page_token: Option<String>, next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar, ) -> types::alpaca::api::outgoing::bar::Bar,
} }
pub fn us_equity_query_constructor( pub fn us_equity_query_constructor(
@@ -35,8 +34,8 @@ pub fn us_equity_query_constructor(
fetch_from: OffsetDateTime, fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime, fetch_to: OffsetDateTime,
next_page_token: Option<String>, next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar { ) -> types::alpaca::api::outgoing::bar::Bar {
alpaca::api::outgoing::bar::Bar::UsEquity(alpaca::api::outgoing::bar::UsEquity { types::alpaca::api::outgoing::bar::Bar::UsEquity(types::alpaca::api::outgoing::bar::UsEquity {
symbols, symbols,
start: Some(fetch_from), start: Some(fetch_from),
end: Some(fetch_to), end: Some(fetch_to),
@@ -52,8 +51,8 @@ pub fn crypto_query_constructor(
fetch_from: OffsetDateTime, fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime, fetch_to: OffsetDateTime,
next_page_token: Option<String>, next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar { ) -> types::alpaca::api::outgoing::bar::Bar {
alpaca::api::outgoing::bar::Bar::Crypto(alpaca::api::outgoing::bar::Crypto { types::alpaca::api::outgoing::bar::Bar::Crypto(types::alpaca::api::outgoing::bar::Crypto {
symbols, symbols,
start: Some(fetch_from), start: Some(fetch_from),
end: Some(fetch_to), end: Some(fetch_to),
@@ -124,7 +123,7 @@ impl super::Handler for Handler {
let mut next_page_token = None; let mut next_page_token = None;
loop { loop {
let message = alpaca::api::incoming::bar::get( let message = alpaca::bars::get(
&self.config.alpaca_client, &self.config.alpaca_client,
&self.config.alpaca_rate_limiter, &self.config.alpaca_rate_limiter,
self.data_url, self.data_url,
@@ -190,7 +189,7 @@ impl super::Handler for Handler {
} }
fn max_limit(&self) -> i64 { fn max_limit(&self) -> i64 {
alpaca::api::outgoing::bar::MAX_LIMIT alpaca::bars::MAX_LIMIT
} }
fn log_string(&self) -> &'static str { fn log_string(&self) -> &'static str {

View File

@@ -7,11 +7,10 @@ use async_trait::async_trait;
use futures_util::future::join_all; use futures_util::future::join_all;
use log::{error, info}; use log::{error, info};
use qrust::{ use qrust::{
alpaca,
types::{ types::{
alpaca::{ self,
self, alpaca::shared::{Sort, Source},
shared::{Sort, Source},
},
news::Prediction, news::Prediction,
Backfill, News, Backfill, News,
}, },
@@ -86,10 +85,10 @@ impl super::Handler for Handler {
let mut next_page_token = None; let mut next_page_token = None;
loop { loop {
let message = alpaca::api::incoming::news::get( let message = alpaca::news::get(
&self.config.alpaca_client, &self.config.alpaca_client,
&self.config.alpaca_rate_limiter, &self.config.alpaca_rate_limiter,
&alpaca::api::outgoing::news::News { &types::alpaca::api::outgoing::news::News {
symbols: symbols.clone(), symbols: symbols.clone(),
start: Some(fetch_from), start: Some(fetch_from),
end: Some(fetch_to), end: Some(fetch_to),
@@ -187,7 +186,7 @@ impl super::Handler for Handler {
} }
fn max_limit(&self) -> i64 { fn max_limit(&self) -> i64 {
alpaca::api::outgoing::news::MAX_LIMIT alpaca::news::MAX_LIMIT
} }
fn log_string(&self) -> &'static str { fn log_string(&self) -> &'static str {

View File

@@ -3,28 +3,26 @@ mod websocket;
use super::clock; use super::clock;
use crate::{ use crate::{
config::{Config, ALPACA_API_BASE, ALPACA_API_KEY, ALPACA_API_SECRET, ALPACA_SOURCE}, config::{Config, ALPACA_API_BASE, ALPACA_SOURCE},
create_send_await, database, create_send_await, database,
}; };
use futures_util::StreamExt;
use itertools::{Either, Itertools}; use itertools::{Either, Itertools};
use log::error; use log::error;
use qrust::types::{ use qrust::{
alpaca::{ alpaca,
self, types::{
websocket::{ alpaca::websocket::{
ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL, ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL,
ALPACA_US_EQUITY_DATA_WEBSOCKET_URL, ALPACA_US_EQUITY_DATA_WEBSOCKET_URL,
}, },
Asset, Class,
}, },
Asset, Class,
}; };
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use tokio::{ use tokio::{
join, select, spawn, join, select, spawn,
sync::{mpsc, oneshot}, sync::{mpsc, oneshot},
}; };
use tokio_tungstenite::connect_async;
#[derive(Clone, Copy, PartialEq, Eq)] #[derive(Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)] #[allow(dead_code)]
@@ -67,11 +65,10 @@ pub async fn run(
mut clock_receiver: mpsc::Receiver<clock::Message>, mut clock_receiver: mpsc::Receiver<clock::Message>,
) { ) {
let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) = let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::UsEquity)).await; init_thread(&config, ThreadType::Bars(Class::UsEquity));
let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) = let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::Crypto)).await; init_thread(&config, ThreadType::Bars(Class::Crypto));
let (news_websocket_sender, news_backfill_sender) = let (news_websocket_sender, news_backfill_sender) = init_thread(&config, ThreadType::News);
init_thread(config.clone(), ThreadType::News).await;
loop { loop {
select! { select! {
@@ -100,8 +97,8 @@ pub async fn run(
} }
} }
async fn init_thread( fn init_thread(
config: Arc<Config>, config: &Arc<Config>,
thread_type: ThreadType, thread_type: ThreadType,
) -> ( ) -> (
mpsc::Sender<websocket::Message>, mpsc::Sender<websocket::Message>,
@@ -115,16 +112,6 @@ async fn init_thread(
ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(), ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(),
}; };
let (websocket, _) = connect_async(websocket_url).await.unwrap();
let (mut websocket_sink, mut websocket_stream) = websocket.split();
alpaca::websocket::data::authenticate(
&mut websocket_sink,
&mut websocket_stream,
(*ALPACA_API_KEY).to_string(),
(*ALPACA_API_SECRET).to_string(),
)
.await;
let (backfill_sender, backfill_receiver) = mpsc::channel(100); let (backfill_sender, backfill_receiver) = mpsc::channel(100);
spawn(backfill::run( spawn(backfill::run(
Arc::new(backfill::create_handler(thread_type, config.clone())), Arc::new(backfill::create_handler(thread_type, config.clone())),
@@ -135,8 +122,7 @@ async fn init_thread(
spawn(websocket::run( spawn(websocket::run(
Arc::new(websocket::create_handler(thread_type, config.clone())), Arc::new(websocket::create_handler(thread_type, config.clone())),
websocket_receiver, websocket_receiver,
websocket_stream, websocket_url,
websocket_sink,
)); ));
(websocket_sender, backfill_sender) (websocket_sender, backfill_sender)
@@ -214,7 +200,7 @@ async fn handle_message(
match message.action { match message.action {
Action::Add => { Action::Add => {
let assets = async { let assets = async {
alpaca::api::incoming::asset::get_by_symbols( alpaca::assets::get_by_symbols(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
&symbols, &symbols,
@@ -229,7 +215,7 @@ async fn handle_message(
}; };
let positions = async { let positions = async {
alpaca::api::incoming::position::get_by_symbols( alpaca::positions::get_by_symbols(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
&symbols, &symbols,
@@ -252,7 +238,7 @@ async fn handle_message(
let position = positions.remove(symbol); let position = positions.remove(symbol);
batch.push(Asset::from((asset, position))); batch.push(Asset::from((asset, position)));
} else { } else {
error!("Failed to find asset for symbol: {}", symbol); error!("Failed to find asset for symbol: {}.", symbol);
} }
} }

View File

@@ -1,9 +1,12 @@
use super::Pending; use super::State;
use crate::{config::Config, database}; use crate::{config::Config, database};
use async_trait::async_trait; use async_trait::async_trait;
use log::{debug, error, info}; use log::{debug, error, info};
use qrust::types::{alpaca::websocket, Bar}; use qrust::types::{alpaca::websocket, Bar};
use std::{collections::HashMap, sync::Arc}; use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::RwLock; use tokio::sync::RwLock;
pub struct Handler { pub struct Handler {
@@ -23,7 +26,7 @@ impl super::Handler for Handler {
async fn handle_websocket_message( async fn handle_websocket_message(
&self, &self,
pending: Arc<RwLock<Pending>>, state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message, message: websocket::data::incoming::Message,
) { ) {
match message { match message {
@@ -35,19 +38,24 @@ impl super::Handler for Handler {
unreachable!() unreachable!()
}; };
let mut pending = pending.write().await; let symbols = symbols.into_iter().collect::<HashSet<_>>();
let mut state = state.write().await;
let newly_subscribed = pending let newly_subscribed = state
.subscriptions .pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol)) .extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>(); .collect::<HashMap<_, _>>();
let newly_unsubscribed = pending let newly_unsubscribed = state
.unsubscriptions .pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol)) .extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>(); .collect::<HashMap<_, _>>();
drop(pending); state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() { if !newly_subscribed.is_empty() {
info!( info!(
@@ -122,4 +130,8 @@ impl super::Handler for Handler {
_ => unreachable!(), _ => unreachable!(),
} }
} }
fn log_string(&self) -> &'static str {
"bars"
}
} }

View File

@@ -2,23 +2,27 @@ mod bars;
mod news; mod news;
use super::ThreadType; use super::ThreadType;
use crate::config::Config; use crate::config::{Config, ALPACA_API_KEY, ALPACA_API_SECRET};
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::{ use backoff::{future::retry_notify, ExponentialBackoff};
future::join_all, use futures_util::{future::join_all, SinkExt, StreamExt};
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use log::error; use log::error;
use qrust::types::{alpaca::websocket, Class}; use qrust::types::{
alpaca::{self, websocket},
Class,
};
use serde_json::{from_str, to_string}; use serde_json::{from_str, to_string};
use std::{collections::HashMap, sync::Arc}; use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::Duration,
};
use tokio::{ use tokio::{
net::TcpStream, net::TcpStream,
select, spawn, select, spawn,
sync::{mpsc, oneshot, Mutex, RwLock}, sync::{mpsc, oneshot, RwLock},
}; };
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; use tokio_tungstenite::{connect_async, tungstenite, MaybeTlsStream, WebSocketStream};
pub enum Action { pub enum Action {
Subscribe, Subscribe,
@@ -54,9 +58,10 @@ impl Message {
} }
} }
pub struct Pending { pub struct State {
pub subscriptions: HashMap<String, oneshot::Sender<()>>, pub active_subscriptions: HashSet<String>,
pub unsubscriptions: HashMap<String, oneshot::Sender<()>>, pub pending_subscriptions: HashMap<String, oneshot::Sender<()>>,
pub pending_unsubscriptions: HashMap<String, oneshot::Sender<()>>,
} }
#[async_trait] #[async_trait]
@@ -67,53 +72,64 @@ pub trait Handler: Send + Sync {
) -> websocket::data::outgoing::subscribe::Message; ) -> websocket::data::outgoing::subscribe::Message;
async fn handle_websocket_message( async fn handle_websocket_message(
&self, &self,
pending: Arc<RwLock<Pending>>, state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message, message: websocket::data::incoming::Message,
); );
fn log_string(&self) -> &'static str;
} }
pub async fn run( pub async fn run(
handler: Arc<Box<dyn Handler>>, handler: Arc<Box<dyn Handler>>,
mut receiver: mpsc::Receiver<Message>, mut receiver: mpsc::Receiver<Message>,
mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>, websocket_url: String,
websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>,
) { ) {
let pending = Arc::new(RwLock::new(Pending { let state = Arc::new(RwLock::new(State {
subscriptions: HashMap::new(), active_subscriptions: HashSet::new(),
unsubscriptions: HashMap::new(), pending_subscriptions: HashMap::new(),
pending_unsubscriptions: HashMap::new(),
})); }));
let websocket_sink = Arc::new(Mutex::new(websocket_sink));
let (sink_sender, sink_receiver) = mpsc::channel(100);
let (stream_sender, mut stream_receiver) = mpsc::channel(10_000);
spawn(run_connection(
handler.clone(),
sink_receiver,
stream_sender,
websocket_url.clone(),
state.clone(),
));
loop { loop {
select! { select! {
Some(message) = receiver.recv() => { Some(message) = receiver.recv() => {
spawn(handle_message( spawn(handle_message(
handler.clone(), handler.clone(),
pending.clone(), state.clone(),
websocket_sink.clone(), sink_sender.clone(),
message, message,
)); ));
} }
Some(Ok(message)) = websocket_stream.next() => { Some(message) = stream_receiver.recv() => {
match message { match message {
tungstenite::Message::Text(message) => { tungstenite::Message::Text(message) => {
let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message); let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);
if parsed_message.is_err() { if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}", message); error!("Failed to deserialize websocket message: {:?}.", message);
continue; continue;
} }
for message in parsed_message.unwrap() { for message in parsed_message.unwrap() {
let handler = handler.clone(); let handler = handler.clone();
let pending = pending.clone(); let state = state.clone();
spawn(async move { spawn(async move {
handler.handle_websocket_message(pending, message).await; handler.handle_websocket_message(state, message).await;
}); });
} }
} }
tungstenite::Message::Ping(_) => {} tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}", message), _ => error!("Unexpected websocket message: {:?}.", message),
} }
} }
else => panic!("Communication channel unexpectedly closed.") else => panic!("Communication channel unexpectedly closed.")
@@ -121,10 +137,142 @@ pub async fn run(
} }
} }
#[allow(clippy::too_many_lines)]
async fn run_connection(
handler: Arc<Box<dyn Handler>>,
mut sink_receiver: mpsc::Receiver<tungstenite::Message>,
stream_sender: mpsc::Sender<tungstenite::Message>,
websocket_url: String,
state: Arc<RwLock<State>>,
) {
let mut peek = None;
'connection: loop {
let (websocket, _): (WebSocketStream<MaybeTlsStream<TcpStream>>, _) = retry_notify(
ExponentialBackoff::default(),
|| async {
connect_async(websocket_url.clone())
.await
.map_err(Into::into)
},
|e, duration: Duration| {
error!(
"Failed to connect to {} websocket, will retry in {} seconds: {}.",
handler.log_string(),
duration.as_secs(),
e
);
},
)
.await
.unwrap();
let (mut sink, mut stream) = websocket.split();
alpaca::websocket::data::authenticate(
&mut sink,
&mut stream,
(*ALPACA_API_KEY).to_string(),
(*ALPACA_API_SECRET).to_string(),
)
.await;
let mut state = state.write().await;
state
.pending_unsubscriptions
.drain()
.for_each(|(_, sender)| {
sender.send(()).unwrap();
});
let (recovered_subscriptions, receivers) = state
.active_subscriptions
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
state.pending_subscriptions.extend(recovered_subscriptions);
let pending_subscriptions = state
.pending_subscriptions
.keys()
.cloned()
.collect::<Vec<_>>();
drop(state);
if !pending_subscriptions.is_empty() {
if let Err(err) = sink
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(pending_subscriptions),
))
.unwrap(),
))
.await
{
error!("Failed to send websocket message: {:?}.", err);
continue;
}
}
join_all(receivers).await;
if peek.is_some() {
if let Err(err) = sink.send(peek.clone().unwrap()).await {
error!("Failed to send websocket message: {:?}.", err);
continue;
}
peek = None;
}
loop {
select! {
Some(message) = sink_receiver.recv() => {
peek = Some(message.clone());
if let Err(err) = sink.send(message).await {
error!("Failed to send websocket message: {:?}.", err);
continue 'connection;
};
peek = None;
}
message = stream.next() => {
if message.is_none() {
error!("Websocket stream unexpectedly closed.");
continue 'connection;
}
let message = message.unwrap();
if let Err(err) = message {
error!("Failed to receive websocket message: {:?}.", err);
continue 'connection;
}
let message = message.unwrap();
if message.is_close() {
error!("Websocket connection closed.");
continue 'connection;
}
stream_sender.send(message).await.unwrap();
}
else => error!("Communication channel unexpectedly closed.")
}
}
}
}
async fn handle_message( async fn handle_message(
handler: Arc<Box<dyn Handler>>, handler: Arc<Box<dyn Handler>>,
pending: Arc<RwLock<Pending>>, pending: Arc<RwLock<State>>,
sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>, sink_sender: mpsc::Sender<tungstenite::Message>,
message: Message, message: Message,
) { ) {
if message.symbols.is_empty() { if message.symbols.is_empty() {
@@ -134,23 +282,22 @@ async fn handle_message(
match message.action { match message.action {
Some(Action::Subscribe) => { Some(Action::Subscribe) => {
let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message let (pending_subscriptions, receivers) = message
.symbols .symbols
.iter() .iter()
.map(|symbol| { .map(|symbol| {
let (sender, receiver) = oneshot::channel(); let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver) ((symbol.clone(), sender), receiver)
}) })
.unzip(); .unzip::<_, _, Vec<_>, Vec<_>>();
pending pending
.write() .write()
.await .await
.subscriptions .pending_subscriptions
.extend(pending_subscriptions); .extend(pending_subscriptions);
sink.lock() sink_sender
.await
.send(tungstenite::Message::Text( .send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe( to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(message.symbols), handler.create_subscription_message(message.symbols),
@@ -175,11 +322,10 @@ async fn handle_message(
pending pending
.write() .write()
.await .await
.unsubscriptions .pending_unsubscriptions
.extend(pending_unsubscriptions); .extend(pending_unsubscriptions);
sink.lock() sink_sender
.await
.send(tungstenite::Message::Text( .send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Unsubscribe( to_string(&websocket::data::outgoing::Message::Unsubscribe(
handler.create_subscription_message(message.symbols.clone()), handler.create_subscription_message(message.symbols.clone()),

View File

@@ -1,4 +1,4 @@
use super::Pending; use super::State;
use crate::{config::Config, database}; use crate::{config::Config, database};
use async_trait::async_trait; use async_trait::async_trait;
use log::{debug, error, info}; use log::{debug, error, info};
@@ -21,7 +21,7 @@ impl super::Handler for Handler {
async fn handle_websocket_message( async fn handle_websocket_message(
&self, &self,
pending: Arc<RwLock<Pending>>, state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message, message: websocket::data::incoming::Message,
) { ) {
match message { match message {
@@ -32,19 +32,23 @@ impl super::Handler for Handler {
unreachable!() unreachable!()
}; };
let mut pending = pending.write().await; let mut state = state.write().await;
let newly_subscribed = pending let newly_subscribed = state
.subscriptions .pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol)) .extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>(); .collect::<HashMap<_, _>>();
let newly_unsubscribed = pending let newly_unsubscribed = state
.unsubscriptions .pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol)) .extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>(); .collect::<HashMap<_, _>>();
drop(pending); state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() { if !newly_subscribed.is_empty() {
info!( info!(
@@ -108,4 +112,8 @@ impl super::Handler for Handler {
_ => unreachable!(), _ => unreachable!(),
} }
} }
fn log_string(&self) -> &'static str {
"news"
}
} }

View File

@@ -21,7 +21,7 @@ pub async fn run(
); );
if parsed_message.is_err() { if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}", message); error!("Failed to deserialize websocket message: {:?}.", message);
continue; continue;
} }
@@ -31,7 +31,7 @@ pub async fn run(
)); ));
} }
tungstenite::Message::Ping(_) => {} tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}", message), _ => error!("Unexpected websocket message: {:?}.", message),
} }
} }
} }
@@ -43,7 +43,7 @@ async fn handle_websocket_message(
match message { match message {
websocket::trading::incoming::Message::Order(message) => { websocket::trading::incoming::Message::Order(message) => {
debug!( debug!(
"Received order message for {}: {:?}", "Received order message for {}: {:?}.",
message.order.symbol, message.event message.order.symbol, message.event
); );