Add automatic websocket reconnection

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-03-11 23:41:06 +00:00
parent d02f958865
commit d2d20e2978
33 changed files with 838 additions and 664 deletions

View File

@@ -3,13 +3,13 @@ use crate::{
database,
};
use log::{info, warn};
use qrust::types::alpaca;
use qrust::{alpaca, types};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::join;
pub async fn check_account(config: &Arc<Config>) {
let account = alpaca::api::incoming::account::get(
let account = alpaca::account::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,
@@ -19,7 +19,7 @@ pub async fn check_account(config: &Arc<Config>) {
.unwrap();
assert!(
!(account.status != alpaca::api::incoming::account::Status::Active),
!(account.status != types::alpaca::api::incoming::account::Status::Active),
"Account status is not active: {:?}.",
account.status
);
@@ -46,11 +46,11 @@ pub async fn rehydrate_orders(config: &Arc<Config>) {
let mut orders = vec![];
let mut after = OffsetDateTime::UNIX_EPOCH;
while let Some(message) = alpaca::api::incoming::order::get(
while let Some(message) = alpaca::orders::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&alpaca::api::outgoing::order::Order {
status: Some(alpaca::api::outgoing::order::Status::All),
&types::alpaca::api::outgoing::order::Order {
status: Some(types::alpaca::api::outgoing::order::Status::All),
after: Some(after),
..Default::default()
},
@@ -67,7 +67,7 @@ pub async fn rehydrate_orders(config: &Arc<Config>) {
let orders = orders
.into_iter()
.flat_map(&alpaca::api::incoming::order::Order::normalize)
.flat_map(&types::alpaca::api::incoming::order::Order::normalize)
.collect::<Vec<_>>();
database::orders::upsert_batch(
@@ -85,7 +85,7 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
info!("Rehydrating position data.");
let positions_future = async {
alpaca::api::incoming::position::get(
alpaca::positions::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,

42
src/lib/alpaca/account.rs Normal file
View File

@@ -0,0 +1,42 @@
use crate::types::alpaca::api::incoming::account::Account;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
/// Fetches the authenticated Alpaca account, retrying transient failures
/// with exponential backoff.
///
/// `backoff` overrides the retry policy (`ExponentialBackoff::default()` when
/// `None`); `api_base` selects the Alpaca host (e.g. "api" or "paper-api").
///
/// # Errors
/// HTTP 400/403 responses and JSON decode failures are treated as permanent
/// and abort the retry loop; other failures are retried.
pub async fn get(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Account, Error> {
    let url = format!("https://{}.alpaca.markets/v2/account", api_base);

    let attempt = || async {
        // Respect the shared API rate limit before every attempt.
        rate_limiter.until_ready().await;

        let response = client.get(&url).send().await?.error_for_status();
        let response = response.map_err(|e| {
            // Client-side errors will not succeed on retry; fail permanently.
            if matches!(
                e.status(),
                Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN)
            ) {
                backoff::Error::Permanent(e)
            } else {
                e.into()
            }
        })?;

        response
            .json::<Account>()
            .await
            .map_err(backoff::Error::Permanent)
    };

    let log_retry = |e, delay: Duration| {
        warn!(
            "Failed to get account, will retry in {} seconds: {}.",
            delay.as_secs(),
            e
        );
    };

    retry_notify(backoff.unwrap_or_default(), attempt, log_retry).await
}

139
src/lib/alpaca/assets.rs Normal file
View File

@@ -0,0 +1,139 @@
use crate::types::alpaca::api::{
incoming::asset::{Asset, Class},
outgoing,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use itertools::Itertools;
use log::warn;
use reqwest::{Client, Error};
use std::{collections::HashSet, time::Duration};
use tokio::try_join;
/// Fetches all assets matching `query`, retrying transient failures with
/// exponential backoff.
///
/// # Errors
/// HTTP 400/403/404 responses and JSON decode failures are permanent and end
/// the retry loop; other failures (network errors, 5xx, 429) are retried.
pub async fn get(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    query: &outgoing::asset::Asset,
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Vec<Asset>, Error> {
    retry_notify(
        // `None` falls back to the backoff crate's default policy.
        backoff.unwrap_or_default(),
        || async {
            // Honor the shared rate limit before each attempt.
            rate_limiter.until_ready().await;
            client
                .get(&format!("https://{}.alpaca.markets/v2/assets", api_base))
                .query(query)
                .send()
                .await?
                .error_for_status()
                .map_err(|e| match e.status() {
                    // Client-side errors will not succeed on retry.
                    Some(
                        reqwest::StatusCode::BAD_REQUEST
                        | reqwest::StatusCode::FORBIDDEN
                        | reqwest::StatusCode::NOT_FOUND,
                    ) => backoff::Error::Permanent(e),
                    _ => e.into(),
                })?
                .json::<Vec<Asset>>()
                .await
                // A body that fails to deserialize will not fix itself on retry.
                .map_err(backoff::Error::Permanent)
        },
        |e, duration: Duration| {
            warn!(
                "Failed to get assets, will retry in {} seconds: {}.",
                duration.as_secs(),
                e
            );
        },
    )
    .await
}
/// Fetches a single asset by its symbol, retrying transient failures with
/// exponential backoff.
///
/// # Errors
/// HTTP 400/403/404 responses (including an unknown symbol) and JSON decode
/// failures are permanent; other failures are retried.
pub async fn get_by_symbol(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    symbol: &str,
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Asset, Error> {
    retry_notify(
        backoff.unwrap_or_default(),
        || async {
            // Honor the shared rate limit before each attempt.
            rate_limiter.until_ready().await;
            client
                .get(&format!(
                    "https://{}.alpaca.markets/v2/assets/{}",
                    api_base, symbol
                ))
                .send()
                .await?
                .error_for_status()
                .map_err(|e| match e.status() {
                    // Client-side errors will not succeed on retry.
                    Some(
                        reqwest::StatusCode::BAD_REQUEST
                        | reqwest::StatusCode::FORBIDDEN
                        | reqwest::StatusCode::NOT_FOUND,
                    ) => backoff::Error::Permanent(e),
                    _ => e.into(),
                })?
                .json::<Asset>()
                .await
                .map_err(backoff::Error::Permanent)
        },
        |e, duration: Duration| {
            warn!(
                "Failed to get asset, will retry in {} seconds: {}.",
                duration.as_secs(),
                e
            );
        },
    )
    .await
}
/// Fetches the assets whose symbols appear in `symbols`.
///
/// A single-symbol request uses the cheaper per-symbol endpoint; otherwise the
/// full US-equity and crypto asset catalogs are downloaded concurrently and
/// filtered down to the requested symbols.
///
/// # Errors
/// Propagates any permanent error from the underlying `get`/`get_by_symbol`
/// calls.
pub async fn get_by_symbols(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    symbols: &[String],
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Vec<Asset>, Error> {
    // Nothing requested: skip the two full-catalog downloads entirely.
    if symbols.is_empty() {
        return Ok(Vec::new());
    }
    if symbols.len() == 1 {
        let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
        return Ok(vec![asset]);
    }
    let symbols = symbols.iter().collect::<HashSet<_>>();
    let backoff_clone = backoff.clone();
    let us_equity_query = outgoing::asset::Asset {
        class: Some(Class::UsEquity),
        ..Default::default()
    };
    let us_equity_assets = get(
        client,
        rate_limiter,
        &us_equity_query,
        backoff_clone,
        api_base,
    );
    let crypto_query = outgoing::asset::Asset {
        class: Some(Class::Crypto),
        ..Default::default()
    };
    let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base);
    let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;
    // NOTE(review): `dedup_by` only removes *adjacent* duplicates; this relies
    // on a symbol never appearing in both the crypto and us_equity listings —
    // confirm against the API's symbol namespaces.
    Ok(crypto_assets
        .into_iter()
        .chain(us_equity_assets)
        .dedup_by(|a, b| a.symbol == b.symbol)
        .filter(|asset| symbols.contains(&asset.symbol))
        .collect())
}

53
src/lib/alpaca/bars.rs Normal file
View File

@@ -0,0 +1,53 @@
use crate::types::alpaca::api::{incoming::bar::Bar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::{collections::HashMap, time::Duration};
// Presumably the API's maximum `limit` value per bars request — TODO confirm
// against the Alpaca market-data documentation.
pub const MAX_LIMIT: i64 = 10_000;
/// One page of the historical-bars response, with bars grouped by symbol.
#[derive(Deserialize)]
pub struct Message {
    pub bars: HashMap<String, Vec<Bar>>,
    // Token for fetching the next page; `None` on the last page.
    pub next_page_token: Option<String>,
}
/// Fetches one page of historical bars from `data_url`, retrying transient
/// failures with exponential backoff.
///
/// # Errors
/// HTTP 400/403 responses and JSON decode failures are permanent and end the
/// retry loop; other failures are retried.
pub async fn get(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    data_url: &str,
    query: &outgoing::bar::Bar,
    backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
    retry_notify(
        backoff.unwrap_or_default(),
        || async {
            // Honor the shared rate limit before each attempt.
            rate_limiter.until_ready().await;
            client
                .get(data_url)
                .query(query)
                .send()
                .await?
                .error_for_status()
                .map_err(|e| match e.status() {
                    // Client-side errors will not succeed on retry.
                    Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
                        backoff::Error::Permanent(e)
                    }
                    _ => e.into(),
                })?
                .json::<Message>()
                .await
                .map_err(backoff::Error::Permanent)
        },
        |e, duration: Duration| {
            warn!(
                "Failed to get historical bars, will retry in {} seconds: {}.",
                duration.as_secs(),
                e
            );
        },
    )
    .await
}

View File

@@ -0,0 +1,44 @@
use crate::types::alpaca::api::{incoming::calendar::Calendar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
/// Fetches the market calendar matching `query`, retrying transient failures
/// with exponential backoff.
///
/// # Errors
/// HTTP 400/403 responses and JSON decode failures are permanent and end the
/// retry loop; other failures are retried.
pub async fn get(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    query: &outgoing::calendar::Calendar,
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Vec<Calendar>, Error> {
    retry_notify(
        backoff.unwrap_or_default(),
        || async {
            // Honor the shared rate limit before each attempt.
            rate_limiter.until_ready().await;
            client
                .get(&format!("https://{}.alpaca.markets/v2/calendar", api_base))
                .query(query)
                .send()
                .await?
                .error_for_status()
                .map_err(|e| match e.status() {
                    // Client-side errors will not succeed on retry.
                    Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
                        backoff::Error::Permanent(e)
                    }
                    _ => e.into(),
                })?
                .json::<Vec<Calendar>>()
                .await
                .map_err(backoff::Error::Permanent)
        },
        |e, duration: Duration| {
            warn!(
                "Failed to get calendar, will retry in {} seconds: {}.",
                duration.as_secs(),
                e
            );
        },
    )
    .await
}

42
src/lib/alpaca/clock.rs Normal file
View File

@@ -0,0 +1,42 @@
use crate::types::alpaca::api::incoming::clock::Clock;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
/// Fetches the current market clock, retrying transient failures with
/// exponential backoff.
///
/// # Errors
/// HTTP 400/403 responses and JSON decode failures are treated as permanent
/// and abort the retry loop; other failures are retried.
pub async fn get(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Clock, Error> {
    let url = format!("https://{}.alpaca.markets/v2/clock", api_base);

    let attempt = || async {
        // Respect the shared API rate limit before every attempt.
        rate_limiter.until_ready().await;

        let response = client.get(&url).send().await?.error_for_status();
        let response = response.map_err(|e| {
            // Client-side errors will not succeed on retry; fail permanently.
            if matches!(
                e.status(),
                Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN)
            ) {
                backoff::Error::Permanent(e)
            } else {
                e.into()
            }
        })?;

        response
            .json::<Clock>()
            .await
            .map_err(backoff::Error::Permanent)
    };

    let log_retry = |e, delay: Duration| {
        warn!(
            "Failed to get clock, will retry in {} seconds: {}.",
            delay.as_secs(),
            e
        );
    };

    retry_notify(backoff.unwrap_or_default(), attempt, log_retry).await
}

8
src/lib/alpaca/mod.rs Normal file
View File

@@ -0,0 +1,8 @@
//! Retry-aware HTTP wrappers around the Alpaca REST endpoints, one module per
//! resource.
pub mod account;
pub mod assets;
pub mod bars;
pub mod calendar;
pub mod clock;
pub mod news;
pub mod orders;
pub mod positions;

52
src/lib/alpaca/news.rs Normal file
View File

@@ -0,0 +1,52 @@
use crate::types::alpaca::api::{incoming::news::News, outgoing, ALPACA_NEWS_DATA_API_URL};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
// Presumably the API's maximum `limit` value per news request — TODO confirm
// against the Alpaca news-API documentation.
pub const MAX_LIMIT: i64 = 50;
/// One page of the historical-news response.
#[derive(Deserialize)]
pub struct Message {
    pub news: Vec<News>,
    // Token for fetching the next page; `None` on the last page.
    pub next_page_token: Option<String>,
}
/// Fetches one page of historical news matching `query`, retrying transient
/// failures with exponential backoff.
///
/// # Errors
/// HTTP 400/403 responses and JSON decode failures are permanent and end the
/// retry loop; other failures are retried.
pub async fn get(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    query: &outgoing::news::News,
    backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
    retry_notify(
        backoff.unwrap_or_default(),
        || async {
            // Honor the shared rate limit before each attempt.
            rate_limiter.until_ready().await;
            client
                .get(ALPACA_NEWS_DATA_API_URL)
                .query(query)
                .send()
                .await?
                .error_for_status()
                .map_err(|e| match e.status() {
                    // Client-side errors will not succeed on retry.
                    Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
                        backoff::Error::Permanent(e)
                    }
                    _ => e.into(),
                })?
                .json::<Message>()
                .await
                .map_err(backoff::Error::Permanent)
        },
        |e, duration: Duration| {
            warn!(
                "Failed to get historical news, will retry in {} seconds: {}.",
                duration.as_secs(),
                e
            );
        },
    )
    .await
}

46
src/lib/alpaca/orders.rs Normal file
View File

@@ -0,0 +1,46 @@
use crate::types::alpaca::{api::outgoing, shared::order};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub use order::Order;
/// Fetches the orders matching `query`, retrying transient failures with
/// exponential backoff.
///
/// # Errors
/// HTTP 400/403 responses and JSON decode failures are treated as permanent
/// and abort the retry loop; other failures are retried.
pub async fn get(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    query: &outgoing::order::Order,
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Vec<Order>, Error> {
    let url = format!("https://{}.alpaca.markets/v2/orders", api_base);

    let attempt = || async {
        // Respect the shared API rate limit before every attempt.
        rate_limiter.until_ready().await;

        let response = client.get(&url).query(query).send().await?.error_for_status();
        let response = response.map_err(|e| {
            // Client-side errors will not succeed on retry; fail permanently.
            if matches!(
                e.status(),
                Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN)
            ) {
                backoff::Error::Permanent(e)
            } else {
                e.into()
            }
        })?;

        response
            .json::<Vec<Order>>()
            .await
            .map_err(backoff::Error::Permanent)
    };

    let log_retry = |e, delay: Duration| {
        warn!(
            "Failed to get orders, will retry in {} seconds: {}.",
            delay.as_secs(),
            e
        );
    };

    retry_notify(backoff.unwrap_or_default(), attempt, log_retry).await
}

111
src/lib/alpaca/positions.rs Normal file
View File

@@ -0,0 +1,111 @@
use crate::types::alpaca::api::incoming::position::Position;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::Client;
use std::{collections::HashSet, time::Duration};
/// Fetches all open positions, retrying transient failures with exponential
/// backoff.
///
/// # Errors
/// HTTP 400/403 responses and JSON decode failures are permanent and end the
/// retry loop; other failures are retried.
pub async fn get(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
    retry_notify(
        backoff.unwrap_or_default(),
        || async {
            // Honor the shared rate limit before each attempt.
            rate_limiter.until_ready().await;
            client
                .get(&format!("https://{}.alpaca.markets/v2/positions", api_base))
                .send()
                .await?
                .error_for_status()
                .map_err(|e| match e.status() {
                    // Client-side errors will not succeed on retry.
                    Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
                        backoff::Error::Permanent(e)
                    }
                    _ => e.into(),
                })?
                .json::<Vec<Position>>()
                .await
                .map_err(backoff::Error::Permanent)
        },
        |e, duration: Duration| {
            warn!(
                "Failed to get positions, will retry in {} seconds: {}.",
                duration.as_secs(),
                e
            );
        },
    )
    .await
}
/// Fetches the open position for `symbol`, retrying transient failures with
/// exponential backoff. Returns `Ok(None)` when no position is held.
///
/// # Errors
/// HTTP 400/403 responses and JSON decode failures are treated as permanent
/// and abort the retry loop; other failures are retried.
pub async fn get_by_symbol(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    symbol: &str,
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Option<Position>, reqwest::Error> {
    let url = format!(
        "https://{}.alpaca.markets/v2/positions/{}",
        api_base, symbol
    );

    let attempt = || async {
        // Respect the shared API rate limit before every attempt.
        rate_limiter.until_ready().await;

        let response = client.get(&url).send().await?;

        // A missing position is not an error: the API answers 404 when the
        // account holds nothing in this symbol.
        if response.status() == reqwest::StatusCode::NOT_FOUND {
            return Ok(None);
        }

        let response = response.error_for_status().map_err(|e| {
            // Client-side errors will not succeed on retry; fail permanently.
            if matches!(
                e.status(),
                Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN)
            ) {
                backoff::Error::Permanent(e)
            } else {
                e.into()
            }
        })?;

        response
            .json::<Position>()
            .await
            .map(Some)
            .map_err(backoff::Error::Permanent)
    };

    let log_retry = |e, delay: Duration| {
        warn!(
            "Failed to get position, will retry in {} seconds: {}.",
            delay.as_secs(),
            e
        );
    };

    retry_notify(backoff.unwrap_or_default(), attempt, log_retry).await
}
/// Fetches the open positions whose symbols appear in `symbols`.
///
/// A single-symbol request uses the per-symbol endpoint; otherwise the full
/// positions list is downloaded once and filtered down to the requested
/// symbols.
///
/// # Errors
/// Propagates any permanent error from the underlying `get`/`get_by_symbol`
/// calls.
pub async fn get_by_symbols(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    symbols: &[String],
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
    // Nothing requested: skip the positions download entirely.
    if symbols.is_empty() {
        return Ok(Vec::new());
    }
    if symbols.len() == 1 {
        let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
        return Ok(position.into_iter().collect());
    }
    let symbols = symbols.iter().collect::<HashSet<_>>();
    let positions = get(client, rate_limiter, backoff, api_base).await?;
    Ok(positions
        .into_iter()
        .filter(|position| symbols.contains(&position.symbol))
        .collect())
}

View File

@@ -1,3 +1,4 @@
pub mod alpaca;
pub mod database;
pub mod types;
pub mod utils;

View File

@@ -1,12 +1,7 @@
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use serde_aux::field_attributes::{
deserialize_number_from_string, deserialize_option_number_from_string,
};
use std::time::Duration;
use time::OffsetDateTime;
use uuid::Uuid;
@@ -78,39 +73,3 @@ pub struct Account {
#[serde(deserialize_with = "deserialize_number_from_string")]
pub regt_buying_power: f64,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/account", api_base))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Account>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,22 +1,11 @@
use super::position::Position;
use crate::types::{
self,
alpaca::{
api::outgoing,
shared::asset::{Class, Exchange, Status},
},
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use itertools::Itertools;
use log::warn;
use reqwest::{Client, Error};
use crate::types::{self, alpaca::shared::asset};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_option_number_from_string;
use std::{collections::HashSet, time::Duration};
use tokio::try_join;
use uuid::Uuid;
pub use asset::{Class, Exchange, Status};
#[allow(clippy::struct_excessive_bools)]
#[derive(Deserialize, Clone)]
pub struct Asset {
@@ -48,131 +37,3 @@ impl From<(Asset, Option<Position>)> for types::Asset {
}
}
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::asset::Asset,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/assets", api_base))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(
reqwest::StatusCode::BAD_REQUEST
| reqwest::StatusCode::FORBIDDEN
| reqwest::StatusCode::NOT_FOUND,
) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Vec<Asset>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get assets, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!(
"https://{}.alpaca.markets/v2/assets/{}",
api_base, symbol
))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(
reqwest::StatusCode::BAD_REQUEST
| reqwest::StatusCode::FORBIDDEN
| reqwest::StatusCode::NOT_FOUND,
) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Asset>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get asset, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
if symbols.len() == 1 {
let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(vec![asset]);
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let backoff_clone = backoff.clone();
let us_equity_query = outgoing::asset::Asset {
class: Some(Class::UsEquity),
..Default::default()
};
let us_equity_assets = get(
client,
rate_limiter,
&us_equity_query,
backoff_clone,
api_base,
);
let crypto_query = outgoing::asset::Asset {
class: Some(Class::Crypto),
..Default::default()
};
let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base);
let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;
Ok(crypto_assets
.into_iter()
.chain(us_equity_assets)
.dedup_by(|a, b| a.symbol == b.symbol)
.filter(|asset| symbols.contains(&asset.symbol))
.collect())
}

View File

@@ -1,10 +1,5 @@
use crate::types::{self, alpaca::api::outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use crate::types;
use serde::Deserialize;
use std::{collections::HashMap, time::Duration};
use time::OffsetDateTime;
#[derive(Deserialize)]
@@ -43,47 +38,3 @@ impl From<(Bar, String)> for types::Bar {
}
}
}
#[derive(Deserialize)]
pub struct Message {
pub bars: HashMap<String, Vec<Bar>>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
data_url: &str,
query: &outgoing::bar::Bar,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(data_url)
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical bars, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,13 +1,8 @@
use crate::{
types::{self, alpaca::api::outgoing},
types,
utils::{de, time::EST_OFFSET},
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
use time::{Date, OffsetDateTime, Time};
#[derive(Deserialize)]
@@ -29,41 +24,3 @@ impl From<Calendar> for types::Calendar {
}
}
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::calendar::Calendar,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Calendar>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/calendar", api_base))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Calendar>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get calendar, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,9 +1,4 @@
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
use time::OffsetDateTime;
#[derive(Deserialize)]
@@ -16,39 +11,3 @@ pub struct Clock {
#[serde(with = "time::serde::rfc3339")]
pub next_close: OffsetDateTime,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/clock", api_base))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Clock>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get clock, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,19 +1,8 @@
use crate::{
types::{
self,
alpaca::{
api::{outgoing, ALPACA_NEWS_DATA_API_URL},
shared::news::normalize_html_content,
},
},
types::{self, alpaca::shared::news::normalize_html_content},
utils::de,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
use time::OffsetDateTime;
#[derive(Deserialize)]
@@ -68,46 +57,3 @@ impl From<News> for types::News {
}
}
}
#[derive(Deserialize)]
pub struct Message {
pub news: Vec<News>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(ALPACA_NEWS_DATA_API_URL)
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical news, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,45 +1,3 @@
use crate::types::alpaca::{api::outgoing, shared};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
pub use shared::order::Order;
use std::time::Duration;
use crate::types::alpaca::shared::order;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::order::Order,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Order>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/orders", api_base))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Order>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get orders, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub use order::{Order, Side};

View File

@@ -1,17 +1,12 @@
use crate::{
types::alpaca::shared::{
self,
types::alpaca::api::incoming::{
asset::{Class, Exchange},
order,
},
utils::de,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::Client;
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_number_from_string;
use std::{collections::HashSet, time::Duration};
use uuid::Uuid;
#[derive(Deserialize, Clone, Copy)]
@@ -21,7 +16,7 @@ pub enum Side {
Short,
}
impl From<Side> for shared::order::Side {
impl From<Side> for order::Side {
fn from(side: Side) -> Self {
match side {
Side::Long => Self::Buy,
@@ -64,110 +59,3 @@ pub struct Position {
pub change_today: f64,
pub asset_marginable: bool,
}
pub const ALPACA_API_URL_TEMPLATE: &str = "";
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/positions", api_base))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Position>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get positions, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Option<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
let response = client
.get(&format!(
"https://{}.alpaca.markets/v2/positions/{}",
api_base, symbol
))
.send()
.await?;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Ok(None);
}
response
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Position>()
.await
.map_err(backoff::Error::Permanent)
.map(Some)
},
|e, duration: Duration| {
warn!(
"Failed to get position, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
if symbols.len() == 1 {
let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(position.into_iter().collect());
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let positions = get(client, rate_limiter, backoff, api_base).await?;
Ok(positions
.into_iter()
.filter(|position| symbols.contains(&position.symbol))
.collect())
}

View File

@@ -1,6 +1,8 @@
use crate::types::alpaca::shared::asset::{Class, Exchange, Status};
use crate::types::alpaca::shared::asset;
use serde::Serialize;
pub use asset::{Class, Exchange, Status};
#[derive(Serialize)]
pub struct Asset {
pub status: Option<Status>,

View File

@@ -1,12 +1,13 @@
use crate::{
types::alpaca::shared::{Sort, Source},
alpaca::bars::MAX_LIMIT,
types::alpaca::shared,
utils::{ser, ONE_MINUTE},
};
use serde::Serialize;
use std::time::Duration;
use time::OffsetDateTime;
pub const MAX_LIMIT: i64 = 10_000;
pub use shared::{Sort, Source};
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]

View File

@@ -1,9 +1,7 @@
use crate::{types::alpaca::shared::Sort, utils::ser};
use crate::{alpaca::news::MAX_LIMIT, types::alpaca::shared::Sort, utils::ser};
use serde::Serialize;
use time::OffsetDateTime;
pub const MAX_LIMIT: i64 = 50;
#[derive(Serialize)]
pub struct News {
#[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")]

View File

@@ -1,10 +1,12 @@
use crate::{
types::alpaca::shared::{order::Side, Sort},
types::alpaca::shared::{order, Sort},
utils::ser,
};
use serde::Serialize;
use time::OffsetDateTime;
pub use order::Side;
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]

View File

@@ -1,10 +1,10 @@
use crate::types::alpaca::shared;
use crate::types::alpaca::shared::order;
use serde::Deserialize;
use serde_aux::prelude::deserialize_number_from_string;
use time::OffsetDateTime;
use uuid::Uuid;
pub use shared::order::Order;
pub use order::Order;
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]

View File

@@ -4,7 +4,10 @@ use crate::{
};
use axum::{extract::Path, Extension, Json};
use http::StatusCode;
use qrust::types::{alpaca, Asset};
use qrust::{
alpaca,
types::{self, Asset},
};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
@@ -69,7 +72,7 @@ pub async fn add(
.map(|asset| asset.symbol)
.collect::<HashSet<_>>();
let mut alpaca_assets = alpaca::api::incoming::asset::get_by_symbols(
let mut alpaca_assets = alpaca::assets::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&request.symbols,
@@ -94,7 +97,7 @@ pub async fn add(
if database_symbols.contains(&symbol) {
skipped.push(symbol);
} else if let Some(asset) = alpaca_assets.remove(&symbol) {
if asset.status == alpaca::shared::asset::Status::Active
if asset.status == types::alpaca::api::incoming::asset::Status::Active
&& asset.tradable
&& asset.fractionable
{
@@ -144,7 +147,7 @@ pub async fn add_symbol(
return Err(StatusCode::CONFLICT);
}
let asset = alpaca::api::incoming::asset::get_by_symbol(
let asset = alpaca::assets::get_by_symbol(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbol,
@@ -159,7 +162,7 @@ pub async fn add_symbol(
})
})?;
if asset.status != alpaca::shared::asset::Status::Active
if asset.status != types::alpaca::api::incoming::asset::Status::Active
|| !asset.tradable
|| !asset.fractionable
{

View File

@@ -4,7 +4,8 @@ use crate::{
};
use log::info;
use qrust::{
types::{alpaca, Calendar},
alpaca,
types::{self, Calendar},
utils::{backoff, duration_until},
};
use std::sync::Arc;
@@ -21,8 +22,8 @@ pub struct Message {
pub next_switch: OffsetDateTime,
}
impl From<alpaca::api::incoming::clock::Clock> for Message {
fn from(clock: alpaca::api::incoming::clock::Clock) -> Self {
impl From<types::alpaca::api::incoming::clock::Clock> for Message {
fn from(clock: types::alpaca::api::incoming::clock::Clock) -> Self {
if clock.is_open {
Self {
status: Status::Open,
@@ -40,7 +41,7 @@ impl From<alpaca::api::incoming::clock::Clock> for Message {
pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
loop {
let clock_future = async {
alpaca::api::incoming::clock::get(
alpaca::clock::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
Some(backoff::infinite()),
@@ -51,10 +52,10 @@ pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
};
let calendar_future = async {
alpaca::api::incoming::calendar::get(
alpaca::calendar::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&alpaca::api::outgoing::calendar::Calendar::default(),
&types::alpaca::api::outgoing::calendar::Calendar::default(),
Some(backoff::infinite()),
&ALPACA_API_BASE,
)

View File

@@ -6,11 +6,10 @@ use crate::{
use async_trait::async_trait;
use log::{error, info};
use qrust::{
alpaca,
types::{
alpaca::{
self,
shared::{Sort, Source},
},
self,
alpaca::shared::{Sort, Source},
Backfill, Bar,
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
@@ -27,7 +26,7 @@ pub struct Handler {
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar,
) -> types::alpaca::api::outgoing::bar::Bar,
}
pub fn us_equity_query_constructor(
@@ -35,8 +34,8 @@ pub fn us_equity_query_constructor(
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
alpaca::api::outgoing::bar::Bar::UsEquity(alpaca::api::outgoing::bar::UsEquity {
) -> types::alpaca::api::outgoing::bar::Bar {
types::alpaca::api::outgoing::bar::Bar::UsEquity(types::alpaca::api::outgoing::bar::UsEquity {
symbols,
start: Some(fetch_from),
end: Some(fetch_to),
@@ -52,8 +51,8 @@ pub fn crypto_query_constructor(
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
alpaca::api::outgoing::bar::Bar::Crypto(alpaca::api::outgoing::bar::Crypto {
) -> types::alpaca::api::outgoing::bar::Bar {
types::alpaca::api::outgoing::bar::Bar::Crypto(types::alpaca::api::outgoing::bar::Crypto {
symbols,
start: Some(fetch_from),
end: Some(fetch_to),
@@ -124,7 +123,7 @@ impl super::Handler for Handler {
let mut next_page_token = None;
loop {
let message = alpaca::api::incoming::bar::get(
let message = alpaca::bars::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
self.data_url,
@@ -190,7 +189,7 @@ impl super::Handler for Handler {
}
fn max_limit(&self) -> i64 {
alpaca::api::outgoing::bar::MAX_LIMIT
alpaca::bars::MAX_LIMIT
}
fn log_string(&self) -> &'static str {

View File

@@ -7,11 +7,10 @@ use async_trait::async_trait;
use futures_util::future::join_all;
use log::{error, info};
use qrust::{
alpaca,
types::{
alpaca::{
self,
shared::{Sort, Source},
},
self,
alpaca::shared::{Sort, Source},
news::Prediction,
Backfill, News,
},
@@ -86,10 +85,10 @@ impl super::Handler for Handler {
let mut next_page_token = None;
loop {
let message = alpaca::api::incoming::news::get(
let message = alpaca::news::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
&alpaca::api::outgoing::news::News {
&types::alpaca::api::outgoing::news::News {
symbols: symbols.clone(),
start: Some(fetch_from),
end: Some(fetch_to),
@@ -187,7 +186,7 @@ impl super::Handler for Handler {
}
fn max_limit(&self) -> i64 {
alpaca::api::outgoing::news::MAX_LIMIT
alpaca::news::MAX_LIMIT
}
fn log_string(&self) -> &'static str {

View File

@@ -3,28 +3,26 @@ mod websocket;
use super::clock;
use crate::{
config::{Config, ALPACA_API_BASE, ALPACA_API_KEY, ALPACA_API_SECRET, ALPACA_SOURCE},
config::{Config, ALPACA_API_BASE, ALPACA_SOURCE},
create_send_await, database,
};
use futures_util::StreamExt;
use itertools::{Either, Itertools};
use log::error;
use qrust::types::{
alpaca::{
self,
websocket::{
use qrust::{
alpaca,
types::{
alpaca::websocket::{
ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL,
ALPACA_US_EQUITY_DATA_WEBSOCKET_URL,
},
Asset, Class,
},
Asset, Class,
};
use std::{collections::HashMap, sync::Arc};
use tokio::{
join, select, spawn,
sync::{mpsc, oneshot},
};
use tokio_tungstenite::connect_async;
#[derive(Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
@@ -67,11 +65,10 @@ pub async fn run(
mut clock_receiver: mpsc::Receiver<clock::Message>,
) {
let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::UsEquity)).await;
init_thread(&config, ThreadType::Bars(Class::UsEquity));
let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::Crypto)).await;
let (news_websocket_sender, news_backfill_sender) =
init_thread(config.clone(), ThreadType::News).await;
init_thread(&config, ThreadType::Bars(Class::Crypto));
let (news_websocket_sender, news_backfill_sender) = init_thread(&config, ThreadType::News);
loop {
select! {
@@ -100,8 +97,8 @@ pub async fn run(
}
}
async fn init_thread(
config: Arc<Config>,
fn init_thread(
config: &Arc<Config>,
thread_type: ThreadType,
) -> (
mpsc::Sender<websocket::Message>,
@@ -115,16 +112,6 @@ async fn init_thread(
ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(),
};
let (websocket, _) = connect_async(websocket_url).await.unwrap();
let (mut websocket_sink, mut websocket_stream) = websocket.split();
alpaca::websocket::data::authenticate(
&mut websocket_sink,
&mut websocket_stream,
(*ALPACA_API_KEY).to_string(),
(*ALPACA_API_SECRET).to_string(),
)
.await;
let (backfill_sender, backfill_receiver) = mpsc::channel(100);
spawn(backfill::run(
Arc::new(backfill::create_handler(thread_type, config.clone())),
@@ -135,8 +122,7 @@ async fn init_thread(
spawn(websocket::run(
Arc::new(websocket::create_handler(thread_type, config.clone())),
websocket_receiver,
websocket_stream,
websocket_sink,
websocket_url,
));
(websocket_sender, backfill_sender)
@@ -214,7 +200,7 @@ async fn handle_message(
match message.action {
Action::Add => {
let assets = async {
alpaca::api::incoming::asset::get_by_symbols(
alpaca::assets::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbols,
@@ -229,7 +215,7 @@ async fn handle_message(
};
let positions = async {
alpaca::api::incoming::position::get_by_symbols(
alpaca::positions::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbols,
@@ -252,7 +238,7 @@ async fn handle_message(
let position = positions.remove(symbol);
batch.push(Asset::from((asset, position)));
} else {
error!("Failed to find asset for symbol: {}", symbol);
error!("Failed to find asset for symbol: {}.", symbol);
}
}

View File

@@ -1,9 +1,12 @@
use super::Pending;
use super::State;
use crate::{config::Config, database};
use async_trait::async_trait;
use log::{debug, error, info};
use qrust::types::{alpaca::websocket, Bar};
use std::{collections::HashMap, sync::Arc};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::RwLock;
pub struct Handler {
@@ -23,7 +26,7 @@ impl super::Handler for Handler {
async fn handle_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
) {
match message {
@@ -35,19 +38,24 @@ impl super::Handler for Handler {
unreachable!()
};
let mut pending = pending.write().await;
let symbols = symbols.into_iter().collect::<HashSet<_>>();
let mut state = state.write().await;
let newly_subscribed = pending
.subscriptions
let newly_subscribed = state
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = pending
.unsubscriptions
let newly_unsubscribed = state
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
drop(pending);
state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() {
info!(
@@ -122,4 +130,8 @@ impl super::Handler for Handler {
_ => unreachable!(),
}
}
fn log_string(&self) -> &'static str {
"bars"
}
}

View File

@@ -2,23 +2,27 @@ mod bars;
mod news;
use super::ThreadType;
use crate::config::Config;
use crate::config::{Config, ALPACA_API_KEY, ALPACA_API_SECRET};
use async_trait::async_trait;
use futures_util::{
future::join_all,
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use futures_util::{future::join_all, SinkExt, StreamExt};
use log::error;
use qrust::types::{alpaca::websocket, Class};
use qrust::types::{
alpaca::{self, websocket},
Class,
};
use serde_json::{from_str, to_string};
use std::{collections::HashMap, sync::Arc};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::Duration,
};
use tokio::{
net::TcpStream,
select, spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
sync::{mpsc, oneshot, RwLock},
};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
use tokio_tungstenite::{connect_async, tungstenite, MaybeTlsStream, WebSocketStream};
pub enum Action {
Subscribe,
@@ -54,9 +58,10 @@ impl Message {
}
}
pub struct Pending {
pub subscriptions: HashMap<String, oneshot::Sender<()>>,
pub unsubscriptions: HashMap<String, oneshot::Sender<()>>,
pub struct State {
pub active_subscriptions: HashSet<String>,
pub pending_subscriptions: HashMap<String, oneshot::Sender<()>>,
pub pending_unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}
#[async_trait]
@@ -67,53 +72,64 @@ pub trait Handler: Send + Sync {
) -> websocket::data::outgoing::subscribe::Message;
async fn handle_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
);
fn log_string(&self) -> &'static str;
}
pub async fn run(
handler: Arc<Box<dyn Handler>>,
mut receiver: mpsc::Receiver<Message>,
mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>,
websocket_url: String,
) {
let pending = Arc::new(RwLock::new(Pending {
subscriptions: HashMap::new(),
unsubscriptions: HashMap::new(),
let state = Arc::new(RwLock::new(State {
active_subscriptions: HashSet::new(),
pending_subscriptions: HashMap::new(),
pending_unsubscriptions: HashMap::new(),
}));
let websocket_sink = Arc::new(Mutex::new(websocket_sink));
let (sink_sender, sink_receiver) = mpsc::channel(100);
let (stream_sender, mut stream_receiver) = mpsc::channel(10_000);
spawn(run_connection(
handler.clone(),
sink_receiver,
stream_sender,
websocket_url.clone(),
state.clone(),
));
loop {
select! {
Some(message) = receiver.recv() => {
spawn(handle_message(
handler.clone(),
pending.clone(),
websocket_sink.clone(),
state.clone(),
sink_sender.clone(),
message,
));
}
Some(Ok(message)) = websocket_stream.next() => {
Some(message) = stream_receiver.recv() => {
match message {
tungstenite::Message::Text(message) => {
let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);
if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}", message);
error!("Failed to deserialize websocket message: {:?}.", message);
continue;
}
for message in parsed_message.unwrap() {
let handler = handler.clone();
let pending = pending.clone();
let state = state.clone();
spawn(async move {
handler.handle_websocket_message(pending, message).await;
handler.handle_websocket_message(state, message).await;
});
}
}
tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}", message),
_ => error!("Unexpected websocket message: {:?}.", message),
}
}
else => panic!("Communication channel unexpectedly closed.")
@@ -121,10 +137,142 @@ pub async fn run(
}
}
#[allow(clippy::too_many_lines)]
/// Owns one websocket connection and transparently reconnects it forever.
///
/// Outbound messages arrive on `sink_receiver` and are written to the socket;
/// inbound frames are forwarded to `stream_sender` (consumed by the dispatch
/// loop in `run`). On any send/receive failure or socket close the outer
/// `'connection` loop tears the connection down and rebuilds it:
/// reconnect with exponential backoff, re-authenticate, then re-subscribe to
/// everything recorded in `state.active_subscriptions`.
async fn run_connection(
    handler: Arc<Box<dyn Handler>>,
    mut sink_receiver: mpsc::Receiver<tungstenite::Message>,
    stream_sender: mpsc::Sender<tungstenite::Message>,
    websocket_url: String,
    state: Arc<RwLock<State>>,
) {
    // Last outbound message whose delivery is unconfirmed; resent after a
    // reconnect so a message in flight when the socket died is not lost.
    let mut peek = None;
    'connection: loop {
        // Connect with exponential backoff, retrying indefinitely; the final
        // unwrap is unreachable in practice because the backoff never gives up.
        let (websocket, _): (WebSocketStream<MaybeTlsStream<TcpStream>>, _) = retry_notify(
            ExponentialBackoff::default(),
            || async {
                connect_async(websocket_url.clone())
                    .await
                    .map_err(Into::into)
            },
            |e, duration: Duration| {
                error!(
                    "Failed to connect to {} websocket, will retry in {} seconds: {}.",
                    handler.log_string(),
                    duration.as_secs(),
                    e
                );
            },
        )
        .await
        .unwrap();
        let (mut sink, mut stream) = websocket.split();
        alpaca::websocket::data::authenticate(
            &mut sink,
            &mut stream,
            (*ALPACA_API_KEY).to_string(),
            (*ALPACA_API_SECRET).to_string(),
        )
        .await;
        let mut state = state.write().await;
        // A fresh connection has no subscriptions, so every pending
        // unsubscription is trivially satisfied: ack each waiter now.
        state
            .pending_unsubscriptions
            .drain()
            .for_each(|(_, sender)| {
                sender.send(()).unwrap();
            });
        // Re-queue every subscription that was active on the previous
        // connection so it is re-established on this one.
        let (recovered_subscriptions, receivers) = state
            .active_subscriptions
            .iter()
            .map(|symbol| {
                let (sender, receiver) = oneshot::channel();
                ((symbol.clone(), sender), receiver)
            })
            .unzip::<_, _, Vec<_>, Vec<_>>();
        state.pending_subscriptions.extend(recovered_subscriptions);
        // Snapshot keys so the lock can be released before any socket I/O.
        let pending_subscriptions = state
            .pending_subscriptions
            .keys()
            .cloned()
            .collect::<Vec<_>>();
        drop(state);
        if !pending_subscriptions.is_empty() {
            if let Err(err) = sink
                .send(tungstenite::Message::Text(
                    to_string(&websocket::data::outgoing::Message::Subscribe(
                        handler.create_subscription_message(pending_subscriptions),
                    ))
                    .unwrap(),
                ))
                .await
            {
                error!("Failed to send websocket message: {:?}.", err);
                // `continue` targets the outer `'connection` loop here (the
                // inner loop has not started yet): retry from a fresh socket.
                continue;
            }
        }
        // Wait for the recovered subscriptions to be acknowledged before
        // replaying `peek`, so the resent message observes the restored state.
        // NOTE(review): the acks fire in `handle_websocket_message`, which is
        // fed from `stream_sender` — but frames are only forwarded to
        // `stream_sender` inside the inner loop below, which has not started
        // yet. Verify the acks can actually arrive here, otherwise this await
        // may deadlock on reconnect with active subscriptions.
        join_all(receivers).await;
        // Replay the message that was in flight when the previous connection
        // dropped. `peek` is cloned (not taken) so a failed resend keeps it
        // queued for the next reconnect attempt.
        if peek.is_some() {
            if let Err(err) = sink.send(peek.clone().unwrap()).await {
                error!("Failed to send websocket message: {:?}.", err);
                continue;
            }
            peek = None;
        }
        // Steady-state pump: shuttle outbound messages to the sink and
        // inbound frames to the dispatcher until the connection fails.
        loop {
            select! {
                Some(message) = sink_receiver.recv() => {
                    // Stash before sending; cleared only on confirmed send so
                    // a failure leaves the message queued for replay above.
                    peek = Some(message.clone());
                    if let Err(err) = sink.send(message).await {
                        error!("Failed to send websocket message: {:?}.", err);
                        continue 'connection;
                    };
                    peek = None;
                }
                message = stream.next() => {
                    if message.is_none() {
                        error!("Websocket stream unexpectedly closed.");
                        continue 'connection;
                    }
                    let message = message.unwrap();
                    if let Err(err) = message {
                        error!("Failed to receive websocket message: {:?}.", err);
                        continue 'connection;
                    }
                    let message = message.unwrap();
                    if message.is_close() {
                        error!("Websocket connection closed.");
                        continue 'connection;
                    }
                    // Hand the frame to `run` for parsing/dispatch; unwrap is
                    // a deliberate crash if the dispatcher has gone away.
                    stream_sender.send(message).await.unwrap();
                }
                // NOTE(review): this arm fires when both branches are
                // disabled (e.g. `sink_receiver` closed and the stream
                // exhausted) and then loops immediately — confirm this cannot
                // become a hot busy-loop of error logs; the sibling loop in
                // `run` panics in this situation instead.
                else => error!("Communication channel unexpectedly closed.")
            }
        }
    }
}
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
pending: Arc<RwLock<Pending>>,
sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
pending: Arc<RwLock<State>>,
sink_sender: mpsc::Sender<tungstenite::Message>,
message: Message,
) {
if message.symbols.is_empty() {
@@ -134,23 +282,22 @@ async fn handle_message(
match message.action {
Some(Action::Subscribe) => {
let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message
let (pending_subscriptions, receivers) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip();
.unzip::<_, _, Vec<_>, Vec<_>>();
pending
.write()
.await
.subscriptions
.pending_subscriptions
.extend(pending_subscriptions);
sink.lock()
.await
sink_sender
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(message.symbols),
@@ -175,11 +322,10 @@ async fn handle_message(
pending
.write()
.await
.unsubscriptions
.pending_unsubscriptions
.extend(pending_unsubscriptions);
sink.lock()
.await
sink_sender
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Unsubscribe(
handler.create_subscription_message(message.symbols.clone()),

View File

@@ -1,4 +1,4 @@
use super::Pending;
use super::State;
use crate::{config::Config, database};
use async_trait::async_trait;
use log::{debug, error, info};
@@ -21,7 +21,7 @@ impl super::Handler for Handler {
async fn handle_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
) {
match message {
@@ -32,19 +32,23 @@ impl super::Handler for Handler {
unreachable!()
};
let mut pending = pending.write().await;
let mut state = state.write().await;
let newly_subscribed = pending
.subscriptions
let newly_subscribed = state
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = pending
.unsubscriptions
let newly_unsubscribed = state
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
drop(pending);
state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() {
info!(
@@ -108,4 +112,8 @@ impl super::Handler for Handler {
_ => unreachable!(),
}
}
fn log_string(&self) -> &'static str {
"news"
}
}

View File

@@ -21,7 +21,7 @@ pub async fn run(
);
if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}", message);
error!("Failed to deserialize websocket message: {:?}.", message);
continue;
}
@@ -31,7 +31,7 @@ pub async fn run(
));
}
tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}", message),
_ => error!("Unexpected websocket message: {:?}.", message),
}
}
}
@@ -43,7 +43,7 @@ async fn handle_websocket_message(
match message {
websocket::trading::incoming::Message::Order(message) => {
debug!(
"Received order message for {}: {:?}",
"Received order message for {}: {:?}.",
message.order.symbol, message.event
);