Add route for adding multiple assets

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-09 20:13:36 +00:00
parent 080f91b044
commit 681d7393d7
31 changed files with 754 additions and 282 deletions

.gitignore
View File

@@ -2,6 +2,7 @@
# will have compiled files and executables
debug/
target/
log/
# These are backup files generated by rustfmt
**/*.rs.bk

View File

@@ -4,7 +4,14 @@ appenders:
encoder:
pattern: "{d} {h({l})} {M}::{L} - {m}{n}"
file:
kind: file
path: "./log/output.log"
encoder:
pattern: "{d} {l} {M}::{L} - {m}{n}"
root:
level: info
appenders:
- stdout
- file

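For reference, a config like the one above is loaded once at startup; a minimal sketch, assuming the file is saved as log4rs.yaml in the working directory and log4rs is built with its yaml feature:

fn main() {
    // Initialize logging from the YAML appender config shown above.
    log4rs::init_file("log4rs.yaml", Default::default())
        .expect("failed to initialize log4rs");
    log::info!("logging to stdout and ./log/output.log");
}
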
View File

@@ -13,7 +13,7 @@ use rust_bert::{
resources::LocalResource,
};
use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
use tokio::sync::Mutex;
use tokio::sync::{Mutex, Semaphore};
pub const ALPACA_STOCK_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
@@ -51,17 +51,21 @@ lazy_static! {
Mode::Paper => String::from("paper-api"),
}
);
pub static ref MAX_BERT_INPUTS: usize = env::var("MAX_BERT_INPUTS")
.expect("MAX_BERT_INPUTS must be set.")
pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS")
.expect("BERT_MAX_INPUTS must be set.")
.parse()
.expect("MAX_BERT_INPUTS must be a positive integer.");
.expect("BERT_MAX_INPUTS must be a positive integer.");
pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS")
.expect("CLICKHOUSE_MAX_CONNECTIONS must be set.")
.parse()
.expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer.");
}
pub struct Config {
pub alpaca_client: Client,
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
pub clickhouse_client: clickhouse::Client,
pub clickhouse_concurrency_limiter: Arc<Semaphore>,
pub sequence_classifier: Mutex<SequenceClassificationModel>,
}
@@ -95,6 +99,7 @@ impl Config {
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
)
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
sequence_classifier: Mutex::new(
SequenceClassificationModel::new(SequenceClassificationConfig::new(
ModelType::Bert,

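The new CLICKHOUSE_MAX_CONNECTIONS semaphore caps in-flight ClickHouse queries. A minimal sketch of the pattern (the query body is elided): the permit must be bound to a named variable such as _permit, because a bare let _ = ... drops the guard immediately and releases the slot before the query runs.

use std::sync::Arc;
use tokio::sync::Semaphore;

async fn limited_query(concurrency_limiter: &Arc<Semaphore>) {
    // Held until the end of scope; dropping the permit frees the slot.
    let _permit = concurrency_limiter.acquire().await.unwrap();
    // ...run the ClickHouse query here while the permit is held...
}
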
View File

@@ -1,8 +1,11 @@
use std::sync::Arc;
use crate::{
delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch,
};
use clickhouse::{error::Error, Client};
use serde::Serialize;
use tokio::sync::Semaphore;
select!(Asset, "assets");
select_where_symbol!(Asset, "assets");
@@ -11,14 +14,16 @@ delete_where_symbols!("assets");
optimize!("assets");
pub async fn update_status_where_symbol<T>(
clickhouse_client: &Client,
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbol: &T,
status: bool,
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?")
.bind(status)
.bind(symbol)
@@ -27,14 +32,16 @@ where
}
pub async fn update_qty_where_symbol<T>(
clickhouse_client: &Client,
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbol: &T,
qty: f64,
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?")
.bind(qty)
.bind(symbol)

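A hypothetical call site for the helper above, with client and limiter taken from the Config fields shown earlier:

async fn disable_asset(
    client: &clickhouse::Client,
    limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
    // Marks AAPL inactive; the helper holds a permit for the ALTER TABLE query.
    update_status_where_symbol(client, limiter, &"AAPL", false).await
}
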
View File

@@ -1,16 +1,20 @@
use std::sync::Arc;
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
cleanup, delete_where_symbols, optimize, select_where_symbols, types::Backfill, upsert,
};
use clickhouse::{error::Error, Client};
use tokio::sync::Semaphore;
select_where_symbol!(Backfill, "backfills_bars");
select_where_symbols!(Backfill, "backfills_bars");
upsert!(Backfill, "backfills_bars");
delete_where_symbols!("backfills_bars");
cleanup!("backfills_bars");
optimize!("backfills_bars");
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
pub async fn unfresh(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("ALTER TABLE backfills_bars UPDATE fresh = false WHERE true")
.execute()
.await

View File

@@ -1,16 +1,20 @@
use std::sync::Arc;
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
cleanup, delete_where_symbols, optimize, select_where_symbols, types::Backfill, upsert,
};
use clickhouse::{error::Error, Client};
use tokio::sync::Semaphore;
select_where_symbol!(Backfill, "backfills_news");
select_where_symbols!(Backfill, "backfills_news");
upsert!(Backfill, "backfills_news");
delete_where_symbols!("backfills_news");
cleanup!("backfills_news");
optimize!("backfills_news");
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
pub async fn unfresh(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("ALTER TABLE backfills_news UPDATE fresh = false WHERE true")
.execute()
.await

View File

@@ -1,7 +1,21 @@
use crate::{cleanup, delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
use std::sync::Arc;
use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
use clickhouse::Client;
use tokio::sync::Semaphore;
upsert!(Bar, "bars");
upsert_batch!(Bar, "bars");
delete_where_symbols!("bars");
cleanup!("bars");
optimize!("bars");
pub async fn cleanup(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
) -> Result<(), clickhouse::error::Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets) OR symbol NOT IN (SELECT symbol FROM backfills_bars)")
.execute()
.await
}

View File

@@ -1,11 +1,14 @@
use std::sync::Arc;
use crate::{optimize, types::Calendar};
use clickhouse::error::Error;
use tokio::try_join;
use clickhouse::{error::Error, Client};
use tokio::{sync::Semaphore, try_join};
optimize!("calendar");
pub async fn upsert_batch_and_delete<'a, T>(
client: &clickhouse::Client,
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
records: T,
) -> Result<(), Error>
where
@@ -34,5 +37,6 @@ where
.await
};
let _permits = concurrency_limiter.acquire_many(2).await.unwrap();
try_join!(upsert_future, delete_future).map(|_| ())
}

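Because upsert_batch_and_delete runs two statements concurrently under try_join!, it reserves both permits up front; a sketch of that pattern in isolation:

use std::sync::Arc;
use tokio::sync::Semaphore;

async fn paired_queries(limiter: &Arc<Semaphore>) {
    // Reserve two slots at once so the joined futures cannot oversubscribe.
    let _permits = limiter.acquire_many(2).await.unwrap();
    // try_join!(upsert_future, delete_future) would run here under both permits.
}
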
View File

@@ -15,7 +15,9 @@ macro_rules! select {
($record:ty, $table_name:expr) => {
pub async fn select(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<Vec<$record>, clickhouse::error::Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
.fetch_all::<$record>()
@@ -29,11 +31,13 @@ macro_rules! select_where_symbol {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbol<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbol: &T,
) -> Result<Option<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol = ?",
@@ -46,13 +50,39 @@ macro_rules! select_where_symbol {
};
}
#[macro_export]
macro_rules! select_where_symbols {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbols: &[T],
) -> Result<Vec<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol IN ?",
$table_name
))
.bind(symbols)
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! upsert {
($record:ty, $table_name:expr) => {
pub async fn upsert(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
record: &$record,
) -> Result<(), clickhouse::error::Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
let mut insert = client.insert($table_name)?;
insert.write(record).await?;
insert.end().await
@@ -65,12 +95,14 @@ macro_rules! upsert_batch {
($record:ty, $table_name:expr) => {
pub async fn upsert_batch<'a, T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
records: T,
) -> Result<(), clickhouse::error::Error>
where
T: IntoIterator<Item = &'a $record> + Send + Sync,
T::IntoIter: Send,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
let mut insert = client.insert($table_name)?;
for record in records {
insert.write(record).await?;
@@ -85,11 +117,13 @@ macro_rules! delete_where_symbols {
($table_name:expr) => {
pub async fn delete_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
.bind(symbols)
@@ -102,7 +136,11 @@ macro_rules! delete_where_symbols {
#[macro_export]
macro_rules! cleanup {
($table_name:expr) => {
pub async fn cleanup(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
pub async fn cleanup(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
@@ -117,7 +155,11 @@ macro_rules! cleanup {
#[macro_export]
macro_rules! optimize {
($table_name:expr) => {
pub async fn optimize(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
pub async fn optimize(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
.execute()
@@ -126,27 +168,33 @@ macro_rules! optimize {
};
}
pub async fn cleanup_all(clickhouse_client: &Client) -> Result<(), Error> {
pub async fn cleanup_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
info!("Cleaning up database.");
try_join!(
bars::cleanup(clickhouse_client),
news::cleanup(clickhouse_client),
backfills_bars::cleanup(clickhouse_client),
backfills_news::cleanup(clickhouse_client)
bars::cleanup(clickhouse_client, concurrency_limiter),
news::cleanup(clickhouse_client, concurrency_limiter),
backfills_bars::cleanup(clickhouse_client, concurrency_limiter),
backfills_news::cleanup(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}
pub async fn optimize_all(clickhouse_client: &Client) -> Result<(), Error> {
pub async fn optimize_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
info!("Optimizing database.");
try_join!(
assets::optimize(clickhouse_client),
bars::optimize(clickhouse_client),
news::optimize(clickhouse_client),
backfills_bars::optimize(clickhouse_client),
backfills_news::optimize(clickhouse_client),
orders::optimize(clickhouse_client),
calendar::optimize(clickhouse_client)
assets::optimize(clickhouse_client, concurrency_limiter),
bars::optimize(clickhouse_client, concurrency_limiter),
news::optimize(clickhouse_client, concurrency_limiter),
backfills_bars::optimize(clickhouse_client, concurrency_limiter),
backfills_news::optimize(clickhouse_client, concurrency_limiter),
orders::optimize(clickhouse_client, concurrency_limiter),
calendar::optimize(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}

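Each macro invocation expands to a plain async fn in the enclosing module, so call sites read like ordinary functions; a hypothetical sketch using the invocations above:

async fn example(
    client: &clickhouse::Client,
    limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
    let _assets = database::assets::select(client, limiter).await?;
    let _backfills = database::backfills_bars::select_where_symbols(
        client,
        limiter,
        &["AAPL".to_string()],
    )
    .await?;
    Ok(())
}
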
View File

@@ -1,24 +1,33 @@
use std::sync::Arc;
use crate::{optimize, types::News, upsert, upsert_batch};
use clickhouse::{error::Error, Client};
use serde::Serialize;
use tokio::sync::Semaphore;
upsert!(News, "news");
upsert_batch!(News, "news");
optimize!("news");
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error>
pub async fn delete_where_symbols<T>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbols: &[T],
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))")
.bind(symbols)
.execute()
.await
}
pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
pub async fn cleanup(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(
"DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))",
)

View File

@@ -68,7 +68,11 @@ pub async fn rehydrate_orders(config: &Arc<Config>) {
.flat_map(&alpaca::api::incoming::order::Order::normalize)
.collect::<Vec<_>>();
database::orders::upsert_batch(&config.clickhouse_client, &orders)
database::orders::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&orders,
)
.await
.unwrap();
@@ -92,7 +96,10 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
};
let assets_future = async {
database::assets::select(&config.clickhouse_client)
database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
};
@@ -111,7 +118,11 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
})
.collect::<Vec<_>>();
database::assets::upsert_batch(&config.clickhouse_client, &assets)
database::assets::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&assets,
)
.await
.unwrap();

View File

@@ -22,15 +22,27 @@ async fn main() {
let config = Config::arc_from_env();
try_join!(
database::backfills_bars::unfresh(&config.clickhouse_client),
database::backfills_news::unfresh(&config.clickhouse_client)
database::backfills_bars::unfresh(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter
),
database::backfills_news::unfresh(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter
)
)
.unwrap();
database::cleanup_all(&config.clickhouse_client)
database::cleanup_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
database::optimize_all(&config.clickhouse_client)
database::optimize_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
@@ -53,7 +65,10 @@ async fn main() {
spawn(threads::clock::run(config.clone(), clock_sender));
let assets = database::assets::select(&config.clickhouse_client)
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
.into_iter()

View File

@@ -5,14 +5,20 @@ use crate::{
};
use axum::{extract::Path, Extension, Json};
use http::StatusCode;
use serde::Deserialize;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::mpsc;
pub async fn get(
Extension(config): Extension<Arc<Config>>,
) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> {
let assets = database::assets::select(&config.clickhouse_client)
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
@@ -23,7 +29,11 @@ pub async fn get_where_symbol(
Extension(config): Extension<Arc<Config>>,
Path(symbol): Path<String>,
) -> Result<(StatusCode, Json<Asset>), StatusCode> {
let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol)
let asset = database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
@@ -33,16 +43,98 @@ pub async fn get_where_symbol(
}
#[derive(Deserialize)]
pub struct AddAssetRequest {
symbol: String,
pub struct AddAssetsRequest {
symbols: Vec<String>,
}
#[derive(Serialize)]
pub struct AddAssetsResponse {
added: Vec<String>,
skipped: Vec<String>,
failed: Vec<String>,
}
pub async fn add(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Json(request): Json<AddAssetRequest>,
Json(request): Json<AddAssetsRequest>,
) -> Result<(StatusCode, Json<AddAssetsResponse>), StatusCode> {
let database_symbols = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.into_iter()
.map(|asset| asset.symbol)
.collect::<HashSet<_>>();
let mut alpaca_assets = alpaca::api::incoming::asset::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&request.symbols,
None,
)
.await
.map_err(|e| {
e.status()
.map_or(StatusCode::INTERNAL_SERVER_ERROR, |status| {
StatusCode::from_u16(status.as_u16()).unwrap()
})
})?
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>();
let (assets, skipped, failed) = request.symbols.into_iter().fold(
(vec![], vec![], vec![]),
|(mut assets, mut skipped, mut failed), symbol| {
if database_symbols.contains(&symbol) {
skipped.push(symbol);
} else if let Some(asset) = alpaca_assets.remove(&symbol) {
if asset.status == alpaca::shared::asset::Status::Active
&& asset.tradable
&& asset.fractionable
{
assets.push((asset.symbol, asset.class.into()));
} else {
failed.push(asset.symbol);
}
} else {
failed.push(symbol);
}
(assets, skipped, failed)
},
);
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
assets.clone()
);
Ok((
StatusCode::CREATED,
Json(AddAssetsResponse {
added: assets.into_iter().map(|asset| asset.0).collect(),
skipped,
failed,
}),
))
}
pub async fn add_symbol(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
if database::assets::select_where_symbol(&config.clickhouse_client, &request.symbol)
if database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.is_some()
@@ -53,7 +145,7 @@ pub async fn add(
let asset = alpaca::api::incoming::asset::get_by_symbol(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&request.symbol,
&symbol,
None,
)
.await
@@ -64,7 +156,10 @@ pub async fn add(
})
})?;
if !asset.tradable || !asset.fractionable {
if asset.status != alpaca::shared::asset::Status::Active
|| !asset.tradable
|| !asset.fractionable
{
return Err(StatusCode::FORBIDDEN);
}
@@ -83,7 +178,11 @@ pub async fn delete(
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol)
let asset = database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;

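A hypothetical client-side exercise of the new bulk route (the bind address is an assumption; the response body mirrors AddAssetsResponse above):

async fn add_assets() -> Result<(), reqwest::Error> {
    let response = reqwest::Client::new()
        .post("http://localhost:7878/assets") // hypothetical bind address
        .json(&serde_json::json!({ "symbols": ["AAPL", "BTC/USD", "FAKE"] }))
        .send()
        .await?;
    // Expected shape: {"added": [...], "skipped": [...], "failed": [...]}
    println!("{}", response.text().await?);
    Ok(())
}
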
View File

@@ -16,6 +16,7 @@ pub async fn run(config: Arc<Config>, data_sender: mpsc::Sender<threads::data::M
.route("/assets", get(assets::get))
.route("/assets/:symbol", get(assets::get_where_symbol))
.route("/assets", post(assets::add))
.route("/assets/:symbol", post(assets::add_symbol))
.route("/assets/:symbol", delete(assets::delete))
.layer(Extension(config))
.layer(Extension(data_sender));

View File

@@ -74,7 +74,11 @@ pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
let sleep_future = sleep(sleep_until);
let calendar_future = async {
database::calendar::upsert_batch_and_delete(&config.clickhouse_client, &calendar)
database::calendar::upsert_batch_and_delete(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&calendar,
)
.await
.unwrap();
};

View File

@@ -2,7 +2,7 @@ use super::ThreadType;
use crate::{
config::{
Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_SOURCE, ALPACA_STOCK_DATA_API_URL,
MAX_BERT_INPUTS,
BERT_MAX_INPUTS,
},
database,
types::{
@@ -30,24 +30,14 @@ pub enum Action {
Purge,
}
impl From<super::Action> for Option<Action> {
fn from(action: super::Action) -> Self {
match action {
super::Action::Add | super::Action::Enable => Some(Action::Backfill),
super::Action::Remove => Some(Action::Purge),
super::Action::Disable => None,
}
}
}
pub struct Message {
pub action: Option<Action>,
pub action: Action,
pub symbols: Vec<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
pub fn new(action: Action, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel::<()>();
(
Self {
@@ -62,10 +52,10 @@ impl Message {
#[async_trait]
pub trait Handler: Send + Sync {
async fn select_latest_backfill(
async fn select_latest_backfills(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error>;
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error>;
async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime);
@@ -94,9 +84,17 @@ async fn handle_backfill_message(
let mut backfill_jobs = backfill_jobs.lock().await;
match message.action {
Some(Action::Backfill) => {
Action::Backfill => {
let log_string = handler.log_string();
let backfills = handler
.select_latest_backfills(&message.symbols)
.await
.unwrap()
.into_iter()
.map(|backfill| (backfill.symbol.clone(), backfill))
.collect::<HashMap<_, _>>();
for symbol in message.symbols {
if let Some(job) = backfill_jobs.get(&symbol) {
if !job.is_finished() {
@@ -108,18 +106,11 @@ async fn handle_backfill_message(
}
}
let handler = handler.clone();
backfill_jobs.insert(
symbol.clone(),
spawn(async move {
let fetch_from = match handler
.select_latest_backfill(symbol.clone())
.await
.unwrap()
{
Some(latest_backfill) => latest_backfill.time + ONE_SECOND,
None => OffsetDateTime::UNIX_EPOCH,
};
let fetch_from = backfills
.get(&symbol)
.map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
backfill.time + ONE_SECOND
});
let fetch_to = last_minute();
@@ -128,13 +119,17 @@ async fn handle_backfill_message(
return;
}
let handler = handler.clone();
backfill_jobs.insert(
symbol.clone(),
spawn(async move {
handler.queue_backfill(&symbol, fetch_to).await;
handler.backfill(symbol, fetch_from, fetch_to).await;
}),
);
}
}
Some(Action::Purge) => {
Action::Purge => {
for symbol in &message.symbols {
if let Some(job) = backfill_jobs.remove(symbol) {
if !job.is_finished() {
@@ -150,7 +145,6 @@ async fn handle_backfill_message(
)
.unwrap();
}
None => {}
}
message.response.send(()).unwrap();
@@ -199,20 +193,34 @@ fn crypto_query_constructor(
#[async_trait]
impl Handler for BarHandler {
async fn select_latest_backfill(
async fn select_latest_backfills(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error> {
database::backfills_bars::select_where_symbol(&self.config.clickhouse_client, &symbol).await
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
database::backfills_bars::select_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_bars::delete_where_symbols(&self.config.clickhouse_client, symbols)
database::backfills_bars::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::bars::delete_where_symbols(&self.config.clickhouse_client, symbols).await
database::bars::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
@@ -230,7 +238,7 @@ impl Handler for BarHandler {
let mut next_page_token = None;
loop {
let Ok(message) = alpaca::api::incoming::bar::get_historical(
let Ok(message) = alpaca::api::incoming::bar::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
self.data_url,
@@ -267,10 +275,18 @@ impl Handler for BarHandler {
let backfill = bars.last().unwrap().clone().into();
database::bars::upsert_batch(&self.config.clickhouse_client, &bars)
database::bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&bars,
)
.await
.unwrap();
database::backfills_bars::upsert(&self.config.clickhouse_client, &backfill)
database::backfills_bars::upsert(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfill,
)
.await
.unwrap();
@@ -288,20 +304,34 @@ struct NewsHandler {
#[async_trait]
impl Handler for NewsHandler {
async fn select_latest_backfill(
async fn select_latest_backfills(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error> {
database::backfills_news::select_where_symbol(&self.config.clickhouse_client, &symbol).await
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
database::backfills_news::select_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_news::delete_where_symbols(&self.config.clickhouse_client, symbols)
database::backfills_news::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::news::delete_where_symbols(&self.config.clickhouse_client, symbols).await
database::news::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
@@ -317,7 +347,7 @@ impl Handler for NewsHandler {
let mut next_page_token = None;
loop {
let Ok(message) = alpaca::api::incoming::news::get_historical(
let Ok(message) = alpaca::api::incoming::news::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
&alpaca::api::outgoing::news::News {
@@ -355,7 +385,7 @@ impl Handler for NewsHandler {
.map(|news| format!("{}\n\n{}", news.headline, news.content))
.collect::<Vec<_>>();
let predictions = join_all(inputs.chunks(*MAX_BERT_INPUTS).map(|inputs| async move {
let predictions = join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
let sequence_classifier = self.config.sequence_classifier.lock().await;
block_in_place(|| {
sequence_classifier
@@ -381,10 +411,18 @@ impl Handler for NewsHandler {
let backfill = (news.last().unwrap().clone(), symbol.clone()).into();
database::news::upsert_batch(&self.config.clickhouse_client, &news)
database::news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&news,
)
.await
.unwrap();
database::backfills_news::upsert(&self.config.clickhouse_client, &backfill)
database::backfills_news::upsert(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfill,
)
.await
.unwrap();

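Judging from Message::new and the response.send(()) calls above, each create_send_await! site presumably reduces to this oneshot round trip (a sketch, assuming the backfill module's Message and Action are in scope):

use tokio::sync::mpsc;

async fn send_and_wait(sender: &mpsc::Sender<Message>, symbols: Vec<String>) {
    // Build the message plus its acknowledgement channel, send it to the
    // backfill thread, then block until the handler replies with ().
    let (message, receiver) = Message::new(Action::Backfill, symbols);
    sender.send(message).await.unwrap();
    receiver.await.unwrap();
}
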
View File

@@ -9,18 +9,18 @@ use crate::{
},
create_send_await, database,
types::{alpaca, Asset, Class},
utils::backoff,
};
use futures_util::{future::join_all, StreamExt};
use futures_util::StreamExt;
use itertools::{Either, Itertools};
use std::sync::Arc;
use log::error;
use std::{collections::HashMap, sync::Arc};
use tokio::{
join, select, spawn,
sync::{mpsc, oneshot},
};
use tokio_tungstenite::connect_async;
#[derive(Clone, Copy)]
#[derive(Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Action {
Add,
@@ -173,13 +173,6 @@ async fn handle_message(
message.action.into(),
us_equity_symbols.clone()
);
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
message.action.into(),
us_equity_symbols
);
};
let bars_crypto_future = async {
@@ -193,13 +186,6 @@ async fn handle_message(
message.action.into(),
crypto_symbols.clone()
);
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
message.action.into(),
crypto_symbols
);
};
let news_future = async {
@@ -209,62 +195,127 @@ async fn handle_message(
message.action.into(),
symbols.clone()
);
create_send_await!(
news_backfill_sender,
backfill::Message::new,
message.action.into(),
symbols.clone()
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
match message.action {
Action::Add => {
let assets = join_all(symbols.into_iter().map(|symbol| {
let config = config.clone();
async move {
let asset_future = async {
alpaca::api::incoming::asset::get_by_symbol(
let assets = async {
alpaca::api::incoming::asset::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbol,
Some(backoff::infinite()),
&symbols,
None,
)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>()
};
let position_future = async {
alpaca::api::incoming::position::get_by_symbol(
let positions = async {
alpaca::api::incoming::position::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbol,
Some(backoff::infinite()),
&symbols,
None,
)
.await
.unwrap()
.into_iter()
.map(|position| (position.symbol.clone(), position))
.collect::<HashMap<_, _>>()
};
let (asset, position) = join!(asset_future, position_future);
Asset::from((asset, position))
let (mut assets, mut positions) = join!(assets, positions);
let mut batch = vec![];
for symbol in &symbols {
if let Some(asset) = assets.remove(symbol) {
let position = positions.remove(symbol);
batch.push(Asset::from((asset, position)));
} else {
error!("Failed to find asset for symbol: {}", symbol);
}
}
}))
.await;
database::assets::upsert_batch(&config.clickhouse_client, &assets)
database::assets::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&batch,
)
.await
.unwrap();
}
Action::Remove => {
database::assets::delete_where_symbols(&config.clickhouse_client, &symbols)
database::assets::delete_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbols,
)
.await
.unwrap();
}
_ => {}
}
if message.action == Action::Disable {
message.response.send(()).unwrap();
return;
}
let bars_us_equity_future = async {
if us_equity_symbols.is_empty() {
return;
}
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
us_equity_symbols
);
};
let bars_crypto_future = async {
if crypto_symbols.is_empty() {
return;
}
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
crypto_symbols
);
};
let news_future = async {
create_send_await!(
news_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
symbols
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
message.response.send(()).unwrap();
}
@@ -274,11 +325,17 @@ async fn handle_clock_message(
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
news_backfill_sender: mpsc::Sender<backfill::Message>,
) {
database::cleanup_all(&config.clickhouse_client)
database::cleanup_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
let assets = database::assets::select(&config.clickhouse_client)
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
@@ -299,8 +356,8 @@ async fn handle_clock_message(
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
Some(backfill::Action::Backfill),
us_equity_symbols.clone()
backfill::Action::Backfill,
us_equity_symbols
);
};
@@ -308,8 +365,8 @@ async fn handle_clock_message(
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
Some(backfill::Action::Backfill),
crypto_symbols.clone()
backfill::Action::Backfill,
crypto_symbols
);
};
@@ -317,7 +374,7 @@ async fn handle_clock_message(
create_send_await!(
news_backfill_sender,
backfill::Message::new,
Some(backfill::Action::Backfill),
backfill::Action::Backfill,
symbols
);
};

View File

@@ -268,7 +268,11 @@ impl Handler for BarsHandler {
let bar = Bar::from(message);
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
database::bars::upsert(&self.config.clickhouse_client, &bar)
database::bars::upsert(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&bar,
)
.await
.unwrap();
}
@@ -283,6 +287,7 @@ impl Handler for BarsHandler {
| websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
false,
)
@@ -293,6 +298,7 @@ impl Handler for BarsHandler {
| websocket::data::incoming::status::Status::TradingResumption(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
true,
)
@@ -398,7 +404,11 @@ impl Handler for NewsHandler {
..news
};
database::news::upsert(&self.config.clickhouse_client, &news)
database::news::upsert(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&news,
)
.await
.unwrap();
}

View File

@@ -52,7 +52,11 @@ async fn handle_websocket_message(
let order = Order::from(message.order);
database::orders::upsert(&config.clickhouse_client, &order)
database::orders::upsert(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&order,
)
.await
.unwrap();
@@ -63,6 +67,7 @@ async fn handle_websocket_message(
} => {
database::assets::update_qty_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&order.symbol,
position_qty,
)

View File

@@ -81,15 +81,15 @@ pub struct Account {
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
rate_limiter.until_ready().await;
client
.get(&format!("{}/account", *ALPACA_API_URL))
.send()
.await?

View File

@@ -3,20 +3,25 @@ use crate::{
config::ALPACA_API_URL,
types::{
self,
alpaca::shared::asset::{Class, Exchange, Status},
alpaca::{
api::outgoing,
shared::asset::{Class, Exchange, Status},
},
},
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use itertools::Itertools;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_option_number_from_string;
use std::time::Duration;
use std::{collections::HashSet, time::Duration};
use tokio::try_join;
use uuid::Uuid;
#[allow(clippy::struct_excessive_bools)]
#[derive(Deserialize)]
#[derive(Deserialize, Clone)]
pub struct Asset {
pub id: Uuid,
pub class: Class,
@@ -47,17 +52,56 @@ impl From<(Asset, Option<Position>)> for types::Asset {
}
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::asset::Asset,
backoff: Option<ExponentialBackoff>,
) -> Result<Vec<Asset>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("{}/assets", *ALPACA_API_URL))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(
reqwest::StatusCode::BAD_REQUEST
| reqwest::StatusCode::FORBIDDEN
| reqwest::StatusCode::NOT_FOUND,
) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Vec<Asset>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get assets, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
rate_limiter.until_ready().await;
client
.get(&format!("{}/assets/{}", *ALPACA_API_URL, symbol))
.send()
.await?
@@ -84,3 +128,43 @@ pub async fn get_by_symbol(
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
) -> Result<Vec<Asset>, Error> {
if symbols.is_empty() {
return Ok(vec![]);
}
if symbols.len() == 1 {
let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff).await?;
return Ok(vec![asset]);
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let backoff_clone = backoff.clone();
let us_equity_query = outgoing::asset::Asset {
class: Some(Class::UsEquity),
..Default::default()
};
let us_equity_assets = get(client, rate_limiter, &us_equity_query, backoff_clone);
let crypto_query = outgoing::asset::Asset {
class: Some(Class::Crypto),
..Default::default()
};
let crypto_assets = get(client, rate_limiter, &crypto_query, backoff);
let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;
Ok(crypto_assets
.into_iter()
.chain(us_equity_assets)
.dedup_by(|a, b| a.symbol == b.symbol)
.filter(|asset| symbols.contains(&asset.symbol))
.collect())
}

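Note the design choice: for two or more symbols, get_by_symbols fetches the full US-equity and crypto listings in parallel and filters client-side, trading per-symbol round trips for at most two bulk requests. A hypothetical call, mirroring the call sites earlier in the diff:

let assets = alpaca::api::incoming::asset::get_by_symbols(
    &config.alpaca_client,
    &config.alpaca_rate_limiter,
    &["AAPL".to_string(), "BTC/USD".to_string()],
    None,
)
.await?;
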
View File

@@ -50,9 +50,9 @@ pub struct Message {
pub next_page_token: Option<String>,
}
pub async fn get_historical(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
data_url: &str,
query: &outgoing::bar::Bar,
backoff: Option<ExponentialBackoff>,
@@ -60,8 +60,8 @@ pub async fn get_historical(
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
rate_limiter.until_ready().await;
client
.get(data_url)
.query(query)
.send()

View File

@@ -32,16 +32,16 @@ impl From<Calendar> for types::Calendar {
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::calendar::Calendar,
backoff: Option<ExponentialBackoff>,
) -> Result<Vec<Calendar>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
rate_limiter.until_ready().await;
client
.get(&format!("{}/calendar", *ALPACA_API_URL))
.query(query)
.send()

View File

@@ -19,15 +19,15 @@ pub struct Clock {
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
rate_limiter.until_ready().await;
client
.get(&format!("{}/clock", *ALPACA_API_URL))
.send()
.await?

View File

@@ -73,17 +73,17 @@ pub struct Message {
pub next_page_token: Option<String>,
}
pub async fn get_historical(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
rate_limiter.until_ready().await;
client
.get(ALPACA_NEWS_DATA_API_URL)
.query(query)
.send()

View File

@@ -11,16 +11,16 @@ use std::time::Duration;
pub use shared::order::Order;
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::order::Order,
backoff: Option<ExponentialBackoff>,
) -> Result<Vec<Order>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
rate_limiter.until_ready().await;
client
.get(&format!("{}/orders", *ALPACA_API_URL))
.query(query)
.send()

View File

@@ -12,10 +12,10 @@ use log::warn;
use reqwest::Client;
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_number_from_string;
use std::time::Duration;
use std::{collections::HashSet, time::Duration};
use uuid::Uuid;
#[derive(Deserialize)]
#[derive(Deserialize, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Long,
@@ -31,7 +31,7 @@ impl From<Side> for shared::order::Side {
}
}
#[derive(Deserialize)]
#[derive(Deserialize, Clone)]
pub struct Position {
pub asset_id: Uuid,
#[serde(deserialize_with = "de::add_slash_to_symbol")]
@@ -67,15 +67,15 @@ pub struct Position {
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
) -> Result<Vec<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
rate_limiter.until_ready().await;
client
.get(&format!("{}/positions", *ALPACA_API_URL))
.send()
.await?
@@ -102,16 +102,16 @@ pub async fn get(
}
pub async fn get_by_symbol(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
) -> Result<Option<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
let response = alpaca_client
rate_limiter.until_ready().await;
let response = client
.get(&format!("{}/positions/{}", *ALPACA_API_URL, symbol))
.send()
.await?;
@@ -143,3 +143,25 @@ pub async fn get_by_symbol(
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
) -> Result<Vec<Position>, reqwest::Error> {
if symbols.is_empty() {
return Ok(vec![]);
}
if symbols.len() == 1 {
let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff).await?;
return Ok(position.into_iter().collect());
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let positions = get(client, rate_limiter, backoff).await?;
Ok(positions
.into_iter()
.filter(|position| symbols.contains(&position.symbol))
.collect())
}

View File

@@ -0,0 +1,21 @@
use crate::types::alpaca::shared::asset::{Class, Exchange, Status};
use serde::Serialize;
#[derive(Serialize)]
pub struct Asset {
pub status: Option<Status>,
pub class: Option<Class>,
pub exchange: Option<Exchange>,
pub attributes: Option<Vec<String>>,
}
impl Default for Asset {
fn default() -> Self {
Self {
status: None,
class: Some(Class::UsEquity),
exchange: None,
attributes: None,
}
}
}

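Since reqwest's .query() serializes this struct, the Default above targets US equities; a sketch of the expected wire form (serde_urlencoded here is an assumption about the serializer, and None fields are skipped):

let query = serde_urlencoded::to_string(Asset::default()).unwrap();
assert_eq!(query, "class=us_equity");
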
View File

@@ -1,3 +1,4 @@
pub mod asset;
pub mod bar;
pub mod calendar;
pub mod news;

View File

@@ -1,7 +1,7 @@
use crate::{impl_from_enum, types};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Class {
UsEquity,
@@ -10,7 +10,7 @@ pub enum Class {
impl_from_enum!(types::Class, Class, UsEquity, Crypto);
#[derive(Deserialize)]
#[derive(Serialize, Deserialize, Clone, Copy)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Exchange {
Amex,
@@ -36,7 +36,7 @@ impl_from_enum!(
Crypto
);
#[derive(Deserialize)]
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum Status {
Active,

View File

@@ -8,7 +8,8 @@ use std::fmt;
use time::{format_description::OwnedFormatItem, macros::format_description, Time};
lazy_static! {
static ref RE_SLASH: Regex = Regex::new(r"^(.+)(BTC|USD.?)$").unwrap();
// This *will* break in the future if a crypto pair with one letter is added
static ref RE_SLASH: Regex = Regex::new(r"^(.{2,})(BTC|USD.?)$").unwrap();
static ref FMT_HH_MM: OwnedFormatItem = format_description!("[hour]:[minute]").into();
}
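
A quick check of the widened RE_SLASH behavior, confirming both the split and the new two-character minimum on the base symbol:

fn main() {
    let re = regex::Regex::new(r"^(.{2,})(BTC|USD.?)$").unwrap();
    let caps = re.captures("ETHUSDT").unwrap();
    assert_eq!((&caps[1], &caps[2]), ("ETH", "USDT"));
    // A one-letter base such as "XUSD" no longer matches, as the comment warns.
    assert!(re.captures("XUSD").is_none());
}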