Compare commits
28 Commits
Author | SHA1 | Date | |
---|---|---|---|
90b7f10a77 | |||
d7e9350257 | |||
f715881b07 | |||
d0ad9f65b1 | |||
ce8c4db422 | |||
46508d1b4f | |||
2ad42c5462 | |||
733e6373e9 | |||
d072b849c0 | |||
718e794f51 | |||
b7a175d5b4 | |||
e9012d6ec3 | |||
10365745aa | |||
8202255132 | |||
0d276d537c | |||
1707d74cf7 | |||
f3f9c6336b | |||
5ed0c7670a | |||
d2d20e2978 | |||
d02f958865 | |||
2d8972dce2 | |||
7bacc2565a | |||
b60cbc891d | |||
2de86b46f7 | |||
8c7ee3d12d | |||
a15fd2c3c9 | |||
acfc0ca4c9 | |||
681d7393d7 |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
# will have compiled files and executables
|
||||
debug/
|
||||
target/
|
||||
log/
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
@@ -10,6 +11,3 @@ target/
|
||||
*.pdb
|
||||
|
||||
.env*
|
||||
|
||||
# ML models
|
||||
models/*/rust_model.ot
|
||||
|
@@ -22,7 +22,7 @@ build:
|
||||
cache:
|
||||
<<: *global_cache
|
||||
script:
|
||||
- cargo +nightly build
|
||||
- cargo +nightly build --workspace
|
||||
|
||||
test:
|
||||
image: registry.karaolidis.com/karaolidis/qrust/rust
|
||||
@@ -30,7 +30,7 @@ test:
|
||||
cache:
|
||||
<<: *global_cache
|
||||
script:
|
||||
- cargo +nightly test
|
||||
- cargo +nightly test --workspace
|
||||
|
||||
lint:
|
||||
image: registry.karaolidis.com/karaolidis/qrust/rust
|
||||
@@ -39,7 +39,7 @@ lint:
|
||||
<<: *global_cache
|
||||
script:
|
||||
- cargo +nightly fmt --all -- --check
|
||||
- cargo +nightly clippy --all-targets --all-features
|
||||
- cargo +nightly clippy --workspace --all-targets --all-features
|
||||
|
||||
depcheck:
|
||||
image: registry.karaolidis.com/karaolidis/qrust/rust
|
||||
@@ -48,7 +48,7 @@ depcheck:
|
||||
<<: *global_cache
|
||||
script:
|
||||
- cargo +nightly outdated
|
||||
- cargo +nightly udeps
|
||||
- cargo +nightly udeps --workspace --all-targets
|
||||
|
||||
build-release:
|
||||
image: registry.karaolidis.com/karaolidis/qrust/rust
|
||||
|
2765
Cargo.lock
generated
2765
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
65
Cargo.toml
65
Cargo.toml
@@ -3,6 +3,18 @@ name = "qrust"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "qrust"
|
||||
path = "src/lib/qrust/mod.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "qrust"
|
||||
path = "src/bin/qrust/mod.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "trainer"
|
||||
path = "src/bin/trainer/mod.rs"
|
||||
|
||||
[profile.release]
|
||||
panic = 'abort'
|
||||
strip = true
|
||||
@@ -12,9 +24,9 @@ codegen-units = 1
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
axum = "0.7.4"
|
||||
axum = "0.7.5"
|
||||
dotenv = "0.15.0"
|
||||
tokio = { version = "1.32.0", features = [
|
||||
tokio = { version = "1.37.0", features = [
|
||||
"macros",
|
||||
"rt-multi-thread",
|
||||
] }
|
||||
@@ -22,29 +34,29 @@ tokio-tungstenite = { version = "0.21.0", features = [
|
||||
"tokio-native-tls",
|
||||
"native-tls",
|
||||
] }
|
||||
log = "0.4.20"
|
||||
log4rs = "1.2.0"
|
||||
serde = "1.0.188"
|
||||
serde_json = "1.0.105"
|
||||
serde_repr = "0.1.18"
|
||||
serde_with = "3.6.1"
|
||||
serde-aux = "4.4.0"
|
||||
futures-util = "0.3.28"
|
||||
reqwest = { version = "0.11.20", features = [
|
||||
log = "0.4.21"
|
||||
log4rs = "1.3.0"
|
||||
serde = "1.0.201"
|
||||
serde_json = "1.0.117"
|
||||
serde_repr = "0.1.19"
|
||||
serde_with = "3.8.1"
|
||||
serde-aux = "4.5.0"
|
||||
futures-util = "0.3.30"
|
||||
reqwest = { version = "0.12.4", features = [
|
||||
"json",
|
||||
"serde_json",
|
||||
] }
|
||||
http = "1.0.0"
|
||||
governor = "0.6.0"
|
||||
http = "1.1.0"
|
||||
governor = "0.6.3"
|
||||
clickhouse = { version = "0.11.6", features = [
|
||||
"watch",
|
||||
"time",
|
||||
"uuid",
|
||||
] }
|
||||
uuid = { version = "1.6.1", features = [
|
||||
uuid = { version = "1.8.0", features = [
|
||||
"serde",
|
||||
"v4",
|
||||
] }
|
||||
time = { version = "0.3.31", features = [
|
||||
time = { version = "0.3.36", features = [
|
||||
"serde",
|
||||
"serde-well-known",
|
||||
"serde-human-readable",
|
||||
@@ -55,9 +67,22 @@ time = { version = "0.3.31", features = [
|
||||
backoff = { version = "0.4.0", features = [
|
||||
"tokio",
|
||||
] }
|
||||
regex = "1.10.3"
|
||||
html-escape = "0.2.13"
|
||||
rust-bert = "0.22.0"
|
||||
async-trait = "0.1.77"
|
||||
regex = "1.10.4"
|
||||
async-trait = "0.1.80"
|
||||
itertools = "0.12.1"
|
||||
lazy_static = "1.4.0"
|
||||
nonempty = { version = "0.10.0", features = [
|
||||
"serialize",
|
||||
] }
|
||||
rand = "0.8.5"
|
||||
rayon = "1.10.0"
|
||||
burn = { version = "0.13.2", features = [
|
||||
"wgpu",
|
||||
"cuda",
|
||||
"tui",
|
||||
"metrics",
|
||||
"train",
|
||||
] }
|
||||
|
||||
[dev-dependencies]
|
||||
serde_test = "1.0.176"
|
||||
|
@@ -1,5 +1,5 @@
|
||||
# qrust
|
||||
|
||||

|
||||

|
||||
|
||||
`qrust` (/kɹʌst/, QuantitativeRust) is an algorithmic trading library written in Rust.
|
||||
|
@@ -4,7 +4,14 @@ appenders:
|
||||
encoder:
|
||||
pattern: "{d} {h({l})} {M}::{L} - {m}{n}"
|
||||
|
||||
file:
|
||||
kind: file
|
||||
path: "./log/output.log"
|
||||
encoder:
|
||||
pattern: "{d} {l} {M}::{L} - {m}{n}"
|
||||
|
||||
root:
|
||||
level: info
|
||||
appenders:
|
||||
- stdout
|
||||
- file
|
||||
|
@@ -1,32 +0,0 @@
|
||||
{
|
||||
"_name_or_path": "/home/ubuntu/finbert/models/language_model/finbertTRC2",
|
||||
"architectures": [
|
||||
"BertForSequenceClassification"
|
||||
],
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"gradient_checkpointing": false,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"hidden_size": 768,
|
||||
"id2label": {
|
||||
"0": "positive",
|
||||
"1": "negative",
|
||||
"2": "neutral"
|
||||
},
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"label2id": {
|
||||
"positive": 0,
|
||||
"negative": 1,
|
||||
"neutral": 2
|
||||
},
|
||||
"layer_norm_eps": 1e-12,
|
||||
"max_position_embeddings": 512,
|
||||
"model_type": "bert",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 0,
|
||||
"position_embedding_type": "absolute",
|
||||
"type_vocab_size": 2,
|
||||
"vocab_size": 30522
|
||||
}
|
@@ -1 +0,0 @@
|
||||
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
@@ -1 +0,0 @@
|
||||
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "bert-base-uncased"}
|
30522
models/finbert/vocab.txt
30522
models/finbert/vocab.txt
File diff suppressed because it is too large
Load Diff
86
src/bin/qrust/config.rs
Normal file
86
src/bin/qrust/config.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
|
||||
use lazy_static::lazy_static;
|
||||
use qrust::types::alpaca::shared::{Mode, Source};
|
||||
use reqwest::{
|
||||
header::{HeaderMap, HeaderName, HeaderValue},
|
||||
Client,
|
||||
};
|
||||
use std::{env, num::NonZeroU32, sync::Arc};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
|
||||
.expect("ALPACA_MODE must be set.")
|
||||
.parse()
|
||||
.expect("ALPACA_MODE must be 'live' or 'paper'");
|
||||
pub static ref ALPACA_API_BASE: String = match *ALPACA_MODE {
|
||||
Mode::Live => String::from("api"),
|
||||
Mode::Paper => String::from("paper-api"),
|
||||
};
|
||||
pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE")
|
||||
.expect("ALPACA_SOURCE must be set.")
|
||||
.parse()
|
||||
.expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'");
|
||||
pub static ref ALPACA_API_KEY: String =
|
||||
env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
|
||||
pub static ref ALPACA_API_SECRET: String =
|
||||
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
|
||||
pub static ref CLICKHOUSE_BATCH_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE")
|
||||
.expect("BATCH_BACKFILL_BARS_SIZE must be set.")
|
||||
.parse()
|
||||
.expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer.");
|
||||
pub static ref CLICKHOUSE_BATCH_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE")
|
||||
.expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
|
||||
.parse()
|
||||
.expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
|
||||
pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS")
|
||||
.expect("CLICKHOUSE_MAX_CONNECTIONS must be set.")
|
||||
.parse()
|
||||
.expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer.");
|
||||
}
|
||||
|
||||
pub struct Config {
|
||||
pub alpaca_client: Client,
|
||||
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
|
||||
pub clickhouse_client: clickhouse::Client,
|
||||
pub clickhouse_concurrency_limiter: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn from_env() -> Self {
|
||||
Self {
|
||||
alpaca_client: Client::builder()
|
||||
.default_headers(HeaderMap::from_iter([
|
||||
(
|
||||
HeaderName::from_static("apca-api-key-id"),
|
||||
HeaderValue::from_str(&ALPACA_API_KEY)
|
||||
.expect("Alpaca API key must not contain invalid characters."),
|
||||
),
|
||||
(
|
||||
HeaderName::from_static("apca-api-secret-key"),
|
||||
HeaderValue::from_str(&ALPACA_API_SECRET)
|
||||
.expect("Alpaca API secret must not contain invalid characters."),
|
||||
),
|
||||
]))
|
||||
.build()
|
||||
.unwrap(),
|
||||
alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE {
|
||||
Source::Iex => NonZeroU32::new(200).unwrap(),
|
||||
Source::Sip => NonZeroU32::new(10_000).unwrap(),
|
||||
Source::Otc => unimplemented!("OTC rate limit not implemented."),
|
||||
})),
|
||||
clickhouse_client: clickhouse::Client::default()
|
||||
.with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
|
||||
.with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
|
||||
.with_password(
|
||||
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
|
||||
)
|
||||
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
|
||||
clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn arc_from_env() -> Arc<Self> {
|
||||
Arc::new(Self::from_env())
|
||||
}
|
||||
}
|
@@ -1,24 +1,25 @@
|
||||
use crate::{
|
||||
config::{Config, ALPACA_MODE},
|
||||
config::{Config, ALPACA_API_BASE},
|
||||
database,
|
||||
types::alpaca,
|
||||
};
|
||||
use log::{info, warn};
|
||||
use qrust::{alpaca, types};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use time::OffsetDateTime;
|
||||
use tokio::join;
|
||||
|
||||
pub async fn check_account(config: &Arc<Config>) {
|
||||
let account = alpaca::api::incoming::account::get(
|
||||
let account = alpaca::account::get(
|
||||
&config.alpaca_client,
|
||||
&config.alpaca_rate_limiter,
|
||||
None,
|
||||
&ALPACA_API_BASE,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
!(account.status != alpaca::api::incoming::account::Status::Active),
|
||||
!(account.status != types::alpaca::api::incoming::account::Status::Active),
|
||||
"Account status is not active: {:?}.",
|
||||
account.status
|
||||
);
|
||||
@@ -33,56 +34,60 @@ pub async fn check_account(config: &Arc<Config>) {
|
||||
warn!("Account cash is zero, qrust will not be able to trade.");
|
||||
}
|
||||
|
||||
warn!(
|
||||
"qrust active on {} account with {} {}, avoid transferring funds without shutting down.",
|
||||
*ALPACA_MODE, account.currency, account.cash
|
||||
info!(
|
||||
"qrust running on {} account with {} {}, avoid transferring funds without shutting down.",
|
||||
*ALPACA_API_BASE, account.currency, account.cash
|
||||
);
|
||||
}
|
||||
|
||||
pub async fn rehydrate_orders(config: &Arc<Config>) {
|
||||
info!("Rehydrating order data.");
|
||||
|
||||
let mut orders = vec![];
|
||||
let mut after = OffsetDateTime::UNIX_EPOCH;
|
||||
|
||||
while let Some(message) = alpaca::api::incoming::order::get(
|
||||
&config.alpaca_client,
|
||||
&config.alpaca_rate_limiter,
|
||||
&alpaca::api::outgoing::order::Order {
|
||||
status: Some(alpaca::api::outgoing::order::Status::All),
|
||||
after: Some(after),
|
||||
..Default::default()
|
||||
},
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.ok()
|
||||
.filter(|message| !message.is_empty())
|
||||
{
|
||||
loop {
|
||||
let message = alpaca::orders::get(
|
||||
&config.alpaca_client,
|
||||
&config.alpaca_rate_limiter,
|
||||
&types::alpaca::api::outgoing::order::Order {
|
||||
status: Some(types::alpaca::api::outgoing::order::Status::All),
|
||||
after: Some(after),
|
||||
..Default::default()
|
||||
},
|
||||
None,
|
||||
&ALPACA_API_BASE,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if message.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
orders.extend(message);
|
||||
after = orders.last().unwrap().submitted_at;
|
||||
}
|
||||
|
||||
let orders = orders
|
||||
.into_iter()
|
||||
.flat_map(&alpaca::api::incoming::order::Order::normalize)
|
||||
.flat_map(&types::alpaca::api::incoming::order::Order::normalize)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
database::orders::upsert_batch(&config.clickhouse_client, &orders)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
info!("Rehydrated order data.");
|
||||
database::orders::upsert_batch(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
&orders,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub async fn rehydrate_positions(config: &Arc<Config>) {
|
||||
info!("Rehydrating position data.");
|
||||
|
||||
let positions_future = async {
|
||||
alpaca::api::incoming::position::get(
|
||||
alpaca::positions::get(
|
||||
&config.alpaca_client,
|
||||
&config.alpaca_rate_limiter,
|
||||
None,
|
||||
&ALPACA_API_BASE,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
@@ -92,9 +97,12 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
|
||||
};
|
||||
|
||||
let assets_future = async {
|
||||
database::assets::select(&config.clickhouse_client)
|
||||
.await
|
||||
.unwrap()
|
||||
database::assets::select(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let (mut positions, assets) = join!(positions_future, assets_future);
|
||||
@@ -111,9 +119,13 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
database::assets::upsert_batch(&config.clickhouse_client, &assets)
|
||||
.await
|
||||
.unwrap();
|
||||
database::assets::upsert_batch(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
&assets,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for position in positions.values() {
|
||||
warn!(
|
||||
@@ -121,6 +133,4 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
|
||||
position.symbol, position.qty
|
||||
);
|
||||
}
|
||||
|
||||
info!("Rehydrated position data.");
|
||||
}
|
115
src/bin/qrust/mod.rs
Normal file
115
src/bin/qrust/mod.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
|
||||
#![allow(clippy::missing_docs_in_private_items)]
|
||||
#![feature(hash_extract_if)]
|
||||
|
||||
mod config;
|
||||
mod init;
|
||||
mod routes;
|
||||
mod threads;
|
||||
|
||||
use config::{
|
||||
Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE,
|
||||
CLICKHOUSE_BATCH_NEWS_SIZE, CLICKHOUSE_MAX_CONNECTIONS,
|
||||
};
|
||||
use dotenv::dotenv;
|
||||
use log::info;
|
||||
use log4rs::config::Deserializers;
|
||||
use nonempty::NonEmpty;
|
||||
use qrust::{create_send_await, database};
|
||||
use tokio::{join, spawn, sync::mpsc, try_join};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
dotenv().ok();
|
||||
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
|
||||
let config = Config::arc_from_env();
|
||||
|
||||
let _ = *ALPACA_MODE;
|
||||
let _ = *ALPACA_API_BASE;
|
||||
let _ = *ALPACA_SOURCE;
|
||||
let _ = *CLICKHOUSE_BATCH_BARS_SIZE;
|
||||
let _ = *CLICKHOUSE_BATCH_NEWS_SIZE;
|
||||
let _ = *CLICKHOUSE_MAX_CONNECTIONS;
|
||||
|
||||
info!("Marking all assets as stale.");
|
||||
|
||||
let assets = database::assets::select(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|asset| (asset.symbol, asset.class))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let symbols = assets.iter().map(|(symbol, _)| symbol).collect::<Vec<_>>();
|
||||
|
||||
try_join!(
|
||||
database::backfills_bars::set_fresh_where_symbols(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
false,
|
||||
&symbols
|
||||
),
|
||||
database::backfills_news::set_fresh_where_symbols(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
false,
|
||||
&symbols
|
||||
)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
info!("Cleaning up database.");
|
||||
|
||||
database::cleanup_all(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
info!("Optimizing database.");
|
||||
|
||||
database::optimize_all(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
info!("Rehydrating account data.");
|
||||
|
||||
init::check_account(&config).await;
|
||||
join!(
|
||||
init::rehydrate_orders(&config),
|
||||
init::rehydrate_positions(&config)
|
||||
);
|
||||
|
||||
info!("Starting threads.");
|
||||
|
||||
spawn(threads::trading::run(config.clone()));
|
||||
|
||||
let (data_sender, data_receiver) = mpsc::channel::<threads::data::Message>(100);
|
||||
let (clock_sender, clock_receiver) = mpsc::channel::<threads::clock::Message>(1);
|
||||
|
||||
spawn(threads::data::run(
|
||||
config.clone(),
|
||||
data_receiver,
|
||||
clock_receiver,
|
||||
));
|
||||
|
||||
spawn(threads::clock::run(config.clone(), clock_sender));
|
||||
|
||||
if let Some(assets) = NonEmpty::from_vec(assets) {
|
||||
create_send_await!(
|
||||
data_sender,
|
||||
threads::data::Message::new,
|
||||
threads::data::Action::Enable,
|
||||
assets
|
||||
);
|
||||
}
|
||||
|
||||
routes::run(config, data_sender).await;
|
||||
}
|
197
src/bin/qrust/routes/assets.rs
Normal file
197
src/bin/qrust/routes/assets.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
use crate::{
|
||||
config::{Config, ALPACA_API_BASE},
|
||||
create_send_await, database, threads,
|
||||
};
|
||||
use axum::{extract::Path, Extension, Json};
|
||||
use http::StatusCode;
|
||||
use nonempty::{nonempty, NonEmpty};
|
||||
use qrust::{
|
||||
alpaca,
|
||||
types::{self, Asset},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub async fn get(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> {
|
||||
let assets = database::assets::select(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok((StatusCode::OK, Json(assets)))
|
||||
}
|
||||
|
||||
pub async fn get_where_symbol(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Path(symbol): Path<String>,
|
||||
) -> Result<(StatusCode, Json<Asset>), StatusCode> {
|
||||
let asset = database::assets::select_where_symbol(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
&symbol,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
asset.map_or(Err(StatusCode::NOT_FOUND), |asset| {
|
||||
Ok((StatusCode::OK, Json(asset)))
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AddAssetsRequest {
|
||||
symbols: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct AddAssetsResponse {
|
||||
added: Vec<String>,
|
||||
skipped: Vec<String>,
|
||||
failed: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn add(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
|
||||
Json(request): Json<AddAssetsRequest>,
|
||||
) -> Result<(StatusCode, Json<AddAssetsResponse>), StatusCode> {
|
||||
let database_symbols = database::assets::select(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.into_iter()
|
||||
.map(|asset| asset.symbol)
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let mut alpaca_assets = alpaca::assets::get_by_symbols(
|
||||
&config.alpaca_client,
|
||||
&config.alpaca_rate_limiter,
|
||||
&request.symbols,
|
||||
None,
|
||||
&ALPACA_API_BASE,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
.into_iter()
|
||||
.map(|asset| (asset.symbol.clone(), asset))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let num_symbols = request.symbols.len();
|
||||
let (assets, skipped, failed) = request.symbols.into_iter().fold(
|
||||
(Vec::with_capacity(num_symbols), vec![], vec![]),
|
||||
|(mut assets, mut skipped, mut failed), symbol| {
|
||||
if database_symbols.contains(&symbol) {
|
||||
skipped.push(symbol);
|
||||
} else if let Some(asset) = alpaca_assets.remove(&symbol) {
|
||||
if asset.status == types::alpaca::api::incoming::asset::Status::Active
|
||||
&& asset.tradable
|
||||
&& asset.fractionable
|
||||
{
|
||||
assets.push((asset.symbol, asset.class.into()));
|
||||
} else {
|
||||
failed.push(asset.symbol);
|
||||
}
|
||||
} else {
|
||||
failed.push(symbol);
|
||||
}
|
||||
|
||||
(assets, skipped, failed)
|
||||
},
|
||||
);
|
||||
|
||||
if let Some(assets) = NonEmpty::from_vec(assets.clone()) {
|
||||
create_send_await!(
|
||||
data_sender,
|
||||
threads::data::Message::new,
|
||||
threads::data::Action::Add,
|
||||
assets
|
||||
);
|
||||
}
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(AddAssetsResponse {
|
||||
added: assets.into_iter().map(|asset| asset.0).collect(),
|
||||
skipped,
|
||||
failed,
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn add_symbol(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
|
||||
Path(symbol): Path<String>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
if database::assets::select_where_symbol(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
&symbol,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.is_some()
|
||||
{
|
||||
return Err(StatusCode::CONFLICT);
|
||||
}
|
||||
|
||||
let asset = alpaca::assets::get_by_symbol(
|
||||
&config.alpaca_client,
|
||||
&config.alpaca_rate_limiter,
|
||||
&symbol,
|
||||
None,
|
||||
&ALPACA_API_BASE,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))?;
|
||||
|
||||
if asset.status != types::alpaca::api::incoming::asset::Status::Active
|
||||
|| !asset.tradable
|
||||
|| !asset.fractionable
|
||||
{
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
create_send_await!(
|
||||
data_sender,
|
||||
threads::data::Message::new,
|
||||
threads::data::Action::Add,
|
||||
nonempty![(asset.symbol, asset.class.into())]
|
||||
);
|
||||
|
||||
Ok(StatusCode::CREATED)
|
||||
}
|
||||
|
||||
pub async fn delete(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
|
||||
Path(symbol): Path<String>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
let asset = database::assets::select_where_symbol(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
&symbol,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
create_send_await!(
|
||||
data_sender,
|
||||
threads::data::Message::new,
|
||||
threads::data::Action::Remove,
|
||||
nonempty![(asset.symbol, asset.class)]
|
||||
);
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
@@ -16,6 +16,7 @@ pub async fn run(config: Arc<Config>, data_sender: mpsc::Sender<threads::data::M
|
||||
.route("/assets", get(assets::get))
|
||||
.route("/assets/:symbol", get(assets::get_where_symbol))
|
||||
.route("/assets", post(assets::add))
|
||||
.route("/assets/:symbol", post(assets::add_symbol))
|
||||
.route("/assets/:symbol", delete(assets::delete))
|
||||
.layer(Extension(config))
|
||||
.layer(Extension(data_sender));
|
@@ -1,14 +1,17 @@
|
||||
use crate::{
|
||||
config::Config,
|
||||
config::{Config, ALPACA_API_BASE},
|
||||
database,
|
||||
types::{alpaca, Calendar},
|
||||
utils::{backoff, duration_until},
|
||||
};
|
||||
use log::info;
|
||||
use qrust::{
|
||||
alpaca,
|
||||
types::{self, Calendar},
|
||||
utils::{backoff, duration_until},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use time::OffsetDateTime;
|
||||
use tokio::{join, sync::mpsc, time::sleep};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum Status {
|
||||
Open,
|
||||
Closed,
|
||||
@@ -16,21 +19,16 @@ pub enum Status {
|
||||
|
||||
pub struct Message {
|
||||
pub status: Status,
|
||||
pub next_switch: OffsetDateTime,
|
||||
}
|
||||
|
||||
impl From<alpaca::api::incoming::clock::Clock> for Message {
|
||||
fn from(clock: alpaca::api::incoming::clock::Clock) -> Self {
|
||||
if clock.is_open {
|
||||
Self {
|
||||
status: Status::Open,
|
||||
next_switch: clock.next_close,
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
status: Status::Closed,
|
||||
next_switch: clock.next_open,
|
||||
}
|
||||
impl From<types::alpaca::api::incoming::clock::Clock> for Message {
|
||||
fn from(clock: types::alpaca::api::incoming::clock::Clock) -> Self {
|
||||
Self {
|
||||
status: if clock.is_open {
|
||||
Status::Open
|
||||
} else {
|
||||
Status::Closed
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -38,21 +36,23 @@ impl From<alpaca::api::incoming::clock::Clock> for Message {
|
||||
pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
|
||||
loop {
|
||||
let clock_future = async {
|
||||
alpaca::api::incoming::clock::get(
|
||||
alpaca::clock::get(
|
||||
&config.alpaca_client,
|
||||
&config.alpaca_rate_limiter,
|
||||
Some(backoff::infinite()),
|
||||
&ALPACA_API_BASE,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let calendar_future = async {
|
||||
alpaca::api::incoming::calendar::get(
|
||||
alpaca::calendar::get(
|
||||
&config.alpaca_client,
|
||||
&config.alpaca_rate_limiter,
|
||||
&alpaca::api::outgoing::calendar::Calendar::default(),
|
||||
&types::alpaca::api::outgoing::calendar::Calendar::default(),
|
||||
Some(backoff::infinite()),
|
||||
&ALPACA_API_BASE,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
@@ -74,9 +74,13 @@ pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
|
||||
let sleep_future = sleep(sleep_until);
|
||||
|
||||
let calendar_future = async {
|
||||
database::calendar::upsert_batch_and_delete(&config.clickhouse_client, &calendar)
|
||||
.await
|
||||
.unwrap();
|
||||
database::calendar::upsert_batch_and_delete(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
&calendar,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
};
|
||||
|
||||
join!(sleep_future, calendar_future);
|
238
src/bin/qrust/threads/data/backfill/bars.rs
Normal file
238
src/bin/qrust/threads/data/backfill/bars.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use super::Job;
|
||||
use crate::{
|
||||
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE},
|
||||
database,
|
||||
threads::data::ThreadType,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use log::{error, info};
|
||||
use nonempty::NonEmpty;
|
||||
use qrust::{
|
||||
alpaca,
|
||||
types::{
|
||||
self,
|
||||
alpaca::{
|
||||
api::{ALPACA_CRYPTO_DATA_API_URL, ALPACA_US_EQUITY_DATA_API_URL},
|
||||
shared::{Sort, Source},
|
||||
},
|
||||
Backfill, Bar, Class,
|
||||
},
|
||||
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
|
||||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use time::OffsetDateTime;
|
||||
use tokio::time::sleep;
|
||||
|
||||
pub struct Handler {
|
||||
pub config: Arc<Config>,
|
||||
pub data_url: &'static str,
|
||||
pub api_query_constructor: fn(
|
||||
symbols: Vec<String>,
|
||||
fetch_from: OffsetDateTime,
|
||||
fetch_to: OffsetDateTime,
|
||||
next_page_token: Option<String>,
|
||||
) -> types::alpaca::api::outgoing::bar::Bar,
|
||||
}
|
||||
|
||||
pub fn us_equity_query_constructor(
|
||||
symbols: Vec<String>,
|
||||
fetch_from: OffsetDateTime,
|
||||
fetch_to: OffsetDateTime,
|
||||
next_page_token: Option<String>,
|
||||
) -> types::alpaca::api::outgoing::bar::Bar {
|
||||
types::alpaca::api::outgoing::bar::Bar::UsEquity(types::alpaca::api::outgoing::bar::UsEquity {
|
||||
symbols,
|
||||
start: Some(fetch_from),
|
||||
end: Some(fetch_to),
|
||||
page_token: next_page_token,
|
||||
sort: Some(Sort::Asc),
|
||||
feed: Some(*ALPACA_SOURCE),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn crypto_query_constructor(
|
||||
symbols: Vec<String>,
|
||||
fetch_from: OffsetDateTime,
|
||||
fetch_to: OffsetDateTime,
|
||||
next_page_token: Option<String>,
|
||||
) -> types::alpaca::api::outgoing::bar::Bar {
|
||||
types::alpaca::api::outgoing::bar::Bar::Crypto(types::alpaca::api::outgoing::bar::Crypto {
|
||||
symbols,
|
||||
start: Some(fetch_from),
|
||||
end: Some(fetch_to),
|
||||
page_token: next_page_token,
|
||||
sort: Some(Sort::Asc),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Handler for Handler {
|
||||
async fn select_latest_backfills(
|
||||
&self,
|
||||
symbols: &[String],
|
||||
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
|
||||
database::backfills_bars::select_where_symbols(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
symbols,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
|
||||
database::backfills_bars::delete_where_symbols(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
symbols,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
|
||||
database::bars::delete_where_symbols(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
symbols,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
|
||||
if *ALPACA_SOURCE == Source::Sip {
|
||||
return;
|
||||
}
|
||||
|
||||
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
|
||||
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
|
||||
let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();
|
||||
|
||||
info!("Queing bar backfill for {:?} in {:?}.", symbols, run_delay);
|
||||
sleep(run_delay).await;
|
||||
}
|
||||
|
||||
async fn backfill(&self, jobs: NonEmpty<Job>) {
|
||||
let symbols = Vec::from(jobs.clone().map(|job| job.symbol));
|
||||
let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
|
||||
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
|
||||
let freshness = jobs
|
||||
.into_iter()
|
||||
.map(|job| (job.symbol, job.fresh))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let mut bars = Vec::with_capacity(*CLICKHOUSE_BATCH_BARS_SIZE);
|
||||
let mut last_times = HashMap::new();
|
||||
let mut next_page_token = None;
|
||||
|
||||
info!("Backfilling bars for {:?}.", symbols);
|
||||
|
||||
loop {
|
||||
let message = alpaca::bars::get(
|
||||
&self.config.alpaca_client,
|
||||
&self.config.alpaca_rate_limiter,
|
||||
self.data_url,
|
||||
&(self.api_query_constructor)(
|
||||
symbols.clone(),
|
||||
fetch_from,
|
||||
fetch_to,
|
||||
next_page_token.clone(),
|
||||
),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(err) = message {
|
||||
error!("Failed to backfill bars for {:?}: {:?}.", symbols, err);
|
||||
return;
|
||||
}
|
||||
|
||||
let message = message.unwrap();
|
||||
|
||||
for (symbol, bars_vec) in message.bars {
|
||||
if let Some(last) = bars_vec.last() {
|
||||
last_times.insert(symbol.clone(), last.time);
|
||||
}
|
||||
|
||||
for bar in bars_vec {
|
||||
bars.push(Bar::from((bar, symbol.clone())));
|
||||
}
|
||||
}
|
||||
|
||||
if bars.len() < *CLICKHOUSE_BATCH_BARS_SIZE && message.next_page_token.is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
database::bars::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&bars,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let backfilled = last_times
|
||||
.drain()
|
||||
.map(|(symbol, time)| Backfill {
|
||||
fresh: freshness[&symbol],
|
||||
symbol,
|
||||
time,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
database::backfills_bars::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&backfilled,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if message.next_page_token.is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
next_page_token = message.next_page_token;
|
||||
bars.clear();
|
||||
}
|
||||
|
||||
database::backfills_bars::set_fresh_where_symbols(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
true,
|
||||
&symbols,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
info!("Backfilled bars for {:?}.", symbols);
|
||||
}
|
||||
|
||||
fn max_limit(&self) -> i64 {
|
||||
alpaca::bars::MAX_LIMIT
|
||||
}
|
||||
|
||||
fn log_string(&self) -> &'static str {
|
||||
"bars"
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
|
||||
let data_url = match thread_type {
|
||||
ThreadType::Bars(Class::UsEquity) => ALPACA_US_EQUITY_DATA_API_URL,
|
||||
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_API_URL,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let api_query_constructor = match thread_type {
|
||||
ThreadType::Bars(Class::UsEquity) => us_equity_query_constructor,
|
||||
ThreadType::Bars(Class::Crypto) => crypto_query_constructor,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Box::new(Handler {
|
||||
config,
|
||||
data_url,
|
||||
api_query_constructor,
|
||||
})
|
||||
}
|
243
src/bin/qrust/threads/data/backfill/mod.rs
Normal file
243
src/bin/qrust/threads/data/backfill/mod.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
pub mod bars;
|
||||
pub mod news;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use itertools::Itertools;
|
||||
use log::{info, warn};
|
||||
use nonempty::{nonempty, NonEmpty};
|
||||
use qrust::{
|
||||
types::Backfill,
|
||||
utils::{last_minute, ONE_SECOND},
|
||||
};
|
||||
use std::{collections::HashMap, hash::Hash, sync::Arc};
|
||||
use time::OffsetDateTime;
|
||||
use tokio::{
|
||||
spawn,
|
||||
sync::{mpsc, oneshot, Mutex},
|
||||
task::JoinHandle,
|
||||
try_join,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub enum Action {
|
||||
Backfill,
|
||||
Purge,
|
||||
}
|
||||
|
||||
pub struct Message {
|
||||
pub action: Action,
|
||||
pub symbols: NonEmpty<String>,
|
||||
pub response: oneshot::Sender<()>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(action: Action, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) {
|
||||
let (sender, receiver) = oneshot::channel::<()>();
|
||||
(
|
||||
Self {
|
||||
action,
|
||||
symbols,
|
||||
response: sender,
|
||||
},
|
||||
receiver,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Job {
|
||||
pub symbol: String,
|
||||
pub fetch_from: OffsetDateTime,
|
||||
pub fetch_to: OffsetDateTime,
|
||||
pub fresh: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Handler: Send + Sync {
|
||||
async fn select_latest_backfills(
|
||||
&self,
|
||||
symbols: &[String],
|
||||
) -> Result<Vec<Backfill>, clickhouse::error::Error>;
|
||||
async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
|
||||
async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
|
||||
async fn queue_backfill(&self, jobs: &NonEmpty<Job>);
|
||||
async fn backfill(&self, jobs: NonEmpty<Job>);
|
||||
fn max_limit(&self) -> i64;
|
||||
fn log_string(&self) -> &'static str;
|
||||
}
|
||||
|
||||
pub struct Jobs {
|
||||
pub symbol_to_uuid: HashMap<String, Uuid>,
|
||||
pub uuid_to_job: HashMap<Uuid, JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl Jobs {
|
||||
pub fn insert(&mut self, jobs: Vec<String>, fut: JoinHandle<()>) {
|
||||
let uuid = Uuid::new_v4();
|
||||
for symbol in jobs {
|
||||
self.symbol_to_uuid.insert(symbol.clone(), uuid);
|
||||
}
|
||||
self.uuid_to_job.insert(uuid, fut);
|
||||
}
|
||||
|
||||
pub fn contains_key(&self, symbol: &str) -> bool {
|
||||
self.symbol_to_uuid.contains_key(symbol)
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, symbol: &str) -> Option<JoinHandle<()>> {
|
||||
self.symbol_to_uuid
|
||||
.remove(symbol)
|
||||
.and_then(|uuid| self.uuid_to_job.remove(&uuid))
|
||||
}
|
||||
|
||||
pub fn remove_many<T>(&mut self, symbols: &[T])
|
||||
where
|
||||
T: AsRef<str> + Hash + Eq,
|
||||
{
|
||||
for symbol in symbols {
|
||||
self.symbol_to_uuid
|
||||
.remove(symbol.as_ref())
|
||||
.and_then(|uuid| self.uuid_to_job.remove(&uuid));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.symbol_to_uuid.len()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
|
||||
let backfill_jobs = Arc::new(Mutex::new(Jobs {
|
||||
symbol_to_uuid: HashMap::new(),
|
||||
uuid_to_job: HashMap::new(),
|
||||
}));
|
||||
|
||||
loop {
|
||||
let message = receiver.recv().await.unwrap();
|
||||
spawn(handle_message(
|
||||
handler.clone(),
|
||||
backfill_jobs.clone(),
|
||||
message,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_message(
|
||||
handler: Arc<Box<dyn Handler>>,
|
||||
backfill_jobs: Arc<Mutex<Jobs>>,
|
||||
message: Message,
|
||||
) {
|
||||
let backfill_jobs_clone = backfill_jobs.clone();
|
||||
let mut backfill_jobs = backfill_jobs.lock().await;
|
||||
let symbols = Vec::from(message.symbols);
|
||||
|
||||
match message.action {
|
||||
Action::Backfill => {
|
||||
let log_string = handler.log_string();
|
||||
let max_limit = handler.max_limit();
|
||||
|
||||
let backfills = handler
|
||||
.select_latest_backfills(&symbols)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|backfill| (backfill.symbol.clone(), backfill))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let mut jobs = Vec::with_capacity(symbols.len());
|
||||
|
||||
for symbol in symbols {
|
||||
if backfill_jobs.contains_key(&symbol) {
|
||||
warn!(
|
||||
"Backfill for {} {} is already running, skipping.",
|
||||
symbol, log_string
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let backfill = backfills.get(&symbol);
|
||||
|
||||
let fetch_from = backfill.map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
|
||||
backfill.time + ONE_SECOND
|
||||
});
|
||||
|
||||
let fetch_to = last_minute();
|
||||
|
||||
if fetch_from > fetch_to {
|
||||
info!("No need to backfill {} {}.", symbol, log_string,);
|
||||
return;
|
||||
}
|
||||
|
||||
let fresh = backfill.map_or(false, |backfill| backfill.fresh);
|
||||
|
||||
jobs.push(Job {
|
||||
symbol,
|
||||
fetch_from,
|
||||
fetch_to,
|
||||
fresh,
|
||||
});
|
||||
}
|
||||
|
||||
let mut current_minutes = 0;
|
||||
let job_groups = jobs
|
||||
.into_iter()
|
||||
.sorted_unstable_by_key(|job| job.fetch_from)
|
||||
.fold(Vec::<NonEmpty<Job>>::new(), |mut job_groups, job| {
|
||||
let minutes = (job.fetch_to - job.fetch_from).whole_minutes();
|
||||
|
||||
if let Some(job_group) = job_groups.last_mut() {
|
||||
if current_minutes + minutes <= max_limit {
|
||||
job_group.push(job);
|
||||
current_minutes += minutes;
|
||||
return job_groups;
|
||||
}
|
||||
}
|
||||
|
||||
job_groups.push(nonempty![job]);
|
||||
current_minutes = minutes;
|
||||
job_groups
|
||||
});
|
||||
|
||||
for job_group in job_groups {
|
||||
let symbols = job_group
|
||||
.iter()
|
||||
.map(|job| job.symbol.clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let handler = handler.clone();
|
||||
let symbols_clone = symbols.clone();
|
||||
let backfill_jobs_clone = backfill_jobs_clone.clone();
|
||||
|
||||
let fut = spawn(async move {
|
||||
handler.queue_backfill(&job_group).await;
|
||||
handler.backfill(job_group).await;
|
||||
|
||||
let mut backfill_jobs = backfill_jobs_clone.lock().await;
|
||||
backfill_jobs.remove_many(&symbols_clone);
|
||||
let remaining = backfill_jobs.len();
|
||||
drop(backfill_jobs);
|
||||
|
||||
info!("{} {} backfills remaining.", remaining, log_string);
|
||||
});
|
||||
|
||||
backfill_jobs.insert(symbols, fut);
|
||||
}
|
||||
}
|
||||
Action::Purge => {
|
||||
for symbol in &symbols {
|
||||
if let Some(job) = backfill_jobs.remove(symbol) {
|
||||
job.abort();
|
||||
let _ = job.await;
|
||||
}
|
||||
}
|
||||
|
||||
try_join!(
|
||||
handler.delete_backfills(&symbols),
|
||||
handler.delete_data(&symbols)
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
message.response.send(()).unwrap();
|
||||
}
|
186
src/bin/qrust/threads/data/backfill/news.rs
Normal file
186
src/bin/qrust/threads/data/backfill/news.rs
Normal file
@@ -0,0 +1,186 @@
|
||||
use super::Job;
|
||||
use crate::{
|
||||
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_NEWS_SIZE},
|
||||
database,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use log::{error, info};
|
||||
use nonempty::NonEmpty;
|
||||
use qrust::{
|
||||
alpaca,
|
||||
types::{
|
||||
self,
|
||||
alpaca::shared::{Sort, Source},
|
||||
Backfill, News,
|
||||
},
|
||||
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
|
||||
};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::time::sleep;
|
||||
|
||||
pub struct Handler {
|
||||
pub config: Arc<Config>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Handler for Handler {
|
||||
async fn select_latest_backfills(
|
||||
&self,
|
||||
symbols: &[String],
|
||||
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
|
||||
database::backfills_news::select_where_symbols(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
symbols,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
|
||||
database::backfills_news::delete_where_symbols(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
symbols,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
|
||||
database::news::delete_where_symbols(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
symbols,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
|
||||
if *ALPACA_SOURCE == Source::Sip {
|
||||
return;
|
||||
}
|
||||
|
||||
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
|
||||
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
|
||||
let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();
|
||||
|
||||
info!("Queing news backfill for {:?} in {:?}.", symbols, run_delay);
|
||||
sleep(run_delay).await;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
#[allow(clippy::iter_with_drain)]
|
||||
async fn backfill(&self, jobs: NonEmpty<Job>) {
|
||||
let symbols = Vec::from(jobs.clone().map(|job| job.symbol));
|
||||
let symbols_set = symbols.clone().into_iter().collect::<HashSet<_>>();
|
||||
let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
|
||||
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
|
||||
let freshness = jobs
|
||||
.into_iter()
|
||||
.map(|job| (job.symbol, job.fresh))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let mut news = Vec::with_capacity(*CLICKHOUSE_BATCH_NEWS_SIZE);
|
||||
let mut last_times = HashMap::new();
|
||||
let mut next_page_token = None;
|
||||
|
||||
info!("Backfilling news for {:?}.", symbols);
|
||||
|
||||
loop {
|
||||
let message = alpaca::news::get(
|
||||
&self.config.alpaca_client,
|
||||
&self.config.alpaca_rate_limiter,
|
||||
&types::alpaca::api::outgoing::news::News {
|
||||
symbols: symbols.clone(),
|
||||
start: Some(fetch_from),
|
||||
end: Some(fetch_to),
|
||||
page_token: next_page_token.clone(),
|
||||
sort: Some(Sort::Asc),
|
||||
..Default::default()
|
||||
},
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(err) = message {
|
||||
error!("Failed to backfill news for {:?}: {:?}.", symbols, err);
|
||||
return;
|
||||
}
|
||||
|
||||
let message = message.unwrap();
|
||||
|
||||
for news_item in message.news {
|
||||
let news_item = News::from(news_item);
|
||||
|
||||
for symbol in &news_item.symbols {
|
||||
if symbols_set.contains(symbol) {
|
||||
last_times.insert(symbol.clone(), news_item.time_created);
|
||||
}
|
||||
}
|
||||
|
||||
news.push(news_item);
|
||||
}
|
||||
|
||||
if news.len() < *CLICKHOUSE_BATCH_NEWS_SIZE && message.next_page_token.is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
database::news::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&news,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let backfilled = last_times
|
||||
.drain()
|
||||
.map(|(symbol, time)| Backfill {
|
||||
fresh: freshness[&symbol],
|
||||
symbol,
|
||||
time,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
database::backfills_news::upsert_batch(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&backfilled,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if message.next_page_token.is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
next_page_token = message.next_page_token;
|
||||
news.clear();
|
||||
}
|
||||
|
||||
database::backfills_news::set_fresh_where_symbols(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
true,
|
||||
&symbols,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
info!("Backfilled news for {:?}.", symbols);
|
||||
}
|
||||
|
||||
fn max_limit(&self) -> i64 {
|
||||
alpaca::news::MAX_LIMIT
|
||||
}
|
||||
|
||||
fn log_string(&self) -> &'static str {
|
||||
"news"
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_handler(config: Arc<Config>) -> Box<dyn super::Handler> {
|
||||
Box::new(Handler { config })
|
||||
}
|
391
src/bin/qrust/threads/data/mod.rs
Normal file
391
src/bin/qrust/threads/data/mod.rs
Normal file
@@ -0,0 +1,391 @@
|
||||
mod backfill;
|
||||
mod websocket;
|
||||
|
||||
use super::clock;
|
||||
use crate::{
|
||||
config::{Config, ALPACA_API_BASE, ALPACA_SOURCE},
|
||||
create_send_await, database,
|
||||
};
|
||||
use itertools::{Either, Itertools};
|
||||
use log::error;
|
||||
use nonempty::NonEmpty;
|
||||
use qrust::{
|
||||
alpaca,
|
||||
types::{
|
||||
alpaca::websocket::{
|
||||
ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL,
|
||||
ALPACA_US_EQUITY_DATA_WEBSOCKET_URL,
|
||||
},
|
||||
Asset, Class,
|
||||
},
|
||||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::{
|
||||
join, select, spawn,
|
||||
sync::{mpsc, oneshot},
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
#[allow(dead_code)]
|
||||
pub enum Action {
|
||||
Add,
|
||||
Enable,
|
||||
Remove,
|
||||
Disable,
|
||||
}
|
||||
|
||||
pub struct Message {
|
||||
pub action: Action,
|
||||
pub assets: NonEmpty<(String, Class)>,
|
||||
pub response: oneshot::Sender<()>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(action: Action, assets: NonEmpty<(String, Class)>) -> (Self, oneshot::Receiver<()>) {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
(
|
||||
Self {
|
||||
action,
|
||||
assets,
|
||||
response: sender,
|
||||
},
|
||||
receiver,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ThreadType {
|
||||
Bars(Class),
|
||||
News,
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
config: Arc<Config>,
|
||||
mut receiver: mpsc::Receiver<Message>,
|
||||
mut clock_receiver: mpsc::Receiver<clock::Message>,
|
||||
) {
|
||||
let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) =
|
||||
init_thread(config.clone(), ThreadType::Bars(Class::UsEquity));
|
||||
let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) =
|
||||
init_thread(config.clone(), ThreadType::Bars(Class::Crypto));
|
||||
let (news_websocket_sender, news_backfill_sender) =
|
||||
init_thread(config.clone(), ThreadType::News);
|
||||
|
||||
loop {
|
||||
select! {
|
||||
Some(message) = receiver.recv() => {
|
||||
spawn(handle_message(
|
||||
config.clone(),
|
||||
bars_us_equity_websocket_sender.clone(),
|
||||
bars_us_equity_backfill_sender.clone(),
|
||||
bars_crypto_websocket_sender.clone(),
|
||||
bars_crypto_backfill_sender.clone(),
|
||||
news_websocket_sender.clone(),
|
||||
news_backfill_sender.clone(),
|
||||
message,
|
||||
));
|
||||
}
|
||||
Some(message) = clock_receiver.recv() => {
|
||||
spawn(handle_clock_message(
|
||||
config.clone(),
|
||||
bars_us_equity_backfill_sender.clone(),
|
||||
bars_crypto_backfill_sender.clone(),
|
||||
news_backfill_sender.clone(),
|
||||
message,
|
||||
));
|
||||
}
|
||||
else => panic!("Communication channel unexpectedly closed.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn init_thread(
|
||||
config: Arc<Config>,
|
||||
thread_type: ThreadType,
|
||||
) -> (
|
||||
mpsc::Sender<websocket::Message>,
|
||||
mpsc::Sender<backfill::Message>,
|
||||
) {
|
||||
let websocket_url = match thread_type {
|
||||
ThreadType::Bars(Class::UsEquity) => {
|
||||
format!("{}/{}", ALPACA_US_EQUITY_DATA_WEBSOCKET_URL, *ALPACA_SOURCE)
|
||||
}
|
||||
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_WEBSOCKET_URL.into(),
|
||||
ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(),
|
||||
};
|
||||
|
||||
let backfill_handler = match thread_type {
|
||||
ThreadType::Bars(_) => backfill::bars::create_handler(config.clone(), thread_type),
|
||||
ThreadType::News => backfill::news::create_handler(config.clone()),
|
||||
};
|
||||
|
||||
let (backfill_sender, backfill_receiver) = mpsc::channel(100);
|
||||
|
||||
spawn(backfill::run(backfill_handler.into(), backfill_receiver));
|
||||
|
||||
let websocket_handler = match thread_type {
|
||||
ThreadType::Bars(_) => websocket::bars::create_handler(config, thread_type),
|
||||
ThreadType::News => websocket::news::create_handler(&config),
|
||||
};
|
||||
|
||||
let (websocket_sender, websocket_receiver) = mpsc::channel(100);
|
||||
|
||||
spawn(websocket::run(
|
||||
websocket_handler.into(),
|
||||
websocket_receiver,
|
||||
websocket_url,
|
||||
));
|
||||
|
||||
(websocket_sender, backfill_sender)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn handle_message(
|
||||
config: Arc<Config>,
|
||||
bars_us_equity_websocket_sender: mpsc::Sender<websocket::Message>,
|
||||
bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
|
||||
bars_crypto_websocket_sender: mpsc::Sender<websocket::Message>,
|
||||
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
|
||||
news_websocket_sender: mpsc::Sender<websocket::Message>,
|
||||
news_backfill_sender: mpsc::Sender<backfill::Message>,
|
||||
message: Message,
|
||||
) {
|
||||
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = message
|
||||
.assets
|
||||
.clone()
|
||||
.into_iter()
|
||||
.partition_map(|asset| match asset.1 {
|
||||
Class::UsEquity => Either::Left(asset.0),
|
||||
Class::Crypto => Either::Right(asset.0),
|
||||
});
|
||||
|
||||
let symbols = message.assets.map(|(symbol, _)| symbol);
|
||||
|
||||
let bars_us_equity_future = async {
|
||||
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols.clone()) {
|
||||
create_send_await!(
|
||||
bars_us_equity_websocket_sender,
|
||||
websocket::Message::new,
|
||||
message.action.into(),
|
||||
us_equity_symbols
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let bars_crypto_future = async {
|
||||
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols.clone()) {
|
||||
create_send_await!(
|
||||
bars_crypto_websocket_sender,
|
||||
websocket::Message::new,
|
||||
message.action.into(),
|
||||
crypto_symbols
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let news_future = async {
|
||||
create_send_await!(
|
||||
news_websocket_sender,
|
||||
websocket::Message::new,
|
||||
message.action.into(),
|
||||
symbols.clone()
|
||||
);
|
||||
};
|
||||
|
||||
join!(bars_us_equity_future, bars_crypto_future, news_future);
|
||||
|
||||
if message.action == Action::Disable {
|
||||
message.response.send(()).unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
match message.action {
|
||||
Action::Add | Action::Enable => {
|
||||
let symbols = Vec::from(symbols.clone());
|
||||
|
||||
let assets = async {
|
||||
alpaca::assets::get_by_symbols(
|
||||
&config.alpaca_client,
|
||||
&config.alpaca_rate_limiter,
|
||||
&symbols,
|
||||
None,
|
||||
&ALPACA_API_BASE,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|asset| (asset.symbol.clone(), asset))
|
||||
.collect::<HashMap<_, _>>()
|
||||
};
|
||||
|
||||
let positions = async {
|
||||
alpaca::positions::get_by_symbols(
|
||||
&config.alpaca_client,
|
||||
&config.alpaca_rate_limiter,
|
||||
&symbols,
|
||||
None,
|
||||
&ALPACA_API_BASE,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|position| (position.symbol.clone(), position))
|
||||
.collect::<HashMap<_, _>>()
|
||||
};
|
||||
|
||||
let (mut assets, mut positions) = join!(assets, positions);
|
||||
|
||||
let batch =
|
||||
symbols
|
||||
.iter()
|
||||
.fold(Vec::with_capacity(symbols.len()), |mut batch, symbol| {
|
||||
if let Some(asset) = assets.remove(symbol) {
|
||||
let position = positions.remove(symbol);
|
||||
batch.push(Asset::from((asset, position)));
|
||||
} else {
|
||||
error!("Failed to find asset for symbol: {}.", symbol);
|
||||
}
|
||||
batch
|
||||
});
|
||||
|
||||
database::assets::upsert_batch(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
&batch,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
Action::Remove => {
|
||||
database::assets::delete_where_symbols(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
&Vec::from(symbols.clone()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
Action::Disable => unreachable!(),
|
||||
}
|
||||
|
||||
let bars_us_equity_future = async {
|
||||
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
|
||||
create_send_await!(
|
||||
bars_us_equity_backfill_sender,
|
||||
backfill::Message::new,
|
||||
match message.action {
|
||||
Action::Add | Action::Enable => backfill::Action::Backfill,
|
||||
Action::Remove => backfill::Action::Purge,
|
||||
Action::Disable => unreachable!(),
|
||||
},
|
||||
us_equity_symbols
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let bars_crypto_future = async {
|
||||
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
|
||||
create_send_await!(
|
||||
bars_crypto_backfill_sender,
|
||||
backfill::Message::new,
|
||||
match message.action {
|
||||
Action::Add | Action::Enable => backfill::Action::Backfill,
|
||||
Action::Remove => backfill::Action::Purge,
|
||||
Action::Disable => unreachable!(),
|
||||
},
|
||||
crypto_symbols
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let news_future = async {
|
||||
create_send_await!(
|
||||
news_backfill_sender,
|
||||
backfill::Message::new,
|
||||
match message.action {
|
||||
Action::Add | Action::Enable => backfill::Action::Backfill,
|
||||
Action::Remove => backfill::Action::Purge,
|
||||
Action::Disable => unreachable!(),
|
||||
},
|
||||
symbols
|
||||
);
|
||||
};
|
||||
|
||||
join!(bars_us_equity_future, bars_crypto_future, news_future);
|
||||
|
||||
message.response.send(()).unwrap();
|
||||
}
|
||||
|
||||
async fn handle_clock_message(
|
||||
config: Arc<Config>,
|
||||
bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
|
||||
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
|
||||
news_backfill_sender: mpsc::Sender<backfill::Message>,
|
||||
message: clock::Message,
|
||||
) {
|
||||
if message.status == clock::Status::Closed {
|
||||
database::cleanup_all(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let assets = database::assets::select(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets
|
||||
.clone()
|
||||
.into_iter()
|
||||
.partition_map(|asset| match asset.class {
|
||||
Class::UsEquity => Either::Left(asset.symbol),
|
||||
Class::Crypto => Either::Right(asset.symbol),
|
||||
});
|
||||
|
||||
let symbols = assets
|
||||
.into_iter()
|
||||
.map(|asset| asset.symbol)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let bars_us_equity_future = async {
|
||||
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
|
||||
create_send_await!(
|
||||
bars_us_equity_backfill_sender,
|
||||
backfill::Message::new,
|
||||
backfill::Action::Backfill,
|
||||
us_equity_symbols
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let bars_crypto_future = async {
|
||||
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
|
||||
create_send_await!(
|
||||
bars_crypto_backfill_sender,
|
||||
backfill::Message::new,
|
||||
backfill::Action::Backfill,
|
||||
crypto_symbols
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let news_future = async {
|
||||
if let Some(symbols) = NonEmpty::from_vec(symbols) {
|
||||
create_send_await!(
|
||||
news_backfill_sender,
|
||||
backfill::Message::new,
|
||||
backfill::Action::Backfill,
|
||||
symbols
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
join!(bars_us_equity_future, bars_crypto_future, news_future);
|
||||
}
|
171
src/bin/qrust/threads/data/websocket/bars.rs
Normal file
171
src/bin/qrust/threads/data/websocket/bars.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
use super::State;
|
||||
use crate::{
|
||||
config::{Config, CLICKHOUSE_BATCH_BARS_SIZE},
|
||||
database,
|
||||
threads::data::ThreadType,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use clickhouse::inserter::Inserter;
|
||||
use log::{debug, error, info};
|
||||
use nonempty::NonEmpty;
|
||||
use qrust::{
|
||||
types::{alpaca::websocket, Bar, Class},
|
||||
utils::ONE_SECOND,
|
||||
};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
|
||||
pub struct Handler {
|
||||
pub config: Arc<Config>,
|
||||
pub inserter: Arc<Mutex<Inserter<Bar>>>,
|
||||
pub subscription_message_constructor:
|
||||
fn(NonEmpty<String>) -> websocket::data::outgoing::subscribe::Message,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Handler for Handler {
|
||||
fn create_subscription_message(
|
||||
&self,
|
||||
symbols: NonEmpty<String>,
|
||||
) -> websocket::data::outgoing::subscribe::Message {
|
||||
(self.subscription_message_constructor)(symbols)
|
||||
}
|
||||
|
||||
async fn handle_websocket_message(
|
||||
&self,
|
||||
state: Arc<RwLock<State>>,
|
||||
message: websocket::data::incoming::Message,
|
||||
) {
|
||||
match message {
|
||||
websocket::data::incoming::Message::Subscription(message) => {
|
||||
let websocket::data::incoming::subscription::Message::Market {
|
||||
bars: symbols, ..
|
||||
} = message
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let symbols = symbols.into_iter().collect::<HashSet<_>>();
|
||||
let mut state = state.write().await;
|
||||
|
||||
let newly_subscribed = state
|
||||
.pending_subscriptions
|
||||
.extract_if(|symbol, _| symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let newly_unsubscribed = state
|
||||
.pending_unsubscriptions
|
||||
.extract_if(|symbol, _| !symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
state
|
||||
.active_subscriptions
|
||||
.extend(newly_subscribed.keys().cloned());
|
||||
|
||||
drop(state);
|
||||
|
||||
if !newly_subscribed.is_empty() {
|
||||
info!(
|
||||
"Subscribed to bars for {:?}.",
|
||||
newly_subscribed.keys().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
for sender in newly_subscribed.into_values() {
|
||||
sender.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
if !newly_unsubscribed.is_empty() {
|
||||
info!(
|
||||
"Unsubscribed from bars for {:?}.",
|
||||
newly_unsubscribed.keys().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
for sender in newly_unsubscribed.into_values() {
|
||||
sender.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
websocket::data::incoming::Message::Bar(message)
|
||||
| websocket::data::incoming::Message::UpdatedBar(message) => {
|
||||
let bar = Bar::from(message);
|
||||
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
|
||||
self.inserter.lock().await.write(&bar).await.unwrap();
|
||||
}
|
||||
websocket::data::incoming::Message::Status(message) => {
|
||||
debug!(
|
||||
"Received status message for {}: {:?}.",
|
||||
message.symbol, message.status
|
||||
);
|
||||
|
||||
match message.status {
|
||||
websocket::data::incoming::status::Status::TradingHalt(_)
|
||||
| websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
|
||||
database::assets::update_status_where_symbol(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&message.symbol,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
websocket::data::incoming::status::Status::Resume(_)
|
||||
| websocket::data::incoming::status::Status::TradingResumption(_) => {
|
||||
database::assets::update_status_where_symbol(
|
||||
&self.config.clickhouse_client,
|
||||
&self.config.clickhouse_concurrency_limiter,
|
||||
&message.symbol,
|
||||
true,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
websocket::data::incoming::Message::Error(message) => {
|
||||
error!("Received error message: {}.", message.message);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn log_string(&self) -> &'static str {
|
||||
"bars"
|
||||
}
|
||||
|
||||
async fn run_inserter(&self) {
|
||||
super::run_inserter(self.inserter.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
|
||||
let inserter = Arc::new(Mutex::new(
|
||||
config
|
||||
.clickhouse_client
|
||||
.inserter("bars")
|
||||
.unwrap()
|
||||
.with_period(Some(ONE_SECOND))
|
||||
.with_max_entries((*CLICKHOUSE_BATCH_BARS_SIZE).try_into().unwrap()),
|
||||
));
|
||||
|
||||
let subscription_message_constructor = match thread_type {
|
||||
ThreadType::Bars(Class::UsEquity) => {
|
||||
websocket::data::outgoing::subscribe::Message::new_market_us_equity
|
||||
}
|
||||
ThreadType::Bars(Class::Crypto) => {
|
||||
websocket::data::outgoing::subscribe::Message::new_market_crypto
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Box::new(Handler {
|
||||
config,
|
||||
inserter,
|
||||
subscription_message_constructor,
|
||||
})
|
||||
}
|
353
src/bin/qrust/threads/data/websocket/mod.rs
Normal file
353
src/bin/qrust/threads/data/websocket/mod.rs
Normal file
@@ -0,0 +1,353 @@
|
||||
pub mod bars;
|
||||
pub mod news;
|
||||
|
||||
use crate::config::{ALPACA_API_KEY, ALPACA_API_SECRET};
|
||||
use async_trait::async_trait;
|
||||
use backoff::{future::retry_notify, ExponentialBackoff};
|
||||
use clickhouse::{inserter::Inserter, Row};
|
||||
use futures_util::{future::join_all, SinkExt, StreamExt};
|
||||
use log::error;
|
||||
use nonempty::NonEmpty;
|
||||
use qrust::types::alpaca::{self, websocket};
|
||||
use serde::Serialize;
|
||||
use serde_json::{from_str, to_string};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{
|
||||
net::TcpStream,
|
||||
select, spawn,
|
||||
sync::{mpsc, oneshot, Mutex, RwLock},
|
||||
};
|
||||
use tokio_tungstenite::{connect_async, tungstenite, MaybeTlsStream, WebSocketStream};
|
||||
|
||||
pub enum Action {
|
||||
Subscribe,
|
||||
Unsubscribe,
|
||||
}
|
||||
|
||||
impl From<super::Action> for Option<Action> {
|
||||
fn from(action: super::Action) -> Self {
|
||||
match action {
|
||||
super::Action::Add | super::Action::Enable => Some(Action::Subscribe),
|
||||
super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Message {
|
||||
pub action: Option<Action>,
|
||||
pub symbols: NonEmpty<String>,
|
||||
pub response: oneshot::Sender<()>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(action: Option<Action>, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
(
|
||||
Self {
|
||||
action,
|
||||
symbols,
|
||||
response: sender,
|
||||
},
|
||||
receiver,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
pub active_subscriptions: HashSet<String>,
|
||||
pub pending_subscriptions: HashMap<String, oneshot::Sender<()>>,
|
||||
pub pending_unsubscriptions: HashMap<String, oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Handler: Send + Sync + 'static {
|
||||
fn create_subscription_message(
|
||||
&self,
|
||||
symbols: NonEmpty<String>,
|
||||
) -> websocket::data::outgoing::subscribe::Message;
|
||||
async fn handle_websocket_message(
|
||||
&self,
|
||||
state: Arc<RwLock<State>>,
|
||||
message: websocket::data::incoming::Message,
|
||||
);
|
||||
fn log_string(&self) -> &'static str;
|
||||
async fn run_inserter(&self);
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
handler: Arc<Box<dyn Handler>>,
|
||||
mut receiver: mpsc::Receiver<Message>,
|
||||
websocket_url: String,
|
||||
) {
|
||||
let state = Arc::new(RwLock::new(State {
|
||||
active_subscriptions: HashSet::new(),
|
||||
pending_subscriptions: HashMap::new(),
|
||||
pending_unsubscriptions: HashMap::new(),
|
||||
}));
|
||||
|
||||
let handler_clone = handler.clone();
|
||||
spawn(async move { handler_clone.run_inserter().await });
|
||||
|
||||
let (sink_sender, sink_receiver) = mpsc::channel(100);
|
||||
let (stream_sender, mut stream_receiver) = mpsc::channel(10_000);
|
||||
|
||||
spawn(run_connection(
|
||||
handler.clone(),
|
||||
sink_receiver,
|
||||
stream_sender,
|
||||
websocket_url.clone(),
|
||||
state.clone(),
|
||||
));
|
||||
|
||||
loop {
|
||||
select! {
|
||||
Some(message) = receiver.recv() => {
|
||||
spawn(handle_message(
|
||||
handler.clone(),
|
||||
state.clone(),
|
||||
sink_sender.clone(),
|
||||
message,
|
||||
));
|
||||
}
|
||||
Some(message) = stream_receiver.recv() => {
|
||||
match message {
|
||||
tungstenite::Message::Text(message) => {
|
||||
let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);
|
||||
|
||||
if parsed_message.is_err() {
|
||||
error!("Failed to deserialize websocket message: {:?}.", message);
|
||||
continue;
|
||||
}
|
||||
|
||||
for message in parsed_message.unwrap() {
|
||||
let handler = handler.clone();
|
||||
let state = state.clone();
|
||||
spawn(async move {
|
||||
handler.handle_websocket_message(state, message).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
tungstenite::Message::Ping(_) => {}
|
||||
_ => error!("Unexpected websocket message: {:?}.", message),
|
||||
}
|
||||
}
|
||||
else => panic!("Communication channel unexpectedly closed.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn run_connection(
|
||||
handler: Arc<Box<dyn Handler>>,
|
||||
mut sink_receiver: mpsc::Receiver<tungstenite::Message>,
|
||||
stream_sender: mpsc::Sender<tungstenite::Message>,
|
||||
websocket_url: String,
|
||||
state: Arc<RwLock<State>>,
|
||||
) {
|
||||
let mut peek = None;
|
||||
|
||||
'connection: loop {
|
||||
let (websocket, _): (WebSocketStream<MaybeTlsStream<TcpStream>>, _) = retry_notify(
|
||||
ExponentialBackoff::default(),
|
||||
|| async {
|
||||
connect_async(websocket_url.clone())
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
error!(
|
||||
"Failed to connect to {} websocket, will retry in {} seconds: {}.",
|
||||
handler.log_string(),
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (mut sink, mut stream) = websocket.split();
|
||||
alpaca::websocket::data::authenticate(
|
||||
&mut sink,
|
||||
&mut stream,
|
||||
(*ALPACA_API_KEY).to_string(),
|
||||
(*ALPACA_API_SECRET).to_string(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut state = state.write().await;
|
||||
|
||||
state
|
||||
.pending_unsubscriptions
|
||||
.drain()
|
||||
.for_each(|(_, sender)| {
|
||||
sender.send(()).unwrap();
|
||||
});
|
||||
|
||||
let (recovered_subscriptions, receivers) = state
|
||||
.active_subscriptions
|
||||
.iter()
|
||||
.map(|symbol| {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
((symbol.clone(), sender), receiver)
|
||||
})
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
state.pending_subscriptions.extend(recovered_subscriptions);
|
||||
|
||||
let pending_subscriptions = state
|
||||
.pending_subscriptions
|
||||
.keys()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
drop(state);
|
||||
|
||||
if let Some(pending_subscriptions) = NonEmpty::from_vec(pending_subscriptions) {
|
||||
if let Err(err) = sink
|
||||
.send(tungstenite::Message::Text(
|
||||
to_string(&websocket::data::outgoing::Message::Subscribe(
|
||||
handler.create_subscription_message(pending_subscriptions),
|
||||
))
|
||||
.unwrap(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
error!("Failed to send websocket message: {:?}.", err);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
join_all(receivers).await;
|
||||
|
||||
if peek.is_some() {
|
||||
if let Err(err) = sink.send(peek.clone().unwrap()).await {
|
||||
error!("Failed to send websocket message: {:?}.", err);
|
||||
continue;
|
||||
}
|
||||
peek = None;
|
||||
}
|
||||
|
||||
loop {
|
||||
select! {
|
||||
Some(message) = sink_receiver.recv() => {
|
||||
peek = Some(message.clone());
|
||||
|
||||
if let Err(err) = sink.send(message).await {
|
||||
error!("Failed to send websocket message: {:?}.", err);
|
||||
continue 'connection;
|
||||
};
|
||||
|
||||
peek = None;
|
||||
}
|
||||
message = stream.next() => {
|
||||
if message.is_none() {
|
||||
error!("Websocket stream unexpectedly closed.");
|
||||
continue 'connection;
|
||||
}
|
||||
|
||||
let message = message.unwrap();
|
||||
|
||||
if let Err(err) = message {
|
||||
error!("Failed to receive websocket message: {:?}.", err);
|
||||
continue 'connection;
|
||||
}
|
||||
|
||||
let message = message.unwrap();
|
||||
|
||||
if message.is_close() {
|
||||
error!("Websocket connection closed.");
|
||||
continue 'connection;
|
||||
}
|
||||
|
||||
stream_sender.send(message).await.unwrap();
|
||||
}
|
||||
else => error!("Communication channel unexpectedly closed.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_message(
|
||||
handler: Arc<Box<dyn Handler>>,
|
||||
pending: Arc<RwLock<State>>,
|
||||
sink_sender: mpsc::Sender<tungstenite::Message>,
|
||||
message: Message,
|
||||
) {
|
||||
match message.action {
|
||||
Some(Action::Subscribe) => {
|
||||
let (pending_subscriptions, receivers) = message
|
||||
.symbols
|
||||
.iter()
|
||||
.map(|symbol| {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
((symbol.clone(), sender), receiver)
|
||||
})
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
pending
|
||||
.write()
|
||||
.await
|
||||
.pending_subscriptions
|
||||
.extend(pending_subscriptions);
|
||||
|
||||
sink_sender
|
||||
.send(tungstenite::Message::Text(
|
||||
to_string(&websocket::data::outgoing::Message::Subscribe(
|
||||
handler.create_subscription_message(message.symbols),
|
||||
))
|
||||
.unwrap(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
join_all(receivers).await;
|
||||
}
|
||||
Some(Action::Unsubscribe) => {
|
||||
let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
|
||||
.symbols
|
||||
.iter()
|
||||
.map(|symbol| {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
((symbol.clone(), sender), receiver)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
pending
|
||||
.write()
|
||||
.await
|
||||
.pending_unsubscriptions
|
||||
.extend(pending_unsubscriptions);
|
||||
|
||||
sink_sender
|
||||
.send(tungstenite::Message::Text(
|
||||
to_string(&websocket::data::outgoing::Message::Unsubscribe(
|
||||
handler.create_subscription_message(message.symbols.clone()),
|
||||
))
|
||||
.unwrap(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
join_all(receivers).await;
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
message.response.send(()).unwrap();
|
||||
}
|
||||
|
||||
async fn run_inserter<T>(inserter: Arc<Mutex<Inserter<T>>>)
|
||||
where
|
||||
T: Row + Serialize,
|
||||
{
|
||||
loop {
|
||||
let time_left = inserter.lock().await.time_left().unwrap();
|
||||
tokio::time::sleep(time_left).await;
|
||||
inserter.lock().await.commit().await.unwrap();
|
||||
}
|
||||
}
|
119
src/bin/qrust/threads/data/websocket/news.rs
Normal file
119
src/bin/qrust/threads/data/websocket/news.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use super::State;
|
||||
use crate::config::{Config, CLICKHOUSE_BATCH_NEWS_SIZE};
|
||||
use async_trait::async_trait;
|
||||
use clickhouse::inserter::Inserter;
|
||||
use log::{debug, error, info};
|
||||
use nonempty::NonEmpty;
|
||||
use qrust::{
|
||||
types::{alpaca::websocket, News},
|
||||
utils::ONE_SECOND,
|
||||
};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
|
||||
pub struct Handler {
|
||||
pub inserter: Arc<Mutex<Inserter<News>>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Handler for Handler {
|
||||
fn create_subscription_message(
|
||||
&self,
|
||||
symbols: NonEmpty<String>,
|
||||
) -> websocket::data::outgoing::subscribe::Message {
|
||||
websocket::data::outgoing::subscribe::Message::new_news(symbols)
|
||||
}
|
||||
|
||||
async fn handle_websocket_message(
|
||||
&self,
|
||||
state: Arc<RwLock<State>>,
|
||||
message: websocket::data::incoming::Message,
|
||||
) {
|
||||
match message {
|
||||
websocket::data::incoming::Message::Subscription(message) => {
|
||||
let websocket::data::incoming::subscription::Message::News { news: symbols } =
|
||||
message
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let symbols = symbols.into_iter().collect::<HashSet<_>>();
|
||||
let mut state = state.write().await;
|
||||
|
||||
let newly_subscribed = state
|
||||
.pending_subscriptions
|
||||
.extract_if(|symbol, _| symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let newly_unsubscribed = state
|
||||
.pending_unsubscriptions
|
||||
.extract_if(|symbol, _| !symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
state
|
||||
.active_subscriptions
|
||||
.extend(newly_subscribed.keys().cloned());
|
||||
|
||||
drop(state);
|
||||
|
||||
if !newly_subscribed.is_empty() {
|
||||
info!(
|
||||
"Subscribed to news for {:?}.",
|
||||
newly_subscribed.keys().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
for sender in newly_subscribed.into_values() {
|
||||
sender.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
if !newly_unsubscribed.is_empty() {
|
||||
info!(
|
||||
"Unsubscribed from news for {:?}.",
|
||||
newly_unsubscribed.keys().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
for sender in newly_unsubscribed.into_values() {
|
||||
sender.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
websocket::data::incoming::Message::News(message) => {
|
||||
let news = News::from(message);
|
||||
debug!(
|
||||
"Received news for {:?}: {}.",
|
||||
news.symbols, news.time_created
|
||||
);
|
||||
self.inserter.lock().await.write(&news).await.unwrap();
|
||||
}
|
||||
websocket::data::incoming::Message::Error(message) => {
|
||||
error!("Received error message: {}.", message.message);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn log_string(&self) -> &'static str {
|
||||
"news"
|
||||
}
|
||||
|
||||
async fn run_inserter(&self) {
|
||||
super::run_inserter(self.inserter.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_handler(config: &Arc<Config>) -> Box<dyn super::Handler> {
|
||||
let inserter = Arc::new(Mutex::new(
|
||||
config
|
||||
.clickhouse_client
|
||||
.inserter("news")
|
||||
.unwrap()
|
||||
.with_period(Some(ONE_SECOND))
|
||||
.with_max_entries((*CLICKHOUSE_BATCH_NEWS_SIZE).try_into().unwrap()),
|
||||
));
|
||||
|
||||
Box::new(Handler { inserter })
|
||||
}
|
27
src/bin/qrust/threads/trading/mod.rs
Normal file
27
src/bin/qrust/threads/trading/mod.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
mod websocket;
|
||||
|
||||
use crate::config::{Config, ALPACA_API_BASE, ALPACA_API_KEY, ALPACA_API_SECRET};
|
||||
use futures_util::StreamExt;
|
||||
use qrust::types::alpaca;
|
||||
use std::sync::Arc;
|
||||
use tokio::spawn;
|
||||
use tokio_tungstenite::connect_async;
|
||||
|
||||
pub async fn run(config: Arc<Config>) {
|
||||
let (websocket, _) =
|
||||
connect_async(&format!("wss://{}.alpaca.markets/stream", *ALPACA_API_BASE))
|
||||
.await
|
||||
.unwrap();
|
||||
let (mut websocket_sink, mut websocket_stream) = websocket.split();
|
||||
|
||||
alpaca::websocket::trading::authenticate(
|
||||
&mut websocket_sink,
|
||||
&mut websocket_stream,
|
||||
(*ALPACA_API_KEY).to_string(),
|
||||
(*ALPACA_API_SECRET).to_string(),
|
||||
)
|
||||
.await;
|
||||
alpaca::websocket::trading::subscribe(&mut websocket_sink, &mut websocket_stream).await;
|
||||
|
||||
spawn(websocket::run(config, websocket_stream));
|
||||
}
|
@@ -1,10 +1,7 @@
|
||||
use crate::{
|
||||
config::Config,
|
||||
database,
|
||||
types::{alpaca::websocket, Order},
|
||||
};
|
||||
use crate::{config::Config, database};
|
||||
use futures_util::{stream::SplitStream, StreamExt};
|
||||
use log::{debug, error};
|
||||
use qrust::types::{alpaca::websocket, Order};
|
||||
use serde_json::from_str;
|
||||
use std::sync::Arc;
|
||||
use tokio::{net::TcpStream, spawn};
|
||||
@@ -24,7 +21,7 @@ pub async fn run(
|
||||
);
|
||||
|
||||
if parsed_message.is_err() {
|
||||
error!("Failed to deserialize websocket message: {:?}", message);
|
||||
error!("Failed to deserialize websocket message: {:?}.", message);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -34,7 +31,7 @@ pub async fn run(
|
||||
));
|
||||
}
|
||||
tungstenite::Message::Ping(_) => {}
|
||||
_ => error!("Unexpected websocket message: {:?}", message),
|
||||
_ => error!("Unexpected websocket message: {:?}.", message),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -46,15 +43,19 @@ async fn handle_websocket_message(
|
||||
match message {
|
||||
websocket::trading::incoming::Message::Order(message) => {
|
||||
debug!(
|
||||
"Received order message for {}: {:?}",
|
||||
"Received order message for {}: {:?}.",
|
||||
message.order.symbol, message.event
|
||||
);
|
||||
|
||||
let order = Order::from(message.order);
|
||||
|
||||
database::orders::upsert(&config.clickhouse_client, &order)
|
||||
.await
|
||||
.unwrap();
|
||||
database::orders::upsert(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
&order,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match message.event {
|
||||
websocket::trading::incoming::order::Event::Fill { position_qty, .. }
|
||||
@@ -63,6 +64,7 @@ async fn handle_websocket_message(
|
||||
} => {
|
||||
database::assets::update_qty_where_symbol(
|
||||
&config.clickhouse_client,
|
||||
&config.clickhouse_concurrency_limiter,
|
||||
&order.symbol,
|
||||
position_qty,
|
||||
)
|
133
src/bin/trainer/mod.rs
Normal file
133
src/bin/trainer/mod.rs
Normal file
@@ -0,0 +1,133 @@
|
||||
#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::{
|
||||
dataloader::{DataLoaderBuilder, Dataset},
|
||||
dataset::transform::{PartialDataset, ShuffledDataset},
|
||||
},
|
||||
module::Module,
|
||||
optim::AdamConfig,
|
||||
record::CompactRecorder,
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::LearnerBuilder,
|
||||
};
|
||||
use dotenv::dotenv;
|
||||
use log::info;
|
||||
use qrust::{
|
||||
database,
|
||||
ml::{
|
||||
BarWindow, BarWindowBatcher, ModelConfig, MultipleSymbolDataset, MyAutodiffBackend, DEVICE,
|
||||
},
|
||||
types::Bar,
|
||||
};
|
||||
use std::{env, fs, path::Path, sync::Arc};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct TrainingConfig {
|
||||
pub model: ModelConfig,
|
||||
pub optimizer: AdamConfig,
|
||||
#[config(default = 100)]
|
||||
pub epochs: usize,
|
||||
#[config(default = 256)]
|
||||
pub batch_size: usize,
|
||||
#[config(default = 16)]
|
||||
pub num_workers: usize,
|
||||
#[config(default = 0)]
|
||||
pub seed: u64,
|
||||
#[config(default = 0.2)]
|
||||
pub valid_pct: f64,
|
||||
#[config(default = 1.0e-4)]
|
||||
pub learning_rate: f64,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
dotenv().ok();
|
||||
let dir = Path::new(file!()).parent().unwrap();
|
||||
|
||||
let model_config = ModelConfig::new();
|
||||
let optimizer = AdamConfig::new();
|
||||
|
||||
let training_config = TrainingConfig::new(model_config, optimizer);
|
||||
|
||||
let clickhouse_client = clickhouse::Client::default()
|
||||
.with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
|
||||
.with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
|
||||
.with_password(env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."))
|
||||
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set."));
|
||||
|
||||
let clickhouse_concurrency_limiter = Arc::new(Semaphore::new(Semaphore::MAX_PERMITS));
|
||||
|
||||
let bars = database::ta::select(&clickhouse_client, &clickhouse_concurrency_limiter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
info!("Loaded {} bars.", bars.len());
|
||||
|
||||
train::<MyAutodiffBackend>(
|
||||
bars,
|
||||
&training_config,
|
||||
dir.join("artifacts").to_str().unwrap(),
|
||||
&DEVICE,
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
fn train<B: AutodiffBackend<FloatElem = f32, IntElem = i32>>(
|
||||
bars: Vec<Bar>,
|
||||
config: &TrainingConfig,
|
||||
dir: &str,
|
||||
device: &B::Device,
|
||||
) {
|
||||
B::seed(config.seed);
|
||||
|
||||
fs::create_dir_all(dir).unwrap();
|
||||
|
||||
let dataset = MultipleSymbolDataset::new(bars);
|
||||
let dataset = ShuffledDataset::with_seed(dataset, config.seed);
|
||||
let dataset = Arc::new(dataset);
|
||||
|
||||
let split = (dataset.len() as f64 * (1.0 - config.valid_pct)) as usize;
|
||||
|
||||
let train: PartialDataset<Arc<ShuffledDataset<MultipleSymbolDataset, BarWindow>>, BarWindow> =
|
||||
PartialDataset::new(dataset.clone(), 0, split);
|
||||
|
||||
let batcher_train = BarWindowBatcher::<B> {
|
||||
device: device.clone(),
|
||||
};
|
||||
|
||||
let dataloader_train = DataLoaderBuilder::new(batcher_train)
|
||||
.batch_size(config.batch_size)
|
||||
.num_workers(config.num_workers)
|
||||
.build(train);
|
||||
|
||||
let valid: PartialDataset<Arc<ShuffledDataset<MultipleSymbolDataset, BarWindow>>, BarWindow> =
|
||||
PartialDataset::new(dataset.clone(), split, dataset.len());
|
||||
|
||||
let batcher_valid = BarWindowBatcher::<B::InnerBackend> {
|
||||
device: device.clone(),
|
||||
};
|
||||
|
||||
let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
|
||||
.batch_size(config.batch_size)
|
||||
.num_workers(config.num_workers)
|
||||
.build(valid);
|
||||
|
||||
let learner = LearnerBuilder::new(dir)
|
||||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.devices(vec![device.clone()])
|
||||
.num_epochs(config.epochs)
|
||||
.build(
|
||||
config.model.init::<B>(device),
|
||||
config.optimizer.init(),
|
||||
config.learning_rate,
|
||||
);
|
||||
|
||||
let trained = learner.fit(dataloader_train, dataloader_valid);
|
||||
|
||||
trained.save_file(dir, &CompactRecorder::new()).unwrap();
|
||||
}
|
123
src/config.rs
123
src/config.rs
@@ -1,123 +0,0 @@
|
||||
use crate::types::alpaca::shared::{Mode, Source};
|
||||
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
|
||||
use lazy_static::lazy_static;
|
||||
use reqwest::{
|
||||
header::{HeaderMap, HeaderName, HeaderValue},
|
||||
Client,
|
||||
};
|
||||
use rust_bert::{
|
||||
pipelines::{
|
||||
common::{ModelResource, ModelType},
|
||||
sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel},
|
||||
},
|
||||
resources::LocalResource,
|
||||
};
|
||||
use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub const ALPACA_STOCK_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
|
||||
pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
|
||||
pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news";
|
||||
|
||||
pub const ALPACA_STOCK_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
|
||||
pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str =
|
||||
"wss://stream.data.alpaca.markets/v1beta3/crypto/us";
|
||||
pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";
|
||||
|
||||
lazy_static! {
|
||||
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
|
||||
.expect("ALPACA_MODE must be set.")
|
||||
.parse()
|
||||
.expect("ALPACA_MODE must be 'live' or 'paper'");
|
||||
pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE")
|
||||
.expect("ALPACA_SOURCE must be set.")
|
||||
.parse()
|
||||
.expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'");
|
||||
pub static ref ALPACA_API_KEY: String = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
|
||||
pub static ref ALPACA_API_SECRET: String = env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
|
||||
#[derive(Debug)]
|
||||
pub static ref ALPACA_API_URL: String = format!(
|
||||
"https://{}.alpaca.markets/v2",
|
||||
match *ALPACA_MODE {
|
||||
Mode::Live => String::from("api"),
|
||||
Mode::Paper => String::from("paper-api"),
|
||||
}
|
||||
);
|
||||
#[derive(Debug)]
|
||||
pub static ref ALPACA_WEBSOCKET_URL: String = format!(
|
||||
"wss://{}.alpaca.markets/stream",
|
||||
match *ALPACA_MODE {
|
||||
Mode::Live => String::from("api"),
|
||||
Mode::Paper => String::from("paper-api"),
|
||||
}
|
||||
);
|
||||
pub static ref MAX_BERT_INPUTS: usize = env::var("MAX_BERT_INPUTS")
|
||||
.expect("MAX_BERT_INPUTS must be set.")
|
||||
.parse()
|
||||
.expect("MAX_BERT_INPUTS must be a positive integer.");
|
||||
|
||||
}
|
||||
|
||||
pub struct Config {
|
||||
pub alpaca_client: Client,
|
||||
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
|
||||
pub clickhouse_client: clickhouse::Client,
|
||||
pub sequence_classifier: Mutex<SequenceClassificationModel>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn from_env() -> Self {
|
||||
Self {
|
||||
alpaca_client: Client::builder()
|
||||
.default_headers(HeaderMap::from_iter([
|
||||
(
|
||||
HeaderName::from_static("apca-api-key-id"),
|
||||
HeaderValue::from_str(&ALPACA_API_KEY)
|
||||
.expect("Alpaca API key must not contain invalid characters."),
|
||||
),
|
||||
(
|
||||
HeaderName::from_static("apca-api-secret-key"),
|
||||
HeaderValue::from_str(&ALPACA_API_SECRET)
|
||||
.expect("Alpaca API secret must not contain invalid characters."),
|
||||
),
|
||||
]))
|
||||
.build()
|
||||
.unwrap(),
|
||||
alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE {
|
||||
Source::Iex => unsafe { NonZeroU32::new_unchecked(200) },
|
||||
Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) },
|
||||
Source::Otc => unimplemented!("OTC rate limit not implemented."),
|
||||
})),
|
||||
clickhouse_client: clickhouse::Client::default()
|
||||
.with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
|
||||
.with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
|
||||
.with_password(
|
||||
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
|
||||
)
|
||||
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
|
||||
sequence_classifier: Mutex::new(
|
||||
SequenceClassificationModel::new(SequenceClassificationConfig::new(
|
||||
ModelType::Bert,
|
||||
ModelResource::Torch(Box::new(LocalResource {
|
||||
local_path: PathBuf::from("./models/finbert/rust_model.ot"),
|
||||
})),
|
||||
LocalResource {
|
||||
local_path: PathBuf::from("./models/finbert/config.json"),
|
||||
},
|
||||
LocalResource {
|
||||
local_path: PathBuf::from("./models/finbert/vocab.txt"),
|
||||
},
|
||||
None,
|
||||
true,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
.unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn arc_from_env() -> Arc<Self> {
|
||||
Arc::new(Self::from_env())
|
||||
}
|
||||
}
|
@@ -1,17 +0,0 @@
|
||||
use crate::{
|
||||
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
|
||||
};
|
||||
use clickhouse::{error::Error, Client};
|
||||
|
||||
select_where_symbol!(Backfill, "backfills_bars");
|
||||
upsert!(Backfill, "backfills_bars");
|
||||
delete_where_symbols!("backfills_bars");
|
||||
cleanup!("backfills_bars");
|
||||
optimize!("backfills_bars");
|
||||
|
||||
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
|
||||
clickhouse_client
|
||||
.query("ALTER TABLE backfills_bars UPDATE fresh = false WHERE true")
|
||||
.execute()
|
||||
.await
|
||||
}
|
@@ -1,17 +0,0 @@
|
||||
use crate::{
|
||||
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
|
||||
};
|
||||
use clickhouse::{error::Error, Client};
|
||||
|
||||
select_where_symbol!(Backfill, "backfills_news");
|
||||
upsert!(Backfill, "backfills_news");
|
||||
delete_where_symbols!("backfills_news");
|
||||
cleanup!("backfills_news");
|
||||
optimize!("backfills_news");
|
||||
|
||||
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
|
||||
clickhouse_client
|
||||
.query("ALTER TABLE backfills_news UPDATE fresh = false WHERE true")
|
||||
.execute()
|
||||
.await
|
||||
}
|
@@ -1,7 +0,0 @@
|
||||
use crate::{cleanup, delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
|
||||
|
||||
upsert!(Bar, "bars");
|
||||
upsert_batch!(Bar, "bars");
|
||||
delete_where_symbols!("bars");
|
||||
cleanup!("bars");
|
||||
optimize!("bars");
|
@@ -1,152 +0,0 @@
|
||||
pub mod assets;
|
||||
pub mod backfills_bars;
|
||||
pub mod backfills_news;
|
||||
pub mod bars;
|
||||
pub mod calendar;
|
||||
pub mod news;
|
||||
pub mod orders;
|
||||
|
||||
use clickhouse::{error::Error, Client};
|
||||
use log::info;
|
||||
use tokio::try_join;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! select {
|
||||
($record:ty, $table_name:expr) => {
|
||||
pub async fn select(
|
||||
client: &clickhouse::Client,
|
||||
) -> Result<Vec<$record>, clickhouse::error::Error> {
|
||||
client
|
||||
.query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
|
||||
.fetch_all::<$record>()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! select_where_symbol {
|
||||
($record:ty, $table_name:expr) => {
|
||||
pub async fn select_where_symbol<T>(
|
||||
client: &clickhouse::Client,
|
||||
symbol: &T,
|
||||
) -> Result<Option<$record>, clickhouse::error::Error>
|
||||
where
|
||||
T: AsRef<str> + serde::Serialize + Send + Sync,
|
||||
{
|
||||
client
|
||||
.query(&format!(
|
||||
"SELECT ?fields FROM {} FINAL WHERE symbol = ?",
|
||||
$table_name
|
||||
))
|
||||
.bind(symbol)
|
||||
.fetch_optional::<$record>()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! upsert {
|
||||
($record:ty, $table_name:expr) => {
|
||||
pub async fn upsert(
|
||||
client: &clickhouse::Client,
|
||||
record: &$record,
|
||||
) -> Result<(), clickhouse::error::Error> {
|
||||
let mut insert = client.insert($table_name)?;
|
||||
insert.write(record).await?;
|
||||
insert.end().await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! upsert_batch {
|
||||
($record:ty, $table_name:expr) => {
|
||||
pub async fn upsert_batch<'a, T>(
|
||||
client: &clickhouse::Client,
|
||||
records: T,
|
||||
) -> Result<(), clickhouse::error::Error>
|
||||
where
|
||||
T: IntoIterator<Item = &'a $record> + Send + Sync,
|
||||
T::IntoIter: Send,
|
||||
{
|
||||
let mut insert = client.insert($table_name)?;
|
||||
for record in records {
|
||||
insert.write(record).await?;
|
||||
}
|
||||
insert.end().await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! delete_where_symbols {
|
||||
($table_name:expr) => {
|
||||
pub async fn delete_where_symbols<T>(
|
||||
client: &clickhouse::Client,
|
||||
symbols: &[T],
|
||||
) -> Result<(), clickhouse::error::Error>
|
||||
where
|
||||
T: AsRef<str> + serde::Serialize + Send + Sync,
|
||||
{
|
||||
client
|
||||
.query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
|
||||
.bind(symbols)
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! cleanup {
|
||||
($table_name:expr) => {
|
||||
pub async fn cleanup(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
|
||||
client
|
||||
.query(&format!(
|
||||
"DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
|
||||
$table_name
|
||||
))
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! optimize {
|
||||
($table_name:expr) => {
|
||||
pub async fn optimize(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
|
||||
client
|
||||
.query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub async fn cleanup_all(clickhouse_client: &Client) -> Result<(), Error> {
|
||||
info!("Cleaning up database.");
|
||||
try_join!(
|
||||
bars::cleanup(clickhouse_client),
|
||||
news::cleanup(clickhouse_client),
|
||||
backfills_bars::cleanup(clickhouse_client),
|
||||
backfills_news::cleanup(clickhouse_client)
|
||||
)
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
pub async fn optimize_all(clickhouse_client: &Client) -> Result<(), Error> {
|
||||
info!("Optimizing database.");
|
||||
try_join!(
|
||||
assets::optimize(clickhouse_client),
|
||||
bars::optimize(clickhouse_client),
|
||||
news::optimize(clickhouse_client),
|
||||
backfills_bars::optimize(clickhouse_client),
|
||||
backfills_news::optimize(clickhouse_client),
|
||||
orders::optimize(clickhouse_client),
|
||||
calendar::optimize(clickhouse_client)
|
||||
)
|
||||
.map(|_| ())
|
||||
}
|
39
src/lib/qrust/alpaca/account.rs
Normal file
39
src/lib/qrust/alpaca/account.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use super::error_to_backoff;
|
||||
use crate::types::alpaca::api::incoming::account::Account;
|
||||
use backoff::{future::retry_notify, ExponentialBackoff};
|
||||
use governor::DefaultDirectRateLimiter;
|
||||
use log::warn;
|
||||
use reqwest::{Client, Error};
|
||||
use std::time::Duration;
|
||||
|
||||
pub async fn get(
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
api_base: &str,
|
||||
) -> Result<Account, Error> {
|
||||
retry_notify(
|
||||
backoff.unwrap_or_default(),
|
||||
|| async {
|
||||
rate_limiter.until_ready().await;
|
||||
client
|
||||
.get(&format!("https://{}.alpaca.markets/v2/account", api_base))
|
||||
.send()
|
||||
.await
|
||||
.map_err(error_to_backoff)?
|
||||
.error_for_status()
|
||||
.map_err(error_to_backoff)?
|
||||
.json::<Account>()
|
||||
.await
|
||||
.map_err(error_to_backoff)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
warn!(
|
||||
"Failed to get account, will retry in {} seconds: {}.",
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
132
src/lib/qrust/alpaca/assets.rs
Normal file
132
src/lib/qrust/alpaca/assets.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
use super::error_to_backoff;
|
||||
use crate::types::alpaca::api::{
|
||||
incoming::asset::{Asset, Class},
|
||||
outgoing,
|
||||
};
|
||||
use backoff::{future::retry_notify, ExponentialBackoff};
|
||||
use governor::DefaultDirectRateLimiter;
|
||||
use itertools::Itertools;
|
||||
use log::warn;
|
||||
use reqwest::{Client, Error};
|
||||
use std::{collections::HashSet, time::Duration};
|
||||
use tokio::try_join;
|
||||
|
||||
pub async fn get(
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
query: &outgoing::asset::Asset,
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
api_base: &str,
|
||||
) -> Result<Vec<Asset>, Error> {
|
||||
retry_notify(
|
||||
backoff.unwrap_or_default(),
|
||||
|| async {
|
||||
rate_limiter.until_ready().await;
|
||||
client
|
||||
.get(&format!("https://{}.alpaca.markets/v2/assets", api_base))
|
||||
.query(query)
|
||||
.send()
|
||||
.await
|
||||
.map_err(error_to_backoff)?
|
||||
.error_for_status()
|
||||
.map_err(error_to_backoff)?
|
||||
.json::<Vec<Asset>>()
|
||||
.await
|
||||
.map_err(error_to_backoff)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
warn!(
|
||||
"Failed to get assets, will retry in {} seconds: {}.",
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_by_symbol(
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
symbol: &str,
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
api_base: &str,
|
||||
) -> Result<Asset, Error> {
|
||||
retry_notify(
|
||||
backoff.unwrap_or_default(),
|
||||
|| async {
|
||||
rate_limiter.until_ready().await;
|
||||
client
|
||||
.get(&format!(
|
||||
"https://{}.alpaca.markets/v2/assets/{}",
|
||||
api_base, symbol
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.map_err(error_to_backoff)?
|
||||
.error_for_status()
|
||||
.map_err(error_to_backoff)?
|
||||
.json::<Asset>()
|
||||
.await
|
||||
.map_err(error_to_backoff)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
warn!(
|
||||
"Failed to get asset, will retry in {} seconds: {}.",
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_by_symbols(
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
symbols: &[String],
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
api_base: &str,
|
||||
) -> Result<Vec<Asset>, Error> {
|
||||
if symbols.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
if symbols.len() == 1 {
|
||||
let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
|
||||
return Ok(vec![asset]);
|
||||
}
|
||||
|
||||
let symbols = symbols.iter().collect::<HashSet<_>>();
|
||||
|
||||
let backoff_clone = backoff.clone();
|
||||
|
||||
let us_equity_query = outgoing::asset::Asset {
|
||||
class: Some(Class::UsEquity),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let us_equity_assets = get(
|
||||
client,
|
||||
rate_limiter,
|
||||
&us_equity_query,
|
||||
backoff_clone,
|
||||
api_base,
|
||||
);
|
||||
|
||||
let crypto_query = outgoing::asset::Asset {
|
||||
class: Some(Class::Crypto),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base);
|
||||
|
||||
let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;
|
||||
|
||||
Ok(crypto_assets
|
||||
.into_iter()
|
||||
.chain(us_equity_assets)
|
||||
.dedup_by(|a, b| a.symbol == b.symbol)
|
||||
.filter(|asset| symbols.contains(&asset.symbol))
|
||||
.collect())
|
||||
}
|
50
src/lib/qrust/alpaca/bars.rs
Normal file
50
src/lib/qrust/alpaca/bars.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use super::error_to_backoff;
|
||||
use crate::types::alpaca::api::{incoming::bar::Bar, outgoing};
|
||||
use backoff::{future::retry_notify, ExponentialBackoff};
|
||||
use governor::DefaultDirectRateLimiter;
|
||||
use log::warn;
|
||||
use reqwest::{Client, Error};
|
||||
use serde::Deserialize;
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
|
||||
pub const MAX_LIMIT: i64 = 10_000;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Message {
|
||||
pub bars: HashMap<String, Vec<Bar>>,
|
||||
pub next_page_token: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
data_url: &str,
|
||||
query: &outgoing::bar::Bar,
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
) -> Result<Message, Error> {
|
||||
retry_notify(
|
||||
backoff.unwrap_or_default(),
|
||||
|| async {
|
||||
rate_limiter.until_ready().await;
|
||||
client
|
||||
.get(data_url)
|
||||
.query(query)
|
||||
.send()
|
||||
.await
|
||||
.map_err(error_to_backoff)?
|
||||
.error_for_status()
|
||||
.map_err(error_to_backoff)?
|
||||
.json::<Message>()
|
||||
.await
|
||||
.map_err(error_to_backoff)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
warn!(
|
||||
"Failed to get historical bars, will retry in {} seconds: {}.",
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
41
src/lib/qrust/alpaca/calendar.rs
Normal file
41
src/lib/qrust/alpaca/calendar.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use super::error_to_backoff;
|
||||
use crate::types::alpaca::api::{incoming::calendar::Calendar, outgoing};
|
||||
use backoff::{future::retry_notify, ExponentialBackoff};
|
||||
use governor::DefaultDirectRateLimiter;
|
||||
use log::warn;
|
||||
use reqwest::{Client, Error};
|
||||
use std::time::Duration;
|
||||
|
||||
pub async fn get(
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
query: &outgoing::calendar::Calendar,
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
api_base: &str,
|
||||
) -> Result<Vec<Calendar>, Error> {
|
||||
retry_notify(
|
||||
backoff.unwrap_or_default(),
|
||||
|| async {
|
||||
rate_limiter.until_ready().await;
|
||||
client
|
||||
.get(&format!("https://{}.alpaca.markets/v2/calendar", api_base))
|
||||
.query(query)
|
||||
.send()
|
||||
.await
|
||||
.map_err(error_to_backoff)?
|
||||
.error_for_status()
|
||||
.map_err(error_to_backoff)?
|
||||
.json::<Vec<Calendar>>()
|
||||
.await
|
||||
.map_err(error_to_backoff)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
warn!(
|
||||
"Failed to get calendar, will retry in {} seconds: {}.",
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
39
src/lib/qrust/alpaca/clock.rs
Normal file
39
src/lib/qrust/alpaca/clock.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use super::error_to_backoff;
|
||||
use crate::types::alpaca::api::incoming::clock::Clock;
|
||||
use backoff::{future::retry_notify, ExponentialBackoff};
|
||||
use governor::DefaultDirectRateLimiter;
|
||||
use log::warn;
|
||||
use reqwest::{Client, Error};
|
||||
use std::time::Duration;
|
||||
|
||||
pub async fn get(
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
api_base: &str,
|
||||
) -> Result<Clock, Error> {
|
||||
retry_notify(
|
||||
backoff.unwrap_or_default(),
|
||||
|| async {
|
||||
rate_limiter.until_ready().await;
|
||||
client
|
||||
.get(&format!("https://{}.alpaca.markets/v2/clock", api_base))
|
||||
.send()
|
||||
.await
|
||||
.map_err(error_to_backoff)?
|
||||
.error_for_status()
|
||||
.map_err(error_to_backoff)?
|
||||
.json::<Clock>()
|
||||
.await
|
||||
.map_err(error_to_backoff)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
warn!(
|
||||
"Failed to get clock, will retry in {} seconds: {}.",
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
27
src/lib/qrust/alpaca/mod.rs
Normal file
27
src/lib/qrust/alpaca/mod.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
pub mod account;
|
||||
pub mod assets;
|
||||
pub mod bars;
|
||||
pub mod calendar;
|
||||
pub mod clock;
|
||||
pub mod news;
|
||||
pub mod orders;
|
||||
pub mod positions;
|
||||
|
||||
use http::StatusCode;
|
||||
|
||||
pub fn error_to_backoff(err: reqwest::Error) -> backoff::Error<reqwest::Error> {
|
||||
if err.is_status() {
|
||||
return match err.status() {
|
||||
Some(StatusCode::BAD_REQUEST | StatusCode::FORBIDDEN | StatusCode::NOT_FOUND)
|
||||
| None => backoff::Error::Permanent(err),
|
||||
_ => err.into(),
|
||||
};
|
||||
}
|
||||
|
||||
if err.is_builder() || err.is_request() || err.is_redirect() || err.is_decode() || err.is_body()
|
||||
{
|
||||
return backoff::Error::Permanent(err);
|
||||
}
|
||||
|
||||
err.into()
|
||||
}
|
49
src/lib/qrust/alpaca/news.rs
Normal file
49
src/lib/qrust/alpaca/news.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use super::error_to_backoff;
|
||||
use crate::types::alpaca::api::{incoming::news::News, outgoing, ALPACA_NEWS_DATA_API_URL};
|
||||
use backoff::{future::retry_notify, ExponentialBackoff};
|
||||
use governor::DefaultDirectRateLimiter;
|
||||
use log::warn;
|
||||
use reqwest::{Client, Error};
|
||||
use serde::Deserialize;
|
||||
use std::time::Duration;
|
||||
|
||||
pub const MAX_LIMIT: i64 = 50;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Message {
|
||||
pub news: Vec<News>,
|
||||
pub next_page_token: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
query: &outgoing::news::News,
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
) -> Result<Message, Error> {
|
||||
retry_notify(
|
||||
backoff.unwrap_or_default(),
|
||||
|| async {
|
||||
rate_limiter.until_ready().await;
|
||||
client
|
||||
.get(ALPACA_NEWS_DATA_API_URL)
|
||||
.query(query)
|
||||
.send()
|
||||
.await
|
||||
.map_err(error_to_backoff)?
|
||||
.error_for_status()
|
||||
.map_err(error_to_backoff)?
|
||||
.json::<Message>()
|
||||
.await
|
||||
.map_err(error_to_backoff)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
warn!(
|
||||
"Failed to get historical news, will retry in {} seconds: {}.",
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
@@ -1,44 +1,39 @@
|
||||
use crate::{
|
||||
config::ALPACA_API_URL,
|
||||
types::alpaca::{api::outgoing, shared},
|
||||
};
|
||||
use super::error_to_backoff;
|
||||
use crate::types::alpaca::{api::outgoing, shared::order};
|
||||
use backoff::{future::retry_notify, ExponentialBackoff};
|
||||
use governor::DefaultDirectRateLimiter;
|
||||
use log::warn;
|
||||
use reqwest::{Client, Error};
|
||||
use std::time::Duration;
|
||||
|
||||
pub use shared::order::Order;
|
||||
pub use order::Order;
|
||||
|
||||
pub async fn get(
|
||||
alpaca_client: &Client,
|
||||
alpaca_rate_limiter: &DefaultDirectRateLimiter,
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
query: &outgoing::order::Order,
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
api_base: &str,
|
||||
) -> Result<Vec<Order>, Error> {
|
||||
retry_notify(
|
||||
backoff.unwrap_or_default(),
|
||||
|| async {
|
||||
alpaca_rate_limiter.until_ready().await;
|
||||
alpaca_client
|
||||
.get(&format!("{}/orders", *ALPACA_API_URL))
|
||||
rate_limiter.until_ready().await;
|
||||
client
|
||||
.get(&format!("https://{}.alpaca.markets/v2/orders", api_base))
|
||||
.query(query)
|
||||
.send()
|
||||
.await?
|
||||
.await
|
||||
.map_err(error_to_backoff)?
|
||||
.error_for_status()
|
||||
.map_err(|e| match e.status() {
|
||||
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
|
||||
backoff::Error::Permanent(e)
|
||||
}
|
||||
_ => e.into(),
|
||||
})?
|
||||
.map_err(error_to_backoff)?
|
||||
.json::<Vec<Order>>()
|
||||
.await
|
||||
.map_err(backoff::Error::Permanent)
|
||||
.map_err(error_to_backoff)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
warn!(
|
||||
"Failed to get orders, will retry in {} seconds: {}",
|
||||
"Failed to get orders, will retry in {} seconds: {}.",
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
109
src/lib/qrust/alpaca/positions.rs
Normal file
109
src/lib/qrust/alpaca/positions.rs
Normal file
@@ -0,0 +1,109 @@
|
||||
use super::error_to_backoff;
|
||||
use crate::types::alpaca::api::incoming::position::Position;
|
||||
use backoff::{future::retry_notify, ExponentialBackoff};
|
||||
use governor::DefaultDirectRateLimiter;
|
||||
use http::StatusCode;
|
||||
use log::warn;
|
||||
use reqwest::Client;
|
||||
use std::{collections::HashSet, time::Duration};
|
||||
|
||||
pub async fn get(
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
api_base: &str,
|
||||
) -> Result<Vec<Position>, reqwest::Error> {
|
||||
retry_notify(
|
||||
backoff.unwrap_or_default(),
|
||||
|| async {
|
||||
rate_limiter.until_ready().await;
|
||||
client
|
||||
.get(&format!("https://{}.alpaca.markets/v2/positions", api_base))
|
||||
.send()
|
||||
.await
|
||||
.map_err(error_to_backoff)?
|
||||
.error_for_status()
|
||||
.map_err(error_to_backoff)?
|
||||
.json::<Vec<Position>>()
|
||||
.await
|
||||
.map_err(error_to_backoff)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
warn!(
|
||||
"Failed to get positions, will retry in {} seconds: {}.",
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_by_symbol(
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
symbol: &str,
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
api_base: &str,
|
||||
) -> Result<Option<Position>, reqwest::Error> {
|
||||
retry_notify(
|
||||
backoff.unwrap_or_default(),
|
||||
|| async {
|
||||
rate_limiter.until_ready().await;
|
||||
let response = client
|
||||
.get(&format!(
|
||||
"https://{}.alpaca.markets/v2/positions/{}",
|
||||
api_base, symbol
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.map_err(error_to_backoff)?;
|
||||
|
||||
if response.status() == StatusCode::NOT_FOUND {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
response
|
||||
.error_for_status()
|
||||
.map_err(error_to_backoff)?
|
||||
.json::<Position>()
|
||||
.await
|
||||
.map_err(error_to_backoff)
|
||||
.map(Some)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
warn!(
|
||||
"Failed to get position, will retry in {} seconds: {}.",
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_by_symbols(
|
||||
client: &Client,
|
||||
rate_limiter: &DefaultDirectRateLimiter,
|
||||
symbols: &[String],
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
api_base: &str,
|
||||
) -> Result<Vec<Position>, reqwest::Error> {
|
||||
if symbols.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
if symbols.len() == 1 {
|
||||
let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
|
||||
return Ok(position.into_iter().collect());
|
||||
}
|
||||
|
||||
let symbols = symbols.iter().collect::<HashSet<_>>();
|
||||
|
||||
let positions = get(client, rate_limiter, backoff, api_base).await?;
|
||||
|
||||
Ok(positions
|
||||
.into_iter()
|
||||
.filter(|position| symbols.contains(&position.symbol))
|
||||
.collect())
|
||||
}
|
@@ -1,8 +1,11 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch,
|
||||
};
|
||||
use clickhouse::{error::Error, Client};
|
||||
use serde::Serialize;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
select!(Asset, "assets");
|
||||
select_where_symbol!(Asset, "assets");
|
||||
@@ -11,14 +14,16 @@ delete_where_symbols!("assets");
|
||||
optimize!("assets");
|
||||
|
||||
pub async fn update_status_where_symbol<T>(
|
||||
clickhouse_client: &Client,
|
||||
client: &Client,
|
||||
concurrency_limiter: &Arc<Semaphore>,
|
||||
symbol: &T,
|
||||
status: bool,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
T: AsRef<str> + Serialize + Send + Sync,
|
||||
{
|
||||
clickhouse_client
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?")
|
||||
.bind(status)
|
||||
.bind(symbol)
|
||||
@@ -27,14 +32,16 @@ where
|
||||
}
|
||||
|
||||
pub async fn update_qty_where_symbol<T>(
|
||||
clickhouse_client: &Client,
|
||||
client: &Client,
|
||||
concurrency_limiter: &Arc<Semaphore>,
|
||||
symbol: &T,
|
||||
qty: f64,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
T: AsRef<str> + Serialize + Send + Sync,
|
||||
{
|
||||
clickhouse_client
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?")
|
||||
.bind(qty)
|
||||
.bind(symbol)
|
11
src/lib/qrust/database/backfills_bars.rs
Normal file
11
src/lib/qrust/database/backfills_bars.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
use crate::{
|
||||
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
|
||||
types::Backfill, upsert_batch,
|
||||
};
|
||||
|
||||
select_where_symbols!(Backfill, "backfills_bars");
|
||||
upsert_batch!(Backfill, "backfills_bars");
|
||||
delete_where_symbols!("backfills_bars");
|
||||
cleanup!("backfills_bars");
|
||||
optimize!("backfills_bars");
|
||||
set_fresh_where_symbols!("backfills_bars");
|
11
src/lib/qrust/database/backfills_news.rs
Normal file
11
src/lib/qrust/database/backfills_news.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
use crate::{
|
||||
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
|
||||
types::Backfill, upsert_batch,
|
||||
};
|
||||
|
||||
select_where_symbols!(Backfill, "backfills_news");
|
||||
upsert_batch!(Backfill, "backfills_news");
|
||||
delete_where_symbols!("backfills_news");
|
||||
cleanup!("backfills_news");
|
||||
optimize!("backfills_news");
|
||||
set_fresh_where_symbols!("backfills_news");
|
21
src/lib/qrust/database/bars.rs
Normal file
21
src/lib/qrust/database/bars.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
|
||||
use clickhouse::Client;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
upsert!(Bar, "bars");
|
||||
upsert_batch!(Bar, "bars");
|
||||
delete_where_symbols!("bars");
|
||||
optimize!("bars");
|
||||
|
||||
pub async fn cleanup(
|
||||
client: &Client,
|
||||
concurrency_limiter: &Arc<Semaphore>,
|
||||
) -> Result<(), clickhouse::error::Error> {
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets) OR symbol NOT IN (SELECT symbol FROM backfills_bars)")
|
||||
.execute()
|
||||
.await
|
||||
}
|
@@ -1,16 +1,19 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{optimize, types::Calendar};
|
||||
use clickhouse::error::Error;
|
||||
use tokio::try_join;
|
||||
use clickhouse::{error::Error, Client};
|
||||
use tokio::{sync::Semaphore, try_join};
|
||||
|
||||
optimize!("calendar");
|
||||
|
||||
pub async fn upsert_batch_and_delete<'a, T>(
|
||||
client: &clickhouse::Client,
|
||||
records: T,
|
||||
pub async fn upsert_batch_and_delete<'a, I>(
|
||||
client: &Client,
|
||||
concurrency_limiter: &Arc<Semaphore>,
|
||||
records: I,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
T: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone,
|
||||
T::IntoIter: Send,
|
||||
I: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone,
|
||||
I::IntoIter: Send,
|
||||
{
|
||||
let upsert_future = async {
|
||||
let mut insert = client.insert("calendar")?;
|
||||
@@ -34,5 +37,6 @@ where
|
||||
.await
|
||||
};
|
||||
|
||||
let _ = concurrency_limiter.acquire_many(2).await.unwrap();
|
||||
try_join!(upsert_future, delete_future).map(|_| ())
|
||||
}
|
224
src/lib/qrust/database/mod.rs
Normal file
224
src/lib/qrust/database/mod.rs
Normal file
@@ -0,0 +1,224 @@
|
||||
pub mod assets;
|
||||
pub mod backfills_bars;
|
||||
pub mod backfills_news;
|
||||
pub mod bars;
|
||||
pub mod calendar;
|
||||
pub mod news;
|
||||
pub mod orders;
|
||||
pub mod ta;
|
||||
|
||||
use clickhouse::{error::Error, Client};
|
||||
use tokio::try_join;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! select {
|
||||
($record:ty, $table_name:expr) => {
|
||||
pub async fn select(
|
||||
client: &clickhouse::Client,
|
||||
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
|
||||
) -> Result<Vec<$record>, clickhouse::error::Error> {
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
|
||||
.fetch_all::<$record>()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! select_where_symbol {
|
||||
($record:ty, $table_name:expr) => {
|
||||
pub async fn select_where_symbol<T>(
|
||||
client: &clickhouse::Client,
|
||||
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
|
||||
symbol: &T,
|
||||
) -> Result<Option<$record>, clickhouse::error::Error>
|
||||
where
|
||||
T: AsRef<str> + serde::Serialize + Send + Sync,
|
||||
{
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query(&format!(
|
||||
"SELECT ?fields FROM {} FINAL WHERE symbol = ?",
|
||||
$table_name
|
||||
))
|
||||
.bind(symbol)
|
||||
.fetch_optional::<$record>()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! select_where_symbols {
|
||||
($record:ty, $table_name:expr) => {
|
||||
pub async fn select_where_symbols<T>(
|
||||
client: &clickhouse::Client,
|
||||
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
|
||||
symbols: &[T],
|
||||
) -> Result<Vec<$record>, clickhouse::error::Error>
|
||||
where
|
||||
T: AsRef<str> + serde::Serialize + Send + Sync,
|
||||
{
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query(&format!(
|
||||
"SELECT ?fields FROM {} FINAL WHERE symbol IN ?",
|
||||
$table_name
|
||||
))
|
||||
.bind(symbols)
|
||||
.fetch_all::<$record>()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! upsert {
|
||||
($record:ty, $table_name:expr) => {
|
||||
pub async fn upsert(
|
||||
client: &clickhouse::Client,
|
||||
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
|
||||
record: &$record,
|
||||
) -> Result<(), clickhouse::error::Error> {
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
let mut insert = client.insert($table_name)?;
|
||||
insert.write(record).await?;
|
||||
insert.end().await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! upsert_batch {
|
||||
($record:ty, $table_name:expr) => {
|
||||
pub async fn upsert_batch<'a, I>(
|
||||
client: &clickhouse::Client,
|
||||
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
|
||||
records: I,
|
||||
) -> Result<(), clickhouse::error::Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a $record> + Send + Sync,
|
||||
I::IntoIter: Send,
|
||||
{
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
let mut insert = client.insert($table_name)?;
|
||||
for record in records {
|
||||
insert.write(record).await?;
|
||||
}
|
||||
insert.end().await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! delete_where_symbols {
|
||||
($table_name:expr) => {
|
||||
pub async fn delete_where_symbols<T>(
|
||||
client: &clickhouse::Client,
|
||||
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
|
||||
symbols: &[T],
|
||||
) -> Result<(), clickhouse::error::Error>
|
||||
where
|
||||
T: AsRef<str> + serde::Serialize + Send + Sync,
|
||||
{
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
|
||||
.bind(symbols)
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! cleanup {
|
||||
($table_name:expr) => {
|
||||
pub async fn cleanup(
|
||||
client: &clickhouse::Client,
|
||||
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
|
||||
) -> Result<(), clickhouse::error::Error> {
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query(&format!(
|
||||
"DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
|
||||
$table_name
|
||||
))
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! optimize {
|
||||
($table_name:expr) => {
|
||||
pub async fn optimize(
|
||||
client: &clickhouse::Client,
|
||||
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
|
||||
) -> Result<(), clickhouse::error::Error> {
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! set_fresh_where_symbols {
|
||||
($table_name:expr) => {
|
||||
pub async fn set_fresh_where_symbols<T>(
|
||||
client: &clickhouse::Client,
|
||||
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
|
||||
fresh: bool,
|
||||
symbols: &[T],
|
||||
) -> Result<(), clickhouse::error::Error>
|
||||
where
|
||||
T: AsRef<str> + serde::Serialize + Send + Sync,
|
||||
{
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query(&format!(
|
||||
"ALTER TABLE {} UPDATE fresh = ? WHERE symbol IN ?",
|
||||
$table_name
|
||||
))
|
||||
.bind(fresh)
|
||||
.bind(symbols)
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub async fn cleanup_all(
|
||||
clickhouse_client: &Client,
|
||||
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
|
||||
) -> Result<(), Error> {
|
||||
try_join!(
|
||||
bars::cleanup(clickhouse_client, concurrency_limiter),
|
||||
news::cleanup(clickhouse_client, concurrency_limiter),
|
||||
backfills_bars::cleanup(clickhouse_client, concurrency_limiter),
|
||||
backfills_news::cleanup(clickhouse_client, concurrency_limiter)
|
||||
)
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
pub async fn optimize_all(
|
||||
clickhouse_client: &Client,
|
||||
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
|
||||
) -> Result<(), Error> {
|
||||
try_join!(
|
||||
assets::optimize(clickhouse_client, concurrency_limiter),
|
||||
bars::optimize(clickhouse_client, concurrency_limiter),
|
||||
news::optimize(clickhouse_client, concurrency_limiter),
|
||||
backfills_bars::optimize(clickhouse_client, concurrency_limiter),
|
||||
backfills_news::optimize(clickhouse_client, concurrency_limiter),
|
||||
orders::optimize(clickhouse_client, concurrency_limiter),
|
||||
calendar::optimize(clickhouse_client, concurrency_limiter)
|
||||
)
|
||||
.map(|_| ())
|
||||
}
|
@@ -1,24 +1,33 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{optimize, types::News, upsert, upsert_batch};
|
||||
use clickhouse::{error::Error, Client};
|
||||
use serde::Serialize;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
upsert!(News, "news");
|
||||
upsert_batch!(News, "news");
|
||||
optimize!("news");
|
||||
|
||||
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error>
|
||||
pub async fn delete_where_symbols<T>(
|
||||
client: &Client,
|
||||
concurrency_limiter: &Arc<Semaphore>,
|
||||
symbols: &[T],
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
T: AsRef<str> + Serialize + Send + Sync,
|
||||
{
|
||||
clickhouse_client
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))")
|
||||
.bind(symbols)
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> {
|
||||
clickhouse_client
|
||||
pub async fn cleanup(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query(
|
||||
"DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))",
|
||||
)
|
30
src/lib/qrust/database/ta.rs
Normal file
30
src/lib/qrust/database/ta.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use crate::types::Bar;
|
||||
use clickhouse::{error::Error, Client};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
pub async fn select(
|
||||
client: &Client,
|
||||
concurrency_limiter: &Arc<Semaphore>,
|
||||
) -> Result<Vec<Bar>, Error> {
|
||||
let _ = concurrency_limiter.acquire().await.unwrap();
|
||||
client
|
||||
.query(
|
||||
"
|
||||
SELECT symbol,
|
||||
toStartOfHour(bars.time) AS time,
|
||||
any(bars.open) AS open,
|
||||
max(bars.high) AS high,
|
||||
min(bars.low) AS low,
|
||||
anyLast(bars.close) AS close,
|
||||
sum(bars.volume) AS volume,
|
||||
sum(bars.trades) AS trades
|
||||
FROM bars FINAL
|
||||
GROUP BY ALL
|
||||
ORDER BY symbol,
|
||||
time
|
||||
",
|
||||
)
|
||||
.fetch_all::<Bar>()
|
||||
.await
|
||||
}
|
75
src/lib/qrust/ml/batcher.rs
Normal file
75
src/lib/qrust/ml/batcher.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
use super::BarWindow;
|
||||
use burn::{
|
||||
data::dataloader::batcher::Batcher,
|
||||
tensor::{self, backend::Backend, Tensor},
|
||||
};
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct BarWindowBatcher<B: Backend> {
|
||||
pub device: B::Device,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct BarWindowBatch<B: Backend> {
|
||||
pub hour_tensor: Tensor<B, 2, tensor::Int>,
|
||||
pub day_tensor: Tensor<B, 2, tensor::Int>,
|
||||
pub numerical_tensor: Tensor<B, 3>,
|
||||
pub target_tensor: Tensor<B, 2>,
|
||||
}
|
||||
|
||||
impl<B: Backend<FloatElem = f32, IntElem = i32>> Batcher<BarWindow, BarWindowBatch<B>>
|
||||
for BarWindowBatcher<B>
|
||||
{
|
||||
fn batch(&self, items: Vec<BarWindow>) -> BarWindowBatch<B> {
|
||||
let batch_size = items.len();
|
||||
|
||||
let (hour_tensors, day_tensors, numerical_tensors, target_tensors) = items
|
||||
.into_par_iter()
|
||||
.fold(
|
||||
|| {
|
||||
(
|
||||
Vec::with_capacity(batch_size),
|
||||
Vec::with_capacity(batch_size),
|
||||
Vec::with_capacity(batch_size),
|
||||
Vec::with_capacity(batch_size),
|
||||
)
|
||||
},
|
||||
|(mut hour_tensors, mut day_tensors, mut numerical_tensors, mut target_tensors),
|
||||
item| {
|
||||
hour_tensors.push(Tensor::from_data(item.hours, &self.device));
|
||||
day_tensors.push(Tensor::from_data(item.days, &self.device));
|
||||
numerical_tensors.push(Tensor::from_data(item.numerical, &self.device));
|
||||
target_tensors.push(Tensor::from_data(item.target, &self.device));
|
||||
|
||||
(hour_tensors, day_tensors, numerical_tensors, target_tensors)
|
||||
},
|
||||
)
|
||||
.reduce(
|
||||
|| {
|
||||
(
|
||||
Vec::with_capacity(batch_size),
|
||||
Vec::with_capacity(batch_size),
|
||||
Vec::with_capacity(batch_size),
|
||||
Vec::with_capacity(batch_size),
|
||||
)
|
||||
},
|
||||
|(mut hour_tensors, mut day_tensors, mut numerical_tensors, mut target_tensors),
|
||||
item| {
|
||||
hour_tensors.extend(item.0);
|
||||
day_tensors.extend(item.1);
|
||||
numerical_tensors.extend(item.2);
|
||||
target_tensors.extend(item.3);
|
||||
|
||||
(hour_tensors, day_tensors, numerical_tensors, target_tensors)
|
||||
},
|
||||
);
|
||||
|
||||
BarWindowBatch {
|
||||
hour_tensor: Tensor::stack(hour_tensors, 0).to_device(&self.device),
|
||||
day_tensor: Tensor::stack(day_tensors, 0).to_device(&self.device),
|
||||
numerical_tensor: Tensor::stack(numerical_tensors, 0).to_device(&self.device),
|
||||
target_tensor: Tensor::stack(target_tensors, 0).to_device(&self.device),
|
||||
}
|
||||
}
|
||||
}
|
219
src/lib/qrust/ml/dataset.rs
Normal file
219
src/lib/qrust/ml/dataset.rs
Normal file
@@ -0,0 +1,219 @@
|
||||
use crate::types::{
|
||||
ta::{calculate_indicators, IndicatedBar, HEAD_SIZE, NUMERICAL_FIELD_COUNT},
|
||||
Bar,
|
||||
};
|
||||
use burn::{
|
||||
data::dataset::{transform::ComposedDataset, Dataset},
|
||||
tensor::Data,
|
||||
};
|
||||
|
||||
pub const WINDOW_SIZE: usize = 48;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct BarWindow {
|
||||
pub hours: Data<i32, 1>,
|
||||
pub days: Data<i32, 1>,
|
||||
pub numerical: Data<f32, 2>,
|
||||
pub target: Data<f32, 1>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct SingleSymbolDataset {
|
||||
hours: Vec<i32>,
|
||||
days: Vec<i32>,
|
||||
numerical: Vec<[f32; NUMERICAL_FIELD_COUNT]>,
|
||||
targets: Vec<f32>,
|
||||
}
|
||||
|
||||
impl SingleSymbolDataset {
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
pub fn new(bars: Vec<IndicatedBar>) -> Self {
|
||||
if !bars.is_empty() {
|
||||
let symbol = &bars[0].symbol;
|
||||
assert!(bars.iter().all(|bar| bar.symbol == *symbol));
|
||||
}
|
||||
|
||||
let (hours, days, numerical, targets) = bars.windows(2).skip(HEAD_SIZE - 1).fold(
|
||||
(
|
||||
Vec::with_capacity(bars.len() - 1),
|
||||
Vec::with_capacity(bars.len() - 1),
|
||||
Vec::with_capacity(bars.len() - 1),
|
||||
Vec::with_capacity(bars.len() - 1),
|
||||
),
|
||||
|(mut hours, mut days, mut numerical, mut targets), bar| {
|
||||
hours.push(i32::from(bar[0].hour));
|
||||
days.push(i32::from(bar[0].day));
|
||||
numerical.push([
|
||||
bar[0].open as f32,
|
||||
(bar[0].open_pct as f32).min(f32::MAX),
|
||||
bar[0].high as f32,
|
||||
(bar[0].high_pct as f32).min(f32::MAX),
|
||||
bar[0].low as f32,
|
||||
(bar[0].low_pct as f32).min(f32::MAX),
|
||||
bar[0].close as f32,
|
||||
(bar[0].close_pct as f32).min(f32::MAX),
|
||||
bar[0].volume as f32,
|
||||
(bar[0].volume_pct as f32).min(f32::MAX),
|
||||
bar[0].trades as f32,
|
||||
(bar[0].trades_pct as f32).min(f32::MAX),
|
||||
bar[0].sma_3 as f32,
|
||||
bar[0].sma_6 as f32,
|
||||
bar[0].sma_12 as f32,
|
||||
bar[0].sma_24 as f32,
|
||||
bar[0].sma_48 as f32,
|
||||
bar[0].sma_72 as f32,
|
||||
bar[0].ema_3 as f32,
|
||||
bar[0].ema_6 as f32,
|
||||
bar[0].ema_12 as f32,
|
||||
bar[0].ema_24 as f32,
|
||||
bar[0].ema_48 as f32,
|
||||
bar[0].ema_72 as f32,
|
||||
bar[0].macd as f32,
|
||||
bar[0].macd_signal as f32,
|
||||
bar[0].obv as f32,
|
||||
bar[0].rsi as f32,
|
||||
bar[0].bbands_lower as f32,
|
||||
bar[0].bbands_mean as f32,
|
||||
bar[0].bbands_upper as f32,
|
||||
]);
|
||||
targets.push(bar[1].close_pct as f32);
|
||||
(hours, days, numerical, targets)
|
||||
},
|
||||
);
|
||||
|
||||
Self {
|
||||
hours,
|
||||
days,
|
||||
numerical,
|
||||
targets,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Dataset<BarWindow> for SingleSymbolDataset {
|
||||
fn len(&self) -> usize {
|
||||
self.targets.len() - WINDOW_SIZE + 1
|
||||
}
|
||||
|
||||
#[allow(clippy::single_range_in_vec_init)]
|
||||
fn get(&self, idx: usize) -> Option<BarWindow> {
|
||||
if idx >= self.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let hours: [i32; WINDOW_SIZE] = self.hours[idx..idx + WINDOW_SIZE].try_into().unwrap();
|
||||
let days: [i32; WINDOW_SIZE] = self.days[idx..idx + WINDOW_SIZE].try_into().unwrap();
|
||||
let numerical: [[f32; NUMERICAL_FIELD_COUNT]; WINDOW_SIZE] =
|
||||
self.numerical[idx..idx + WINDOW_SIZE].try_into().unwrap();
|
||||
let target: [f32; 1] = [self.targets[idx + WINDOW_SIZE - 1]];
|
||||
|
||||
Some(BarWindow {
|
||||
hours: Data::from(hours),
|
||||
days: Data::from(days),
|
||||
numerical: Data::from(numerical),
|
||||
target: Data::from(target),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MultipleSymbolDataset {
|
||||
composed_dataset: ComposedDataset<SingleSymbolDataset>,
|
||||
}
|
||||
|
||||
impl MultipleSymbolDataset {
|
||||
pub fn new(bars: Vec<Bar>) -> Self {
|
||||
let groups = calculate_indicators(bars)
|
||||
.into_iter()
|
||||
.map(SingleSymbolDataset::new)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Self {
|
||||
composed_dataset: ComposedDataset::new(groups),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Dataset<BarWindow> for MultipleSymbolDataset {
|
||||
fn len(&self) -> usize {
|
||||
self.composed_dataset.len()
|
||||
}
|
||||
|
||||
fn get(&self, idx: usize) -> Option<BarWindow> {
|
||||
self.composed_dataset.get(idx)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::{
|
||||
distributions::{Distribution, Uniform},
|
||||
Rng,
|
||||
};
|
||||
use time::OffsetDateTime;
|
||||
|
||||
fn generate_random_dataset(length: usize) -> MultipleSymbolDataset {
|
||||
let mut rng = rand::thread_rng();
|
||||
let uniform = Uniform::new(1.0, 100.0);
|
||||
let mut bars = Vec::with_capacity(length);
|
||||
|
||||
for _ in 0..=(length + (HEAD_SIZE - 1) + (WINDOW_SIZE - 1)) {
|
||||
bars.push(Bar {
|
||||
symbol: "AAPL".to_string(),
|
||||
time: OffsetDateTime::now_utc(),
|
||||
open: uniform.sample(&mut rng),
|
||||
high: uniform.sample(&mut rng),
|
||||
low: uniform.sample(&mut rng),
|
||||
close: uniform.sample(&mut rng),
|
||||
volume: uniform.sample(&mut rng),
|
||||
trades: rng.gen_range(1..100),
|
||||
});
|
||||
}
|
||||
|
||||
MultipleSymbolDataset::new(bars)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_symbol_dataset() {
|
||||
let length = 100;
|
||||
let dataset = generate_random_dataset(length);
|
||||
|
||||
assert_eq!(dataset.len(), length);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_symbol_dataset_window() {
|
||||
let length = 100;
|
||||
let dataset = generate_random_dataset(length);
|
||||
|
||||
let item = dataset.get(0).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
item.numerical.shape.dims,
|
||||
[WINDOW_SIZE, NUMERICAL_FIELD_COUNT]
|
||||
);
|
||||
assert_eq!(item.target.shape.dims, [1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_symbol_dataset_last_window() {
|
||||
let length = 100;
|
||||
let dataset = generate_random_dataset(length);
|
||||
|
||||
let item = dataset.get(dataset.len() - 1).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
item.numerical.shape.dims,
|
||||
[WINDOW_SIZE, NUMERICAL_FIELD_COUNT]
|
||||
);
|
||||
assert_eq!(item.target.shape.dims, [1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_symbol_dataset_out_of_bounds() {
|
||||
let length = 100;
|
||||
let dataset = generate_random_dataset(length);
|
||||
|
||||
assert!(dataset.get(dataset.len()).is_none());
|
||||
}
|
||||
}
|
21
src/lib/qrust/ml/mod.rs
Normal file
21
src/lib/qrust/ml/mod.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
pub mod batcher;
|
||||
pub mod dataset;
|
||||
pub mod model;
|
||||
|
||||
pub use batcher::{BarWindowBatch, BarWindowBatcher};
|
||||
pub use dataset::{BarWindow, MultipleSymbolDataset};
|
||||
pub use model::{Model, ModelConfig};
|
||||
|
||||
use burn::{
|
||||
backend::{
|
||||
wgpu::{AutoGraphicsApi, WgpuDevice},
|
||||
Autodiff, Wgpu,
|
||||
},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
|
||||
pub type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
|
||||
pub type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||
pub type MyDevice = <Autodiff<Wgpu> as Backend>::Device;
|
||||
|
||||
pub const DEVICE: MyDevice = WgpuDevice::BestAvailable;
|
160
src/lib/qrust/ml/model.rs
Normal file
160
src/lib/qrust/ml/model.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use super::BarWindowBatch;
|
||||
use crate::types::ta::NUMERICAL_FIELD_COUNT;
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
nn::{
|
||||
loss::{MseLoss, Reduction},
|
||||
Dropout, DropoutConfig, Embedding, EmbeddingConfig, Linear, LinearConfig, Lstm, LstmConfig,
|
||||
},
|
||||
tensor::{
|
||||
self,
|
||||
backend::{AutodiffBackend, Backend},
|
||||
Tensor,
|
||||
},
|
||||
train::{RegressionOutput, TrainOutput, TrainStep, ValidStep},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
hour_embedding: Embedding<B>,
|
||||
day_embedding: Embedding<B>,
|
||||
lstm_1: Lstm<B>,
|
||||
dropout_1: Dropout,
|
||||
lstm_2: Lstm<B>,
|
||||
dropout_2: Dropout,
|
||||
lstm_3: Lstm<B>,
|
||||
dropout_3: Dropout,
|
||||
lstm_4: Lstm<B>,
|
||||
dropout_4: Dropout,
|
||||
linear: Linear<B>,
|
||||
}
|
||||
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ModelConfig {
|
||||
#[config(default = "3")]
|
||||
pub hour_features: usize,
|
||||
#[config(default = "2")]
|
||||
pub day_features: usize,
|
||||
#[config(default = "{NUMERICAL_FIELD_COUNT}")]
|
||||
pub numerical_features: usize,
|
||||
#[config(default = "0.2")]
|
||||
pub dropout: f64,
|
||||
}
|
||||
|
||||
impl ModelConfig {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
|
||||
let num_features = self.numerical_features + self.hour_features + self.day_features;
|
||||
|
||||
let lstm_1_hidden_size = 512;
|
||||
let lstm_2_hidden_size = 256;
|
||||
let lstm_3_hidden_size = 64;
|
||||
let lstm_4_hidden_size = 32;
|
||||
|
||||
Model {
|
||||
hour_embedding: EmbeddingConfig::new(24, self.hour_features).init(device),
|
||||
day_embedding: EmbeddingConfig::new(7, self.day_features).init(device),
|
||||
lstm_1: LstmConfig::new(num_features, lstm_1_hidden_size, true).init(device),
|
||||
dropout_1: DropoutConfig::new(self.dropout).init(),
|
||||
lstm_2: LstmConfig::new(lstm_1_hidden_size, lstm_2_hidden_size, true).init(device),
|
||||
dropout_2: DropoutConfig::new(self.dropout).init(),
|
||||
lstm_3: LstmConfig::new(lstm_2_hidden_size, lstm_3_hidden_size, true).init(device),
|
||||
dropout_3: DropoutConfig::new(self.dropout).init(),
|
||||
lstm_4: LstmConfig::new(lstm_3_hidden_size, lstm_4_hidden_size, true).init(device),
|
||||
dropout_4: DropoutConfig::new(self.dropout).init(),
|
||||
linear: LinearConfig::new(lstm_4_hidden_size, 1).init(device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B> {
|
||||
pub fn forward(
|
||||
&self,
|
||||
hour: Tensor<B, 2, tensor::Int>,
|
||||
day: Tensor<B, 2, tensor::Int>,
|
||||
numerical: Tensor<B, 3>,
|
||||
) -> Tensor<B, 2> {
|
||||
let hour = self.hour_embedding.forward(hour);
|
||||
let day = self.day_embedding.forward(day);
|
||||
|
||||
let x = Tensor::cat(vec![hour, day, numerical], 2);
|
||||
|
||||
let (_, x) = self.lstm_1.forward(x, None);
|
||||
let x = self.dropout_1.forward(x);
|
||||
let (_, x) = self.lstm_2.forward(x, None);
|
||||
let x = self.dropout_2.forward(x);
|
||||
let (_, x) = self.lstm_3.forward(x, None);
|
||||
let x = self.dropout_3.forward(x);
|
||||
let (_, x) = self.lstm_4.forward(x, None);
|
||||
let x = self.dropout_4.forward(x);
|
||||
|
||||
let [batch_size, window_size, features] = x.shape().dims;
|
||||
|
||||
let x = x.slice([0..batch_size, window_size - 1..window_size, 0..features]);
|
||||
let x = x.squeeze(1);
|
||||
|
||||
self.linear.forward(x)
|
||||
}
|
||||
|
||||
pub fn forward_regression(
|
||||
&self,
|
||||
hour: Tensor<B, 2, tensor::Int>,
|
||||
day: Tensor<B, 2, tensor::Int>,
|
||||
numerical: Tensor<B, 3>,
|
||||
target: Tensor<B, 2>,
|
||||
) -> RegressionOutput<B> {
|
||||
let output = self.forward(hour, day, numerical);
|
||||
let loss = MseLoss::new().forward(output.clone(), target.clone(), Reduction::Mean);
|
||||
|
||||
RegressionOutput::new(loss, output, target)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> TrainStep<BarWindowBatch<B>, RegressionOutput<B>> for Model<B> {
|
||||
fn step(&self, batch: BarWindowBatch<B>) -> TrainOutput<RegressionOutput<B>> {
|
||||
let item = self.forward_regression(
|
||||
batch.hour_tensor,
|
||||
batch.day_tensor,
|
||||
batch.numerical_tensor,
|
||||
batch.target_tensor,
|
||||
);
|
||||
|
||||
TrainOutput::new(self, item.loss.backward(), item)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ValidStep<BarWindowBatch<B>, RegressionOutput<B>> for Model<B> {
|
||||
fn step(&self, batch: BarWindowBatch<B>) -> RegressionOutput<B> {
|
||||
self.forward_regression(
|
||||
batch.hour_tensor,
|
||||
batch.day_tensor,
|
||||
batch.numerical_tensor,
|
||||
batch.target_tensor,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn::{backend::Wgpu, tensor::Distribution};
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_model() {
|
||||
let device = Default::default();
|
||||
let distribution = Distribution::Normal(0.0, 1.0);
|
||||
|
||||
let config = ModelConfig::new().with_numerical_features(7);
|
||||
|
||||
let model = config.init::<Wgpu>(&device);
|
||||
|
||||
let hour = Tensor::ones([2, 10], &device);
|
||||
let day = Tensor::ones([2, 10], &device);
|
||||
let numerical = Tensor::random([2, 10, 7], distribution, &device);
|
||||
|
||||
let output = model.forward(hour, day, numerical);
|
||||
|
||||
assert_eq!(output.shape().dims, [2, 1]);
|
||||
}
|
||||
}
|
6
src/lib/qrust/mod.rs
Normal file
6
src/lib/qrust/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod alpaca;
|
||||
pub mod database;
|
||||
pub mod ml;
|
||||
pub mod ta;
|
||||
pub mod types;
|
||||
pub mod utils;
|
149
src/lib/qrust/ta/bbands.rs
Normal file
149
src/lib/qrust/ta/bbands.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
use std::{borrow::Borrow, collections::VecDeque, iter::Scan, num::NonZeroUsize};
|
||||
|
||||
pub struct BbandsState {
|
||||
window: VecDeque<f64>,
|
||||
sum: f64,
|
||||
squared_sum: f64,
|
||||
multiplier: f64,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub trait Bbands<T>: Iterator + Sized {
|
||||
fn bbands(
|
||||
self,
|
||||
period: NonZeroUsize,
|
||||
multiplier: f64, // Typically 2.0
|
||||
) -> Scan<Self, BbandsState, fn(&mut BbandsState, T) -> Option<(f64, f64, f64)>>;
|
||||
}
|
||||
|
||||
impl<I, T> Bbands<T> for I
|
||||
where
|
||||
I: Iterator<Item = T>,
|
||||
T: Borrow<f64>,
|
||||
{
|
||||
fn bbands(
|
||||
self,
|
||||
period: NonZeroUsize,
|
||||
multiplier: f64,
|
||||
) -> Scan<Self, BbandsState, fn(&mut BbandsState, T) -> Option<(f64, f64, f64)>> {
|
||||
self.scan(
|
||||
BbandsState {
|
||||
window: VecDeque::from(vec![0.0; period.get()]),
|
||||
sum: 0.0,
|
||||
squared_sum: 0.0,
|
||||
multiplier,
|
||||
},
|
||||
|state: &mut BbandsState, value: T| {
|
||||
let value = *value.borrow();
|
||||
|
||||
let front = state.window.pop_front().unwrap();
|
||||
state.sum -= front;
|
||||
state.squared_sum -= front.powi(2);
|
||||
|
||||
state.window.push_back(value);
|
||||
state.sum += value;
|
||||
state.squared_sum += value.powi(2);
|
||||
|
||||
let mean = state.sum / state.window.len() as f64;
|
||||
let variance =
|
||||
((state.squared_sum / state.window.len() as f64) - mean.powi(2)).max(0.0);
|
||||
let standard_deviation = variance.sqrt();
|
||||
|
||||
let upper_band = mean + state.multiplier * standard_deviation;
|
||||
let lower_band = mean - state.multiplier * standard_deviation;
|
||||
|
||||
Some((upper_band, mean, lower_band))
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_bbands() {
|
||||
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let bbands = data
|
||||
.into_iter()
|
||||
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
|
||||
.map(|(upper, mean, lower)| {
|
||||
(
|
||||
(upper * 100.0).round() / 100.0,
|
||||
(mean * 100.0).round() / 100.0,
|
||||
(lower * 100.0).round() / 100.0,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(
|
||||
bbands,
|
||||
vec![
|
||||
(1.28, 0.33, -0.61),
|
||||
(2.63, 1.0, -0.63),
|
||||
(3.63, 2.0, 0.37),
|
||||
(4.63, 3.0, 1.37),
|
||||
(5.63, 4.0, 2.37)
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bbands_empty() {
|
||||
let data = Vec::<f64>::new();
|
||||
let bbands = data
|
||||
.into_iter()
|
||||
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(bbands, Vec::<(f64, f64, f64)>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bbands_1_period() {
|
||||
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let bbands = data
|
||||
.into_iter()
|
||||
.bbands(NonZeroUsize::new(1).unwrap(), 2.0)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(
|
||||
bbands,
|
||||
vec![
|
||||
(1.0, 1.0, 1.0),
|
||||
(2.0, 2.0, 2.0),
|
||||
(3.0, 3.0, 3.0),
|
||||
(4.0, 4.0, 4.0),
|
||||
(5.0, 5.0, 5.0)
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bbands_borrow() {
|
||||
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let bbands = data
|
||||
.iter()
|
||||
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
|
||||
.map(|(upper, mean, lower)| {
|
||||
(
|
||||
(upper * 100.0).round() / 100.0,
|
||||
(mean * 100.0).round() / 100.0,
|
||||
(lower * 100.0).round() / 100.0,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(
|
||||
bbands,
|
||||
vec![
|
||||
(1.28, 0.33, -0.61),
|
||||
(2.63, 1.0, -0.63),
|
||||
(3.63, 2.0, 0.37),
|
||||
(4.63, 3.0, 1.37),
|
||||
(5.63, 4.0, 2.37)
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
59
src/lib/qrust/ta/deriv.rs
Normal file
59
src/lib/qrust/ta/deriv.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use std::{borrow::Borrow, iter::Scan};
|
||||
|
||||
pub struct DerivState {
|
||||
pub last: f64,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub trait Deriv<T>: Iterator + Sized {
|
||||
fn deriv(self) -> Scan<Self, DerivState, fn(&mut DerivState, T) -> Option<f64>>;
|
||||
}
|
||||
|
||||
impl<I, T> Deriv<T> for I
|
||||
where
|
||||
I: Iterator<Item = T>,
|
||||
T: Borrow<f64>,
|
||||
{
|
||||
fn deriv(self) -> Scan<Self, DerivState, fn(&mut DerivState, T) -> Option<f64>> {
|
||||
self.scan(
|
||||
DerivState { last: 0.0 },
|
||||
|state: &mut DerivState, value: T| {
|
||||
let value = *value.borrow();
|
||||
|
||||
let deriv = value - state.last;
|
||||
state.last = value;
|
||||
|
||||
Some(deriv)
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_deriv() {
|
||||
let data = vec![1.0, 3.0, 6.0, 3.0, 1.0];
|
||||
let deriv = data.into_iter().deriv().collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(deriv, vec![1.0, 2.0, 3.0, -3.0, -2.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deriv_empty() {
|
||||
let data = Vec::<f64>::new();
|
||||
let deriv = data.into_iter().deriv().collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(deriv, Vec::<f64>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deriv_borrow() {
|
||||
let data = [1.0, 3.0, 6.0, 3.0, 1.0];
|
||||
let deriv = data.iter().deriv().collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(deriv, vec![1.0, 2.0, 3.0, -3.0, -2.0]);
|
||||
}
|
||||
}
|
95
src/lib/qrust/ta/ema.rs
Normal file
95
src/lib/qrust/ta/ema.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
iter::{Peekable, Scan},
|
||||
num::NonZeroUsize,
|
||||
};
|
||||
|
||||
pub struct EmaState {
|
||||
weight: f64,
|
||||
ema: f64,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub trait Ema<T>: Iterator + Sized {
|
||||
fn ema(
|
||||
self,
|
||||
period: NonZeroUsize,
|
||||
) -> Scan<Peekable<Self>, EmaState, fn(&mut EmaState, T) -> Option<f64>>;
|
||||
}
|
||||
|
||||
impl<I, T> Ema<T> for I
|
||||
where
|
||||
I: Iterator<Item = T>,
|
||||
T: Borrow<f64>,
|
||||
{
|
||||
fn ema(
|
||||
self,
|
||||
period: NonZeroUsize,
|
||||
) -> Scan<Peekable<Self>, EmaState, fn(&mut EmaState, T) -> Option<f64>> {
|
||||
let smoothing = 2.0;
|
||||
let weight = smoothing / (1.0 + period.get() as f64);
|
||||
|
||||
let mut iter = self.peekable();
|
||||
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
|
||||
|
||||
iter.scan(
|
||||
EmaState { weight, ema: first },
|
||||
|state: &mut EmaState, value: T| {
|
||||
let value = *value.borrow();
|
||||
|
||||
state.ema = (value * state.weight) + (state.ema * (1.0 - state.weight));
|
||||
|
||||
Some(state.ema)
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ema() {
|
||||
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let ema = data
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(ema, vec![1.0, 1.5, 2.25, 3.125, 4.0625]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ema_empty() {
|
||||
let data = Vec::<f64>::new();
|
||||
let ema = data
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(ema, Vec::<f64>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ema_1_period() {
|
||||
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let ema = data
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(1).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(ema, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ema_borrow() {
|
||||
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let ema = data
|
||||
.iter()
|
||||
.ema(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(ema, vec![1.0, 1.5, 2.25, 3.125, 4.0625]);
|
||||
}
|
||||
}
|
216
src/lib/qrust/ta/macd.rs
Normal file
216
src/lib/qrust/ta/macd.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
iter::{Peekable, Scan},
|
||||
num::NonZeroUsize,
|
||||
};
|
||||
|
||||
pub struct MacdState {
|
||||
short_weight: f64,
|
||||
long_weight: f64,
|
||||
signal_weight: f64,
|
||||
short_ema: f64,
|
||||
long_ema: f64,
|
||||
signal_ema: f64,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub trait Macd<T>: Iterator + Sized {
|
||||
fn macd(
|
||||
self,
|
||||
short_period: NonZeroUsize, // Typically 12
|
||||
long_period: NonZeroUsize, // Typically 26
|
||||
signal_period: NonZeroUsize, // Typically 9
|
||||
) -> Scan<Peekable<Self>, MacdState, fn(&mut MacdState, T) -> Option<(f64, f64)>>;
|
||||
}
|
||||
|
||||
impl<I, T> Macd<T> for I
|
||||
where
|
||||
I: Iterator<Item = T>,
|
||||
T: Borrow<f64>,
|
||||
{
|
||||
fn macd(
|
||||
self,
|
||||
short_period: NonZeroUsize,
|
||||
long_period: NonZeroUsize,
|
||||
signal_period: NonZeroUsize,
|
||||
) -> Scan<Peekable<Self>, MacdState, fn(&mut MacdState, T) -> Option<(f64, f64)>> {
|
||||
let smoothing = 2.0;
|
||||
let short_weight = smoothing / (1.0 + short_period.get() as f64);
|
||||
let long_weight = smoothing / (1.0 + long_period.get() as f64);
|
||||
let signal_weight = smoothing / (1.0 + signal_period.get() as f64);
|
||||
|
||||
let mut iter = self.peekable();
|
||||
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
|
||||
|
||||
iter.scan(
|
||||
MacdState {
|
||||
short_weight,
|
||||
long_weight,
|
||||
signal_weight,
|
||||
short_ema: first,
|
||||
long_ema: first,
|
||||
signal_ema: 0.0,
|
||||
},
|
||||
|state: &mut MacdState, value: T| {
|
||||
let value = *value.borrow();
|
||||
|
||||
state.short_ema =
|
||||
(value * state.short_weight) + (state.short_ema * (1.0 - state.short_weight));
|
||||
state.long_ema =
|
||||
(value * state.long_weight) + (state.long_ema * (1.0 - state.long_weight));
|
||||
|
||||
let macd = state.short_ema - state.long_ema;
|
||||
state.signal_ema =
|
||||
(macd * state.signal_weight) + (state.signal_ema * (1.0 - state.signal_weight));
|
||||
|
||||
Some((macd, state.signal_ema))
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::ema::Ema;
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_macd() {
|
||||
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
|
||||
let short_ema = data
|
||||
.clone()
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let long_ema = data
|
||||
.clone()
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(5).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let macd = short_ema
|
||||
.into_iter()
|
||||
.zip(long_ema)
|
||||
.map(|(short, long)| short - long)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let signal = macd
|
||||
.clone()
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(
|
||||
data.into_iter()
|
||||
.macd(
|
||||
NonZeroUsize::new(3).unwrap(),
|
||||
NonZeroUsize::new(5).unwrap(),
|
||||
NonZeroUsize::new(3).unwrap()
|
||||
)
|
||||
.collect::<Vec<_>>(),
|
||||
expected
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_macd_empty() {
|
||||
let data = Vec::<f64>::new();
|
||||
|
||||
assert_eq!(
|
||||
data.into_iter()
|
||||
.macd(
|
||||
NonZeroUsize::new(3).unwrap(),
|
||||
NonZeroUsize::new(5).unwrap(),
|
||||
NonZeroUsize::new(3).unwrap()
|
||||
)
|
||||
.collect::<Vec<_>>(),
|
||||
vec![]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_macd_1_period() {
|
||||
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
|
||||
let short_ema = data
|
||||
.clone()
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(1).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let long_ema = data
|
||||
.clone()
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(1).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let macd = short_ema
|
||||
.into_iter()
|
||||
.zip(long_ema)
|
||||
.map(|(short, long)| short - long)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let signal = macd
|
||||
.clone()
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(1).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(
|
||||
data.into_iter()
|
||||
.macd(
|
||||
NonZeroUsize::new(1).unwrap(),
|
||||
NonZeroUsize::new(1).unwrap(),
|
||||
NonZeroUsize::new(1).unwrap()
|
||||
)
|
||||
.collect::<Vec<_>>(),
|
||||
expected
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_macd_borrow() {
|
||||
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
|
||||
let short_ema = data
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let long_ema = data
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(5).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let macd = short_ema
|
||||
.into_iter()
|
||||
.zip(long_ema)
|
||||
.map(|(short, long)| short - long)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let signal = macd
|
||||
.clone()
|
||||
.into_iter()
|
||||
.ema(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(
|
||||
data.iter()
|
||||
.macd(
|
||||
NonZeroUsize::new(3).unwrap(),
|
||||
NonZeroUsize::new(5).unwrap(),
|
||||
NonZeroUsize::new(3).unwrap()
|
||||
)
|
||||
.collect::<Vec<_>>(),
|
||||
expected
|
||||
);
|
||||
}
|
||||
}
|
17
src/lib/qrust/ta/mod.rs
Normal file
17
src/lib/qrust/ta/mod.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
pub mod bbands;
|
||||
pub mod deriv;
|
||||
pub mod ema;
|
||||
pub mod macd;
|
||||
pub mod obv;
|
||||
pub mod pct;
|
||||
pub mod rsi;
|
||||
pub mod sma;
|
||||
|
||||
pub use bbands::Bbands;
|
||||
pub use deriv::Deriv;
|
||||
pub use ema::Ema;
|
||||
pub use macd::Macd;
|
||||
pub use obv::Obv;
|
||||
pub use pct::Pct;
|
||||
pub use rsi::Rsi;
|
||||
pub use sma::Sma;
|
73
src/lib/qrust/ta/obv.rs
Normal file
73
src/lib/qrust/ta/obv.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
iter::{Peekable, Scan},
|
||||
};
|
||||
|
||||
pub struct ObvState {
|
||||
last: f64,
|
||||
obv: f64,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub trait Obv<T>: Iterator + Sized {
|
||||
fn obv(self) -> Scan<Peekable<Self>, ObvState, fn(&mut ObvState, T) -> Option<f64>>;
|
||||
}
|
||||
|
||||
impl<I, T> Obv<T> for I
|
||||
where
|
||||
I: Iterator<Item = T>,
|
||||
T: Borrow<(f64, f64)>,
|
||||
{
|
||||
fn obv(self) -> Scan<Peekable<Self>, ObvState, fn(&mut ObvState, T) -> Option<f64>> {
|
||||
let mut iter = self.peekable();
|
||||
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
|
||||
|
||||
iter.scan(
|
||||
ObvState {
|
||||
last: first.0,
|
||||
obv: 0.0,
|
||||
},
|
||||
|state: &mut ObvState, value: T| {
|
||||
let (close, volume) = *value.borrow();
|
||||
|
||||
if close > state.last {
|
||||
state.obv += volume;
|
||||
} else if close < state.last {
|
||||
state.obv -= volume;
|
||||
}
|
||||
state.last = close;
|
||||
|
||||
Some(state.obv)
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_obv() {
|
||||
let data = vec![(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (2.0, 4.0), (1.0, 5.0)];
|
||||
let obv = data.into_iter().obv().collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(obv, vec![0.0, 2.0, 5.0, 1.0, -4.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_obv_empty() {
|
||||
let data = Vec::<(f64, f64)>::new();
|
||||
let obv = data.into_iter().obv().collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(obv, Vec::<f64>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_obv_borrow() {
|
||||
let data = [(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (2.0, 4.0), (1.0, 5.0)];
|
||||
let obv = data.iter().obv().collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(obv, vec![0.0, 2.0, 5.0, 1.0, -4.0]);
|
||||
}
|
||||
}
|
64
src/lib/qrust/ta/pct.rs
Normal file
64
src/lib/qrust/ta/pct.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use std::{borrow::Borrow, iter::Scan};
|
||||
|
||||
pub struct PctState {
|
||||
pub last: f64,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub trait Pct<T>: Iterator + Sized {
|
||||
fn pct(self) -> Scan<Self, PctState, fn(&mut PctState, T) -> Option<f64>>;
|
||||
}
|
||||
|
||||
impl<I, T> Pct<T> for I
|
||||
where
|
||||
I: Iterator<Item = T>,
|
||||
T: Borrow<f64>,
|
||||
{
|
||||
fn pct(self) -> Scan<Self, PctState, fn(&mut PctState, T) -> Option<f64>> {
|
||||
self.scan(PctState { last: 0.0 }, |state: &mut PctState, value: T| {
|
||||
let value = *value.borrow();
|
||||
|
||||
let pct = value / state.last - 1.0;
|
||||
state.last = value;
|
||||
|
||||
Some(pct)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pct() {
|
||||
let data = vec![1.0, 2.0, 4.0, 2.0, 1.0];
|
||||
let pct = data.into_iter().pct().collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(pct, vec![f64::INFINITY, 1.0, 1.0, -0.5, -0.5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pct_empty() {
|
||||
let data = Vec::<f64>::new();
|
||||
let pct = data.into_iter().pct().collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(pct, Vec::<f64>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pct_0() {
|
||||
let data = vec![1.0, 0.0, 4.0, 2.0, 1.0];
|
||||
let pct = data.into_iter().pct().collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(pct, vec![f64::INFINITY, -1.0, f64::INFINITY, -0.5, -0.5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pct_borrow() {
|
||||
let data = [1.0, 2.0, 4.0, 2.0, 1.0];
|
||||
let pct = data.iter().pct().collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(pct, vec![f64::INFINITY, 1.0, 1.0, -0.5, -0.5]);
|
||||
}
|
||||
}
|
135
src/lib/qrust/ta/rsi.rs
Normal file
135
src/lib/qrust/ta/rsi.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
collections::VecDeque,
|
||||
iter::{Peekable, Scan},
|
||||
num::NonZeroUsize,
|
||||
};
|
||||
|
||||
pub struct RsiState {
|
||||
last: f64,
|
||||
window_gains: VecDeque<f64>,
|
||||
window_losses: VecDeque<f64>,
|
||||
sum_gains: f64,
|
||||
sum_losses: f64,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub trait Rsi<T>: Iterator + Sized {
|
||||
fn rsi(
|
||||
self,
|
||||
period: NonZeroUsize, // Typically 14
|
||||
) -> Scan<Peekable<Self>, RsiState, fn(&mut RsiState, T) -> Option<f64>>;
|
||||
}
|
||||
|
||||
impl<I, T> Rsi<T> for I
|
||||
where
|
||||
I: Iterator<Item = T>,
|
||||
T: Borrow<f64>,
|
||||
{
|
||||
fn rsi(
|
||||
self,
|
||||
period: NonZeroUsize,
|
||||
) -> Scan<Peekable<Self>, RsiState, fn(&mut RsiState, T) -> Option<f64>> {
|
||||
let mut iter = self.peekable();
|
||||
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
|
||||
|
||||
iter.scan(
|
||||
RsiState {
|
||||
last: first,
|
||||
window_gains: VecDeque::from(vec![0.0; period.get()]),
|
||||
window_losses: VecDeque::from(vec![0.0; period.get()]),
|
||||
sum_gains: 0.0,
|
||||
sum_losses: 0.0,
|
||||
},
|
||||
|state, value| {
|
||||
let value = *value.borrow();
|
||||
|
||||
state.sum_gains -= state.window_gains.pop_front().unwrap();
|
||||
state.sum_losses -= state.window_losses.pop_front().unwrap();
|
||||
|
||||
let gain = (value - state.last).max(0.0);
|
||||
let loss = (state.last - value).max(0.0);
|
||||
|
||||
state.last = value;
|
||||
|
||||
state.window_gains.push_back(gain);
|
||||
state.window_losses.push_back(loss);
|
||||
state.sum_gains += gain;
|
||||
state.sum_losses += loss;
|
||||
|
||||
let avg_loss = state.sum_losses / state.window_losses.len() as f64;
|
||||
|
||||
if avg_loss == 0.0 {
|
||||
return Some(100.0);
|
||||
}
|
||||
|
||||
let avg_gain = state.sum_gains / state.window_gains.len() as f64;
|
||||
let rs = avg_gain / avg_loss;
|
||||
|
||||
Some(100.0 - (100.0 / (1.0 + rs)))
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_rsi() {
|
||||
let data = vec![1.0, 4.0, 7.0, 4.0, 1.0];
|
||||
let rsi = data
|
||||
.into_iter()
|
||||
.rsi(NonZeroUsize::new(3).unwrap())
|
||||
.map(|v| (v * 100.0).round() / 100.0)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 66.67, 33.33]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rsi_empty() {
|
||||
let data = Vec::<f64>::new();
|
||||
let rsi = data
|
||||
.into_iter()
|
||||
.rsi(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(rsi, Vec::<f64>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rsi_no_loss() {
|
||||
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let rsi = data
|
||||
.into_iter()
|
||||
.rsi(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 100.0, 100.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rsi_no_gain() {
|
||||
let data = vec![5.0, 4.0, 3.0, 2.0, 1.0];
|
||||
let rsi = data
|
||||
.into_iter()
|
||||
.rsi(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(rsi, vec![100.0, 0.0, 0.0, 0.0, 0.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rsi_borrow() {
|
||||
let data = [1.0, 4.0, 7.0, 4.0, 1.0];
|
||||
let rsi = data
|
||||
.iter()
|
||||
.rsi(NonZeroUsize::new(3).unwrap())
|
||||
.map(|v| (v * 100.0).round() / 100.0)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 66.67, 33.33]);
|
||||
}
|
||||
}
|
88
src/lib/qrust/ta/sma.rs
Normal file
88
src/lib/qrust/ta/sma.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use std::{borrow::Borrow, collections::VecDeque, iter::Scan, num::NonZeroUsize};
|
||||
|
||||
pub struct SmaState {
|
||||
window: VecDeque<f64>,
|
||||
sum: f64,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub trait Sma<T>: Iterator + Sized {
|
||||
fn sma(self, period: NonZeroUsize)
|
||||
-> Scan<Self, SmaState, fn(&mut SmaState, T) -> Option<f64>>;
|
||||
}
|
||||
|
||||
impl<I, T> Sma<T> for I
|
||||
where
|
||||
I: Iterator<Item = T>,
|
||||
T: Borrow<f64>,
|
||||
{
|
||||
fn sma(
|
||||
self,
|
||||
period: NonZeroUsize,
|
||||
) -> Scan<Self, SmaState, fn(&mut SmaState, T) -> Option<f64>> {
|
||||
self.scan(
|
||||
SmaState {
|
||||
window: VecDeque::from(vec![0.0; period.get()]),
|
||||
sum: 0.0,
|
||||
},
|
||||
|state: &mut SmaState, value: T| {
|
||||
let value = *value.borrow();
|
||||
|
||||
state.sum -= state.window.pop_front().unwrap();
|
||||
state.window.push_back(value);
|
||||
state.sum += value;
|
||||
|
||||
Some(state.sum / state.window.len() as f64)
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_sma() {
|
||||
let data = vec![3.0, 6.0, 9.0, 12.0, 15.0];
|
||||
let sma = data
|
||||
.into_iter()
|
||||
.sma(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(sma, vec![1.0, 3.0, 6.0, 9.0, 12.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sma_empty() {
|
||||
let data = Vec::<f64>::new();
|
||||
let sma = data
|
||||
.into_iter()
|
||||
.sma(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(sma, Vec::<f64>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sma_1_period() {
|
||||
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let sma = data
|
||||
.into_iter()
|
||||
.sma(NonZeroUsize::new(1).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(sma, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sma_borrow() {
|
||||
let data = [3.0, 6.0, 9.0, 12.0, 15.0];
|
||||
let sma = data
|
||||
.iter()
|
||||
.sma(NonZeroUsize::new(3).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(sma, vec![1.0, 3.0, 6.0, 9.0, 12.0]);
|
||||
}
|
||||
}
|
@@ -1,13 +1,7 @@
|
||||
use crate::config::ALPACA_API_URL;
|
||||
use backoff::{future::retry_notify, ExponentialBackoff};
|
||||
use governor::DefaultDirectRateLimiter;
|
||||
use log::warn;
|
||||
use reqwest::{Client, Error};
|
||||
use serde::Deserialize;
|
||||
use serde_aux::field_attributes::{
|
||||
deserialize_number_from_string, deserialize_option_number_from_string,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use time::OffsetDateTime;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -79,38 +73,3 @@ pub struct Account {
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub regt_buying_power: f64,
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
alpaca_client: &Client,
|
||||
alpaca_rate_limiter: &DefaultDirectRateLimiter,
|
||||
backoff: Option<ExponentialBackoff>,
|
||||
) -> Result<Account, Error> {
|
||||
retry_notify(
|
||||
backoff.unwrap_or_default(),
|
||||
|| async {
|
||||
alpaca_rate_limiter.until_ready().await;
|
||||
alpaca_client
|
||||
.get(&format!("{}/account", *ALPACA_API_URL))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()
|
||||
.map_err(|e| match e.status() {
|
||||
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
|
||||
backoff::Error::Permanent(e)
|
||||
}
|
||||
_ => e.into(),
|
||||
})?
|
||||
.json::<Account>()
|
||||
.await
|
||||
.map_err(backoff::Error::Permanent)
|
||||
},
|
||||
|e, duration: Duration| {
|
||||
warn!(
|
||||
"Failed to get account, will retry in {} seconds: {}",
|
||||
duration.as_secs(),
|
||||
e
|
||||
);
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
39
src/lib/qrust/types/alpaca/api/incoming/asset.rs
Normal file
39
src/lib/qrust/types/alpaca/api/incoming/asset.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use super::position::Position;
|
||||
use crate::types::{self, alpaca::shared::asset};
|
||||
use serde::Deserialize;
|
||||
use serde_aux::field_attributes::deserialize_option_number_from_string;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub use asset::{Class, Exchange, Status};
|
||||
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
#[derive(Deserialize, Clone)]
|
||||
pub struct Asset {
|
||||
pub id: Uuid,
|
||||
pub class: Class,
|
||||
pub exchange: Exchange,
|
||||
pub symbol: String,
|
||||
pub name: String,
|
||||
pub status: Status,
|
||||
pub tradable: bool,
|
||||
pub marginable: bool,
|
||||
pub shortable: bool,
|
||||
pub easy_to_borrow: bool,
|
||||
pub fractionable: bool,
|
||||
#[serde(deserialize_with = "deserialize_option_number_from_string")]
|
||||
pub maintenance_margin_requirement: Option<f32>,
|
||||
pub attributes: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl From<(Asset, Option<Position>)> for types::Asset {
|
||||
fn from((asset, position): (Asset, Option<Position>)) -> Self {
|
||||
Self {
|
||||
symbol: asset.symbol,
|
||||
class: asset.class.into(),
|
||||
exchange: asset.exchange.into(),
|
||||
status: asset.status.into(),
|
||||
time_added: time::OffsetDateTime::now_utc(),
|
||||
qty: position.map(|position| position.qty).unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
39
src/lib/qrust/types/alpaca/api/incoming/bar.rs
Normal file
39
src/lib/qrust/types/alpaca/api/incoming/bar.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use crate::types;
|
||||
use serde::Deserialize;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Bar {
|
||||
#[serde(rename = "t")]
|
||||
#[serde(with = "time::serde::rfc3339")]
|
||||
pub time: OffsetDateTime,
|
||||
#[serde(rename = "o")]
|
||||
pub open: f64,
|
||||
#[serde(rename = "h")]
|
||||
pub high: f64,
|
||||
#[serde(rename = "l")]
|
||||
pub low: f64,
|
||||
#[serde(rename = "c")]
|
||||
pub close: f64,
|
||||
#[serde(rename = "v")]
|
||||
pub volume: f64,
|
||||
#[serde(rename = "n")]
|
||||
pub trades: i64,
|
||||
#[serde(rename = "vw")]
|
||||
pub vwap: f64,
|
||||
}
|
||||
|
||||
impl From<(Bar, String)> for types::Bar {
|
||||
fn from((bar, symbol): (Bar, String)) -> Self {
|
||||
Self {
|
||||
symbol,
|
||||
time: bar.time,
|
||||
open: bar.open,
|
||||
high: bar.high,
|
||||
low: bar.low,
|
||||
close: bar.close,
|
||||
volume: bar.volume,
|
||||
trades: bar.trades,
|
||||
}
|
||||
}
|
||||
}
|
26
src/lib/qrust/types/alpaca/api/incoming/calendar.rs
Normal file
26
src/lib/qrust/types/alpaca/api/incoming/calendar.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use crate::{
|
||||
types,
|
||||
utils::{de, time::EST_OFFSET},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use time::{Date, OffsetDateTime, Time};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Calendar {
|
||||
pub date: Date,
|
||||
#[serde(deserialize_with = "de::human_time_hh_mm")]
|
||||
pub open: Time,
|
||||
#[serde(deserialize_with = "de::human_time_hh_mm")]
|
||||
pub close: Time,
|
||||
pub settlement_date: Date,
|
||||
}
|
||||
|
||||
impl From<Calendar> for types::Calendar {
|
||||
fn from(calendar: Calendar) -> Self {
|
||||
Self {
|
||||
date: calendar.date,
|
||||
open: OffsetDateTime::new_in_offset(calendar.date, calendar.open, *EST_OFFSET),
|
||||
close: OffsetDateTime::new_in_offset(calendar.date, calendar.close, *EST_OFFSET),
|
||||
}
|
||||
}
|
||||
}
|
13
src/lib/qrust/types/alpaca/api/incoming/clock.rs
Normal file
13
src/lib/qrust/types/alpaca/api/incoming/clock.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use serde::Deserialize;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Clock {
|
||||
#[serde(with = "time::serde::rfc3339")]
|
||||
pub timestamp: OffsetDateTime,
|
||||
pub is_open: bool,
|
||||
#[serde(with = "time::serde::rfc3339")]
|
||||
pub next_open: OffsetDateTime,
|
||||
#[serde(with = "time::serde::rfc3339")]
|
||||
pub next_close: OffsetDateTime,
|
||||
}
|
57
src/lib/qrust/types/alpaca/api/incoming/news.rs
Normal file
57
src/lib/qrust/types/alpaca/api/incoming/news.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use crate::{
|
||||
types::{self, alpaca::shared::news::strip},
|
||||
utils::de,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ImageSize {
|
||||
Thumb,
|
||||
Small,
|
||||
Large,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Image {
|
||||
pub size: ImageSize,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct News {
|
||||
pub id: i64,
|
||||
#[serde(with = "time::serde::rfc3339")]
|
||||
#[serde(rename = "created_at")]
|
||||
pub time_created: OffsetDateTime,
|
||||
#[serde(with = "time::serde::rfc3339")]
|
||||
#[serde(rename = "updated_at")]
|
||||
pub time_updated: OffsetDateTime,
|
||||
#[serde(deserialize_with = "de::add_slash_to_symbols")]
|
||||
pub symbols: Vec<String>,
|
||||
pub headline: String,
|
||||
pub author: String,
|
||||
pub source: String,
|
||||
pub summary: String,
|
||||
pub content: String,
|
||||
pub url: Option<String>,
|
||||
pub images: Vec<Image>,
|
||||
}
|
||||
|
||||
impl From<News> for types::News {
|
||||
fn from(news: News) -> Self {
|
||||
Self {
|
||||
id: news.id,
|
||||
time_created: news.time_created,
|
||||
time_updated: news.time_updated,
|
||||
symbols: news.symbols,
|
||||
headline: strip(&news.headline),
|
||||
author: strip(&news.author),
|
||||
source: strip(&news.source),
|
||||
summary: news.summary,
|
||||
content: news.content,
|
||||
url: news.url.unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
3
src/lib/qrust/types/alpaca/api/incoming/order.rs
Normal file
3
src/lib/qrust/types/alpaca/api/incoming/order.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
use crate::types::alpaca::shared::order;
|
||||
|
||||
pub use order::{Order, Side};
|
61
src/lib/qrust/types/alpaca/api/incoming/position.rs
Normal file
61
src/lib/qrust/types/alpaca/api/incoming/position.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use crate::{
|
||||
types::alpaca::api::incoming::{
|
||||
asset::{Class, Exchange},
|
||||
order,
|
||||
},
|
||||
utils::de,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_aux::field_attributes::deserialize_number_from_string;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Deserialize, Clone, Copy)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Side {
|
||||
Long,
|
||||
Short,
|
||||
}
|
||||
|
||||
impl From<Side> for order::Side {
|
||||
fn from(side: Side) -> Self {
|
||||
match side {
|
||||
Side::Long => Self::Buy,
|
||||
Side::Short => Self::Sell,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
pub struct Position {
|
||||
pub asset_id: Uuid,
|
||||
#[serde(deserialize_with = "de::add_slash_to_symbol")]
|
||||
pub symbol: String,
|
||||
pub exchange: Exchange,
|
||||
pub asset_class: Class,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub avg_entry_price: f64,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub qty: f64,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub qty_available: f64,
|
||||
pub side: Side,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub market_value: f64,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub cost_basis: f64,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub unrealized_pl: f64,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub unrealized_plpc: f64,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub unrealized_intraday_pl: f64,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub unrealized_intraday_plpc: f64,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub current_price: f64,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub lastday_price: f64,
|
||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||
pub change_today: f64,
|
||||
pub asset_marginable: bool,
|
||||
}
|
6
src/lib/qrust/types/alpaca/api/mod.rs
Normal file
6
src/lib/qrust/types/alpaca/api/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod incoming;
|
||||
pub mod outgoing;
|
||||
|
||||
pub const ALPACA_US_EQUITY_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
|
||||
pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
|
||||
pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news";
|
23
src/lib/qrust/types/alpaca/api/outgoing/asset.rs
Normal file
23
src/lib/qrust/types/alpaca/api/outgoing/asset.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use crate::types::alpaca::shared::asset;
|
||||
use serde::Serialize;
|
||||
|
||||
pub use asset::{Class, Exchange, Status};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct Asset {
|
||||
pub status: Option<Status>,
|
||||
pub class: Option<Class>,
|
||||
pub exchange: Option<Exchange>,
|
||||
pub attributes: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl Default for Asset {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
status: None,
|
||||
class: Some(Class::UsEquity),
|
||||
exchange: None,
|
||||
attributes: None,
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,12 +1,14 @@
|
||||
use crate::{
|
||||
config::ALPACA_SOURCE,
|
||||
types::alpaca::shared::{Sort, Source},
|
||||
alpaca::bars::MAX_LIMIT,
|
||||
types::alpaca::shared,
|
||||
utils::{ser, ONE_MINUTE},
|
||||
};
|
||||
use serde::Serialize;
|
||||
use std::time::Duration;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
pub use shared::{Sort, Source};
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[allow(dead_code)]
|
||||
@@ -53,10 +55,10 @@ impl Default for UsEquity {
|
||||
timeframe: ONE_MINUTE,
|
||||
start: None,
|
||||
end: None,
|
||||
limit: Some(10000),
|
||||
limit: Some(MAX_LIMIT),
|
||||
adjustment: Some(Adjustment::All),
|
||||
asof: None,
|
||||
feed: Some(*ALPACA_SOURCE),
|
||||
feed: Some(Source::Iex),
|
||||
currency: None,
|
||||
page_token: None,
|
||||
sort: Some(Sort::Asc),
|
||||
@@ -91,7 +93,7 @@ impl Default for Crypto {
|
||||
timeframe: ONE_MINUTE,
|
||||
start: None,
|
||||
end: None,
|
||||
limit: Some(10000),
|
||||
limit: Some(MAX_LIMIT),
|
||||
page_token: None,
|
||||
sort: Some(Sort::Asc),
|
||||
}
|
@@ -1,3 +1,4 @@
|
||||
pub mod asset;
|
||||
pub mod bar;
|
||||
pub mod calendar;
|
||||
pub mod news;
|
@@ -1,10 +1,10 @@
|
||||
use crate::{types::alpaca::shared::Sort, utils::ser};
|
||||
use crate::{alpaca::news::MAX_LIMIT, types::alpaca::shared::Sort, utils::ser};
|
||||
use serde::Serialize;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct News {
|
||||
#[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")]
|
||||
#[serde(serialize_with = "ser::remove_slash_and_join_symbols")]
|
||||
pub symbols: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(with = "time::serde::rfc3339::option")]
|
||||
@@ -30,7 +30,7 @@ impl Default for News {
|
||||
symbols: vec![],
|
||||
start: None,
|
||||
end: None,
|
||||
limit: Some(50),
|
||||
limit: Some(MAX_LIMIT),
|
||||
include_content: Some(true),
|
||||
exclude_contentless: Some(false),
|
||||
page_token: None,
|
@@ -1,10 +1,12 @@
|
||||
use crate::{
|
||||
types::alpaca::shared::{order::Side, Sort},
|
||||
types::alpaca::shared::{order, Sort},
|
||||
utils::ser,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
pub use order::Side;
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[allow(dead_code)]
|
@@ -1,7 +1,7 @@
|
||||
use crate::{impl_from_enum, types};
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Class {
|
||||
UsEquity,
|
||||
@@ -10,7 +10,7 @@ pub enum Class {
|
||||
|
||||
impl_from_enum!(types::Class, Class, UsEquity, Crypto);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Serialize, Deserialize, Clone, Copy)]
|
||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum Exchange {
|
||||
Amex,
|
||||
@@ -36,7 +36,7 @@ impl_from_enum!(
|
||||
Crypto
|
||||
);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Status {
|
||||
Active,
|
@@ -1,4 +1,3 @@
|
||||
use html_escape::decode_html_entities;
|
||||
use lazy_static::lazy_static;
|
||||
use regex::Regex;
|
||||
|
||||
@@ -7,12 +6,21 @@ lazy_static! {
|
||||
static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap();
|
||||
}
|
||||
|
||||
pub fn normalize_html_content(content: &str) -> String {
|
||||
pub fn strip(content: &str) -> String {
|
||||
let content = content.replace('\n', " ");
|
||||
let content = RE_TAGS.replace_all(&content, "");
|
||||
let content = RE_SPACES.replace_all(&content, " ");
|
||||
let content = decode_html_entities(&content);
|
||||
let content = content.trim();
|
||||
|
||||
content.to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_strip() {
|
||||
let content = "<p> <b> Hello, </b> <i> World! </i> </p>";
|
||||
assert_eq!(strip(content), "Hello, World!");
|
||||
}
|
||||
}
|
@@ -223,3 +223,53 @@ impl Order {
|
||||
orders
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_normalize() {
|
||||
let order_template = Order {
|
||||
id: Uuid::new_v4(),
|
||||
client_order_id: Uuid::new_v4(),
|
||||
created_at: OffsetDateTime::now_utc(),
|
||||
updated_at: None,
|
||||
submitted_at: OffsetDateTime::now_utc(),
|
||||
filled_at: None,
|
||||
expired_at: None,
|
||||
cancel_requested_at: None,
|
||||
canceled_at: None,
|
||||
failed_at: None,
|
||||
replaced_at: None,
|
||||
replaced_by: None,
|
||||
replaces: None,
|
||||
asset_id: Uuid::new_v4(),
|
||||
symbol: "AAPL".to_string(),
|
||||
asset_class: super::super::asset::Class::UsEquity,
|
||||
notional: None,
|
||||
qty: None,
|
||||
filled_qty: 0.0,
|
||||
filled_avg_price: None,
|
||||
order_class: Class::Simple,
|
||||
order_type: Type::Market,
|
||||
side: Side::Buy,
|
||||
time_in_force: TimeInForce::Day,
|
||||
limit_price: None,
|
||||
stop_price: None,
|
||||
status: Status::New,
|
||||
extended_hours: false,
|
||||
legs: None,
|
||||
trail_percent: None,
|
||||
trail_price: None,
|
||||
hwm: None,
|
||||
};
|
||||
|
||||
let mut order = order_template.clone();
|
||||
order.legs = Some(vec![order_template.clone(), order_template.clone()]);
|
||||
order.legs.as_mut().unwrap()[0].legs = Some(vec![order_template.clone()]);
|
||||
let orders = order.normalize();
|
||||
|
||||
assert_eq!(orders.len(), 4);
|
||||
}
|
||||
}
|
@@ -28,15 +28,14 @@ pub struct Message {
|
||||
impl From<Message> for Bar {
|
||||
fn from(bar: Message) -> Self {
|
||||
Self {
|
||||
time: bar.time,
|
||||
symbol: bar.symbol,
|
||||
time: bar.time,
|
||||
open: bar.open,
|
||||
high: bar.high,
|
||||
low: bar.low,
|
||||
close: bar.close,
|
||||
volume: bar.volume,
|
||||
trades: bar.trades,
|
||||
vwap: bar.vwap,
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
types::{alpaca::shared::news::normalize_html_content, news::Sentiment, News},
|
||||
types::{alpaca::shared::news::strip, News},
|
||||
utils::de,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
@@ -31,13 +31,11 @@ impl From<Message> for News {
|
||||
time_created: news.time_created,
|
||||
time_updated: news.time_updated,
|
||||
symbols: news.symbols,
|
||||
headline: normalize_html_content(&news.headline),
|
||||
author: normalize_html_content(&news.author),
|
||||
source: normalize_html_content(&news.source),
|
||||
summary: normalize_html_content(&news.summary),
|
||||
content: normalize_html_content(&news.content),
|
||||
sentiment: Sentiment::Neutral,
|
||||
confidence: 0.0,
|
||||
headline: strip(&news.headline),
|
||||
author: strip(&news.author),
|
||||
source: strip(&news.source),
|
||||
summary: news.summary,
|
||||
content: news.content,
|
||||
url: news.url.unwrap_or_default(),
|
||||
}
|
||||
}
|
@@ -6,13 +6,13 @@ use serde::Deserialize;
|
||||
pub enum Message {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
Market {
|
||||
trades: Vec<String>,
|
||||
quotes: Vec<String>,
|
||||
bars: Vec<String>,
|
||||
updated_bars: Vec<String>,
|
||||
daily_bars: Vec<String>,
|
||||
statuses: Vec<String>,
|
||||
trades: Option<Vec<String>>,
|
||||
quotes: Option<Vec<String>>,
|
||||
daily_bars: Option<Vec<String>>,
|
||||
orderbooks: Option<Vec<String>>,
|
||||
statuses: Option<Vec<String>>,
|
||||
lulds: Option<Vec<String>>,
|
||||
cancel_errors: Option<Vec<String>>,
|
||||
},
|
@@ -1,10 +1,7 @@
|
||||
pub mod incoming;
|
||||
pub mod outgoing;
|
||||
|
||||
use crate::{
|
||||
config::{ALPACA_API_KEY, ALPACA_API_SECRET},
|
||||
types::alpaca::websocket,
|
||||
};
|
||||
use crate::types::alpaca::websocket;
|
||||
use core::panic;
|
||||
use futures_util::{
|
||||
stream::{SplitSink, SplitStream},
|
||||
@@ -17,6 +14,8 @@ use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
|
||||
pub async fn authenticate(
|
||||
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
|
||||
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
|
||||
api_key: String,
|
||||
api_secret: String,
|
||||
) {
|
||||
match stream.next().await.unwrap().unwrap() {
|
||||
Message::Text(data)
|
||||
@@ -32,8 +31,8 @@ pub async fn authenticate(
|
||||
sink.send(Message::Text(
|
||||
to_string(&websocket::data::outgoing::Message::Auth(
|
||||
websocket::auth::Message {
|
||||
key: (*ALPACA_API_KEY).clone(),
|
||||
secret: (*ALPACA_API_SECRET).clone(),
|
||||
key: api_key,
|
||||
secret: api_secret,
|
||||
},
|
||||
))
|
||||
.unwrap(),
|
@@ -1,4 +1,5 @@
|
||||
use crate::utils::ser;
|
||||
use nonempty::NonEmpty;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -6,14 +7,14 @@ use serde::Serialize;
|
||||
pub enum Market {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
UsEquity {
|
||||
bars: Vec<String>,
|
||||
updated_bars: Vec<String>,
|
||||
statuses: Vec<String>,
|
||||
bars: NonEmpty<String>,
|
||||
updated_bars: NonEmpty<String>,
|
||||
statuses: NonEmpty<String>,
|
||||
},
|
||||
#[serde(rename_all = "camelCase")]
|
||||
Crypto {
|
||||
bars: Vec<String>,
|
||||
updated_bars: Vec<String>,
|
||||
bars: NonEmpty<String>,
|
||||
updated_bars: NonEmpty<String>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -23,12 +24,12 @@ pub enum Message {
|
||||
Market(Market),
|
||||
News {
|
||||
#[serde(serialize_with = "ser::remove_slash_from_symbols")]
|
||||
news: Vec<String>,
|
||||
news: NonEmpty<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new_market_us_equity(symbols: Vec<String>) -> Self {
|
||||
pub fn new_market_us_equity(symbols: NonEmpty<String>) -> Self {
|
||||
Self::Market(Market::UsEquity {
|
||||
bars: symbols.clone(),
|
||||
updated_bars: symbols.clone(),
|
||||
@@ -36,14 +37,14 @@ impl Message {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_market_crypto(symbols: Vec<String>) -> Self {
|
||||
pub fn new_market_crypto(symbols: NonEmpty<String>) -> Self {
|
||||
Self::Market(Market::Crypto {
|
||||
bars: symbols.clone(),
|
||||
updated_bars: symbols,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_news(symbols: Vec<String>) -> Self {
|
||||
pub fn new_news(symbols: NonEmpty<String>) -> Self {
|
||||
Self::News { news: symbols }
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user