Compare commits


1 Commit

SHA1: 2036e5fa32
Message: Add LSTM experiment
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
Date: 2024-02-27 09:53:52 +00:00
116 changed files with 37389 additions and 3203 deletions

7
.gitignore vendored
View File

@@ -2,7 +2,6 @@
# will have compiled files and executables
debug/
target/
log/
# These are backup files generated by rustfmt
**/*.rs.bk
@@ -11,3 +10,9 @@ log/
*.pdb
.env*
# ML models
models/*/rust_model.ot
notebooks/models/
libdevice.10.bc

View File

@@ -24,13 +24,13 @@ build:
script:
- cargo +nightly build
# test:
# image: registry.karaolidis.com/karaolidis/qrust/rust
# stage: test
# cache:
# <<: *global_cache
# script:
# - cargo +nightly test
test:
image: registry.karaolidis.com/karaolidis/qrust/rust
stage: test
cache:
<<: *global_cache
script:
- cargo +nightly test
lint:
image: registry.karaolidis.com/karaolidis/qrust/rust
@@ -48,7 +48,7 @@ depcheck:
<<: *global_cache
script:
- cargo +nightly outdated
- cargo +nightly udeps --all-targets
- cargo +nightly udeps
build-release:
image: registry.karaolidis.com/karaolidis/qrust/rust

986
Cargo.lock generated

File diff suppressed because it is too large

View File

@@ -3,14 +3,6 @@ name = "qrust"
version = "0.1.0"
edition = "2021"
[lib]
name = "qrust"
path = "src/lib/mod.rs"
[[bin]]
name = "qrust"
path = "src/main.rs"
[profile.release]
panic = 'abort'
strip = true
@@ -51,7 +43,6 @@ clickhouse = { version = "0.11.6", features = [
] }
uuid = { version = "1.6.1", features = [
"serde",
"v4",
] }
time = { version = "0.3.31", features = [
"serde",
@@ -65,9 +56,8 @@ backoff = { version = "0.4.0", features = [
"tokio",
] }
regex = "1.10.3"
html-escape = "0.2.13"
rust-bert = "0.22.0"
async-trait = "0.1.77"
itertools = "0.12.1"
lazy_static = "1.4.0"
nonempty = { version = "0.10.0", features = [
"serialize",
] }

View File

@@ -4,14 +4,7 @@ appenders:
encoder:
pattern: "{d} {h({l})} {M}::{L} - {m}{n}"
file:
kind: file
path: "./log/output.log"
encoder:
pattern: "{d} {l} {M}::{L} - {m}{n}"
root:
level: info
appenders:
- stdout
- file

View File

@@ -0,0 +1,32 @@
{
"_name_or_path": "/home/ubuntu/finbert/models/language_model/finbertTRC2",
"architectures": [
"BertForSequenceClassification"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"id2label": {
"0": "positive",
"1": "negative",
"2": "neutral"
},
"initializer_range": 0.02,
"intermediate_size": 3072,
"label2id": {
"positive": 0,
"negative": 1,
"neutral": 2
},
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"type_vocab_size": 2,
"vocab_size": 30522
}

View File

@@ -0,0 +1 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}

View File

@@ -0,0 +1 @@
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "bert-base-uncased"}

30522
models/finbert/vocab.txt Normal file

File diff suppressed because it is too large

3868
notebooks/lstm.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@@ -1,49 +1,68 @@
use crate::types::alpaca::shared::{Mode, Source};
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use lazy_static::lazy_static;
use qrust::types::alpaca::shared::{Mode, Source};
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client,
};
use std::{env, num::NonZeroU32, sync::Arc};
use tokio::sync::Semaphore;
use rust_bert::{
pipelines::{
common::{ModelResource, ModelType},
sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel},
},
resources::LocalResource,
};
use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
use tokio::sync::Mutex;
pub const ALPACA_STOCK_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news";
pub const ALPACA_STOCK_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str =
"wss://stream.data.alpaca.markets/v1beta3/crypto/us";
pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";
lazy_static! {
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
.expect("ALPACA_MODE must be set.")
.parse()
.expect("ALPACA_MODE must be 'live' or 'paper'");
pub static ref ALPACA_API_BASE: String = match *ALPACA_MODE {
Mode::Live => String::from("api"),
Mode::Paper => String::from("paper-api"),
};
pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE")
.expect("ALPACA_SOURCE must be set.")
.parse()
.expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'");
pub static ref ALPACA_API_KEY: String =
env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
pub static ref ALPACA_API_SECRET: String =
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
pub static ref CLICKHOUSE_BATCH_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE")
.expect("BATCH_BACKFILL_BARS_SIZE must be set.")
pub static ref ALPACA_API_KEY: String = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
pub static ref ALPACA_API_SECRET: String = env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
#[derive(Debug)]
pub static ref ALPACA_API_URL: String = format!(
"https://{}.alpaca.markets/v2",
match *ALPACA_MODE {
Mode::Live => String::from("api"),
Mode::Paper => String::from("paper-api"),
}
);
#[derive(Debug)]
pub static ref ALPACA_WEBSOCKET_URL: String = format!(
"wss://{}.alpaca.markets/stream",
match *ALPACA_MODE {
Mode::Live => String::from("api"),
Mode::Paper => String::from("paper-api"),
}
);
pub static ref MAX_BERT_INPUTS: usize = env::var("MAX_BERT_INPUTS")
.expect("MAX_BERT_INPUTS must be set.")
.parse()
.expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer.");
pub static ref CLICKHOUSE_BATCH_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE")
.expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
.parse()
.expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS")
.expect("CLICKHOUSE_MAX_CONNECTIONS must be set.")
.parse()
.expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer.");
.expect("MAX_BERT_INPUTS must be a positive integer.");
}
pub struct Config {
pub alpaca_client: Client,
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
pub clickhouse_client: clickhouse::Client,
pub clickhouse_concurrency_limiter: Arc<Semaphore>,
pub sequence_classifier: Mutex<SequenceClassificationModel>,
}
impl Config {
@@ -66,7 +85,7 @@ impl Config {
.unwrap(),
alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE {
Source::Iex => unsafe { NonZeroU32::new_unchecked(200) },
Source::Sip => unsafe { NonZeroU32::new_unchecked(10_000) },
Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) },
Source::Otc => unimplemented!("OTC rate limit not implemented."),
})),
clickhouse_client: clickhouse::Client::default()
@@ -76,7 +95,25 @@ impl Config {
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
)
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
sequence_classifier: Mutex::new(
SequenceClassificationModel::new(SequenceClassificationConfig::new(
ModelType::Bert,
ModelResource::Torch(Box::new(LocalResource {
local_path: PathBuf::from("./models/finbert/rust_model.ot"),
})),
LocalResource {
local_path: PathBuf::from("./models/finbert/config.json"),
},
LocalResource {
local_path: PathBuf::from("./models/finbert/vocab.txt"),
},
None,
true,
None,
None,
))
.unwrap(),
),
}
}
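
The config module reads every tunable from the environment through lazy_static, so a missing or malformed variable panics at first dereference instead of surfacing mid-trade. A minimal sketch of the same pattern, using a hypothetical MY_LIMIT variable rather than any real qrust setting:

    use lazy_static::lazy_static;
    use std::env;

    lazy_static! {
        // Evaluated once, on first dereference; panics with a clear
        // message if the variable is unset or fails to parse.
        static ref MY_LIMIT: usize = env::var("MY_LIMIT")
            .expect("MY_LIMIT must be set.")
            .parse()
            .expect("MY_LIMIT must be a positive integer.");
    }

    fn main() {
        env::set_var("MY_LIMIT", "42");
        // Dereferencing forces evaluation; main.rs uses the same
        // `let _ = *STATIC;` trick to fail fast at startup.
        println!("limit = {}", *MY_LIMIT);
    }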

View File

@@ -1,11 +1,8 @@
use std::sync::Arc;
use crate::{
delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch,
};
use clickhouse::{error::Error, Client};
use serde::Serialize;
use tokio::sync::Semaphore;
select!(Asset, "assets");
select_where_symbol!(Asset, "assets");
@@ -14,16 +11,14 @@ delete_where_symbols!("assets");
optimize!("assets");
pub async fn update_status_where_symbol<T>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
clickhouse_client: &Client,
symbol: &T,
status: bool,
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
let _ = concurrency_limiter.acquire().await.unwrap();
client
clickhouse_client
.query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?")
.bind(status)
.bind(symbol)
@@ -32,16 +27,14 @@ where
}
pub async fn update_qty_where_symbol<T>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
clickhouse_client: &Client,
symbol: &T,
qty: f64,
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
let _ = concurrency_limiter.acquire().await.unwrap();
client
clickhouse_client
.query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?")
.bind(qty)
.bind(symbol)

View File

@@ -0,0 +1,17 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
};
use clickhouse::{error::Error, Client};
select_where_symbol!(Backfill, "backfills_bars");
upsert!(Backfill, "backfills_bars");
delete_where_symbols!("backfills_bars");
cleanup!("backfills_bars");
optimize!("backfills_bars");
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
.query("ALTER TABLE backfills_bars UPDATE fresh = false WHERE true")
.execute()
.await
}

View File

@@ -0,0 +1,17 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
};
use clickhouse::{error::Error, Client};
select_where_symbol!(Backfill, "backfills_news");
upsert!(Backfill, "backfills_news");
delete_where_symbols!("backfills_news");
cleanup!("backfills_news");
optimize!("backfills_news");
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
.query("ALTER TABLE backfills_news UPDATE fresh = false WHERE true")
.execute()
.await
}

7
src/database/bars.rs Normal file
View File

@@ -0,0 +1,7 @@
use crate::{cleanup, delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
upsert!(Bar, "bars");
upsert_batch!(Bar, "bars");
delete_where_symbols!("bars");
cleanup!("bars");
optimize!("bars");

View File

@@ -1,19 +1,16 @@
use std::sync::Arc;
use crate::{optimize, types::Calendar};
use clickhouse::{error::Error, Client};
use tokio::{sync::Semaphore, try_join};
use clickhouse::error::Error;
use tokio::try_join;
optimize!("calendar");
pub async fn upsert_batch_and_delete<'a, I>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
records: I,
pub async fn upsert_batch_and_delete<'a, T>(
client: &clickhouse::Client,
records: T,
) -> Result<(), Error>
where
I: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone,
I::IntoIter: Send,
T: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone,
T::IntoIter: Send,
{
let upsert_future = async {
let mut insert = client.insert("calendar")?;
@@ -37,6 +34,5 @@ where
.await
};
let _ = concurrency_limiter.acquire_many(2).await.unwrap();
try_join!(upsert_future, delete_future).map(|_| ())
}

152
src/database/mod.rs Normal file
View File

@@ -0,0 +1,152 @@
pub mod assets;
pub mod backfills_bars;
pub mod backfills_news;
pub mod bars;
pub mod calendar;
pub mod news;
pub mod orders;
use clickhouse::{error::Error, Client};
use log::info;
use tokio::try_join;
#[macro_export]
macro_rules! select {
($record:ty, $table_name:expr) => {
pub async fn select(
client: &clickhouse::Client,
) -> Result<Vec<$record>, clickhouse::error::Error> {
client
.query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! select_where_symbol {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbol<T>(
client: &clickhouse::Client,
symbol: &T,
) -> Result<Option<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol = ?",
$table_name
))
.bind(symbol)
.fetch_optional::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! upsert {
($record:ty, $table_name:expr) => {
pub async fn upsert(
client: &clickhouse::Client,
record: &$record,
) -> Result<(), clickhouse::error::Error> {
let mut insert = client.insert($table_name)?;
insert.write(record).await?;
insert.end().await
}
};
}
#[macro_export]
macro_rules! upsert_batch {
($record:ty, $table_name:expr) => {
pub async fn upsert_batch<'a, T>(
client: &clickhouse::Client,
records: T,
) -> Result<(), clickhouse::error::Error>
where
T: IntoIterator<Item = &'a $record> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = client.insert($table_name)?;
for record in records {
insert.write(record).await?;
}
insert.end().await
}
};
}
#[macro_export]
macro_rules! delete_where_symbols {
($table_name:expr) => {
pub async fn delete_where_symbols<T>(
client: &clickhouse::Client,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
client
.query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
.bind(symbols)
.execute()
.await
}
};
}
#[macro_export]
macro_rules! cleanup {
($table_name:expr) => {
pub async fn cleanup(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
client
.query(&format!(
"DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
$table_name
))
.execute()
.await
}
};
}
#[macro_export]
macro_rules! optimize {
($table_name:expr) => {
pub async fn optimize(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
client
.query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
.execute()
.await
}
};
}
pub async fn cleanup_all(clickhouse_client: &Client) -> Result<(), Error> {
info!("Cleaning up database.");
try_join!(
bars::cleanup(clickhouse_client),
news::cleanup(clickhouse_client),
backfills_bars::cleanup(clickhouse_client),
backfills_news::cleanup(clickhouse_client)
)
.map(|_| ())
}
pub async fn optimize_all(clickhouse_client: &Client) -> Result<(), Error> {
info!("Optimizing database.");
try_join!(
assets::optimize(clickhouse_client),
bars::optimize(clickhouse_client),
news::optimize(clickhouse_client),
backfills_bars::optimize(clickhouse_client),
backfills_news::optimize(clickhouse_client),
orders::optimize(clickhouse_client),
calendar::optimize(clickhouse_client)
)
.map(|_| ())
}
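
Each of these macros stamps out a plain async fn in the invoking module, which is why bars::cleanup and news::cleanup above are distinct functions over different tables. Roughly, select!(Asset, "assets") in assets.rs expands to the following (with the runtime format! folded into a literal for readability):

    pub async fn select(
        client: &clickhouse::Client,
    ) -> Result<Vec<Asset>, clickhouse::error::Error> {
        client
            .query("SELECT ?fields FROM assets FINAL")
            .fetch_all::<Asset>()
            .await
    }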

View File

@@ -1,33 +1,24 @@
use std::sync::Arc;
use crate::{optimize, types::News, upsert, upsert_batch};
use clickhouse::{error::Error, Client};
use serde::Serialize;
use tokio::sync::Semaphore;
upsert!(News, "news");
upsert_batch!(News, "news");
optimize!("news");
pub async fn delete_where_symbols<T>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbols: &[T],
) -> Result<(), Error>
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
let _ = concurrency_limiter.acquire().await.unwrap();
client
clickhouse_client
.query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))")
.bind(symbols)
.execute()
.await
}
pub async fn cleanup(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
let _ = concurrency_limiter.acquire().await.unwrap();
client
pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
.query(
"DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))",
)

View File

@@ -1,25 +1,24 @@
use crate::{
config::{Config, ALPACA_API_BASE},
config::{Config, ALPACA_MODE},
database,
types::alpaca,
};
use log::{info, warn};
use qrust::{alpaca, types};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::join;
pub async fn check_account(config: &Arc<Config>) {
let account = alpaca::account::get(
let account = alpaca::api::incoming::account::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,
&ALPACA_API_BASE,
)
.await
.unwrap();
assert!(
!(account.status != types::alpaca::api::incoming::account::Status::Active),
!(account.status != alpaca::api::incoming::account::Status::Active),
"Account status is not active: {:?}.",
account.status
);
@@ -34,60 +33,56 @@ pub async fn check_account(config: &Arc<Config>) {
warn!("Account cash is zero, qrust will not be able to trade.");
}
info!(
"qrust running on {} account with {} {}, avoid transferring funds without shutting down.",
*ALPACA_API_BASE, account.currency, account.cash
warn!(
"qrust active on {} account with {} {}, avoid transferring funds without shutting down.",
*ALPACA_MODE, account.currency, account.cash
);
}
pub async fn rehydrate_orders(config: &Arc<Config>) {
info!("Rehydrating order data.");
let mut orders = vec![];
let mut after = OffsetDateTime::UNIX_EPOCH;
loop {
let message = alpaca::orders::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&types::alpaca::api::outgoing::order::Order {
status: Some(types::alpaca::api::outgoing::order::Status::All),
after: Some(after),
..Default::default()
},
None,
&ALPACA_API_BASE,
)
.await
.unwrap();
if message.is_empty() {
break;
}
while let Some(message) = alpaca::api::incoming::order::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&alpaca::api::outgoing::order::Order {
status: Some(alpaca::api::outgoing::order::Status::All),
after: Some(after),
..Default::default()
},
None,
)
.await
.ok()
.filter(|message| !message.is_empty())
{
orders.extend(message);
after = orders.last().unwrap().submitted_at;
}
let orders = orders
.into_iter()
.flat_map(&types::alpaca::api::incoming::order::Order::normalize)
.flat_map(&alpaca::api::incoming::order::Order::normalize)
.collect::<Vec<_>>();
database::orders::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&orders,
)
.await
.unwrap();
database::orders::upsert_batch(&config.clickhouse_client, &orders)
.await
.unwrap();
info!("Rehydrated order data.");
}
pub async fn rehydrate_positions(config: &Arc<Config>) {
info!("Rehydrating position data.");
let positions_future = async {
alpaca::positions::get(
alpaca::api::incoming::position::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
@@ -97,12 +92,9 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
};
let assets_future = async {
database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
database::assets::select(&config.clickhouse_client)
.await
.unwrap()
};
let (mut positions, assets) = join!(positions_future, assets_future);
@@ -119,13 +111,9 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
})
.collect::<Vec<_>>();
database::assets::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&assets,
)
.await
.unwrap();
database::assets::upsert_batch(&config.clickhouse_client, &assets)
.await
.unwrap();
for position in positions.values() {
warn!(
@@ -133,4 +121,6 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
position.symbol, position.qty
);
}
info!("Rehydrated position data.");
}

View File

@@ -1,39 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::account::Account;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/account", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Account>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,132 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{
incoming::asset::{Asset, Class},
outgoing,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use itertools::Itertools;
use log::warn;
use reqwest::{Client, Error};
use std::{collections::HashSet, time::Duration};
use tokio::try_join;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::asset::Asset,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/assets", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Asset>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get assets, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!(
"https://{}.alpaca.markets/v2/assets/{}",
api_base, symbol
))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Asset>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get asset, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
if symbols.is_empty() {
return Ok(vec![]);
}
if symbols.len() == 1 {
let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(vec![asset]);
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let backoff_clone = backoff.clone();
let us_equity_query = outgoing::asset::Asset {
class: Some(Class::UsEquity),
..Default::default()
};
let us_equity_assets = get(
client,
rate_limiter,
&us_equity_query,
backoff_clone,
api_base,
);
let crypto_query = outgoing::asset::Asset {
class: Some(Class::Crypto),
..Default::default()
};
let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base);
let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;
Ok(crypto_assets
.into_iter()
.chain(us_equity_assets)
.dedup_by(|a, b| a.symbol == b.symbol)
.filter(|asset| symbols.contains(&asset.symbol))
.collect())
}

View File

@@ -1,50 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::bar::Bar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::{collections::HashMap, time::Duration};
pub const MAX_LIMIT: i64 = 10_000;
#[derive(Deserialize)]
pub struct Message {
pub bars: HashMap<String, Vec<Bar>>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
data_url: &str,
query: &outgoing::bar::Bar,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(data_url)
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Message>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get historical bars, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,41 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::calendar::Calendar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::calendar::Calendar,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Calendar>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/calendar", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Calendar>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get calendar, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,39 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::clock::Clock;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/clock", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Clock>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get clock, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,27 +0,0 @@
pub mod account;
pub mod assets;
pub mod bars;
pub mod calendar;
pub mod clock;
pub mod news;
pub mod orders;
pub mod positions;
use reqwest::StatusCode;
pub fn error_to_backoff(err: reqwest::Error) -> backoff::Error<reqwest::Error> {
if err.is_status() {
return match err.status() {
Some(StatusCode::BAD_REQUEST | StatusCode::FORBIDDEN | StatusCode::NOT_FOUND)
| None => backoff::Error::Permanent(err),
_ => err.into(),
};
}
if err.is_builder() || err.is_request() || err.is_redirect() || err.is_decode() || err.is_body()
{
return backoff::Error::Permanent(err);
}
err.into()
}
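
error_to_backoff marks client-side failures (bad request, forbidden, not found, plus builder/request/redirect/decode/body errors) as permanent so retry_notify gives up immediately, and converts everything else into a transient error that retries with exponential backoff. A condensed sketch of the call pattern the endpoint modules above all follow (the URL is a placeholder):

    use backoff::{future::retry_notify, ExponentialBackoff};
    use std::time::Duration;

    async fn fetch_text(client: &reqwest::Client) -> Result<String, reqwest::Error> {
        retry_notify(
            ExponentialBackoff::default(),
            || async {
                client
                    .get("https://example.invalid/endpoint") // placeholder
                    .send()
                    .await
                    .map_err(error_to_backoff)?
                    .error_for_status()
                    .map_err(error_to_backoff)?
                    .text()
                    .await
                    .map_err(error_to_backoff)
            },
            |e, duration: Duration| eprintln!("retrying in {}s: {}", duration.as_secs(), e),
        )
        .await
    }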

View File

@@ -1,49 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::news::News, outgoing, ALPACA_NEWS_DATA_API_URL};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
pub const MAX_LIMIT: i64 = 50;
#[derive(Deserialize)]
pub struct Message {
pub news: Vec<News>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(ALPACA_NEWS_DATA_API_URL)
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Message>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get historical news, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,108 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::position::Position;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::Client;
use std::{collections::HashSet, time::Duration};
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/positions", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Position>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get positions, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Option<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
let response = client
.get(&format!(
"https://{}.alpaca.markets/v2/positions/{}",
api_base, symbol
))
.send()
.await
.map_err(error_to_backoff)?;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Ok(None);
}
response
.error_for_status()
.map_err(error_to_backoff)?
.json::<Position>()
.await
.map_err(error_to_backoff)
.map(Some)
},
|e, duration: Duration| {
warn!(
"Failed to get position, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
if symbols.is_empty() {
return Ok(vec![]);
}
if symbols.len() == 1 {
let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(position.into_iter().collect());
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let positions = get(client, rate_limiter, backoff, api_base).await?;
Ok(positions
.into_iter()
.filter(|position| symbols.contains(&position.symbol))
.collect())
}

View File

@@ -1,11 +0,0 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
types::Backfill, upsert_batch,
};
select_where_symbols!(Backfill, "backfills_bars");
upsert_batch!(Backfill, "backfills_bars");
delete_where_symbols!("backfills_bars");
cleanup!("backfills_bars");
optimize!("backfills_bars");
set_fresh_where_symbols!("backfills_bars");

View File

@@ -1,11 +0,0 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
types::Backfill, upsert_batch,
};
select_where_symbols!(Backfill, "backfills_news");
upsert_batch!(Backfill, "backfills_news");
delete_where_symbols!("backfills_news");
cleanup!("backfills_news");
optimize!("backfills_news");
set_fresh_where_symbols!("backfills_news");

View File

@@ -1,21 +0,0 @@
use std::sync::Arc;
use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
use clickhouse::Client;
use tokio::sync::Semaphore;
upsert!(Bar, "bars");
upsert_batch!(Bar, "bars");
delete_where_symbols!("bars");
optimize!("bars");
pub async fn cleanup(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
) -> Result<(), clickhouse::error::Error> {
let _ = concurrency_limiter.acquire().await.unwrap();
client
.query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets) OR symbol NOT IN (SELECT symbol FROM backfills_bars)")
.execute()
.await
}

View File

@@ -1,223 +0,0 @@
pub mod assets;
pub mod backfills_bars;
pub mod backfills_news;
pub mod bars;
pub mod calendar;
pub mod news;
pub mod orders;
use clickhouse::{error::Error, Client};
use tokio::try_join;
#[macro_export]
macro_rules! select {
($record:ty, $table_name:expr) => {
pub async fn select(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<Vec<$record>, clickhouse::error::Error> {
let _ = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! select_where_symbol {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbol<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbol: &T,
) -> Result<Option<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
let _ = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol = ?",
$table_name
))
.bind(symbol)
.fetch_optional::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! select_where_symbols {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbols: &[T],
) -> Result<Vec<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
let _ = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol IN ?",
$table_name
))
.bind(symbols)
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! upsert {
($record:ty, $table_name:expr) => {
pub async fn upsert(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
record: &$record,
) -> Result<(), clickhouse::error::Error> {
let _ = concurrency_limiter.acquire().await.unwrap();
let mut insert = client.insert($table_name)?;
insert.write(record).await?;
insert.end().await
}
};
}
#[macro_export]
macro_rules! upsert_batch {
($record:ty, $table_name:expr) => {
pub async fn upsert_batch<'a, I>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
records: I,
) -> Result<(), clickhouse::error::Error>
where
I: IntoIterator<Item = &'a $record> + Send + Sync,
I::IntoIter: Send,
{
let _ = concurrency_limiter.acquire().await.unwrap();
let mut insert = client.insert($table_name)?;
for record in records {
insert.write(record).await?;
}
insert.end().await
}
};
}
#[macro_export]
macro_rules! delete_where_symbols {
($table_name:expr) => {
pub async fn delete_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
let _ = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
.bind(symbols)
.execute()
.await
}
};
}
#[macro_export]
macro_rules! cleanup {
($table_name:expr) => {
pub async fn cleanup(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
let _ = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
$table_name
))
.execute()
.await
}
};
}
#[macro_export]
macro_rules! optimize {
($table_name:expr) => {
pub async fn optimize(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
let _ = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
.execute()
.await
}
};
}
#[macro_export]
macro_rules! set_fresh_where_symbols {
($table_name:expr) => {
pub async fn set_fresh_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
fresh: bool,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
let _ = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"ALTER TABLE {} UPDATE fresh = ? WHERE symbol IN ?",
$table_name
))
.bind(fresh)
.bind(symbols)
.execute()
.await
}
};
}
pub async fn cleanup_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
try_join!(
bars::cleanup(clickhouse_client, concurrency_limiter),
news::cleanup(clickhouse_client, concurrency_limiter),
backfills_bars::cleanup(clickhouse_client, concurrency_limiter),
backfills_news::cleanup(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}
pub async fn optimize_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
try_join!(
assets::optimize(clickhouse_client, concurrency_limiter),
bars::optimize(clickhouse_client, concurrency_limiter),
news::optimize(clickhouse_client, concurrency_limiter),
backfills_bars::optimize(clickhouse_client, concurrency_limiter),
backfills_news::optimize(clickhouse_client, concurrency_limiter),
orders::optimize(clickhouse_client, concurrency_limiter),
calendar::optimize(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}

View File

@@ -1,4 +0,0 @@
pub mod alpaca;
pub mod database;
pub mod types;
pub mod utils;

View File

@@ -1,39 +0,0 @@
use super::position::Position;
use crate::types::{self, alpaca::shared::asset};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_option_number_from_string;
use uuid::Uuid;
pub use asset::{Class, Exchange, Status};
#[allow(clippy::struct_excessive_bools)]
#[derive(Deserialize, Clone)]
pub struct Asset {
pub id: Uuid,
pub class: Class,
pub exchange: Exchange,
pub symbol: String,
pub name: String,
pub status: Status,
pub tradable: bool,
pub marginable: bool,
pub shortable: bool,
pub easy_to_borrow: bool,
pub fractionable: bool,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub maintenance_margin_requirement: Option<f32>,
pub attributes: Option<Vec<String>>,
}
impl From<(Asset, Option<Position>)> for types::Asset {
fn from((asset, position): (Asset, Option<Position>)) -> Self {
Self {
symbol: asset.symbol,
class: asset.class.into(),
exchange: asset.exchange.into(),
status: asset.status.into(),
time_added: time::OffsetDateTime::now_utc(),
qty: position.map(|position| position.qty).unwrap_or_default(),
}
}
}

View File

@@ -1,40 +0,0 @@
use crate::types;
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Deserialize)]
pub struct Bar {
#[serde(rename = "t")]
#[serde(with = "time::serde::rfc3339")]
pub time: OffsetDateTime,
#[serde(rename = "o")]
pub open: f64,
#[serde(rename = "h")]
pub high: f64,
#[serde(rename = "l")]
pub low: f64,
#[serde(rename = "c")]
pub close: f64,
#[serde(rename = "v")]
pub volume: f64,
#[serde(rename = "n")]
pub trades: i64,
#[serde(rename = "vw")]
pub vwap: f64,
}
impl From<(Bar, String)> for types::Bar {
fn from((bar, symbol): (Bar, String)) -> Self {
Self {
time: bar.time,
symbol,
open: bar.open,
high: bar.high,
low: bar.low,
close: bar.close,
volume: bar.volume,
trades: bar.trades,
vwap: bar.vwap,
}
}
}
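
The #[serde(rename)] attributes map Alpaca's single-letter wire fields back to descriptive names. A hypothetical payload round-trip, assuming the Bar struct above is in scope and serde_json is available:

    fn parse_example() -> Result<(), serde_json::Error> {
        // Field names follow Alpaca's bar schema: t/o/h/l/c/v/n/vw.
        let payload = r#"{"t":"2024-02-27T09:53:00Z","o":1.0,"h":2.0,"l":0.5,"c":1.5,"v":1000.0,"n":42,"vw":1.25}"#;
        let bar: Bar = serde_json::from_str(payload)?;
        assert_eq!(bar.trades, 42);
        Ok(())
    }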

View File

@@ -1,26 +0,0 @@
use crate::{
types,
utils::{de, time::EST_OFFSET},
};
use serde::Deserialize;
use time::{Date, OffsetDateTime, Time};
#[derive(Deserialize)]
pub struct Calendar {
pub date: Date,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub open: Time,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub close: Time,
pub settlement_date: Date,
}
impl From<Calendar> for types::Calendar {
fn from(calendar: Calendar) -> Self {
Self {
date: calendar.date,
open: OffsetDateTime::new_in_offset(calendar.date, calendar.open, *EST_OFFSET),
close: OffsetDateTime::new_in_offset(calendar.date, calendar.close, *EST_OFFSET),
}
}
}

View File

@@ -1,13 +0,0 @@
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Deserialize)]
pub struct Clock {
#[serde(with = "time::serde::rfc3339")]
pub timestamp: OffsetDateTime,
pub is_open: bool,
#[serde(with = "time::serde::rfc3339")]
pub next_open: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
pub next_close: OffsetDateTime,
}

View File

@@ -1,57 +0,0 @@
use crate::{
types::{self, alpaca::shared::news::strip},
utils::de,
};
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImageSize {
Thumb,
Small,
Large,
}
#[derive(Deserialize)]
pub struct Image {
pub size: ImageSize,
pub url: String,
}
#[derive(Deserialize)]
pub struct News {
pub id: i64,
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "created_at")]
pub time_created: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "updated_at")]
pub time_updated: OffsetDateTime,
#[serde(deserialize_with = "de::add_slash_to_symbols")]
pub symbols: Vec<String>,
pub headline: String,
pub author: String,
pub source: String,
pub summary: String,
pub content: String,
pub url: Option<String>,
pub images: Vec<Image>,
}
impl From<News> for types::News {
fn from(news: News) -> Self {
Self {
id: news.id,
time_created: news.time_created,
time_updated: news.time_updated,
symbols: news.symbols,
headline: strip(&news.headline),
author: strip(&news.author),
source: strip(&news.source),
summary: news.summary,
content: news.content,
url: news.url.unwrap_or_default(),
}
}
}

View File

@@ -1,3 +0,0 @@
use crate::types::alpaca::shared::order;
pub use order::{Order, Side};

View File

@@ -1,61 +0,0 @@
use crate::{
types::alpaca::api::incoming::{
asset::{Class, Exchange},
order,
},
utils::de,
};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_number_from_string;
use uuid::Uuid;
#[derive(Deserialize, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Long,
Short,
}
impl From<Side> for order::Side {
fn from(side: Side) -> Self {
match side {
Side::Long => Self::Buy,
Side::Short => Self::Sell,
}
}
}
#[derive(Deserialize, Clone)]
pub struct Position {
pub asset_id: Uuid,
#[serde(deserialize_with = "de::add_slash_to_symbol")]
pub symbol: String,
pub exchange: Exchange,
pub asset_class: Class,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub avg_entry_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty_available: f64,
pub side: Side,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub cost_basis: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub current_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub lastday_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub change_today: f64,
pub asset_marginable: bool,
}

View File

@@ -1,6 +0,0 @@
pub mod incoming;
pub mod outgoing;
pub const ALPACA_US_EQUITY_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news";

View File

@@ -1,23 +0,0 @@
use crate::types::alpaca::shared::asset;
use serde::Serialize;
pub use asset::{Class, Exchange, Status};
#[derive(Serialize)]
pub struct Asset {
pub status: Option<Status>,
pub class: Option<Class>,
pub exchange: Option<Exchange>,
pub attributes: Option<Vec<String>>,
}
impl Default for Asset {
fn default() -> Self {
Self {
status: None,
class: Some(Class::UsEquity),
exchange: None,
attributes: None,
}
}
}

View File

@@ -1,8 +0,0 @@
pub mod auth;
pub mod data;
pub mod trading;
pub const ALPACA_US_EQUITY_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str =
"wss://stream.data.alpaca.markets/v1beta3/crypto/us";
pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";

View File

@@ -1,11 +0,0 @@
use clickhouse::Row;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Row)]
pub struct Backfill {
pub symbol: String,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time: OffsetDateTime,
pub fresh: bool,
}

View File

@@ -1,19 +0,0 @@
use clickhouse::Row;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
pub struct News {
pub id: i64,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_created: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_updated: OffsetDateTime,
pub symbols: Vec<String>,
pub headline: String,
pub author: String,
pub source: String,
pub summary: String,
pub content: String,
pub url: String,
}

View File

@@ -3,19 +3,16 @@
#![feature(hash_extract_if)]
mod config;
mod database;
mod init;
mod routes;
mod threads;
mod types;
mod utils;
use config::{
Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE,
CLICKHOUSE_BATCH_NEWS_SIZE, CLICKHOUSE_MAX_CONNECTIONS,
};
use config::Config;
use dotenv::dotenv;
use log::info;
use log4rs::config::Deserializers;
use nonempty::NonEmpty;
use qrust::{create_send_await, database};
use tokio::{join, spawn, sync::mpsc, try_join};
#[tokio::main]
@@ -24,62 +21,18 @@ async fn main() {
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
let config = Config::arc_from_env();
let _ = *ALPACA_MODE;
let _ = *ALPACA_API_BASE;
let _ = *ALPACA_SOURCE;
let _ = *CLICKHOUSE_BATCH_BARS_SIZE;
let _ = *CLICKHOUSE_BATCH_NEWS_SIZE;
let _ = *CLICKHOUSE_MAX_CONNECTIONS;
info!("Marking all assets as stale.");
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol, asset.class))
.collect::<Vec<_>>();
let symbols = assets.iter().map(|(symbol, _)| symbol).collect::<Vec<_>>();
try_join!(
database::backfills_bars::set_fresh_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
false,
&symbols
),
database::backfills_news::set_fresh_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
false,
&symbols
)
database::backfills_bars::unfresh(&config.clickhouse_client),
database::backfills_news::unfresh(&config.clickhouse_client)
)
.unwrap();
info!("Cleaning up database.");
database::cleanup_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
info!("Optimizing database.");
database::optimize_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
info!("Rehydrating account data.");
database::cleanup_all(&config.clickhouse_client)
.await
.unwrap();
database::optimize_all(&config.clickhouse_client)
.await
.unwrap();
init::check_account(&config).await;
join!(
@@ -87,8 +40,6 @@ async fn main() {
init::rehydrate_positions(&config)
);
info!("Starting threads.");
spawn(threads::trading::run(config.clone()));
let (data_sender, data_receiver) = mpsc::channel::<threads::data::Message>(100);
@@ -102,14 +53,19 @@ async fn main() {
spawn(threads::clock::run(config.clone(), clock_sender));
if let Some(assets) = NonEmpty::from_vec(assets) {
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Enable,
assets
);
}
let assets = database::assets::select(&config.clickhouse_client)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol, asset.class))
.collect::<Vec<_>>();
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Enable,
assets
);
routes::run(config, data_sender).await;
}

View File

@@ -1,30 +1,20 @@
use crate::{
config::{Config, ALPACA_API_BASE},
config::Config,
create_send_await, database, threads,
types::{alpaca, Asset},
};
use axum::{extract::Path, Extension, Json};
use http::StatusCode;
use nonempty::{nonempty, NonEmpty};
use qrust::{
alpaca,
types::{self, Asset},
};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::mpsc;
pub async fn get(
Extension(config): Extension<Arc<Config>>,
) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> {
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let assets = database::assets::select(&config.clickhouse_client)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok((StatusCode::OK, Json(assets)))
}
@@ -33,13 +23,9 @@ pub async fn get_where_symbol(
Extension(config): Extension<Arc<Config>>,
Path(symbol): Path<String>,
) -> Result<(StatusCode, Json<Asset>), StatusCode> {
let asset = database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
asset.map_or(Err(StatusCode::NOT_FOUND), |asset| {
Ok((StatusCode::OK, Json(asset)))
@@ -47,115 +33,28 @@ pub async fn get_where_symbol(
}
#[derive(Deserialize)]
pub struct AddAssetsRequest {
symbols: Vec<String>,
}
#[derive(Serialize)]
pub struct AddAssetsResponse {
added: Vec<String>,
skipped: Vec<String>,
failed: Vec<String>,
pub struct AddAssetRequest {
symbol: String,
}
pub async fn add(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Json(request): Json<AddAssetsRequest>,
) -> Result<(StatusCode, Json<AddAssetsResponse>), StatusCode> {
let database_symbols = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.into_iter()
.map(|asset| asset.symbol)
.collect::<HashSet<_>>();
let mut alpaca_assets = alpaca::assets::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&request.symbols,
None,
&ALPACA_API_BASE,
)
.await
.map_err(|e| {
e.status()
.map_or(StatusCode::INTERNAL_SERVER_ERROR, |status| {
StatusCode::from_u16(status.as_u16()).unwrap()
})
})?
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>();
let num_symbols = request.symbols.len();
let (assets, skipped, failed) = request.symbols.into_iter().fold(
(Vec::with_capacity(num_symbols), vec![], vec![]),
|(mut assets, mut skipped, mut failed), symbol| {
if database_symbols.contains(&symbol) {
skipped.push(symbol);
} else if let Some(asset) = alpaca_assets.remove(&symbol) {
if asset.status == types::alpaca::api::incoming::asset::Status::Active
&& asset.tradable
&& asset.fractionable
{
assets.push((asset.symbol, asset.class.into()));
} else {
failed.push(asset.symbol);
}
} else {
failed.push(symbol);
}
(assets, skipped, failed)
},
);
if let Some(assets) = NonEmpty::from_vec(assets.clone()) {
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
assets
);
}
Ok((
StatusCode::OK,
Json(AddAssetsResponse {
added: assets.into_iter().map(|asset| asset.0).collect(),
skipped,
failed,
}),
))
}
pub async fn add_symbol(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
Json(request): Json<AddAssetRequest>,
) -> Result<StatusCode, StatusCode> {
if database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.is_some()
if database::assets::select_where_symbol(&config.clickhouse_client, &request.symbol)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.is_some()
{
return Err(StatusCode::CONFLICT);
}
let asset = alpaca::assets::get_by_symbol(
let asset = alpaca::api::incoming::asset::get_by_symbol(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbol,
&request.symbol,
None,
&ALPACA_API_BASE,
)
.await
.map_err(|e| {
@@ -165,10 +64,7 @@ pub async fn add_symbol(
})
})?;
if asset.status != types::alpaca::api::incoming::asset::Status::Active
|| !asset.tradable
|| !asset.fractionable
{
if !asset.tradable || !asset.fractionable {
return Err(StatusCode::FORBIDDEN);
}
@@ -176,7 +72,7 @@ pub async fn add_symbol(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
nonempty![(asset.symbol, asset.class.into())]
vec![(asset.symbol, asset.class.into())]
);
Ok(StatusCode::CREATED)
@@ -187,20 +83,16 @@ pub async fn delete(
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
let asset = database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Remove,
nonempty![(asset.symbol, asset.class)]
vec![(asset.symbol, asset.class)]
);
Ok(StatusCode::NO_CONTENT)

View File

@@ -16,7 +16,6 @@ pub async fn run(config: Arc<Config>, data_sender: mpsc::Sender<threads::data::M
.route("/assets", get(assets::get))
.route("/assets/:symbol", get(assets::get_where_symbol))
.route("/assets", post(assets::add))
.route("/assets/:symbol", post(assets::add_symbol))
.route("/assets/:symbol", delete(assets::delete))
.layer(Extension(config))
.layer(Extension(data_sender));

View File

@@ -1,13 +1,10 @@
use crate::{
config::{Config, ALPACA_API_BASE},
config::Config,
database,
};
use log::info;
use qrust::{
alpaca,
types::{self, Calendar},
types::{alpaca, Calendar},
utils::{backoff, duration_until},
};
use log::info;
use std::sync::Arc;
use time::OffsetDateTime;
use tokio::{join, sync::mpsc, time::sleep};
@@ -22,8 +19,8 @@ pub struct Message {
pub next_switch: OffsetDateTime,
}
impl From<types::alpaca::api::incoming::clock::Clock> for Message {
fn from(clock: types::alpaca::api::incoming::clock::Clock) -> Self {
impl From<alpaca::api::incoming::clock::Clock> for Message {
fn from(clock: alpaca::api::incoming::clock::Clock) -> Self {
if clock.is_open {
Self {
status: Status::Open,
@@ -41,23 +38,21 @@ impl From<types::alpaca::api::incoming::clock::Clock> for Message {
pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
loop {
let clock_future = async {
alpaca::clock::get(
alpaca::api::incoming::clock::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
Some(backoff::infinite()),
&ALPACA_API_BASE,
)
.await
.unwrap()
};
let calendar_future = async {
alpaca::calendar::get(
alpaca::api::incoming::calendar::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&types::alpaca::api::outgoing::calendar::Calendar::default(),
&alpaca::api::outgoing::calendar::Calendar::default(),
Some(backoff::infinite()),
&ALPACA_API_BASE,
)
.await
.unwrap()
@@ -79,13 +74,9 @@ pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
let sleep_future = sleep(sleep_until);
let calendar_future = async {
database::calendar::upsert_batch_and_delete(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&calendar,
)
.await
.unwrap();
database::calendar::upsert_batch_and_delete(&config.clickhouse_client, &calendar)
.await
.unwrap();
};
join!(sleep_future, calendar_future);

View File

@@ -0,0 +1,413 @@
use super::ThreadType;
use crate::{
config::{
Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_SOURCE, ALPACA_STOCK_DATA_API_URL,
MAX_BERT_INPUTS,
},
database,
types::{
alpaca::{self, shared::Source},
news::Prediction,
Backfill, Bar, Class, News,
},
utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE, ONE_SECOND},
};
use async_trait::async_trait;
use futures_util::future::join_all;
use log::{error, info, warn};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::{
spawn,
sync::{mpsc, oneshot, Mutex},
task::{block_in_place, JoinHandle},
time::sleep,
try_join,
};
pub enum Action {
Backfill,
Purge,
}
impl From<super::Action> for Option<Action> {
fn from(action: super::Action) -> Self {
match action {
super::Action::Add | super::Action::Enable => Some(Action::Backfill),
super::Action::Remove => Some(Action::Purge),
super::Action::Disable => None,
}
}
}
pub struct Message {
pub action: Option<Action>,
pub symbols: Vec<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel::<()>();
(
Self {
action,
symbols,
response: sender,
},
receiver,
)
}
}
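
Message::new pairs each request with a oneshot channel so callers can await completion; driving it by hand looks roughly like this (a sketch of what the create_send_await! macro used in main.rs presumably wraps):

    use tokio::sync::mpsc;

    // Hypothetical caller; assumes the Message and Action types above.
    async fn request_backfill(sender: &mpsc::Sender<Message>) {
        let (message, receiver) = Message::new(Some(Action::Backfill), vec!["AAPL".to_string()]);
        sender.send(message).await.unwrap();
        // Resolves once handle_backfill_message fires the oneshot response.
        receiver.await.unwrap();
    }
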
#[async_trait]
pub trait Handler: Send + Sync {
async fn select_latest_backfill(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error>;
async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime);
async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime);
fn log_string(&self) -> &'static str;
}
pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
let backfill_jobs = Arc::new(Mutex::new(HashMap::new()));
loop {
let message = receiver.recv().await.unwrap();
spawn(handle_backfill_message(
handler.clone(),
backfill_jobs.clone(),
message,
));
}
}
async fn handle_backfill_message(
handler: Arc<Box<dyn Handler>>,
backfill_jobs: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
message: Message,
) {
let mut backfill_jobs = backfill_jobs.lock().await;
match message.action {
Some(Action::Backfill) => {
let log_string = handler.log_string();
for symbol in message.symbols {
if let Some(job) = backfill_jobs.get(&symbol) {
if !job.is_finished() {
warn!(
"Backfill for {} {} is already running, skipping.",
symbol, log_string
);
continue;
}
}
let handler = handler.clone();
backfill_jobs.insert(
symbol.clone(),
spawn(async move {
let fetch_from = match handler
.select_latest_backfill(symbol.clone())
.await
.unwrap()
{
Some(latest_backfill) => latest_backfill.time + ONE_SECOND,
None => OffsetDateTime::UNIX_EPOCH,
};
let fetch_to = last_minute();
if fetch_from > fetch_to {
info!("No need to backfill {} {}.", symbol, log_string,);
return;
}
handler.queue_backfill(&symbol, fetch_to).await;
handler.backfill(symbol, fetch_from, fetch_to).await;
}),
);
}
}
Some(Action::Purge) => {
for symbol in &message.symbols {
if let Some(job) = backfill_jobs.remove(symbol) {
if !job.is_finished() {
job.abort();
}
let _ = job.await;
}
}
try_join!(
handler.delete_backfills(&message.symbols),
handler.delete_data(&message.symbols)
)
.unwrap();
}
None => {}
}
message.response.send(()).unwrap();
}
struct BarHandler {
config: Arc<Config>,
data_url: &'static str,
api_query_constructor: fn(
symbol: String,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar,
}
fn us_equity_query_constructor(
symbol: String,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
alpaca::api::outgoing::bar::Bar::UsEquity(alpaca::api::outgoing::bar::UsEquity {
symbols: vec![symbol],
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
..Default::default()
})
}
fn crypto_query_constructor(
symbol: String,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
alpaca::api::outgoing::bar::Bar::Crypto(alpaca::api::outgoing::bar::Crypto {
symbols: vec![symbol],
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
..Default::default()
})
}
#[async_trait]
impl Handler for BarHandler {
async fn select_latest_backfill(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error> {
database::backfills_bars::select_where_symbol(&self.config.clickhouse_client, &symbol).await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_bars::delete_where_symbols(&self.config.clickhouse_client, symbols)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::bars::delete_where_symbols(&self.config.clickhouse_client, symbols).await
}
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
if *ALPACA_SOURCE == Source::Iex {
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
info!("Queing bar backfill for {} in {:?}.", symbol, run_delay);
sleep(run_delay).await;
}
}
async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) {
info!("Backfilling bars for {}.", symbol);
let mut bars = vec![];
let mut next_page_token = None;
loop {
let Ok(message) = alpaca::api::incoming::bar::get_historical(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
self.data_url,
&(self.api_query_constructor)(
symbol.clone(),
fetch_from,
fetch_to,
next_page_token.clone(),
),
None,
)
.await
else {
error!("Failed to backfill bars for {}.", symbol);
return;
};
message.bars.into_iter().for_each(|(symbol, bar_vec)| {
for bar in bar_vec {
bars.push(Bar::from((bar, symbol.clone())));
}
});
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
}
if bars.is_empty() {
info!("No bars to backfill for {}.", symbol);
return;
}
let backfill = bars.last().unwrap().clone().into();
database::bars::upsert_batch(&self.config.clickhouse_client, &bars)
.await
.unwrap();
database::backfills_bars::upsert(&self.config.clickhouse_client, &backfill)
.await
.unwrap();
info!("Backfilled bars for {}.", symbol);
}
fn log_string(&self) -> &'static str {
"bars"
}
}
struct NewsHandler {
config: Arc<Config>,
}
#[async_trait]
impl Handler for NewsHandler {
async fn select_latest_backfill(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error> {
database::backfills_news::select_where_symbol(&self.config.clickhouse_client, &symbol).await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_news::delete_where_symbols(&self.config.clickhouse_client, symbols)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::news::delete_where_symbols(&self.config.clickhouse_client, symbols).await
}
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
info!("Queing news backfill for {} in {:?}.", symbol, run_delay);
sleep(run_delay).await;
}
async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) {
info!("Backfilling news for {}.", symbol);
let mut news = vec![];
let mut next_page_token = None;
loop {
let Ok(message) = alpaca::api::incoming::news::get_historical(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
&alpaca::api::outgoing::news::News {
symbols: vec![symbol.clone()],
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token.clone(),
..Default::default()
},
None,
)
.await
else {
error!("Failed to backfill news for {}.", symbol);
return;
};
message.news.into_iter().for_each(|news_item| {
news.push(News::from(news_item));
});
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
}
if news.is_empty() {
info!("No news to backfill for {}.", symbol);
return;
}
let inputs = news
.iter()
.map(|news| format!("{}\n\n{}", news.headline, news.content))
.collect::<Vec<_>>();
let predictions = join_all(inputs.chunks(*MAX_BERT_INPUTS).map(|inputs| async move {
let sequence_classifier = self.config.sequence_classifier.lock().await;
block_in_place(|| {
sequence_classifier
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()
})
}))
.await
.into_iter()
.flatten();
let news = news
.into_iter()
.zip(predictions)
.map(|(news, prediction)| News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
})
.collect::<Vec<_>>();
let backfill = (news.last().unwrap().clone(), symbol.clone()).into();
database::news::upsert_batch(&self.config.clickhouse_client, &news)
.await
.unwrap();
database::backfills_news::upsert(&self.config.clickhouse_client, &backfill)
.await
.unwrap();
info!("Backfilled news for {}.", symbol);
}
fn log_string(&self) -> &'static str {
"news"
}
}
pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
match thread_type {
ThreadType::Bars(Class::UsEquity) => Box::new(BarHandler {
config,
data_url: ALPACA_STOCK_DATA_API_URL,
api_query_constructor: us_equity_query_constructor,
}),
ThreadType::Bars(Class::Crypto) => Box::new(BarHandler {
config,
data_url: ALPACA_CRYPTO_DATA_API_URL,
api_query_constructor: crypto_query_constructor,
}),
ThreadType::News => Box::new(NewsHandler { config }),
}
}
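
Caller-side, the backfill channel is used as a request/acknowledge pair: Message::new hands back a oneshot receiver that resolves once handle_backfill_message has finished. A sketch, assuming a backfill_sender: mpsc::Sender<Message> handle (the symbol is illustrative):

// Request a backfill for one symbol and wait for the handler's reply.
async fn request_backfill(backfill_sender: &tokio::sync::mpsc::Sender<Message>) {
    let (message, receiver) = Message::new(Some(Action::Backfill), vec!["AAPL".to_string()]);
    backfill_sender.send(message).await.unwrap();
    // Resolves when handle_backfill_message sends () on the oneshot channel.
    receiver.await.unwrap();
}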

View File

@@ -1,238 +0,0 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE},
database,
threads::data::ThreadType,
};
use async_trait::async_trait;
use log::{error, info};
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
self,
alpaca::{
api::{ALPACA_CRYPTO_DATA_API_URL, ALPACA_US_EQUITY_DATA_API_URL},
shared::{Sort, Source},
},
Backfill, Bar, Class,
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::time::sleep;
pub struct Handler {
pub config: Arc<Config>,
pub data_url: &'static str,
pub api_query_constructor: fn(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar,
}
pub fn us_equity_query_constructor(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar {
types::alpaca::api::outgoing::bar::Bar::UsEquity(types::alpaca::api::outgoing::bar::UsEquity {
symbols,
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
sort: Some(Sort::Asc),
feed: Some(*ALPACA_SOURCE),
..Default::default()
})
}
pub fn crypto_query_constructor(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar {
types::alpaca::api::outgoing::bar::Bar::Crypto(types::alpaca::api::outgoing::bar::Crypto {
symbols,
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
sort: Some(Sort::Asc),
..Default::default()
})
}
#[async_trait]
impl super::Handler for Handler {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
database::backfills_bars::select_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_bars::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::bars::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
if *ALPACA_SOURCE == Source::Sip {
return;
}
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();
info!("Queing bar backfill for {:?} in {:?}.", symbols, run_delay);
sleep(run_delay).await;
}
async fn backfill(&self, jobs: NonEmpty<Job>) {
let symbols = Vec::from(jobs.clone().map(|job| job.symbol));
let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let freshness = jobs
.into_iter()
.map(|job| (job.symbol, job.fresh))
.collect::<HashMap<_, _>>();
let mut bars = Vec::with_capacity(*CLICKHOUSE_BATCH_BARS_SIZE);
let mut last_times = HashMap::new();
let mut next_page_token = None;
info!("Backfilling bars for {:?}.", symbols);
loop {
let message = alpaca::bars::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
self.data_url,
&(self.api_query_constructor)(
symbols.clone(),
fetch_from,
fetch_to,
next_page_token.clone(),
),
None,
)
.await;
if let Err(err) = message {
error!("Failed to backfill bars for {:?}: {:?}.", symbols, err);
return;
}
let message = message.unwrap();
for (symbol, bars_vec) in message.bars {
if let Some(last) = bars_vec.last() {
last_times.insert(symbol.clone(), last.time);
}
for bar in bars_vec {
bars.push(Bar::from((bar, symbol.clone())));
}
}
if bars.len() < *CLICKHOUSE_BATCH_BARS_SIZE && message.next_page_token.is_some() {
continue;
}
database::bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&bars,
)
.await
.unwrap();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill {
fresh: freshness[&symbol],
symbol,
time,
})
.collect::<Vec<_>>();
database::backfills_bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
bars.clear();
}
database::backfills_bars::set_fresh_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
true,
&symbols,
)
.await
.unwrap();
info!("Backfilled bars for {:?}.", symbols);
}
fn max_limit(&self) -> i64 {
alpaca::bars::MAX_LIMIT
}
fn log_string(&self) -> &'static str {
"bars"
}
}
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
let data_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => ALPACA_US_EQUITY_DATA_API_URL,
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_API_URL,
_ => unreachable!(),
};
let api_query_constructor = match thread_type {
ThreadType::Bars(Class::UsEquity) => us_equity_query_constructor,
ThreadType::Bars(Class::Crypto) => crypto_query_constructor,
_ => unreachable!(),
};
Box::new(Handler {
config,
data_url,
api_query_constructor,
})
}
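
The paging loop in the deleted handler reduces to one control flow: accumulate rows while the buffer is under the batch size and more pages remain, flush, and stop once the API returns no page token. A generic sketch with hypothetical next_page/flush stand-ins for the Alpaca call and the ClickHouse upsert:

fn drain_pages(
    mut next_page: impl FnMut(Option<String>) -> (Vec<u64>, Option<String>),
    mut flush: impl FnMut(&[u64]),
    batch_size: usize,
) {
    let mut buf = Vec::with_capacity(batch_size);
    let mut token = None;
    loop {
        let (rows, next_token) = next_page(token);
        buf.extend(rows);
        token = next_token;
        // Keep paging while the batch is small and more pages remain.
        if buf.len() < batch_size && token.is_some() {
            continue;
        }
        flush(&buf);
        if token.is_none() {
            break;
        }
        buf.clear();
    }
}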

View File

@@ -1,244 +0,0 @@
pub mod bars;
pub mod news;
use async_trait::async_trait;
use itertools::Itertools;
use log::{info, warn};
use nonempty::{nonempty, NonEmpty};
use qrust::{
types::Backfill,
utils::{last_minute, ONE_SECOND},
};
use std::{collections::HashMap, hash::Hash, sync::Arc};
use time::OffsetDateTime;
use tokio::{
spawn,
sync::{mpsc, oneshot, Mutex},
task::JoinHandle,
try_join,
};
use uuid::Uuid;
pub enum Action {
Backfill,
Purge,
}
pub struct Message {
pub action: Action,
pub symbols: NonEmpty<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Action, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel::<()>();
(
Self {
action,
symbols,
response: sender,
},
receiver,
)
}
}
#[derive(Clone)]
pub struct Job {
pub symbol: String,
pub fetch_from: OffsetDateTime,
pub fetch_to: OffsetDateTime,
pub fresh: bool,
}
#[async_trait]
pub trait Handler: Send + Sync {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error>;
async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn queue_backfill(&self, jobs: &NonEmpty<Job>);
async fn backfill(&self, jobs: NonEmpty<Job>);
fn max_limit(&self) -> i64;
fn log_string(&self) -> &'static str;
}
pub struct Jobs {
pub symbol_to_uuid: HashMap<String, Uuid>,
pub uuid_to_job: HashMap<Uuid, JoinHandle<()>>,
}
impl Jobs {
pub fn insert(&mut self, jobs: Vec<String>, fut: JoinHandle<()>) {
let uuid = Uuid::new_v4();
for symbol in jobs {
self.symbol_to_uuid.insert(symbol.clone(), uuid);
}
self.uuid_to_job.insert(uuid, fut);
}
pub fn contains_key(&self, symbol: &str) -> bool {
self.symbol_to_uuid.contains_key(symbol)
}
pub fn remove(&mut self, symbol: &str) -> Option<JoinHandle<()>> {
self.symbol_to_uuid
.remove(symbol)
.and_then(|uuid| self.uuid_to_job.remove(&uuid))
}
pub fn remove_many<T>(&mut self, symbols: &[T])
where
T: AsRef<str> + Hash + Eq,
{
for symbol in symbols {
self.symbol_to_uuid
.remove(symbol.as_ref())
.and_then(|uuid| self.uuid_to_job.remove(&uuid));
}
}
pub fn len(&self) -> usize {
self.symbol_to_uuid.len()
}
}
pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
let backfill_jobs = Arc::new(Mutex::new(Jobs {
symbol_to_uuid: HashMap::new(),
uuid_to_job: HashMap::new(),
}));
loop {
let message = receiver.recv().await.unwrap();
spawn(handle_message(
handler.clone(),
backfill_jobs.clone(),
message,
));
}
}
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
backfill_jobs: Arc<Mutex<Jobs>>,
message: Message,
) {
let backfill_jobs_clone = backfill_jobs.clone();
let mut backfill_jobs = backfill_jobs.lock().await;
let symbols = Vec::from(message.symbols);
match message.action {
Action::Backfill => {
let log_string = handler.log_string();
let max_limit = handler.max_limit();
let backfills = handler
.select_latest_backfills(&symbols)
.await
.unwrap()
.into_iter()
.map(|backfill| (backfill.symbol.clone(), backfill))
.collect::<HashMap<_, _>>();
let mut jobs = Vec::with_capacity(symbols.len());
for symbol in symbols {
if backfill_jobs.contains_key(&symbol) {
warn!(
"Backfill for {} {} is already running, skipping.",
symbol, log_string
);
continue;
}
let backfill = backfills.get(&symbol);
let fetch_from = backfill.map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
backfill.time + ONE_SECOND
});
let fetch_to = last_minute();
if fetch_from > fetch_to {
info!("No need to backfill {} {}.", symbol, log_string,);
return;
}
let fresh = backfill.map_or(false, |backfill| backfill.fresh);
jobs.push(Job {
symbol,
fetch_from,
fetch_to,
fresh,
});
}
let jobs = jobs
.into_iter()
.sorted_unstable_by_key(|job| job.fetch_from)
.collect::<Vec<_>>();
let mut job_groups: Vec<NonEmpty<Job>> = vec![];
let mut current_minutes = 0;
for job in jobs {
let minutes = (job.fetch_to - job.fetch_from).whole_minutes();
if job_groups.last().is_some() && current_minutes + minutes <= max_limit {
let job_group = job_groups.last_mut().unwrap();
job_group.push(job);
current_minutes += minutes;
} else {
job_groups.push(nonempty![job]);
current_minutes = minutes;
}
}
for job_group in job_groups {
let symbols = job_group
.iter()
.map(|job| job.symbol.clone())
.collect::<Vec<_>>();
let handler = handler.clone();
let symbols_clone = symbols.clone();
let backfill_jobs_clone = backfill_jobs_clone.clone();
let fut = spawn(async move {
handler.queue_backfill(&job_group).await;
handler.backfill(job_group).await;
let mut backfill_jobs = backfill_jobs_clone.lock().await;
backfill_jobs.remove_many(&symbols_clone);
let remaining = backfill_jobs.len();
drop(backfill_jobs);
info!("{} {} backfills remaining.", remaining, log_string);
});
backfill_jobs.insert(symbols, fut);
}
}
Action::Purge => {
for symbol in &symbols {
if let Some(job) = backfill_jobs.remove(symbol) {
job.abort();
let _ = job.await;
}
}
try_join!(
handler.delete_backfills(&symbols),
handler.delete_data(&symbols)
)
.unwrap();
}
}
message.response.send(()).unwrap();
}
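
The grouping step above packs jobs, already sorted by fetch_from, into batches whose combined minute span fits within one API request (max_limit). The same greedy logic in isolation, with jobs reduced to their minute spans and NonEmpty dropped for brevity:

fn group_jobs(minutes: Vec<i64>, max_limit: i64) -> Vec<Vec<i64>> {
    let mut groups: Vec<Vec<i64>> = vec![];
    let mut current = 0;
    for m in minutes {
        match groups.last_mut() {
            // Extend the open group while the running total stays in budget.
            Some(group) if current + m <= max_limit => {
                current += m;
                group.push(m);
            }
            // Otherwise start a new group seeded with this job.
            _ => {
                current = m;
                groups.push(vec![m]);
            }
        }
    }
    groups
}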

View File

@@ -1,186 +0,0 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_NEWS_SIZE},
database,
};
use async_trait::async_trait;
use log::{error, info};
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
self,
alpaca::shared::{Sort, Source},
Backfill, News,
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::time::sleep;
pub struct Handler {
pub config: Arc<Config>,
}
#[async_trait]
impl super::Handler for Handler {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
database::backfills_news::select_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_news::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::news::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
if *ALPACA_SOURCE == Source::Sip {
return;
}
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();
info!("Queing news backfill for {:?} in {:?}.", symbols, run_delay);
sleep(run_delay).await;
}
#[allow(clippy::too_many_lines)]
#[allow(clippy::iter_with_drain)]
async fn backfill(&self, jobs: NonEmpty<Job>) {
let symbols = Vec::from(jobs.clone().map(|job| job.symbol));
let symbols_set = symbols.clone().into_iter().collect::<HashSet<_>>();
let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let freshness = jobs
.into_iter()
.map(|job| (job.symbol, job.fresh))
.collect::<HashMap<_, _>>();
let mut news = Vec::with_capacity(*CLICKHOUSE_BATCH_NEWS_SIZE);
let mut last_times = HashMap::new();
let mut next_page_token = None;
info!("Backfilling news for {:?}.", symbols);
loop {
let message = alpaca::news::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
&types::alpaca::api::outgoing::news::News {
symbols: symbols.clone(),
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token.clone(),
sort: Some(Sort::Asc),
..Default::default()
},
None,
)
.await;
if let Err(err) = message {
error!("Failed to backfill news for {:?}: {:?}.", symbols, err);
return;
}
let message = message.unwrap();
for news_item in message.news {
let news_item = News::from(news_item);
for symbol in &news_item.symbols {
if symbols_set.contains(symbol) {
last_times.insert(symbol.clone(), news_item.time_created);
}
}
news.push(news_item);
}
if news.len() < *CLICKHOUSE_BATCH_NEWS_SIZE && message.next_page_token.is_some() {
continue;
}
database::news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&news,
)
.await
.unwrap();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill {
fresh: freshness[&symbol],
symbol,
time,
})
.collect::<Vec<_>>();
database::backfills_news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
news.clear();
}
database::backfills_news::set_fresh_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
true,
&symbols,
)
.await
.unwrap();
info!("Backfilled news for {:?}.", symbols);
}
fn max_limit(&self) -> i64 {
alpaca::news::MAX_LIMIT
}
fn log_string(&self) -> &'static str {
"news"
}
}
pub fn create_handler(config: Arc<Config>) -> Box<dyn super::Handler> {
Box::new(Handler { config })
}

View File

@@ -3,29 +3,24 @@ mod websocket;
use super::clock;
use crate::{
config::{Config, ALPACA_API_BASE, ALPACA_SOURCE},
create_send_await, database,
};
use itertools::{Either, Itertools};
use log::error;
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
alpaca::websocket::{
ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL,
ALPACA_US_EQUITY_DATA_WEBSOCKET_URL,
},
Asset, Class,
config::{
Config, ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL, ALPACA_SOURCE,
ALPACA_STOCK_DATA_WEBSOCKET_URL,
},
create_send_await, database,
types::{alpaca, Asset, Class},
utils::backoff,
};
use std::{collections::HashMap, sync::Arc};
use futures_util::{future::join_all, StreamExt};
use itertools::{Either, Itertools};
use std::sync::Arc;
use tokio::{
join, select, spawn,
sync::{mpsc, oneshot},
};
use tokio_tungstenite::connect_async;
#[derive(Clone, Copy, PartialEq, Eq)]
#[derive(Clone, Copy)]
#[allow(dead_code)]
pub enum Action {
Add,
@@ -36,12 +31,12 @@ pub enum Action {
pub struct Message {
pub action: Action,
pub assets: NonEmpty<(String, Class)>,
pub assets: Vec<(String, Class)>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Action, assets: NonEmpty<(String, Class)>) -> (Self, oneshot::Receiver<()>) {
pub fn new(action: Action, assets: Vec<(String, Class)>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel();
(
Self {
@@ -66,11 +61,11 @@ pub async fn run(
mut clock_receiver: mpsc::Receiver<clock::Message>,
) {
let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::UsEquity));
init_thread(config.clone(), ThreadType::Bars(Class::UsEquity)).await;
let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::Crypto));
init_thread(config.clone(), ThreadType::Bars(Class::Crypto)).await;
let (news_websocket_sender, news_backfill_sender) =
init_thread(config.clone(), ThreadType::News);
init_thread(config.clone(), ThreadType::News).await;
loop {
select! {
@@ -99,7 +94,7 @@ pub async fn run(
}
}
fn init_thread(
async fn init_thread(
config: Arc<Config>,
thread_type: ThreadType,
) -> (
@@ -108,32 +103,28 @@ fn init_thread(
) {
let websocket_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => {
format!("{}/{}", ALPACA_US_EQUITY_DATA_WEBSOCKET_URL, *ALPACA_SOURCE)
format!("{}/{}", ALPACA_STOCK_DATA_WEBSOCKET_URL, *ALPACA_SOURCE)
}
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_WEBSOCKET_URL.into(),
ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(),
};
let backfill_handler = match thread_type {
ThreadType::Bars(_) => backfill::bars::create_handler(config.clone(), thread_type),
ThreadType::News => backfill::news::create_handler(config.clone()),
};
let (websocket, _) = connect_async(websocket_url).await.unwrap();
let (mut websocket_sink, mut websocket_stream) = websocket.split();
alpaca::websocket::data::authenticate(&mut websocket_sink, &mut websocket_stream).await;
let (backfill_sender, backfill_receiver) = mpsc::channel(100);
spawn(backfill::run(backfill_handler.into(), backfill_receiver));
let websocket_handler = match thread_type {
ThreadType::Bars(_) => websocket::bars::create_handler(config, thread_type),
ThreadType::News => websocket::news::create_handler(config),
};
spawn(backfill::run(
Arc::new(backfill::create_handler(thread_type, config.clone())),
backfill_receiver,
));
let (websocket_sender, websocket_receiver) = mpsc::channel(100);
spawn(websocket::run(
websocket_handler.into(),
Arc::new(websocket::create_handler(thread_type, config.clone())),
websocket_receiver,
websocket_url,
websocket_stream,
websocket_sink,
));
(websocket_sender, backfill_sender)
@@ -151,6 +142,11 @@ async fn handle_message(
news_backfill_sender: mpsc::Sender<backfill::Message>,
message: Message,
) {
if message.assets.is_empty() {
message.response.send(()).unwrap();
return;
}
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = message
.assets
.clone()
@@ -160,28 +156,50 @@ async fn handle_message(
Class::Crypto => Either::Right(asset.0),
});
let symbols = message.assets.map(|(symbol, _)| symbol);
let symbols = message
.assets
.into_iter()
.map(|(symbol, _)| symbol)
.collect::<Vec<_>>();
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols.clone()) {
create_send_await!(
bars_us_equity_websocket_sender,
websocket::Message::new,
message.action.into(),
us_equity_symbols
);
if us_equity_symbols.is_empty() {
return;
}
create_send_await!(
bars_us_equity_websocket_sender,
websocket::Message::new,
message.action.into(),
us_equity_symbols.clone()
);
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
message.action.into(),
us_equity_symbols
);
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols.clone()) {
create_send_await!(
bars_crypto_websocket_sender,
websocket::Message::new,
message.action.into(),
crypto_symbols
);
if crypto_symbols.is_empty() {
return;
}
create_send_await!(
bars_crypto_websocket_sender,
websocket::Message::new,
message.action.into(),
crypto_symbols.clone()
);
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
message.action.into(),
crypto_symbols
);
};
let news_future = async {
@@ -191,127 +209,62 @@ async fn handle_message(
message.action.into(),
symbols.clone()
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
if message.action == Action::Disable {
message.response.send(()).unwrap();
return;
}
match message.action {
Action::Add | Action::Enable => {
let symbols = Vec::from(symbols.clone());
let assets = async {
alpaca::assets::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbols,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>()
};
let positions = async {
alpaca::positions::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbols,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(|position| (position.symbol.clone(), position))
.collect::<HashMap<_, _>>()
};
let (mut assets, mut positions) = join!(assets, positions);
let mut batch = Vec::with_capacity(symbols.len());
for symbol in &symbols {
if let Some(asset) = assets.remove(symbol) {
let position = positions.remove(symbol);
batch.push(Asset::from((asset, position)));
} else {
error!("Failed to find asset for symbol: {}.", symbol);
}
}
database::assets::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&batch,
)
.await
.unwrap();
}
Action::Remove => {
database::assets::delete_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&Vec::from(symbols.clone()),
)
.await
.unwrap();
}
Action::Disable => unreachable!(),
}
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
us_equity_symbols
);
}
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
crypto_symbols
);
}
};
let news_future = async {
create_send_await!(
news_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
symbols
message.action.into(),
symbols.clone()
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
match message.action {
Action::Add => {
let assets = join_all(symbols.into_iter().map(|symbol| {
let config = config.clone();
async move {
let asset_future = async {
alpaca::api::incoming::asset::get_by_symbol(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbol,
Some(backoff::infinite()),
)
.await
.unwrap()
};
let position_future = async {
alpaca::api::incoming::position::get_by_symbol(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbol,
Some(backoff::infinite()),
)
.await
.unwrap()
};
let (asset, position) = join!(asset_future, position_future);
Asset::from((asset, position))
}
}))
.await;
database::assets::upsert_batch(&config.clickhouse_client, &assets)
.await
.unwrap();
}
Action::Remove => {
database::assets::delete_where_symbols(&config.clickhouse_client, &symbols)
.await
.unwrap();
}
_ => {}
}
message.response.send(()).unwrap();
}
@@ -321,19 +274,13 @@ async fn handle_clock_message(
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
news_backfill_sender: mpsc::Sender<backfill::Message>,
) {
database::cleanup_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
database::cleanup_all(&config.clickhouse_client)
.await
.unwrap();
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
let assets = database::assets::select(&config.clickhouse_client)
.await
.unwrap();
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets
.clone()
@@ -349,36 +296,30 @@ async fn handle_clock_message(
.collect::<Vec<_>>();
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
us_equity_symbols
);
}
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
Some(backfill::Action::Backfill),
us_equity_symbols.clone()
);
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
crypto_symbols
);
}
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
Some(backfill::Action::Backfill),
crypto_symbols.clone()
);
};
let news_future = async {
if let Some(symbols) = NonEmpty::from_vec(symbols) {
create_send_await!(
news_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
symbols
);
}
create_send_await!(
news_backfill_sender,
backfill::Message::new,
Some(backfill::Action::Backfill),
symbols
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
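
The create_send_await! macro is not included in this diff. Given that the Message::new constructors return a (message, receiver) pair, a plausible definition is the following; this is an inference from the call sites above, and the real macro may differ:

// Hypothetical expansion: build the message, send it, await the ack.
macro_rules! create_send_await {
    ($sender:expr, $constructor:expr, $($arg:expr),* $(,)?) => {
        let (message, receiver) = $constructor($($arg),*);
        $sender.send(message).await.unwrap();
        receiver.await.unwrap();
    };
}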

View File

@@ -0,0 +1,427 @@
use super::ThreadType;
use crate::{
config::Config,
database,
types::{alpaca::websocket, news::Prediction, Bar, Class, News},
};
use async_trait::async_trait;
use futures_util::{
future::join_all,
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use log::{debug, error, info};
use serde_json::{from_str, to_string};
use std::{collections::HashMap, sync::Arc};
use tokio::{
net::TcpStream,
select, spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
task::block_in_place,
};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
pub enum Action {
Subscribe,
Unsubscribe,
}
impl From<super::Action> for Option<Action> {
fn from(action: super::Action) -> Self {
match action {
super::Action::Add | super::Action::Enable => Some(Action::Subscribe),
super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe),
}
}
}
pub struct Message {
pub action: Option<Action>,
pub symbols: Vec<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel();
(
Self {
action,
symbols,
response: sender,
},
receiver,
)
}
}
pub struct Pending {
pub subscriptions: HashMap<String, oneshot::Sender<()>>,
pub unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}
#[async_trait]
pub trait Handler: Send + Sync {
fn create_subscription_message(
&self,
symbols: Vec<String>,
) -> websocket::data::outgoing::subscribe::Message;
async fn handle_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
message: websocket::data::incoming::Message,
);
}
pub async fn run(
handler: Arc<Box<dyn Handler>>,
mut receiver: mpsc::Receiver<Message>,
mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>,
) {
let pending = Arc::new(RwLock::new(Pending {
subscriptions: HashMap::new(),
unsubscriptions: HashMap::new(),
}));
let websocket_sink = Arc::new(Mutex::new(websocket_sink));
loop {
select! {
Some(message) = receiver.recv() => {
spawn(handle_message(
handler.clone(),
pending.clone(),
websocket_sink.clone(),
message,
));
}
Some(Ok(message)) = websocket_stream.next() => {
match message {
tungstenite::Message::Text(message) => {
let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);
if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}", message);
continue;
}
for message in parsed_message.unwrap() {
let handler = handler.clone();
let pending = pending.clone();
spawn(async move {
handler.handle_websocket_message(pending, message).await;
});
}
}
tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}", message),
}
}
else => panic!("Communication channel unexpectedly closed.")
}
}
}
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
pending: Arc<RwLock<Pending>>,
sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
message: Message,
) {
if message.symbols.is_empty() {
message.response.send(()).unwrap();
return;
}
match message.action {
Some(Action::Subscribe) => {
let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip();
pending
.write()
.await
.subscriptions
.extend(pending_subscriptions);
sink.lock()
.await
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(message.symbols),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
Some(Action::Unsubscribe) => {
let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip();
pending
.write()
.await
.unsubscriptions
.extend(pending_unsubscriptions);
sink.lock()
.await
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Unsubscribe(
handler.create_subscription_message(message.symbols.clone()),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
None => {}
}
message.response.send(()).unwrap();
}
struct BarsHandler {
config: Arc<Config>,
subscription_message_constructor:
fn(Vec<String>) -> websocket::data::outgoing::subscribe::Message,
}
#[async_trait]
impl Handler for BarsHandler {
fn create_subscription_message(
&self,
symbols: Vec<String>,
) -> websocket::data::outgoing::subscribe::Message {
(self.subscription_message_constructor)(symbols)
}
async fn handle_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::Market {
bars: symbols, ..
} = message
else {
unreachable!()
};
let mut pending = pending.write().await;
let newly_subscribed = pending
.subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = pending
.unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
drop(pending);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to bars for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from bars for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::Bar(message)
| websocket::data::incoming::Message::UpdatedBar(message) => {
let bar = Bar::from(message);
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
database::bars::upsert(&self.config.clickhouse_client, &bar)
.await
.unwrap();
}
websocket::data::incoming::Message::Status(message) => {
debug!(
"Received status message for {}: {:?}.",
message.symbol, message.status
);
match message.status {
websocket::data::incoming::status::Status::TradingHalt(_)
| websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&message.symbol,
false,
)
.await
.unwrap();
}
websocket::data::incoming::status::Status::Resume(_)
| websocket::data::incoming::status::Status::TradingResumption(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&message.symbol,
true,
)
.await
.unwrap();
}
_ => {}
}
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
}
struct NewsHandler {
config: Arc<Config>,
}
#[async_trait]
impl Handler for NewsHandler {
fn create_subscription_message(
&self,
symbols: Vec<String>,
) -> websocket::data::outgoing::subscribe::Message {
websocket::data::outgoing::subscribe::Message::new_news(symbols)
}
async fn handle_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::News { news: symbols } =
message
else {
unreachable!()
};
let mut pending = pending.write().await;
let newly_subscribed = pending
.subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = pending
.unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
drop(pending);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to news for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from news for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::News(message) => {
let news = News::from(message);
debug!(
"Received news for {:?}: {}.",
news.symbols, news.time_created
);
let input = format!("{}\n\n{}", news.headline, news.content);
let sequence_classifier = self.config.sequence_classifier.lock().await;
let prediction = block_in_place(|| {
sequence_classifier
.predict(vec![input.as_str()])
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()[0]
});
drop(sequence_classifier);
let news = News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
};
database::news::upsert(&self.config.clickhouse_client, &news)
.await
.unwrap();
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
}
pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
match thread_type {
ThreadType::Bars(Class::UsEquity) => Box::new(BarsHandler {
config,
subscription_message_constructor:
websocket::data::outgoing::subscribe::Message::new_market_us_equity,
}),
ThreadType::Bars(Class::Crypto) => Box::new(BarsHandler {
config,
subscription_message_constructor:
websocket::data::outgoing::subscribe::Message::new_market_crypto,
}),
ThreadType::News => Box::new(NewsHandler { config }),
}
}
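
The Pending bookkeeping above is a per-symbol rendezvous: each subscribe/unsubscribe request parks a oneshot sender keyed by symbol, and the websocket task fires it once the exchange acknowledges. A minimal self-contained sketch of that handshake:

use std::collections::HashMap;
use tokio::sync::oneshot;

async fn rendezvous_demo() {
    let mut pending: HashMap<String, oneshot::Sender<()>> = HashMap::new();
    let (tx, rx) = oneshot::channel();
    pending.insert("AAPL".to_string(), tx);

    // On receiving a subscription ack that contains "AAPL":
    if let Some(sender) = pending.remove("AAPL") {
        sender.send(()).unwrap();
    }

    rx.await.unwrap(); // the original requester unblocks here
}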

View File

@@ -1,171 +0,0 @@
use super::State;
use crate::{
config::{Config, CLICKHOUSE_BATCH_BARS_SIZE},
database,
threads::data::ThreadType,
};
use async_trait::async_trait;
use clickhouse::inserter::Inserter;
use log::{debug, error, info};
use nonempty::NonEmpty;
use qrust::{
types::{alpaca::websocket, Bar, Class},
utils::ONE_SECOND,
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::{Mutex, RwLock};
pub struct Handler {
pub config: Arc<Config>,
pub inserter: Arc<Mutex<Inserter<Bar>>>,
pub subscription_message_constructor:
fn(NonEmpty<String>) -> websocket::data::outgoing::subscribe::Message,
}
#[async_trait]
impl super::Handler for Handler {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message {
(self.subscription_message_constructor)(symbols)
}
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::Market {
bars: symbols, ..
} = message
else {
unreachable!()
};
let symbols = symbols.into_iter().collect::<HashSet<_>>();
let mut state = state.write().await;
let newly_subscribed = state
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = state
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to bars for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from bars for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::Bar(message)
| websocket::data::incoming::Message::UpdatedBar(message) => {
let bar = Bar::from(message);
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
self.inserter.lock().await.write(&bar).await.unwrap();
}
websocket::data::incoming::Message::Status(message) => {
debug!(
"Received status message for {}: {:?}.",
message.symbol, message.status
);
match message.status {
websocket::data::incoming::status::Status::TradingHalt(_)
| websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
false,
)
.await
.unwrap();
}
websocket::data::incoming::status::Status::Resume(_)
| websocket::data::incoming::status::Status::TradingResumption(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
true,
)
.await
.unwrap();
}
_ => {}
}
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
fn log_string(&self) -> &'static str {
"bars"
}
async fn run_inserter(&self) {
super::run_inserter(self.inserter.clone()).await;
}
}
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
let inserter = Arc::new(Mutex::new(
config
.clickhouse_client
.inserter("bars")
.unwrap()
.with_period(Some(ONE_SECOND))
.with_max_entries((*CLICKHOUSE_BATCH_BARS_SIZE).try_into().unwrap()),
));
let subscription_message_constructor = match thread_type {
ThreadType::Bars(Class::UsEquity) => {
websocket::data::outgoing::subscribe::Message::new_market_us_equity
}
ThreadType::Bars(Class::Crypto) => {
websocket::data::outgoing::subscribe::Message::new_market_crypto
}
_ => unreachable!(),
};
Box::new(Handler {
config,
inserter,
subscription_message_constructor,
})
}

View File

@@ -1,353 +0,0 @@
pub mod bars;
pub mod news;
use crate::config::{ALPACA_API_KEY, ALPACA_API_SECRET};
use async_trait::async_trait;
use backoff::{future::retry_notify, ExponentialBackoff};
use clickhouse::{inserter::Inserter, Row};
use futures_util::{future::join_all, SinkExt, StreamExt};
use log::error;
use nonempty::NonEmpty;
use qrust::types::alpaca::{self, websocket};
use serde::Serialize;
use serde_json::{from_str, to_string};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::Duration,
};
use tokio::{
net::TcpStream,
select, spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
};
use tokio_tungstenite::{connect_async, tungstenite, MaybeTlsStream, WebSocketStream};
pub enum Action {
Subscribe,
Unsubscribe,
}
impl From<super::Action> for Option<Action> {
fn from(action: super::Action) -> Self {
match action {
super::Action::Add | super::Action::Enable => Some(Action::Subscribe),
super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe),
}
}
}
pub struct Message {
pub action: Option<Action>,
pub symbols: NonEmpty<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Option<Action>, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel();
(
Self {
action,
symbols,
response: sender,
},
receiver,
)
}
}
pub struct State {
pub active_subscriptions: HashSet<String>,
pub pending_subscriptions: HashMap<String, oneshot::Sender<()>>,
pub pending_unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}
#[async_trait]
pub trait Handler: Send + Sync + 'static {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message;
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
);
fn log_string(&self) -> &'static str;
async fn run_inserter(&self);
}
pub async fn run(
handler: Arc<Box<dyn Handler>>,
mut receiver: mpsc::Receiver<Message>,
websocket_url: String,
) {
let state = Arc::new(RwLock::new(State {
active_subscriptions: HashSet::new(),
pending_subscriptions: HashMap::new(),
pending_unsubscriptions: HashMap::new(),
}));
let handler_clone = handler.clone();
spawn(async move { handler_clone.run_inserter().await });
let (sink_sender, sink_receiver) = mpsc::channel(100);
let (stream_sender, mut stream_receiver) = mpsc::channel(10_000);
spawn(run_connection(
handler.clone(),
sink_receiver,
stream_sender,
websocket_url.clone(),
state.clone(),
));
loop {
select! {
Some(message) = receiver.recv() => {
spawn(handle_message(
handler.clone(),
state.clone(),
sink_sender.clone(),
message,
));
}
Some(message) = stream_receiver.recv() => {
match message {
tungstenite::Message::Text(message) => {
let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);
if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}.", message);
continue;
}
for message in parsed_message.unwrap() {
let handler = handler.clone();
let state = state.clone();
spawn(async move {
handler.handle_websocket_message(state, message).await;
});
}
}
tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}.", message),
}
}
else => panic!("Communication channel unexpectedly closed.")
}
}
}
#[allow(clippy::too_many_lines)]
async fn run_connection(
handler: Arc<Box<dyn Handler>>,
mut sink_receiver: mpsc::Receiver<tungstenite::Message>,
stream_sender: mpsc::Sender<tungstenite::Message>,
websocket_url: String,
state: Arc<RwLock<State>>,
) {
let mut peek = None;
'connection: loop {
let (websocket, _): (WebSocketStream<MaybeTlsStream<TcpStream>>, _) = retry_notify(
ExponentialBackoff::default(),
|| async {
connect_async(websocket_url.clone())
.await
.map_err(Into::into)
},
|e, duration: Duration| {
error!(
"Failed to connect to {} websocket, will retry in {} seconds: {}.",
handler.log_string(),
duration.as_secs(),
e
);
},
)
.await
.unwrap();
let (mut sink, mut stream) = websocket.split();
alpaca::websocket::data::authenticate(
&mut sink,
&mut stream,
(*ALPACA_API_KEY).to_string(),
(*ALPACA_API_SECRET).to_string(),
)
.await;
let mut state = state.write().await;
state
.pending_unsubscriptions
.drain()
.for_each(|(_, sender)| {
sender.send(()).unwrap();
});
let (recovered_subscriptions, receivers) = state
.active_subscriptions
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
state.pending_subscriptions.extend(recovered_subscriptions);
let pending_subscriptions = state
.pending_subscriptions
.keys()
.cloned()
.collect::<Vec<_>>();
drop(state);
if let Some(pending_subscriptions) = NonEmpty::from_vec(pending_subscriptions) {
if let Err(err) = sink
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(pending_subscriptions),
))
.unwrap(),
))
.await
{
error!("Failed to send websocket message: {:?}.", err);
continue;
}
}
join_all(receivers).await;
if peek.is_some() {
if let Err(err) = sink.send(peek.clone().unwrap()).await {
error!("Failed to send websocket message: {:?}.", err);
continue;
}
peek = None;
}
loop {
select! {
Some(message) = sink_receiver.recv() => {
peek = Some(message.clone());
if let Err(err) = sink.send(message).await {
error!("Failed to send websocket message: {:?}.", err);
continue 'connection;
};
peek = None;
}
message = stream.next() => {
if message.is_none() {
error!("Websocket stream unexpectedly closed.");
continue 'connection;
}
let message = message.unwrap();
if let Err(err) = message {
error!("Failed to receive websocket message: {:?}.", err);
continue 'connection;
}
let message = message.unwrap();
if message.is_close() {
error!("Websocket connection closed.");
continue 'connection;
}
stream_sender.send(message).await.unwrap();
}
else => error!("Communication channel unexpectedly closed.")
}
}
}
}
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
pending: Arc<RwLock<State>>,
sink_sender: mpsc::Sender<tungstenite::Message>,
message: Message,
) {
match message.action {
Some(Action::Subscribe) => {
let (pending_subscriptions, receivers) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
pending
.write()
.await
.pending_subscriptions
.extend(pending_subscriptions);
sink_sender
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(message.symbols),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
Some(Action::Unsubscribe) => {
let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip();
pending
.write()
.await
.pending_unsubscriptions
.extend(pending_unsubscriptions);
sink_sender
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Unsubscribe(
handler.create_subscription_message(message.symbols.clone()),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
None => {}
}
message.response.send(()).unwrap();
}
async fn run_inserter<T>(inserter: Arc<Mutex<Inserter<T>>>)
where
T: Row + Serialize,
{
loop {
let time_left = inserter.lock().await.time_left().unwrap();
tokio::time::sleep(time_left).await;
inserter.lock().await.commit().await.unwrap();
}
}
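
The reconnect path above wraps connect_async in retry_notify with exponential backoff, logging each failed attempt. The same pattern in isolation, with a hypothetical fallible operation standing in for the websocket connect:

use backoff::{future::retry_notify, ExponentialBackoff};
use std::time::Duration;

async fn fallible() -> Result<(), std::io::Error> {
    Ok(()) // stand-in for connect_async(...)
}

async fn retry_demo() -> Result<(), std::io::Error> {
    retry_notify(
        ExponentialBackoff::default(),
        || async {
            // Transient errors are retried with growing delays.
            fallible().await.map_err(backoff::Error::transient)
        },
        |err, duration: Duration| {
            log::error!("Retrying in {}s: {err}.", duration.as_secs());
        },
    )
    .await
}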

View File

@@ -1,116 +0,0 @@
use super::State;
use crate::config::{Config, CLICKHOUSE_BATCH_NEWS_SIZE};
use async_trait::async_trait;
use clickhouse::inserter::Inserter;
use log::{debug, error, info};
use nonempty::NonEmpty;
use qrust::{
types::{alpaca::websocket, News},
utils::ONE_SECOND,
};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{Mutex, RwLock};
pub struct Handler {
pub config: Arc<Config>,
pub inserter: Arc<Mutex<Inserter<News>>>,
}
#[async_trait]
impl super::Handler for Handler {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message {
websocket::data::outgoing::subscribe::Message::new_news(symbols)
}
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::News { news: symbols } =
message
else {
unreachable!()
};
let mut state = state.write().await;
let newly_subscribed = state
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = state
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to news for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from news for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::News(message) => {
let news = News::from(message);
debug!(
"Received news for {:?}: {}.",
news.symbols, news.time_created
);
self.inserter.lock().await.write(&news).await.unwrap();
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
fn log_string(&self) -> &'static str {
"news"
}
async fn run_inserter(&self) {
super::run_inserter(self.inserter.clone()).await;
}
}
pub fn create_handler(config: Arc<Config>) -> Box<dyn super::Handler> {
let inserter = Arc::new(Mutex::new(
config
.clickhouse_client
.inserter("news")
.unwrap()
.with_period(Some(ONE_SECOND))
.with_max_entries((*CLICKHOUSE_BATCH_NEWS_SIZE).try_into().unwrap()),
));
Box::new(Handler { config, inserter })
}

View File

@@ -1,26 +1,19 @@
mod websocket;
use crate::config::{Config, ALPACA_API_BASE, ALPACA_API_KEY, ALPACA_API_SECRET};
use crate::{
config::{Config, ALPACA_WEBSOCKET_URL},
types::alpaca,
};
use futures_util::StreamExt;
use qrust::types::alpaca;
use std::sync::Arc;
use tokio::spawn;
use tokio_tungstenite::connect_async;
pub async fn run(config: Arc<Config>) {
let (websocket, _) =
connect_async(&format!("wss://{}.alpaca.markets/stream", *ALPACA_API_BASE))
.await
.unwrap();
let (websocket, _) = connect_async(&*ALPACA_WEBSOCKET_URL).await.unwrap();
let (mut websocket_sink, mut websocket_stream) = websocket.split();
alpaca::websocket::trading::authenticate(
&mut websocket_sink,
&mut websocket_stream,
(*ALPACA_API_KEY).to_string(),
(*ALPACA_API_SECRET).to_string(),
)
.await;
alpaca::websocket::trading::authenticate(&mut websocket_sink, &mut websocket_stream).await;
alpaca::websocket::trading::subscribe(&mut websocket_sink, &mut websocket_stream).await;
spawn(websocket::run(config, websocket_stream));

View File

@@ -1,7 +1,10 @@
use crate::{config::Config, database};
use crate::{
config::Config,
database,
types::{alpaca::websocket, Order},
};
use futures_util::{stream::SplitStream, StreamExt};
use log::{debug, error};
use qrust::types::{alpaca::websocket, Order};
use serde_json::from_str;
use std::sync::Arc;
use tokio::{net::TcpStream, spawn};
@@ -21,7 +24,7 @@ pub async fn run(
);
if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}.", message);
error!("Failed to deserialize websocket message: {:?}", message);
continue;
}
@@ -31,7 +34,7 @@ pub async fn run(
));
}
tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}.", message),
_ => error!("Unexpected websocket message: {:?}", message),
}
}
}
@@ -43,19 +46,15 @@ async fn handle_websocket_message(
match message {
websocket::trading::incoming::Message::Order(message) => {
debug!(
"Received order message for {}: {:?}.",
"Received order message for {}: {:?}",
message.order.symbol, message.event
);
let order = Order::from(message.order);
database::orders::upsert(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&order,
)
.await
.unwrap();
database::orders::upsert(&config.clickhouse_client, &order)
.await
.unwrap();
match message.event {
websocket::trading::incoming::order::Event::Fill { position_qty, .. }
@@ -64,7 +63,6 @@ async fn handle_websocket_message(
} => {
database::assets::update_qty_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&order.symbol,
position_qty,
)

View File

@@ -1,7 +1,13 @@
use crate::config::ALPACA_API_URL;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use serde_aux::field_attributes::{
deserialize_number_from_string, deserialize_option_number_from_string,
};
use std::time::Duration;
use time::OffsetDateTime;
use uuid::Uuid;
@@ -73,3 +79,38 @@ pub struct Account {
#[serde(deserialize_with = "deserialize_number_from_string")]
pub regt_buying_power: f64,
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/account", *ALPACA_API_URL))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Account>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
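The status-code match above draws the retry boundary: BAD_REQUEST and FORBIDDEN are treated as permanent (retrying cannot help), while everything else is retried under exponential backoff. A hedged call-site sketch follows; the quota value and the pre-authenticated client are assumptions for illustration.

use governor::{Quota, RateLimiter};
use std::num::NonZeroU32;

// The client is assumed to already carry the Alpaca auth headers.
let alpaca_client = reqwest::Client::new();
let alpaca_rate_limiter = RateLimiter::direct(Quota::per_minute(NonZeroU32::new(200).unwrap()));
let account = get(&alpaca_client, &alpaca_rate_limiter, None).await?;
println!("RegT buying power: {}", account.regt_buying_power);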

View File

@@ -0,0 +1,86 @@
use super::position::Position;
use crate::{
config::ALPACA_API_URL,
types::{
self,
alpaca::shared::asset::{Class, Exchange, Status},
},
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_option_number_from_string;
use std::time::Duration;
use uuid::Uuid;
#[allow(clippy::struct_excessive_bools)]
#[derive(Deserialize)]
pub struct Asset {
pub id: Uuid,
pub class: Class,
pub exchange: Exchange,
pub symbol: String,
pub name: String,
pub status: Status,
pub tradable: bool,
pub marginable: bool,
pub shortable: bool,
pub easy_to_borrow: bool,
pub fractionable: bool,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub maintenance_margin_requirement: Option<f32>,
pub attributes: Option<Vec<String>>,
}
impl From<(Asset, Option<Position>)> for types::Asset {
fn from((asset, position): (Asset, Option<Position>)) -> Self {
Self {
symbol: asset.symbol,
class: asset.class.into(),
exchange: asset.exchange.into(),
status: asset.status.into(),
time_added: time::OffsetDateTime::now_utc(),
qty: position.map(|position| position.qty).unwrap_or_default(),
}
}
}
pub async fn get_by_symbol(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/assets/{}", *ALPACA_API_URL, symbol))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(
reqwest::StatusCode::BAD_REQUEST
| reqwest::StatusCode::FORBIDDEN
| reqwest::StatusCode::NOT_FOUND,
) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Asset>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get asset, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -0,0 +1,89 @@
use crate::types::{self, alpaca::api::outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::{collections::HashMap, time::Duration};
use time::OffsetDateTime;
#[derive(Deserialize)]
pub struct Bar {
#[serde(rename = "t")]
#[serde(with = "time::serde::rfc3339")]
pub time: OffsetDateTime,
#[serde(rename = "o")]
pub open: f64,
#[serde(rename = "h")]
pub high: f64,
#[serde(rename = "l")]
pub low: f64,
#[serde(rename = "c")]
pub close: f64,
#[serde(rename = "v")]
pub volume: f64,
#[serde(rename = "n")]
pub trades: i64,
#[serde(rename = "vw")]
pub vwap: f64,
}
impl From<(Bar, String)> for types::Bar {
fn from((bar, symbol): (Bar, String)) -> Self {
Self {
time: bar.time,
symbol,
open: bar.open,
high: bar.high,
low: bar.low,
close: bar.close,
volume: bar.volume,
trades: bar.trades,
vwap: bar.vwap,
}
}
}
#[derive(Deserialize)]
pub struct Message {
pub bars: HashMap<String, Vec<Bar>>,
pub next_page_token: Option<String>,
}
pub async fn get_historical(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
data_url: &str,
query: &outgoing::bar::Bar,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(data_url)
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical bars, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
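Because the response carries next_page_token, callers are expected to loop until it comes back None. A minimal pagination sketch, assuming the query type exposes a mutable page_token field like the outgoing bar definitions later in this diff:

let mut all_bars = Vec::new();
loop {
    let message = get_historical(&alpaca_client, &alpaca_rate_limiter, &data_url, &query, None).await?;
    for (symbol, bars) in message.bars {
        // From<(Bar, String)> is implemented above.
        all_bars.extend(bars.into_iter().map(|bar| types::Bar::from((bar, symbol.clone()))));
    }
    match message.next_page_token {
        Some(token) => query.page_token = Some(token),
        None => break,
    }
}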

View File

@@ -0,0 +1,69 @@
use crate::{
config::ALPACA_API_URL,
types::{self, alpaca::api::outgoing},
utils::{de, time::EST_OFFSET},
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
use time::{Date, OffsetDateTime, Time};
#[derive(Deserialize)]
pub struct Calendar {
pub date: Date,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub open: Time,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub close: Time,
pub settlement_date: Date,
}
impl From<Calendar> for types::Calendar {
fn from(calendar: Calendar) -> Self {
Self {
date: calendar.date,
open: OffsetDateTime::new_in_offset(calendar.date, calendar.open, *EST_OFFSET),
close: OffsetDateTime::new_in_offset(calendar.date, calendar.close, *EST_OFFSET),
}
}
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::calendar::Calendar,
backoff: Option<ExponentialBackoff>,
) -> Result<Vec<Calendar>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/calendar", *ALPACA_API_URL))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Calendar>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get calendar, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -0,0 +1,54 @@
use crate::config::ALPACA_API_URL;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
use time::OffsetDateTime;
#[derive(Deserialize)]
pub struct Clock {
#[serde(with = "time::serde::rfc3339")]
pub timestamp: OffsetDateTime,
pub is_open: bool,
#[serde(with = "time::serde::rfc3339")]
pub next_open: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
pub next_close: OffsetDateTime,
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/clock", *ALPACA_API_URL))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Clock>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get clock, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -0,0 +1,111 @@
use crate::{
config::ALPACA_NEWS_DATA_API_URL,
types::{
self,
alpaca::{api::outgoing, shared::news::normalize_html_content},
},
utils::de,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
use time::OffsetDateTime;
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImageSize {
Thumb,
Small,
Large,
}
#[derive(Deserialize)]
pub struct Image {
pub size: ImageSize,
pub url: String,
}
#[derive(Deserialize)]
pub struct News {
pub id: i64,
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "created_at")]
pub time_created: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "updated_at")]
pub time_updated: OffsetDateTime,
#[serde(deserialize_with = "de::add_slash_to_symbols")]
pub symbols: Vec<String>,
pub headline: String,
pub author: String,
pub source: String,
pub summary: String,
pub content: String,
pub url: Option<String>,
pub images: Vec<Image>,
}
impl From<News> for types::News {
fn from(news: News) -> Self {
Self {
id: news.id,
time_created: news.time_created,
time_updated: news.time_updated,
symbols: news.symbols,
headline: normalize_html_content(&news.headline),
author: normalize_html_content(&news.author),
source: normalize_html_content(&news.source),
summary: normalize_html_content(&news.summary),
content: normalize_html_content(&news.content),
sentiment: types::news::Sentiment::Neutral,
confidence: 0.0,
url: news.url.unwrap_or_default(),
}
}
}
#[derive(Deserialize)]
pub struct Message {
pub news: Vec<News>,
pub next_page_token: Option<String>,
}
pub async fn get_historical(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(ALPACA_NEWS_DATA_API_URL)
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical news, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,39 +1,44 @@
use super::error_to_backoff;
use crate::types::alpaca::{api::outgoing, shared::order};
use crate::{
config::ALPACA_API_URL,
types::alpaca::{api::outgoing, shared},
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub use order::Order;
pub use shared::order::Order;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::order::Order,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Order>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/orders", api_base))
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/orders", *ALPACA_API_URL))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.await?
.error_for_status()
.map_err(error_to_backoff)?
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Order>>()
.await
.map_err(error_to_backoff)
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get orders, will retry in {} seconds: {}.",
"Failed to get orders, will retry in {} seconds: {}",
duration.as_secs(),
e
);

View File

@@ -0,0 +1,145 @@
use crate::{
config::ALPACA_API_URL,
types::alpaca::shared::{
self,
asset::{Class, Exchange},
},
utils::de,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::Client;
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_number_from_string;
use std::time::Duration;
use uuid::Uuid;
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Long,
Short,
}
impl From<Side> for shared::order::Side {
fn from(side: Side) -> Self {
match side {
Side::Long => Self::Buy,
Side::Short => Self::Sell,
}
}
}
#[derive(Deserialize)]
pub struct Position {
pub asset_id: Uuid,
#[serde(deserialize_with = "de::add_slash_to_symbol")]
pub symbol: String,
pub exchange: Exchange,
pub asset_class: Class,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub avg_entry_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty_available: f64,
pub side: Side,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub cost_basis: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub current_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub lastday_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub change_today: f64,
pub asset_marginable: bool,
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
) -> Result<Vec<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/positions", *ALPACA_API_URL))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Position>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get positions, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
) -> Result<Option<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
let response = alpaca_client
.get(&format!("{}/positions/{}", *ALPACA_API_URL, symbol))
.send()
.await?;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Ok(None);
}
response
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Position>()
.await
.map_err(backoff::Error::Permanent)
.map(Some)
},
|e, duration: Duration| {
warn!(
"Failed to get position, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
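Note that a 404 from this endpoint is a normal outcome (no open position) rather than an error, which is why the happy path returns an Option. A small hedged usage sketch:

match get_by_symbol(&alpaca_client, &alpaca_rate_limiter, "AAPL", None).await? {
    Some(position) => info!("{} open qty: {}.", position.symbol, position.qty),
    None => info!("No open position for AAPL."),
}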

View File

@@ -0,0 +1,2 @@
pub mod incoming;
pub mod outgoing;

View File

@@ -1,14 +1,12 @@
use crate::{
alpaca::bars::MAX_LIMIT,
types::alpaca::shared,
config::ALPACA_SOURCE,
types::alpaca::shared::{Sort, Source},
utils::{ser, ONE_MINUTE},
};
use serde::Serialize;
use std::time::Duration;
use time::OffsetDateTime;
pub use shared::{Sort, Source};
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
@@ -55,10 +53,10 @@ impl Default for UsEquity {
timeframe: ONE_MINUTE,
start: None,
end: None,
limit: Some(MAX_LIMIT),
limit: Some(10000),
adjustment: Some(Adjustment::All),
asof: None,
feed: Some(Source::Iex),
feed: Some(*ALPACA_SOURCE),
currency: None,
page_token: None,
sort: Some(Sort::Asc),
@@ -93,7 +91,7 @@ impl Default for Crypto {
timeframe: ONE_MINUTE,
start: None,
end: None,
limit: Some(MAX_LIMIT),
limit: Some(10000),
page_token: None,
sort: Some(Sort::Asc),
}

View File

@@ -1,4 +1,3 @@
pub mod asset;
pub mod bar;
pub mod calendar;
pub mod news;

View File

@@ -1,10 +1,10 @@
use crate::{alpaca::news::MAX_LIMIT, types::alpaca::shared::Sort, utils::ser};
use crate::{types::alpaca::shared::Sort, utils::ser};
use serde::Serialize;
use time::OffsetDateTime;
#[derive(Serialize)]
pub struct News {
#[serde(serialize_with = "ser::remove_slash_and_join_symbols")]
#[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")]
pub symbols: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
@@ -30,7 +30,7 @@ impl Default for News {
symbols: vec![],
start: None,
end: None,
limit: Some(MAX_LIMIT),
limit: Some(50),
include_content: Some(true),
exclude_contentless: Some(false),
page_token: None,

View File

@@ -1,12 +1,10 @@
use crate::{
types::alpaca::shared::{order, Sort},
types::alpaca::shared::{order::Side, Sort},
utils::ser,
};
use serde::Serialize;
use time::OffsetDateTime;
pub use order::Side;
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]

View File

@@ -1,7 +1,7 @@
use crate::{impl_from_enum, types};
use serde::{Deserialize, Serialize};
use serde::Deserialize;
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Class {
UsEquity,
@@ -10,7 +10,7 @@ pub enum Class {
impl_from_enum!(types::Class, Class, UsEquity, Crypto);
#[derive(Serialize, Deserialize, Clone, Copy)]
#[derive(Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Exchange {
Amex,
@@ -36,7 +36,7 @@ impl_from_enum!(
Crypto
);
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Status {
Active,

View File

@@ -1,3 +1,4 @@
use html_escape::decode_html_entities;
use lazy_static::lazy_static;
use regex::Regex;
@@ -6,10 +7,12 @@ lazy_static! {
static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap();
}
pub fn strip(content: &str) -> String {
pub fn normalize_html_content(content: &str) -> String {
let content = content.replace('\n', " ");
let content = RE_TAGS.replace_all(&content, "");
let content = RE_SPACES.replace_all(&content, " ");
let content = decode_html_entities(&content);
let content = content.trim();
content.to_string()
}
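To make the pipeline concrete, a small example, assuming RE_TAGS (defined just above this hunk) matches HTML tags. Note the ordering: newline removal, tag stripping, and whitespace collapsing all run before entity decoding, so entities such as &amp; survive until the final step.

let raw = "<p>Apple &amp; Alphabet</p>\n  report earnings  ";
assert_eq!(
    normalize_html_content(raw),
    "Apple & Alphabet report earnings"
);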

View File

@@ -1,5 +1,5 @@
use crate::{
types::{alpaca::shared::news::strip, News},
types::{alpaca::shared::news::normalize_html_content, news::Sentiment, News},
utils::de,
};
use serde::Deserialize;
@@ -31,11 +31,13 @@ impl From<Message> for News {
time_created: news.time_created,
time_updated: news.time_updated,
symbols: news.symbols,
headline: strip(&news.headline),
author: strip(&news.author),
source: strip(&news.source),
summary: news.summary,
content: news.content,
headline: normalize_html_content(&news.headline),
author: normalize_html_content(&news.author),
source: normalize_html_content(&news.source),
summary: normalize_html_content(&news.summary),
content: normalize_html_content(&news.content),
sentiment: Sentiment::Neutral,
confidence: 0.0,
url: news.url.unwrap_or_default(),
}
}

View File

@@ -1,7 +1,10 @@
pub mod incoming;
pub mod outgoing;
use crate::types::alpaca::websocket;
use crate::{
config::{ALPACA_API_KEY, ALPACA_API_SECRET},
types::alpaca::websocket,
};
use core::panic;
use futures_util::{
stream::{SplitSink, SplitStream},
@@ -14,8 +17,6 @@ use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub async fn authenticate(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
api_key: String,
api_secret: String,
) {
match stream.next().await.unwrap().unwrap() {
Message::Text(data)
@@ -31,8 +32,8 @@ pub async fn authenticate(
sink.send(Message::Text(
to_string(&websocket::data::outgoing::Message::Auth(
websocket::auth::Message {
key: api_key,
secret: api_secret,
key: (*ALPACA_API_KEY).clone(),
secret: (*ALPACA_API_SECRET).clone(),
},
))
.unwrap(),

View File

@@ -1,5 +1,4 @@
use crate::utils::ser;
use nonempty::NonEmpty;
use serde::Serialize;
#[derive(Serialize)]
@@ -7,14 +6,14 @@ use serde::Serialize;
pub enum Market {
#[serde(rename_all = "camelCase")]
UsEquity {
bars: NonEmpty<String>,
updated_bars: NonEmpty<String>,
statuses: NonEmpty<String>,
bars: Vec<String>,
updated_bars: Vec<String>,
statuses: Vec<String>,
},
#[serde(rename_all = "camelCase")]
Crypto {
bars: NonEmpty<String>,
updated_bars: NonEmpty<String>,
bars: Vec<String>,
updated_bars: Vec<String>,
},
}
@@ -24,12 +23,12 @@ pub enum Message {
Market(Market),
News {
#[serde(serialize_with = "ser::remove_slash_from_symbols")]
news: NonEmpty<String>,
news: Vec<String>,
},
}
impl Message {
pub fn new_market_us_equity(symbols: NonEmpty<String>) -> Self {
pub fn new_market_us_equity(symbols: Vec<String>) -> Self {
Self::Market(Market::UsEquity {
bars: symbols.clone(),
updated_bars: symbols.clone(),
@@ -37,14 +36,14 @@ impl Message {
})
}
pub fn new_market_crypto(symbols: NonEmpty<String>) -> Self {
pub fn new_market_crypto(symbols: Vec<String>) -> Self {
Self::Market(Market::Crypto {
bars: symbols.clone(),
updated_bars: symbols,
})
}
pub fn new_news(symbols: NonEmpty<String>) -> Self {
pub fn new_news(symbols: Vec<String>) -> Self {
Self::News { news: symbols }
}
}
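For reference, a brief hedged sketch of building and serializing one of these subscription payloads, following the Vec<String> signatures shown in this hunk:

let message = Message::new_market_us_equity(vec!["AAPL".to_string(), "MSFT".to_string()]);
let json = serde_json::to_string(&message).unwrap();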

View File

@@ -0,0 +1,3 @@
pub mod auth;
pub mod data;
pub mod trading;

View File

@@ -1,10 +1,10 @@
use crate::types::alpaca::shared::order;
use crate::types::alpaca::shared;
use serde::Deserialize;
use serde_aux::prelude::deserialize_number_from_string;
use time::OffsetDateTime;
use uuid::Uuid;
pub use order::Order;
pub use shared::order::Order;
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]

View File

@@ -1,7 +1,10 @@
pub mod incoming;
pub mod outgoing;
use crate::types::alpaca::websocket;
use crate::{
config::{ALPACA_API_KEY, ALPACA_API_SECRET},
types::alpaca::websocket,
};
use core::panic;
use futures_util::{
stream::{SplitSink, SplitStream},
@@ -14,14 +17,12 @@ use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub async fn authenticate(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
api_key: String,
api_secret: String,
) {
sink.send(Message::Text(
to_string(&websocket::trading::outgoing::Message::Auth(
websocket::auth::Message {
key: api_key,
secret: api_secret,
key: (*ALPACA_API_KEY).clone(),
secret: (*ALPACA_API_SECRET).clone(),
},
))
.unwrap(),

Some files were not shown because too many files have changed in this diff.