28 Commits

Author SHA1 Message Date
90b7f10a77 Update and fix bugs
It's good to be back

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-05-10 17:49:16 +01:00
d7e9350257 Add initial ML implementation
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-29 13:40:49 +00:00
f715881b07 Remove unsafe blocks
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-26 09:51:52 +00:00
d0ad9f65b1 Remove vwap
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-26 09:51:16 +00:00
ce8c4db422 Remove local XKCD image
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-26 09:50:39 +00:00
46508d1b4f Update reqwest
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-20 20:00:35 +00:00
2ad42c5462 Foldify for loops
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-20 19:43:48 +00:00
733e6373e9 Add tests
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-20 19:43:47 +00:00
d072b849c0 Reorganize crate source code
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-20 19:43:46 +00:00
718e794f51 Reorder Bar struct
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-19 15:48:45 +00:00
b7a175d5b4 Improve ser function naming
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-17 09:22:45 +00:00
e9012d6ec3 Add backfill progress logging
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-15 17:26:58 +00:00
10365745aa Attempt to fix bugs related to empty vecs
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-14 22:38:20 +00:00
8202255132 Fix backfill freshness
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-14 12:47:52 +00:00
0d276d537c Add websocket infinite inserting
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-14 01:46:18 +00:00
1707d74cf7 Improve alpaca request error handling
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-13 19:45:06 +00:00
f3f9c6336b Remove rust-bert
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-13 12:09:50 +00:00
5ed0c7670a Fix backfill sentiment batching bug
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-12 21:00:11 +00:00
d2d20e2978 Add automatic websocket reconnection
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 23:41:06 +00:00
d02f958865 Optimize backfill early saving allocations
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 20:41:59 +00:00
2d8972dce2 Fix possible crashes on .unwrap()s
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 20:15:19 +00:00
7bacc2565a Fix CI
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 16:53:22 +00:00
b60cbc891d Add backfill early saving
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 16:53:12 +00:00
2de86b46f7 Improve backfill error logging
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 19:51:41 +00:00
8c7ee3d12d Add shared lib
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 18:28:40 +00:00
a15fd2c3c9 Separate data management code
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 16:59:21 +00:00
acfc0ca4c9 Add pipelined backfilling
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 11:22:24 +00:00
681d7393d7 Add multiple asset adding route
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-09 20:13:36 +00:00
141 changed files with 8040 additions and 33641 deletions

.gitignore vendored

@@ -2,6 +2,7 @@
# will have compiled files and executables
debug/
target/
log/
# These are backup files generated by rustfmt
**/*.rs.bk
@@ -10,6 +11,3 @@ target/
*.pdb
.env*
# ML models
models/*/rust_model.ot


@@ -22,7 +22,7 @@ build:
cache:
<<: *global_cache
script:
- cargo +nightly build
- cargo +nightly build --workspace
test:
image: registry.karaolidis.com/karaolidis/qrust/rust
@@ -30,7 +30,7 @@ test:
cache:
<<: *global_cache
script:
- cargo +nightly test
- cargo +nightly test --workspace
lint:
image: registry.karaolidis.com/karaolidis/qrust/rust
@@ -39,7 +39,7 @@ lint:
<<: *global_cache
script:
- cargo +nightly fmt --all -- --check
- cargo +nightly clippy --all-targets --all-features
- cargo +nightly clippy --workspace --all-targets --all-features
depcheck:
image: registry.karaolidis.com/karaolidis/qrust/rust
@@ -48,7 +48,7 @@ depcheck:
<<: *global_cache
script:
- cargo +nightly outdated
- cargo +nightly udeps
- cargo +nightly udeps --workspace --all-targets
build-release:
image: registry.karaolidis.com/karaolidis/qrust/rust

Cargo.lock generated
File diff suppressed because it is too large.

Cargo.toml

@@ -3,6 +3,18 @@ name = "qrust"
version = "0.1.0"
edition = "2021"
[lib]
name = "qrust"
path = "src/lib/qrust/mod.rs"
[[bin]]
name = "qrust"
path = "src/bin/qrust/mod.rs"
[[bin]]
name = "trainer"
path = "src/bin/trainer/mod.rs"
[profile.release]
panic = 'abort'
strip = true
@@ -12,9 +24,9 @@ codegen-units = 1
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = "0.7.4"
axum = "0.7.5"
dotenv = "0.15.0"
tokio = { version = "1.32.0", features = [
tokio = { version = "1.37.0", features = [
"macros",
"rt-multi-thread",
] }
@@ -22,29 +34,29 @@ tokio-tungstenite = { version = "0.21.0", features = [
"tokio-native-tls",
"native-tls",
] }
log = "0.4.20"
log4rs = "1.2.0"
serde = "1.0.188"
serde_json = "1.0.105"
serde_repr = "0.1.18"
serde_with = "3.6.1"
serde-aux = "4.4.0"
futures-util = "0.3.28"
reqwest = { version = "0.11.20", features = [
log = "0.4.21"
log4rs = "1.3.0"
serde = "1.0.201"
serde_json = "1.0.117"
serde_repr = "0.1.19"
serde_with = "3.8.1"
serde-aux = "4.5.0"
futures-util = "0.3.30"
reqwest = { version = "0.12.4", features = [
"json",
"serde_json",
] }
http = "1.0.0"
governor = "0.6.0"
http = "1.1.0"
governor = "0.6.3"
clickhouse = { version = "0.11.6", features = [
"watch",
"time",
"uuid",
] }
uuid = { version = "1.6.1", features = [
uuid = { version = "1.8.0", features = [
"serde",
"v4",
] }
time = { version = "0.3.31", features = [
time = { version = "0.3.36", features = [
"serde",
"serde-well-known",
"serde-human-readable",
@@ -55,9 +67,22 @@ time = { version = "0.3.31", features = [
backoff = { version = "0.4.0", features = [
"tokio",
] }
regex = "1.10.3"
html-escape = "0.2.13"
rust-bert = "0.22.0"
async-trait = "0.1.77"
regex = "1.10.4"
async-trait = "0.1.80"
itertools = "0.12.1"
lazy_static = "1.4.0"
nonempty = { version = "0.10.0", features = [
"serialize",
] }
rand = "0.8.5"
rayon = "1.10.0"
burn = { version = "0.13.2", features = [
"wgpu",
"cuda",
"tui",
"metrics",
"train",
] }
[dev-dependencies]
serde_test = "1.0.176"

README.md

@@ -1,5 +1,5 @@
# qrust
![XKCD - Engineer Syllogism](./static/engineer-syllogism.png)
![XKCD - Engineer Syllogism](https://imgs.xkcd.com/comics/engineer_syllogism.png)
`qrust` (/kɹʌst/, QuantitativeRust) is an algorithmic trading library written in Rust.

log4rs.yaml

@@ -4,7 +4,14 @@ appenders:
encoder:
pattern: "{d} {h({l})} {M}::{L} - {m}{n}"
file:
kind: file
path: "./log/output.log"
encoder:
pattern: "{d} {l} {M}::{L} - {m}{n}"
root:
level: info
appenders:
- stdout
- file


@@ -1,32 +0,0 @@
{
"_name_or_path": "/home/ubuntu/finbert/models/language_model/finbertTRC2",
"architectures": [
"BertForSequenceClassification"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"id2label": {
"0": "positive",
"1": "negative",
"2": "neutral"
},
"initializer_range": 0.02,
"intermediate_size": 3072,
"label2id": {
"positive": 0,
"negative": 1,
"neutral": 2
},
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"type_vocab_size": 2,
"vocab_size": 30522
}


@@ -1 +0,0 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}


@@ -1 +0,0 @@
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "bert-base-uncased"}

File diff suppressed because it is too large.

src/bin/qrust/config.rs Normal file

@@ -0,0 +1,86 @@
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use lazy_static::lazy_static;
use qrust::types::alpaca::shared::{Mode, Source};
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client,
};
use std::{env, num::NonZeroU32, sync::Arc};
use tokio::sync::Semaphore;
lazy_static! {
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
.expect("ALPACA_MODE must be set.")
.parse()
.expect("ALPACA_MODE must be 'live' or 'paper'");
pub static ref ALPACA_API_BASE: String = match *ALPACA_MODE {
Mode::Live => String::from("api"),
Mode::Paper => String::from("paper-api"),
};
pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE")
.expect("ALPACA_SOURCE must be set.")
.parse()
.expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'");
pub static ref ALPACA_API_KEY: String =
env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
pub static ref ALPACA_API_SECRET: String =
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
pub static ref CLICKHOUSE_BATCH_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE")
.expect("BATCH_BACKFILL_BARS_SIZE must be set.")
.parse()
.expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer.");
pub static ref CLICKHOUSE_BATCH_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE")
.expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
.parse()
.expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS")
.expect("CLICKHOUSE_MAX_CONNECTIONS must be set.")
.parse()
.expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer.");
}
pub struct Config {
pub alpaca_client: Client,
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
pub clickhouse_client: clickhouse::Client,
pub clickhouse_concurrency_limiter: Arc<Semaphore>,
}
impl Config {
pub fn from_env() -> Self {
Self {
alpaca_client: Client::builder()
.default_headers(HeaderMap::from_iter([
(
HeaderName::from_static("apca-api-key-id"),
HeaderValue::from_str(&ALPACA_API_KEY)
.expect("Alpaca API key must not contain invalid characters."),
),
(
HeaderName::from_static("apca-api-secret-key"),
HeaderValue::from_str(&ALPACA_API_SECRET)
.expect("Alpaca API secret must not contain invalid characters."),
),
]))
.build()
.unwrap(),
alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE {
Source::Iex => NonZeroU32::new(200).unwrap(),
Source::Sip => NonZeroU32::new(10_000).unwrap(),
Source::Otc => unimplemented!("OTC rate limit not implemented."),
})),
clickhouse_client: clickhouse::Client::default()
.with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
.with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
.with_password(
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
)
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
}
}
pub fn arc_from_env() -> Arc<Self> {
Arc::new(Self::from_env())
}
}
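
Note: the config above builds a `governor` direct rate limiter keyed to the data source, so every Alpaca call shares one per-minute quota. As a minimal, self-contained sketch of how such a limiter behaves (an illustration assuming governor 0.6's `until_ready` API; the 200/minute quota mirrors the `Source::Iex` arm above):

use governor::{Quota, RateLimiter};
use std::num::NonZeroU32;

#[tokio::main]
async fn main() {
    // 200 cells per minute, as configured for Source::Iex above.
    let limiter = RateLimiter::direct(Quota::per_minute(NonZeroU32::new(200).unwrap()));
    for i in 0..3 {
        // Resolves immediately while quota remains, otherwise waits for refill.
        limiter.until_ready().await;
        println!("request {i} permitted");
    }
}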


@@ -1,24 +1,25 @@
use crate::{
config::{Config, ALPACA_MODE},
config::{Config, ALPACA_API_BASE},
database,
types::alpaca,
};
use log::{info, warn};
use qrust::{alpaca, types};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::join;
pub async fn check_account(config: &Arc<Config>) {
let account = alpaca::api::incoming::account::get(
let account = alpaca::account::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,
&ALPACA_API_BASE,
)
.await
.unwrap();
assert!(
!(account.status != alpaca::api::incoming::account::Status::Active),
!(account.status != types::alpaca::api::incoming::account::Status::Active),
"Account status is not active: {:?}.",
account.status
);
@@ -33,56 +34,60 @@ pub async fn check_account(config: &Arc<Config>) {
warn!("Account cash is zero, qrust will not be able to trade.");
}
warn!(
"qrust active on {} account with {} {}, avoid transferring funds without shutting down.",
*ALPACA_MODE, account.currency, account.cash
info!(
"qrust running on {} account with {} {}, avoid transferring funds without shutting down.",
*ALPACA_API_BASE, account.currency, account.cash
);
}
pub async fn rehydrate_orders(config: &Arc<Config>) {
info!("Rehydrating order data.");
let mut orders = vec![];
let mut after = OffsetDateTime::UNIX_EPOCH;
while let Some(message) = alpaca::api::incoming::order::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&alpaca::api::outgoing::order::Order {
status: Some(alpaca::api::outgoing::order::Status::All),
after: Some(after),
..Default::default()
},
None,
)
.await
.ok()
.filter(|message| !message.is_empty())
{
loop {
let message = alpaca::orders::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&types::alpaca::api::outgoing::order::Order {
status: Some(types::alpaca::api::outgoing::order::Status::All),
after: Some(after),
..Default::default()
},
None,
&ALPACA_API_BASE,
)
.await
.unwrap();
if message.is_empty() {
break;
}
orders.extend(message);
after = orders.last().unwrap().submitted_at;
}
let orders = orders
.into_iter()
.flat_map(&alpaca::api::incoming::order::Order::normalize)
.flat_map(&types::alpaca::api::incoming::order::Order::normalize)
.collect::<Vec<_>>();
database::orders::upsert_batch(&config.clickhouse_client, &orders)
.await
.unwrap();
info!("Rehydrated order data.");
database::orders::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&orders,
)
.await
.unwrap();
}
pub async fn rehydrate_positions(config: &Arc<Config>) {
info!("Rehydrating position data.");
let positions_future = async {
alpaca::api::incoming::position::get(
alpaca::positions::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
@@ -92,9 +97,12 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
};
let assets_future = async {
database::assets::select(&config.clickhouse_client)
.await
.unwrap()
database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
};
let (mut positions, assets) = join!(positions_future, assets_future);
@@ -111,9 +119,13 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
})
.collect::<Vec<_>>();
database::assets::upsert_batch(&config.clickhouse_client, &assets)
.await
.unwrap();
database::assets::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&assets,
)
.await
.unwrap();
for position in positions.values() {
warn!(
@@ -121,6 +133,4 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
position.symbol, position.qty
);
}
info!("Rehydrated position data.");
}

src/bin/qrust/mod.rs Normal file

@@ -0,0 +1,115 @@
#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
#![allow(clippy::missing_docs_in_private_items)]
#![feature(hash_extract_if)]
mod config;
mod init;
mod routes;
mod threads;
use config::{
Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE,
CLICKHOUSE_BATCH_NEWS_SIZE, CLICKHOUSE_MAX_CONNECTIONS,
};
use dotenv::dotenv;
use log::info;
use log4rs::config::Deserializers;
use nonempty::NonEmpty;
use qrust::{create_send_await, database};
use tokio::{join, spawn, sync::mpsc, try_join};
#[tokio::main]
async fn main() {
dotenv().ok();
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
let config = Config::arc_from_env();
let _ = *ALPACA_MODE;
let _ = *ALPACA_API_BASE;
let _ = *ALPACA_SOURCE;
let _ = *CLICKHOUSE_BATCH_BARS_SIZE;
let _ = *CLICKHOUSE_BATCH_NEWS_SIZE;
let _ = *CLICKHOUSE_MAX_CONNECTIONS;
info!("Marking all assets as stale.");
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol, asset.class))
.collect::<Vec<_>>();
let symbols = assets.iter().map(|(symbol, _)| symbol).collect::<Vec<_>>();
try_join!(
database::backfills_bars::set_fresh_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
false,
&symbols
),
database::backfills_news::set_fresh_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
false,
&symbols
)
)
.unwrap();
info!("Cleaning up database.");
database::cleanup_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
info!("Optimizing database.");
database::optimize_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
info!("Rehydrating account data.");
init::check_account(&config).await;
join!(
init::rehydrate_orders(&config),
init::rehydrate_positions(&config)
);
info!("Starting threads.");
spawn(threads::trading::run(config.clone()));
let (data_sender, data_receiver) = mpsc::channel::<threads::data::Message>(100);
let (clock_sender, clock_receiver) = mpsc::channel::<threads::clock::Message>(1);
spawn(threads::data::run(
config.clone(),
data_receiver,
clock_receiver,
));
spawn(threads::clock::run(config.clone(), clock_sender));
if let Some(assets) = NonEmpty::from_vec(assets) {
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Enable,
assets
);
}
routes::run(config, data_sender).await;
}
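
Note: the block of `let _ = *ALPACA_MODE;`-style lines above is an eager-initialization idiom: dereferencing each `lazy_static` forces its environment parsing (and its `expect` panics) to run at startup instead of on first use deep inside a thread. A minimal sketch of the idiom, using a hypothetical `PORT` variable rather than any of the statics above:

use lazy_static::lazy_static;
use std::env;

lazy_static! {
    static ref PORT: u16 = env::var("PORT")
        .expect("PORT must be set.")
        .parse()
        .expect("PORT must be a positive integer.");
}

fn main() {
    env::set_var("PORT", "8080");
    // Force evaluation now: a malformed PORT aborts here, at startup,
    // rather than at the first read somewhere deep in the program.
    let _ = *PORT;
    println!("listening on {}", *PORT);
}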


@@ -0,0 +1,197 @@
use crate::{
config::{Config, ALPACA_API_BASE},
create_send_await, database, threads,
};
use axum::{extract::Path, Extension, Json};
use http::StatusCode;
use nonempty::{nonempty, NonEmpty};
use qrust::{
alpaca,
types::{self, Asset},
};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::mpsc;
pub async fn get(
Extension(config): Extension<Arc<Config>>,
) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> {
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok((StatusCode::OK, Json(assets)))
}
pub async fn get_where_symbol(
Extension(config): Extension<Arc<Config>>,
Path(symbol): Path<String>,
) -> Result<(StatusCode, Json<Asset>), StatusCode> {
let asset = database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
asset.map_or(Err(StatusCode::NOT_FOUND), |asset| {
Ok((StatusCode::OK, Json(asset)))
})
}
#[derive(Deserialize)]
pub struct AddAssetsRequest {
symbols: Vec<String>,
}
#[derive(Serialize)]
pub struct AddAssetsResponse {
added: Vec<String>,
skipped: Vec<String>,
failed: Vec<String>,
}
pub async fn add(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Json(request): Json<AddAssetsRequest>,
) -> Result<(StatusCode, Json<AddAssetsResponse>), StatusCode> {
let database_symbols = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.into_iter()
.map(|asset| asset.symbol)
.collect::<HashSet<_>>();
let mut alpaca_assets = alpaca::assets::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&request.symbols,
None,
&ALPACA_API_BASE,
)
.await
.map_err(|e| e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))?
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>();
let num_symbols = request.symbols.len();
let (assets, skipped, failed) = request.symbols.into_iter().fold(
(Vec::with_capacity(num_symbols), vec![], vec![]),
|(mut assets, mut skipped, mut failed), symbol| {
if database_symbols.contains(&symbol) {
skipped.push(symbol);
} else if let Some(asset) = alpaca_assets.remove(&symbol) {
if asset.status == types::alpaca::api::incoming::asset::Status::Active
&& asset.tradable
&& asset.fractionable
{
assets.push((asset.symbol, asset.class.into()));
} else {
failed.push(asset.symbol);
}
} else {
failed.push(symbol);
}
(assets, skipped, failed)
},
);
if let Some(assets) = NonEmpty::from_vec(assets.clone()) {
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
assets
);
}
Ok((
StatusCode::OK,
Json(AddAssetsResponse {
added: assets.into_iter().map(|asset| asset.0).collect(),
skipped,
failed,
}),
))
}
pub async fn add_symbol(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
if database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.is_some()
{
return Err(StatusCode::CONFLICT);
}
let asset = alpaca::assets::get_by_symbol(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbol,
None,
&ALPACA_API_BASE,
)
.await
.map_err(|e| e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))?;
if asset.status != types::alpaca::api::incoming::asset::Status::Active
|| !asset.tradable
|| !asset.fractionable
{
return Err(StatusCode::FORBIDDEN);
}
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
nonempty![(asset.symbol, asset.class.into())]
);
Ok(StatusCode::CREATED)
}
pub async fn delete(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
let asset = database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Remove,
nonempty![(asset.symbol, asset.class)]
);
Ok(StatusCode::NO_CONTENT)
}


@@ -16,6 +16,7 @@ pub async fn run(config: Arc<Config>, data_sender: mpsc::Sender<threads::data::M
.route("/assets", get(assets::get))
.route("/assets/:symbol", get(assets::get_where_symbol))
.route("/assets", post(assets::add))
.route("/assets/:symbol", post(assets::add_symbol))
.route("/assets/:symbol", delete(assets::delete))
.layer(Extension(config))
.layer(Extension(data_sender));


@@ -1,14 +1,17 @@
use crate::{
config::Config,
config::{Config, ALPACA_API_BASE},
database,
types::{alpaca, Calendar},
utils::{backoff, duration_until},
};
use log::info;
use qrust::{
alpaca,
types::{self, Calendar},
utils::{backoff, duration_until},
};
use std::sync::Arc;
use time::OffsetDateTime;
use tokio::{join, sync::mpsc, time::sleep};
#[derive(PartialEq, Eq)]
pub enum Status {
Open,
Closed,
@@ -16,21 +19,16 @@ pub enum Status {
pub struct Message {
pub status: Status,
pub next_switch: OffsetDateTime,
}
impl From<alpaca::api::incoming::clock::Clock> for Message {
fn from(clock: alpaca::api::incoming::clock::Clock) -> Self {
if clock.is_open {
Self {
status: Status::Open,
next_switch: clock.next_close,
}
} else {
Self {
status: Status::Closed,
next_switch: clock.next_open,
}
impl From<types::alpaca::api::incoming::clock::Clock> for Message {
fn from(clock: types::alpaca::api::incoming::clock::Clock) -> Self {
Self {
status: if clock.is_open {
Status::Open
} else {
Status::Closed
},
}
}
}
@@ -38,21 +36,23 @@ impl From<alpaca::api::incoming::clock::Clock> for Message {
pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
loop {
let clock_future = async {
alpaca::api::incoming::clock::get(
alpaca::clock::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
Some(backoff::infinite()),
&ALPACA_API_BASE,
)
.await
.unwrap()
};
let calendar_future = async {
alpaca::api::incoming::calendar::get(
alpaca::calendar::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&alpaca::api::outgoing::calendar::Calendar::default(),
&types::alpaca::api::outgoing::calendar::Calendar::default(),
Some(backoff::infinite()),
&ALPACA_API_BASE,
)
.await
.unwrap()
@@ -74,9 +74,13 @@ pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
let sleep_future = sleep(sleep_until);
let calendar_future = async {
database::calendar::upsert_batch_and_delete(&config.clickhouse_client, &calendar)
.await
.unwrap();
database::calendar::upsert_batch_and_delete(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&calendar,
)
.await
.unwrap();
};
join!(sleep_future, calendar_future);


@@ -0,0 +1,238 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE},
database,
threads::data::ThreadType,
};
use async_trait::async_trait;
use log::{error, info};
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
self,
alpaca::{
api::{ALPACA_CRYPTO_DATA_API_URL, ALPACA_US_EQUITY_DATA_API_URL},
shared::{Sort, Source},
},
Backfill, Bar, Class,
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::time::sleep;
pub struct Handler {
pub config: Arc<Config>,
pub data_url: &'static str,
pub api_query_constructor: fn(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar,
}
pub fn us_equity_query_constructor(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar {
types::alpaca::api::outgoing::bar::Bar::UsEquity(types::alpaca::api::outgoing::bar::UsEquity {
symbols,
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
sort: Some(Sort::Asc),
feed: Some(*ALPACA_SOURCE),
..Default::default()
})
}
pub fn crypto_query_constructor(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar {
types::alpaca::api::outgoing::bar::Bar::Crypto(types::alpaca::api::outgoing::bar::Crypto {
symbols,
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
sort: Some(Sort::Asc),
..Default::default()
})
}
#[async_trait]
impl super::Handler for Handler {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
database::backfills_bars::select_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_bars::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::bars::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
if *ALPACA_SOURCE == Source::Sip {
return;
}
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();
info!("Queing bar backfill for {:?} in {:?}.", symbols, run_delay);
sleep(run_delay).await;
}
async fn backfill(&self, jobs: NonEmpty<Job>) {
let symbols = Vec::from(jobs.clone().map(|job| job.symbol));
let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let freshness = jobs
.into_iter()
.map(|job| (job.symbol, job.fresh))
.collect::<HashMap<_, _>>();
let mut bars = Vec::with_capacity(*CLICKHOUSE_BATCH_BARS_SIZE);
let mut last_times = HashMap::new();
let mut next_page_token = None;
info!("Backfilling bars for {:?}.", symbols);
loop {
let message = alpaca::bars::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
self.data_url,
&(self.api_query_constructor)(
symbols.clone(),
fetch_from,
fetch_to,
next_page_token.clone(),
),
None,
)
.await;
if let Err(err) = message {
error!("Failed to backfill bars for {:?}: {:?}.", symbols, err);
return;
}
let message = message.unwrap();
for (symbol, bars_vec) in message.bars {
if let Some(last) = bars_vec.last() {
last_times.insert(symbol.clone(), last.time);
}
for bar in bars_vec {
bars.push(Bar::from((bar, symbol.clone())));
}
}
if bars.len() < *CLICKHOUSE_BATCH_BARS_SIZE && message.next_page_token.is_some() {
continue;
}
database::bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&bars,
)
.await
.unwrap();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill {
fresh: freshness[&symbol],
symbol,
time,
})
.collect::<Vec<_>>();
database::backfills_bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
bars.clear();
}
database::backfills_bars::set_fresh_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
true,
&symbols,
)
.await
.unwrap();
info!("Backfilled bars for {:?}.", symbols);
}
fn max_limit(&self) -> i64 {
alpaca::bars::MAX_LIMIT
}
fn log_string(&self) -> &'static str {
"bars"
}
}
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
let data_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => ALPACA_US_EQUITY_DATA_API_URL,
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_API_URL,
_ => unreachable!(),
};
let api_query_constructor = match thread_type {
ThreadType::Bars(Class::UsEquity) => us_equity_query_constructor,
ThreadType::Bars(Class::Crypto) => crypto_query_constructor,
_ => unreachable!(),
};
Box::new(Handler {
config,
data_url,
api_query_constructor,
})
}
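
Note: `duration_until`, `FIFTEEN_MINUTES`, and `ONE_MINUTE` come from `qrust::utils`, which this diff does not show. A plausible shape for `duration_until` (an assumption about the helper, not its actual source), clamping targets already in the past to a zero sleep:

use std::time::Duration;
use time::OffsetDateTime;

// Hypothetical sketch of qrust::utils::duration_until.
fn duration_until(target: OffsetDateTime) -> Duration {
    let remaining = target - OffsetDateTime::now_utc();
    // A negative remainder (target already passed) converts to Duration::ZERO.
    remaining.try_into().unwrap_or(Duration::ZERO)
}

fn main() {
    let in_a_minute = OffsetDateTime::now_utc() + time::Duration::minutes(1);
    assert!(duration_until(in_a_minute) <= Duration::from_secs(60));
    assert_eq!(duration_until(OffsetDateTime::now_utc()), Duration::ZERO);
}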


@@ -0,0 +1,243 @@
pub mod bars;
pub mod news;
use async_trait::async_trait;
use itertools::Itertools;
use log::{info, warn};
use nonempty::{nonempty, NonEmpty};
use qrust::{
types::Backfill,
utils::{last_minute, ONE_SECOND},
};
use std::{collections::HashMap, hash::Hash, sync::Arc};
use time::OffsetDateTime;
use tokio::{
spawn,
sync::{mpsc, oneshot, Mutex},
task::JoinHandle,
try_join,
};
use uuid::Uuid;
pub enum Action {
Backfill,
Purge,
}
pub struct Message {
pub action: Action,
pub symbols: NonEmpty<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Action, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel::<()>();
(
Self {
action,
symbols,
response: sender,
},
receiver,
)
}
}
#[derive(Clone)]
pub struct Job {
pub symbol: String,
pub fetch_from: OffsetDateTime,
pub fetch_to: OffsetDateTime,
pub fresh: bool,
}
#[async_trait]
pub trait Handler: Send + Sync {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error>;
async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn queue_backfill(&self, jobs: &NonEmpty<Job>);
async fn backfill(&self, jobs: NonEmpty<Job>);
fn max_limit(&self) -> i64;
fn log_string(&self) -> &'static str;
}
pub struct Jobs {
pub symbol_to_uuid: HashMap<String, Uuid>,
pub uuid_to_job: HashMap<Uuid, JoinHandle<()>>,
}
impl Jobs {
pub fn insert(&mut self, jobs: Vec<String>, fut: JoinHandle<()>) {
let uuid = Uuid::new_v4();
for symbol in jobs {
self.symbol_to_uuid.insert(symbol.clone(), uuid);
}
self.uuid_to_job.insert(uuid, fut);
}
pub fn contains_key(&self, symbol: &str) -> bool {
self.symbol_to_uuid.contains_key(symbol)
}
pub fn remove(&mut self, symbol: &str) -> Option<JoinHandle<()>> {
self.symbol_to_uuid
.remove(symbol)
.and_then(|uuid| self.uuid_to_job.remove(&uuid))
}
pub fn remove_many<T>(&mut self, symbols: &[T])
where
T: AsRef<str> + Hash + Eq,
{
for symbol in symbols {
self.symbol_to_uuid
.remove(symbol.as_ref())
.and_then(|uuid| self.uuid_to_job.remove(&uuid));
}
}
pub fn len(&self) -> usize {
self.symbol_to_uuid.len()
}
}
pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
let backfill_jobs = Arc::new(Mutex::new(Jobs {
symbol_to_uuid: HashMap::new(),
uuid_to_job: HashMap::new(),
}));
loop {
let message = receiver.recv().await.unwrap();
spawn(handle_message(
handler.clone(),
backfill_jobs.clone(),
message,
));
}
}
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
backfill_jobs: Arc<Mutex<Jobs>>,
message: Message,
) {
let backfill_jobs_clone = backfill_jobs.clone();
let mut backfill_jobs = backfill_jobs.lock().await;
let symbols = Vec::from(message.symbols);
match message.action {
Action::Backfill => {
let log_string = handler.log_string();
let max_limit = handler.max_limit();
let backfills = handler
.select_latest_backfills(&symbols)
.await
.unwrap()
.into_iter()
.map(|backfill| (backfill.symbol.clone(), backfill))
.collect::<HashMap<_, _>>();
let mut jobs = Vec::with_capacity(symbols.len());
for symbol in symbols {
if backfill_jobs.contains_key(&symbol) {
warn!(
"Backfill for {} {} is already running, skipping.",
symbol, log_string
);
continue;
}
let backfill = backfills.get(&symbol);
let fetch_from = backfill.map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
backfill.time + ONE_SECOND
});
let fetch_to = last_minute();
if fetch_from > fetch_to {
info!("No need to backfill {} {}.", symbol, log_string,);
return;
}
let fresh = backfill.map_or(false, |backfill| backfill.fresh);
jobs.push(Job {
symbol,
fetch_from,
fetch_to,
fresh,
});
}
let mut current_minutes = 0;
let job_groups = jobs
.into_iter()
.sorted_unstable_by_key(|job| job.fetch_from)
.fold(Vec::<NonEmpty<Job>>::new(), |mut job_groups, job| {
let minutes = (job.fetch_to - job.fetch_from).whole_minutes();
if let Some(job_group) = job_groups.last_mut() {
if current_minutes + minutes <= max_limit {
job_group.push(job);
current_minutes += minutes;
return job_groups;
}
}
job_groups.push(nonempty![job]);
current_minutes = minutes;
job_groups
});
for job_group in job_groups {
let symbols = job_group
.iter()
.map(|job| job.symbol.clone())
.collect::<Vec<_>>();
let handler = handler.clone();
let symbols_clone = symbols.clone();
let backfill_jobs_clone = backfill_jobs_clone.clone();
let fut = spawn(async move {
handler.queue_backfill(&job_group).await;
handler.backfill(job_group).await;
let mut backfill_jobs = backfill_jobs_clone.lock().await;
backfill_jobs.remove_many(&symbols_clone);
let remaining = backfill_jobs.len();
drop(backfill_jobs);
info!("{} {} backfills remaining.", remaining, log_string);
});
backfill_jobs.insert(symbols, fut);
}
}
Action::Purge => {
for symbol in &symbols {
if let Some(job) = backfill_jobs.remove(symbol) {
job.abort();
let _ = job.await;
}
}
try_join!(
handler.delete_backfills(&symbols),
handler.delete_data(&symbols)
)
.unwrap();
}
}
message.response.send(()).unwrap();
}
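
Note: `create_send_await!` is exported by the `qrust` lib crate and its definition is not part of this diff. Every call site suggests it constructs a message with its paired `oneshot` receiver, sends the message, and awaits the acknowledgment. A hand-written sketch of the presumed expansion for one call (an assumption about the macro, reusing the `Message` and `Action` types above):

// Presumed expansion of:
//     create_send_await!(sender, Message::new, Action::Backfill, symbols);
async fn request_backfill(
    sender: tokio::sync::mpsc::Sender<Message>,
    symbols: nonempty::NonEmpty<String>,
) {
    let (message, receiver) = Message::new(Action::Backfill, symbols);
    sender.send(message).await.unwrap();
    // Blocks until handle_message fires message.response.send(()).
    receiver.await.unwrap();
}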


@@ -0,0 +1,186 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_NEWS_SIZE},
database,
};
use async_trait::async_trait;
use log::{error, info};
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
self,
alpaca::shared::{Sort, Source},
Backfill, News,
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::time::sleep;
pub struct Handler {
pub config: Arc<Config>,
}
#[async_trait]
impl super::Handler for Handler {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
database::backfills_news::select_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_news::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::news::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
if *ALPACA_SOURCE == Source::Sip {
return;
}
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();
info!("Queing news backfill for {:?} in {:?}.", symbols, run_delay);
sleep(run_delay).await;
}
#[allow(clippy::too_many_lines)]
#[allow(clippy::iter_with_drain)]
async fn backfill(&self, jobs: NonEmpty<Job>) {
let symbols = Vec::from(jobs.clone().map(|job| job.symbol));
let symbols_set = symbols.clone().into_iter().collect::<HashSet<_>>();
let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let freshness = jobs
.into_iter()
.map(|job| (job.symbol, job.fresh))
.collect::<HashMap<_, _>>();
let mut news = Vec::with_capacity(*CLICKHOUSE_BATCH_NEWS_SIZE);
let mut last_times = HashMap::new();
let mut next_page_token = None;
info!("Backfilling news for {:?}.", symbols);
loop {
let message = alpaca::news::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
&types::alpaca::api::outgoing::news::News {
symbols: symbols.clone(),
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token.clone(),
sort: Some(Sort::Asc),
..Default::default()
},
None,
)
.await;
if let Err(err) = message {
error!("Failed to backfill news for {:?}: {:?}.", symbols, err);
return;
}
let message = message.unwrap();
for news_item in message.news {
let news_item = News::from(news_item);
for symbol in &news_item.symbols {
if symbols_set.contains(symbol) {
last_times.insert(symbol.clone(), news_item.time_created);
}
}
news.push(news_item);
}
if news.len() < *CLICKHOUSE_BATCH_NEWS_SIZE && message.next_page_token.is_some() {
continue;
}
database::news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&news,
)
.await
.unwrap();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill {
fresh: freshness[&symbol],
symbol,
time,
})
.collect::<Vec<_>>();
database::backfills_news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
news.clear();
}
database::backfills_news::set_fresh_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
true,
&symbols,
)
.await
.unwrap();
info!("Backfilled news for {:?}.", symbols);
}
fn max_limit(&self) -> i64 {
alpaca::news::MAX_LIMIT
}
fn log_string(&self) -> &'static str {
"news"
}
}
pub fn create_handler(config: Arc<Config>) -> Box<dyn super::Handler> {
Box::new(Handler { config })
}


@@ -0,0 +1,391 @@
mod backfill;
mod websocket;
use super::clock;
use crate::{
config::{Config, ALPACA_API_BASE, ALPACA_SOURCE},
create_send_await, database,
};
use itertools::{Either, Itertools};
use log::error;
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
alpaca::websocket::{
ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL,
ALPACA_US_EQUITY_DATA_WEBSOCKET_URL,
},
Asset, Class,
},
};
use std::{collections::HashMap, sync::Arc};
use tokio::{
join, select, spawn,
sync::{mpsc, oneshot},
};
#[derive(Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Action {
Add,
Enable,
Remove,
Disable,
}
pub struct Message {
pub action: Action,
pub assets: NonEmpty<(String, Class)>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Action, assets: NonEmpty<(String, Class)>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel();
(
Self {
action,
assets,
response: sender,
},
receiver,
)
}
}
#[derive(Clone, Copy)]
pub enum ThreadType {
Bars(Class),
News,
}
pub async fn run(
config: Arc<Config>,
mut receiver: mpsc::Receiver<Message>,
mut clock_receiver: mpsc::Receiver<clock::Message>,
) {
let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::UsEquity));
let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::Crypto));
let (news_websocket_sender, news_backfill_sender) =
init_thread(config.clone(), ThreadType::News);
loop {
select! {
Some(message) = receiver.recv() => {
spawn(handle_message(
config.clone(),
bars_us_equity_websocket_sender.clone(),
bars_us_equity_backfill_sender.clone(),
bars_crypto_websocket_sender.clone(),
bars_crypto_backfill_sender.clone(),
news_websocket_sender.clone(),
news_backfill_sender.clone(),
message,
));
}
Some(message) = clock_receiver.recv() => {
spawn(handle_clock_message(
config.clone(),
bars_us_equity_backfill_sender.clone(),
bars_crypto_backfill_sender.clone(),
news_backfill_sender.clone(),
message,
));
}
else => panic!("Communication channel unexpectedly closed.")
}
}
}
fn init_thread(
config: Arc<Config>,
thread_type: ThreadType,
) -> (
mpsc::Sender<websocket::Message>,
mpsc::Sender<backfill::Message>,
) {
let websocket_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => {
format!("{}/{}", ALPACA_US_EQUITY_DATA_WEBSOCKET_URL, *ALPACA_SOURCE)
}
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_WEBSOCKET_URL.into(),
ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(),
};
let backfill_handler = match thread_type {
ThreadType::Bars(_) => backfill::bars::create_handler(config.clone(), thread_type),
ThreadType::News => backfill::news::create_handler(config.clone()),
};
let (backfill_sender, backfill_receiver) = mpsc::channel(100);
spawn(backfill::run(backfill_handler.into(), backfill_receiver));
let websocket_handler = match thread_type {
ThreadType::Bars(_) => websocket::bars::create_handler(config, thread_type),
ThreadType::News => websocket::news::create_handler(&config),
};
let (websocket_sender, websocket_receiver) = mpsc::channel(100);
spawn(websocket::run(
websocket_handler.into(),
websocket_receiver,
websocket_url,
));
(websocket_sender, backfill_sender)
}
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_lines)]
async fn handle_message(
config: Arc<Config>,
bars_us_equity_websocket_sender: mpsc::Sender<websocket::Message>,
bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
bars_crypto_websocket_sender: mpsc::Sender<websocket::Message>,
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
news_websocket_sender: mpsc::Sender<websocket::Message>,
news_backfill_sender: mpsc::Sender<backfill::Message>,
message: Message,
) {
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = message
.assets
.clone()
.into_iter()
.partition_map(|asset| match asset.1 {
Class::UsEquity => Either::Left(asset.0),
Class::Crypto => Either::Right(asset.0),
});
let symbols = message.assets.map(|(symbol, _)| symbol);
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols.clone()) {
create_send_await!(
bars_us_equity_websocket_sender,
websocket::Message::new,
message.action.into(),
us_equity_symbols
);
}
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols.clone()) {
create_send_await!(
bars_crypto_websocket_sender,
websocket::Message::new,
message.action.into(),
crypto_symbols
);
}
};
let news_future = async {
create_send_await!(
news_websocket_sender,
websocket::Message::new,
message.action.into(),
symbols.clone()
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
if message.action == Action::Disable {
message.response.send(()).unwrap();
return;
}
match message.action {
Action::Add | Action::Enable => {
let symbols = Vec::from(symbols.clone());
let assets = async {
alpaca::assets::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbols,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>()
};
let positions = async {
alpaca::positions::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbols,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(|position| (position.symbol.clone(), position))
.collect::<HashMap<_, _>>()
};
let (mut assets, mut positions) = join!(assets, positions);
let batch =
symbols
.iter()
.fold(Vec::with_capacity(symbols.len()), |mut batch, symbol| {
if let Some(asset) = assets.remove(symbol) {
let position = positions.remove(symbol);
batch.push(Asset::from((asset, position)));
} else {
error!("Failed to find asset for symbol: {}.", symbol);
}
batch
});
database::assets::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&batch,
)
.await
.unwrap();
}
Action::Remove => {
database::assets::delete_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&Vec::from(symbols.clone()),
)
.await
.unwrap();
}
Action::Disable => unreachable!(),
}
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
us_equity_symbols
);
}
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
crypto_symbols
);
}
};
let news_future = async {
create_send_await!(
news_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
symbols
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
message.response.send(()).unwrap();
}
async fn handle_clock_message(
config: Arc<Config>,
bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
news_backfill_sender: mpsc::Sender<backfill::Message>,
message: clock::Message,
) {
if message.status == clock::Status::Closed {
database::cleanup_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
}
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets
.clone()
.into_iter()
.partition_map(|asset| match asset.class {
Class::UsEquity => Either::Left(asset.symbol),
Class::Crypto => Either::Right(asset.symbol),
});
let symbols = assets
.into_iter()
.map(|asset| asset.symbol)
.collect::<Vec<_>>();
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
us_equity_symbols
);
}
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
crypto_symbols
);
}
};
let news_future = async {
if let Some(symbols) = NonEmpty::from_vec(symbols) {
create_send_await!(
news_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
symbols
);
}
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
}
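
Note: both `handle_message` and `handle_clock_message` use `itertools`' `partition_map` to split assets into US-equity and crypto symbol lists in a single pass. A self-contained sketch of the idiom, with made-up symbols:

use itertools::{Either, Itertools};

fn main() {
    let assets = vec![("AAPL", "us_equity"), ("BTC/USD", "crypto"), ("MSFT", "us_equity")];
    let (equities, cryptos): (Vec<_>, Vec<_>) =
        assets.into_iter().partition_map(|(symbol, class)| match class {
            // Each item is routed to exactly one of the two output collections.
            "crypto" => Either::Right(symbol),
            _ => Either::Left(symbol),
        });
    assert_eq!(equities, ["AAPL", "MSFT"]);
    assert_eq!(cryptos, ["BTC/USD"]);
}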


@@ -0,0 +1,171 @@
use super::State;
use crate::{
config::{Config, CLICKHOUSE_BATCH_BARS_SIZE},
database,
threads::data::ThreadType,
};
use async_trait::async_trait;
use clickhouse::inserter::Inserter;
use log::{debug, error, info};
use nonempty::NonEmpty;
use qrust::{
types::{alpaca::websocket, Bar, Class},
utils::ONE_SECOND,
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::{Mutex, RwLock};
pub struct Handler {
pub config: Arc<Config>,
pub inserter: Arc<Mutex<Inserter<Bar>>>,
pub subscription_message_constructor:
fn(NonEmpty<String>) -> websocket::data::outgoing::subscribe::Message,
}
#[async_trait]
impl super::Handler for Handler {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message {
(self.subscription_message_constructor)(symbols)
}
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::Market {
bars: symbols, ..
} = message
else {
unreachable!()
};
let symbols = symbols.into_iter().collect::<HashSet<_>>();
let mut state = state.write().await;
let newly_subscribed = state
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = state
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to bars for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from bars for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::Bar(message)
| websocket::data::incoming::Message::UpdatedBar(message) => {
let bar = Bar::from(message);
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
self.inserter.lock().await.write(&bar).await.unwrap();
}
websocket::data::incoming::Message::Status(message) => {
debug!(
"Received status message for {}: {:?}.",
message.symbol, message.status
);
match message.status {
websocket::data::incoming::status::Status::TradingHalt(_)
| websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
false,
)
.await
.unwrap();
}
websocket::data::incoming::status::Status::Resume(_)
| websocket::data::incoming::status::Status::TradingResumption(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
true,
)
.await
.unwrap();
}
_ => {}
}
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
fn log_string(&self) -> &'static str {
"bars"
}
async fn run_inserter(&self) {
super::run_inserter(self.inserter.clone()).await;
}
}
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
let inserter = Arc::new(Mutex::new(
config
.clickhouse_client
.inserter("bars")
.unwrap()
.with_period(Some(ONE_SECOND))
.with_max_entries((*CLICKHOUSE_BATCH_BARS_SIZE).try_into().unwrap()),
));
let subscription_message_constructor = match thread_type {
ThreadType::Bars(Class::UsEquity) => {
websocket::data::outgoing::subscribe::Message::new_market_us_equity
}
ThreadType::Bars(Class::Crypto) => {
websocket::data::outgoing::subscribe::Message::new_market_crypto
}
_ => unreachable!(),
};
Box::new(Handler {
config,
inserter,
subscription_message_constructor,
})
}
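
Note: the subscription-confirmation handler above relies on the nightly `HashMap::extract_if` (hence `#![feature(hash_extract_if)]` in src/bin/qrust/mod.rs) to drain exactly the pending entries the server just confirmed while leaving the rest pending. A minimal sketch of that operation:

#![feature(hash_extract_if)] // nightly-only, matching the crate attribute above
use std::collections::HashMap;

fn main() {
    let mut pending: HashMap<String, u32> = HashMap::from([
        ("AAPL".to_string(), 1),
        ("MSFT".to_string(), 2),
    ]);
    // Remove and yield only the entries matching the predicate; the rest stay put.
    let confirmed: HashMap<String, u32> = pending
        .extract_if(|symbol, _| symbol == "AAPL")
        .collect();
    assert_eq!(confirmed.len(), 1);
    assert_eq!(pending.len(), 1);
}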


@@ -0,0 +1,353 @@
pub mod bars;
pub mod news;
use crate::config::{ALPACA_API_KEY, ALPACA_API_SECRET};
use async_trait::async_trait;
use backoff::{future::retry_notify, ExponentialBackoff};
use clickhouse::{inserter::Inserter, Row};
use futures_util::{future::join_all, SinkExt, StreamExt};
use log::error;
use nonempty::NonEmpty;
use qrust::types::alpaca::{self, websocket};
use serde::Serialize;
use serde_json::{from_str, to_string};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::Duration,
};
use tokio::{
net::TcpStream,
select, spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
};
use tokio_tungstenite::{connect_async, tungstenite, MaybeTlsStream, WebSocketStream};
pub enum Action {
Subscribe,
Unsubscribe,
}
impl From<super::Action> for Option<Action> {
fn from(action: super::Action) -> Self {
match action {
super::Action::Add | super::Action::Enable => Some(Action::Subscribe),
super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe),
}
}
}
pub struct Message {
pub action: Option<Action>,
pub symbols: NonEmpty<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Option<Action>, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel();
(
Self {
action,
symbols,
response: sender,
},
receiver,
)
}
}
pub struct State {
pub active_subscriptions: HashSet<String>,
pub pending_subscriptions: HashMap<String, oneshot::Sender<()>>,
pub pending_unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}
#[async_trait]
pub trait Handler: Send + Sync + 'static {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message;
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
);
fn log_string(&self) -> &'static str;
async fn run_inserter(&self);
}
pub async fn run(
handler: Arc<Box<dyn Handler>>,
mut receiver: mpsc::Receiver<Message>,
websocket_url: String,
) {
let state = Arc::new(RwLock::new(State {
active_subscriptions: HashSet::new(),
pending_subscriptions: HashMap::new(),
pending_unsubscriptions: HashMap::new(),
}));
let handler_clone = handler.clone();
spawn(async move { handler_clone.run_inserter().await });
let (sink_sender, sink_receiver) = mpsc::channel(100);
let (stream_sender, mut stream_receiver) = mpsc::channel(10_000);
spawn(run_connection(
handler.clone(),
sink_receiver,
stream_sender,
websocket_url.clone(),
state.clone(),
));
loop {
select! {
Some(message) = receiver.recv() => {
spawn(handle_message(
handler.clone(),
state.clone(),
sink_sender.clone(),
message,
));
}
Some(message) = stream_receiver.recv() => {
match message {
tungstenite::Message::Text(message) => {
let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);
if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}.", message);
continue;
}
for message in parsed_message.unwrap() {
let handler = handler.clone();
let state = state.clone();
spawn(async move {
handler.handle_websocket_message(state, message).await;
});
}
}
tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}.", message),
}
}
else => panic!("Communication channel unexpectedly closed.")
}
}
}
#[allow(clippy::too_many_lines)]
async fn run_connection(
handler: Arc<Box<dyn Handler>>,
mut sink_receiver: mpsc::Receiver<tungstenite::Message>,
stream_sender: mpsc::Sender<tungstenite::Message>,
websocket_url: String,
state: Arc<RwLock<State>>,
) {
let mut peek = None;
'connection: loop {
let (websocket, _): (WebSocketStream<MaybeTlsStream<TcpStream>>, _) = retry_notify(
ExponentialBackoff::default(),
|| async {
connect_async(websocket_url.clone())
.await
.map_err(Into::into)
},
|e, duration: Duration| {
error!(
"Failed to connect to {} websocket, will retry in {} seconds: {}.",
handler.log_string(),
duration.as_secs(),
e
);
},
)
.await
.unwrap();
let (mut sink, mut stream) = websocket.split();
alpaca::websocket::data::authenticate(
&mut sink,
&mut stream,
(*ALPACA_API_KEY).to_string(),
(*ALPACA_API_SECRET).to_string(),
)
.await;
let mut state = state.write().await;
state
.pending_unsubscriptions
.drain()
.for_each(|(_, sender)| {
sender.send(()).unwrap();
});
let (recovered_subscriptions, receivers) = state
.active_subscriptions
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
state.pending_subscriptions.extend(recovered_subscriptions);
let pending_subscriptions = state
.pending_subscriptions
.keys()
.cloned()
.collect::<Vec<_>>();
drop(state);
if let Some(pending_subscriptions) = NonEmpty::from_vec(pending_subscriptions) {
if let Err(err) = sink
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(pending_subscriptions),
))
.unwrap(),
))
.await
{
error!("Failed to send websocket message: {:?}.", err);
continue;
}
}
join_all(receivers).await;
        // Replay the message whose send failed on the previous connection, if any.
        if let Some(message) = peek.clone() {
            if let Err(err) = sink.send(message).await {
                error!("Failed to send websocket message: {:?}.", err);
                continue;
            }
            peek = None;
        }
loop {
select! {
Some(message) = sink_receiver.recv() => {
peek = Some(message.clone());
if let Err(err) = sink.send(message).await {
error!("Failed to send websocket message: {:?}.", err);
continue 'connection;
};
peek = None;
}
message = stream.next() => {
                    let message = match message {
                        Some(Ok(message)) => message,
                        Some(Err(err)) => {
                            error!("Failed to receive websocket message: {:?}.", err);
                            continue 'connection;
                        }
                        None => {
                            error!("Websocket stream unexpectedly closed.");
                            continue 'connection;
                        }
                    };
                    if message.is_close() {
                        error!("Websocket connection closed.");
                        continue 'connection;
                    }
stream_sender.send(message).await.unwrap();
}
else => error!("Communication channel unexpectedly closed.")
}
}
}
}
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
    state: Arc<RwLock<State>>,
sink_sender: mpsc::Sender<tungstenite::Message>,
message: Message,
) {
match message.action {
Some(Action::Subscribe) => {
let (pending_subscriptions, receivers) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
            state
                .write()
                .await
                .pending_subscriptions
                .extend(pending_subscriptions);
sink_sender
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(message.symbols),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
Some(Action::Unsubscribe) => {
let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip();
            state
                .write()
                .await
                .pending_unsubscriptions
                .extend(pending_unsubscriptions);
sink_sender
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Unsubscribe(
handler.create_subscription_message(message.symbols.clone()),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
None => {}
}
message.response.send(()).unwrap();
}
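// Periodically flushes a shared ClickHouse inserter: sleep for the remainder
// of the current commit period, then commit whatever rows have been buffered.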
async fn run_inserter<T>(inserter: Arc<Mutex<Inserter<T>>>)
where
T: Row + Serialize,
{
loop {
let time_left = inserter.lock().await.time_left().unwrap();
tokio::time::sleep(time_left).await;
inserter.lock().await.commit().await.unwrap();
}
}

@@ -0,0 +1,119 @@
use super::State;
use crate::config::{Config, CLICKHOUSE_BATCH_NEWS_SIZE};
use async_trait::async_trait;
use clickhouse::inserter::Inserter;
use log::{debug, error, info};
use nonempty::NonEmpty;
use qrust::{
types::{alpaca::websocket, News},
utils::ONE_SECOND,
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::{Mutex, RwLock};
pub struct Handler {
pub inserter: Arc<Mutex<Inserter<News>>>,
}
#[async_trait]
impl super::Handler for Handler {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message {
websocket::data::outgoing::subscribe::Message::new_news(symbols)
}
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::News { news: symbols } =
message
else {
unreachable!()
};
let symbols = symbols.into_iter().collect::<HashSet<_>>();
let mut state = state.write().await;
let newly_subscribed = state
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = state
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to news for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from news for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::News(message) => {
let news = News::from(message);
debug!(
"Received news for {:?}: {}.",
news.symbols, news.time_created
);
self.inserter.lock().await.write(&news).await.unwrap();
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
fn log_string(&self) -> &'static str {
"news"
}
async fn run_inserter(&self) {
super::run_inserter(self.inserter.clone()).await;
}
}
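// The shared inserter batches incoming news rows: `run_inserter` flushes it
// at least once per second, while `CLICKHOUSE_BATCH_NEWS_SIZE` caps how many
// rows a single INSERT may carry.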
pub fn create_handler(config: &Arc<Config>) -> Box<dyn super::Handler> {
let inserter = Arc::new(Mutex::new(
config
.clickhouse_client
.inserter("news")
.unwrap()
.with_period(Some(ONE_SECOND))
.with_max_entries((*CLICKHOUSE_BATCH_NEWS_SIZE).try_into().unwrap()),
));
Box::new(Handler { inserter })
}

@@ -0,0 +1,27 @@
mod websocket;
use crate::config::{Config, ALPACA_API_BASE, ALPACA_API_KEY, ALPACA_API_SECRET};
use futures_util::StreamExt;
use qrust::types::alpaca;
use std::sync::Arc;
use tokio::spawn;
use tokio_tungstenite::connect_async;
pub async fn run(config: Arc<Config>) {
let (websocket, _) =
connect_async(&format!("wss://{}.alpaca.markets/stream", *ALPACA_API_BASE))
.await
.unwrap();
let (mut websocket_sink, mut websocket_stream) = websocket.split();
alpaca::websocket::trading::authenticate(
&mut websocket_sink,
&mut websocket_stream,
(*ALPACA_API_KEY).to_string(),
(*ALPACA_API_SECRET).to_string(),
)
.await;
alpaca::websocket::trading::subscribe(&mut websocket_sink, &mut websocket_stream).await;
spawn(websocket::run(config, websocket_stream));
}

@@ -1,10 +1,7 @@
-use crate::{
-    config::Config,
-    database,
-    types::{alpaca::websocket, Order},
-};
+use crate::{config::Config, database};
 use futures_util::{stream::SplitStream, StreamExt};
 use log::{debug, error};
+use qrust::types::{alpaca::websocket, Order};
 use serde_json::from_str;
 use std::sync::Arc;
 use tokio::{net::TcpStream, spawn};
@@ -24,7 +21,7 @@ pub async fn run(
             );
             if parsed_message.is_err() {
-                error!("Failed to deserialize websocket message: {:?}", message);
+                error!("Failed to deserialize websocket message: {:?}.", message);
                 continue;
             }
@@ -34,7 +31,7 @@ pub async fn run(
                 ));
             }
             tungstenite::Message::Ping(_) => {}
-            _ => error!("Unexpected websocket message: {:?}", message),
+            _ => error!("Unexpected websocket message: {:?}.", message),
         }
     }
 }
@@ -46,15 +43,19 @@ async fn handle_websocket_message(
     match message {
         websocket::trading::incoming::Message::Order(message) => {
             debug!(
-                "Received order message for {}: {:?}",
+                "Received order message for {}: {:?}.",
                 message.order.symbol, message.event
             );
             let order = Order::from(message.order);
-            database::orders::upsert(&config.clickhouse_client, &order)
-                .await
-                .unwrap();
+            database::orders::upsert(
+                &config.clickhouse_client,
+                &config.clickhouse_concurrency_limiter,
+                &order,
+            )
+            .await
+            .unwrap();
             match message.event {
                 websocket::trading::incoming::order::Event::Fill { position_qty, .. }
@@ -63,6 +64,7 @@ async fn handle_websocket_message(
                 } => {
                     database::assets::update_qty_where_symbol(
                         &config.clickhouse_client,
+                        &config.clickhouse_concurrency_limiter,
                        &order.symbol,
                        position_qty,
                    )

src/bin/trainer/mod.rs
@@ -0,0 +1,133 @@
#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
use burn::{
config::Config,
data::{
dataloader::{DataLoaderBuilder, Dataset},
dataset::transform::{PartialDataset, ShuffledDataset},
},
module::Module,
optim::AdamConfig,
record::CompactRecorder,
tensor::backend::AutodiffBackend,
train::LearnerBuilder,
};
use dotenv::dotenv;
use log::info;
use qrust::{
database,
ml::{
BarWindow, BarWindowBatcher, ModelConfig, MultipleSymbolDataset, MyAutodiffBackend, DEVICE,
},
types::Bar,
};
use std::{env, fs, path::Path, sync::Arc};
use tokio::sync::Semaphore;
#[derive(Config)]
pub struct TrainingConfig {
pub model: ModelConfig,
pub optimizer: AdamConfig,
#[config(default = 100)]
pub epochs: usize,
#[config(default = 256)]
pub batch_size: usize,
#[config(default = 16)]
pub num_workers: usize,
#[config(default = 0)]
pub seed: u64,
#[config(default = 0.2)]
pub valid_pct: f64,
#[config(default = 1.0e-4)]
pub learning_rate: f64,
}
#[tokio::main]
async fn main() {
dotenv().ok();
let dir = Path::new(file!()).parent().unwrap();
let model_config = ModelConfig::new();
let optimizer = AdamConfig::new();
let training_config = TrainingConfig::new(model_config, optimizer);
let clickhouse_client = clickhouse::Client::default()
.with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
.with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
.with_password(env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."))
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set."));
let clickhouse_concurrency_limiter = Arc::new(Semaphore::new(Semaphore::MAX_PERMITS));
let bars = database::ta::select(&clickhouse_client, &clickhouse_concurrency_limiter)
.await
.unwrap();
info!("Loaded {} bars.", bars.len());
train::<MyAutodiffBackend>(
bars,
&training_config,
dir.join("artifacts").to_str().unwrap(),
&DEVICE,
);
}
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cast_sign_loss)]
#[allow(clippy::cast_precision_loss)]
fn train<B: AutodiffBackend<FloatElem = f32, IntElem = i32>>(
bars: Vec<Bar>,
config: &TrainingConfig,
dir: &str,
device: &B::Device,
) {
B::seed(config.seed);
fs::create_dir_all(dir).unwrap();
let dataset = MultipleSymbolDataset::new(bars);
let dataset = ShuffledDataset::with_seed(dataset, config.seed);
let dataset = Arc::new(dataset);
let split = (dataset.len() as f64 * (1.0 - config.valid_pct)) as usize;
let train: PartialDataset<Arc<ShuffledDataset<MultipleSymbolDataset, BarWindow>>, BarWindow> =
PartialDataset::new(dataset.clone(), 0, split);
let batcher_train = BarWindowBatcher::<B> {
device: device.clone(),
};
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.num_workers(config.num_workers)
.build(train);
let valid: PartialDataset<Arc<ShuffledDataset<MultipleSymbolDataset, BarWindow>>, BarWindow> =
PartialDataset::new(dataset.clone(), split, dataset.len());
let batcher_valid = BarWindowBatcher::<B::InnerBackend> {
device: device.clone(),
};
let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.num_workers(config.num_workers)
.build(valid);
let learner = LearnerBuilder::new(dir)
.with_file_checkpointer(CompactRecorder::new())
.devices(vec![device.clone()])
.num_epochs(config.epochs)
.build(
config.model.init::<B>(device),
config.optimizer.init(),
config.learning_rate,
);
let trained = learner.fit(dataloader_train, dataloader_valid);
trained.save_file(dir, &CompactRecorder::new()).unwrap();
}

@@ -1,123 +0,0 @@
use crate::types::alpaca::shared::{Mode, Source};
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use lazy_static::lazy_static;
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client,
};
use rust_bert::{
pipelines::{
common::{ModelResource, ModelType},
sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel},
},
resources::LocalResource,
};
use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
use tokio::sync::Mutex;
pub const ALPACA_STOCK_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news";
pub const ALPACA_STOCK_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str =
"wss://stream.data.alpaca.markets/v1beta3/crypto/us";
pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";
lazy_static! {
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
.expect("ALPACA_MODE must be set.")
.parse()
.expect("ALPACA_MODE must be 'live' or 'paper'");
pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE")
.expect("ALPACA_SOURCE must be set.")
.parse()
.expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'");
pub static ref ALPACA_API_KEY: String = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
pub static ref ALPACA_API_SECRET: String = env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
#[derive(Debug)]
pub static ref ALPACA_API_URL: String = format!(
"https://{}.alpaca.markets/v2",
match *ALPACA_MODE {
Mode::Live => String::from("api"),
Mode::Paper => String::from("paper-api"),
}
);
#[derive(Debug)]
pub static ref ALPACA_WEBSOCKET_URL: String = format!(
"wss://{}.alpaca.markets/stream",
match *ALPACA_MODE {
Mode::Live => String::from("api"),
Mode::Paper => String::from("paper-api"),
}
);
pub static ref MAX_BERT_INPUTS: usize = env::var("MAX_BERT_INPUTS")
.expect("MAX_BERT_INPUTS must be set.")
.parse()
.expect("MAX_BERT_INPUTS must be a positive integer.");
}
pub struct Config {
pub alpaca_client: Client,
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
pub clickhouse_client: clickhouse::Client,
pub sequence_classifier: Mutex<SequenceClassificationModel>,
}
impl Config {
pub fn from_env() -> Self {
Self {
alpaca_client: Client::builder()
.default_headers(HeaderMap::from_iter([
(
HeaderName::from_static("apca-api-key-id"),
HeaderValue::from_str(&ALPACA_API_KEY)
.expect("Alpaca API key must not contain invalid characters."),
),
(
HeaderName::from_static("apca-api-secret-key"),
HeaderValue::from_str(&ALPACA_API_SECRET)
.expect("Alpaca API secret must not contain invalid characters."),
),
]))
.build()
.unwrap(),
alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE {
Source::Iex => unsafe { NonZeroU32::new_unchecked(200) },
Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) },
Source::Otc => unimplemented!("OTC rate limit not implemented."),
})),
clickhouse_client: clickhouse::Client::default()
.with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
.with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
.with_password(
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
)
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
sequence_classifier: Mutex::new(
SequenceClassificationModel::new(SequenceClassificationConfig::new(
ModelType::Bert,
ModelResource::Torch(Box::new(LocalResource {
local_path: PathBuf::from("./models/finbert/rust_model.ot"),
})),
LocalResource {
local_path: PathBuf::from("./models/finbert/config.json"),
},
LocalResource {
local_path: PathBuf::from("./models/finbert/vocab.txt"),
},
None,
true,
None,
None,
))
.unwrap(),
),
}
}
pub fn arc_from_env() -> Arc<Self> {
Arc::new(Self::from_env())
}
}

@@ -1,17 +0,0 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
};
use clickhouse::{error::Error, Client};
select_where_symbol!(Backfill, "backfills_bars");
upsert!(Backfill, "backfills_bars");
delete_where_symbols!("backfills_bars");
cleanup!("backfills_bars");
optimize!("backfills_bars");
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
.query("ALTER TABLE backfills_bars UPDATE fresh = false WHERE true")
.execute()
.await
}

@@ -1,17 +0,0 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
};
use clickhouse::{error::Error, Client};
select_where_symbol!(Backfill, "backfills_news");
upsert!(Backfill, "backfills_news");
delete_where_symbols!("backfills_news");
cleanup!("backfills_news");
optimize!("backfills_news");
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
.query("ALTER TABLE backfills_news UPDATE fresh = false WHERE true")
.execute()
.await
}

@@ -1,7 +0,0 @@
use crate::{cleanup, delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
upsert!(Bar, "bars");
upsert_batch!(Bar, "bars");
delete_where_symbols!("bars");
cleanup!("bars");
optimize!("bars");

@@ -1,152 +0,0 @@
pub mod assets;
pub mod backfills_bars;
pub mod backfills_news;
pub mod bars;
pub mod calendar;
pub mod news;
pub mod orders;
use clickhouse::{error::Error, Client};
use log::info;
use tokio::try_join;
#[macro_export]
macro_rules! select {
($record:ty, $table_name:expr) => {
pub async fn select(
client: &clickhouse::Client,
) -> Result<Vec<$record>, clickhouse::error::Error> {
client
.query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! select_where_symbol {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbol<T>(
client: &clickhouse::Client,
symbol: &T,
) -> Result<Option<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol = ?",
$table_name
))
.bind(symbol)
.fetch_optional::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! upsert {
($record:ty, $table_name:expr) => {
pub async fn upsert(
client: &clickhouse::Client,
record: &$record,
) -> Result<(), clickhouse::error::Error> {
let mut insert = client.insert($table_name)?;
insert.write(record).await?;
insert.end().await
}
};
}
#[macro_export]
macro_rules! upsert_batch {
($record:ty, $table_name:expr) => {
pub async fn upsert_batch<'a, T>(
client: &clickhouse::Client,
records: T,
) -> Result<(), clickhouse::error::Error>
where
T: IntoIterator<Item = &'a $record> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = client.insert($table_name)?;
for record in records {
insert.write(record).await?;
}
insert.end().await
}
};
}
#[macro_export]
macro_rules! delete_where_symbols {
($table_name:expr) => {
pub async fn delete_where_symbols<T>(
client: &clickhouse::Client,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
client
.query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
.bind(symbols)
.execute()
.await
}
};
}
#[macro_export]
macro_rules! cleanup {
($table_name:expr) => {
pub async fn cleanup(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
client
.query(&format!(
"DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
$table_name
))
.execute()
.await
}
};
}
#[macro_export]
macro_rules! optimize {
($table_name:expr) => {
pub async fn optimize(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
client
.query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
.execute()
.await
}
};
}
pub async fn cleanup_all(clickhouse_client: &Client) -> Result<(), Error> {
info!("Cleaning up database.");
try_join!(
bars::cleanup(clickhouse_client),
news::cleanup(clickhouse_client),
backfills_bars::cleanup(clickhouse_client),
backfills_news::cleanup(clickhouse_client)
)
.map(|_| ())
}
pub async fn optimize_all(clickhouse_client: &Client) -> Result<(), Error> {
info!("Optimizing database.");
try_join!(
assets::optimize(clickhouse_client),
bars::optimize(clickhouse_client),
news::optimize(clickhouse_client),
backfills_bars::optimize(clickhouse_client),
backfills_news::optimize(clickhouse_client),
orders::optimize(clickhouse_client),
calendar::optimize(clickhouse_client)
)
.map(|_| ())
}

@@ -0,0 +1,39 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::account::Account;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/account", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Account>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

@@ -0,0 +1,132 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{
incoming::asset::{Asset, Class},
outgoing,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use itertools::Itertools;
use log::warn;
use reqwest::{Client, Error};
use std::{collections::HashSet, time::Duration};
use tokio::try_join;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::asset::Asset,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/assets", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Asset>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get assets, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!(
"https://{}.alpaca.markets/v2/assets/{}",
api_base, symbol
))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Asset>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get asset, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
if symbols.is_empty() {
return Ok(vec![]);
}
if symbols.len() == 1 {
let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(vec![asset]);
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let backoff_clone = backoff.clone();
let us_equity_query = outgoing::asset::Asset {
class: Some(Class::UsEquity),
..Default::default()
};
let us_equity_assets = get(
client,
rate_limiter,
&us_equity_query,
backoff_clone,
api_base,
);
let crypto_query = outgoing::asset::Asset {
class: Some(Class::Crypto),
..Default::default()
};
let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base);
let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;
Ok(crypto_assets
.into_iter()
.chain(us_equity_assets)
.dedup_by(|a, b| a.symbol == b.symbol)
.filter(|asset| symbols.contains(&asset.symbol))
.collect())
}

@@ -0,0 +1,50 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::bar::Bar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::{collections::HashMap, time::Duration};
pub const MAX_LIMIT: i64 = 10_000;
#[derive(Deserialize)]
pub struct Message {
pub bars: HashMap<String, Vec<Bar>>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
data_url: &str,
query: &outgoing::bar::Bar,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(data_url)
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Message>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get historical bars, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

@@ -0,0 +1,41 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::calendar::Calendar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::calendar::Calendar,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Calendar>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/calendar", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Calendar>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get calendar, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

@@ -0,0 +1,39 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::clock::Clock;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/clock", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Clock>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get clock, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

@@ -0,0 +1,27 @@
pub mod account;
pub mod assets;
pub mod bars;
pub mod calendar;
pub mod clock;
pub mod news;
pub mod orders;
pub mod positions;
use http::StatusCode;
// Map a reqwest error onto backoff's retry semantics: 400/403/404 responses
// and local builder/request/redirect/decode/body failures are `Permanent`
// (no retry), while everything else, e.g. 429s and 5xxs, stays transient.
pub fn error_to_backoff(err: reqwest::Error) -> backoff::Error<reqwest::Error> {
if err.is_status() {
return match err.status() {
Some(StatusCode::BAD_REQUEST | StatusCode::FORBIDDEN | StatusCode::NOT_FOUND)
| None => backoff::Error::Permanent(err),
_ => err.into(),
};
}
if err.is_builder() || err.is_request() || err.is_redirect() || err.is_decode() || err.is_body()
{
return backoff::Error::Permanent(err);
}
err.into()
}

@@ -0,0 +1,49 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::news::News, outgoing, ALPACA_NEWS_DATA_API_URL};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
pub const MAX_LIMIT: i64 = 50;
#[derive(Deserialize)]
pub struct Message {
pub news: Vec<News>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(ALPACA_NEWS_DATA_API_URL)
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Message>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get historical news, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

@@ -1,44 +1,39 @@
-use crate::{
-    config::ALPACA_API_URL,
-    types::alpaca::{api::outgoing, shared},
-};
+use super::error_to_backoff;
+use crate::types::alpaca::{api::outgoing, shared::order};
 use backoff::{future::retry_notify, ExponentialBackoff};
 use governor::DefaultDirectRateLimiter;
 use log::warn;
 use reqwest::{Client, Error};
 use std::time::Duration;
-pub use shared::order::Order;
+pub use order::Order;
 pub async fn get(
-    alpaca_client: &Client,
-    alpaca_rate_limiter: &DefaultDirectRateLimiter,
+    client: &Client,
+    rate_limiter: &DefaultDirectRateLimiter,
     query: &outgoing::order::Order,
     backoff: Option<ExponentialBackoff>,
+    api_base: &str,
 ) -> Result<Vec<Order>, Error> {
     retry_notify(
         backoff.unwrap_or_default(),
         || async {
-            alpaca_rate_limiter.until_ready().await;
-            alpaca_client
-                .get(&format!("{}/orders", *ALPACA_API_URL))
+            rate_limiter.until_ready().await;
+            client
+                .get(&format!("https://{}.alpaca.markets/v2/orders", api_base))
                 .query(query)
                 .send()
-                .await?
+                .await
+                .map_err(error_to_backoff)?
                 .error_for_status()
-                .map_err(|e| match e.status() {
-                    Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
-                        backoff::Error::Permanent(e)
-                    }
-                    _ => e.into(),
-                })?
+                .map_err(error_to_backoff)?
                 .json::<Vec<Order>>()
                 .await
-                .map_err(backoff::Error::Permanent)
+                .map_err(error_to_backoff)
         },
         |e, duration: Duration| {
             warn!(
-                "Failed to get orders, will retry in {} seconds: {}",
+                "Failed to get orders, will retry in {} seconds: {}.",
                 duration.as_secs(),
                 e
             );

@@ -0,0 +1,109 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::position::Position;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use http::StatusCode;
use log::warn;
use reqwest::Client;
use std::{collections::HashSet, time::Duration};
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/positions", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Position>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get positions, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Option<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
let response = client
.get(&format!(
"https://{}.alpaca.markets/v2/positions/{}",
api_base, symbol
))
.send()
.await
.map_err(error_to_backoff)?;
if response.status() == StatusCode::NOT_FOUND {
return Ok(None);
}
response
.error_for_status()
.map_err(error_to_backoff)?
.json::<Position>()
.await
.map_err(error_to_backoff)
.map(Some)
},
|e, duration: Duration| {
warn!(
"Failed to get position, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
if symbols.is_empty() {
return Ok(vec![]);
}
if symbols.len() == 1 {
let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(position.into_iter().collect());
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let positions = get(client, rate_limiter, backoff, api_base).await?;
Ok(positions
.into_iter()
.filter(|position| symbols.contains(&position.symbol))
.collect())
}

@@ -1,8 +1,11 @@
+use std::sync::Arc;
 use crate::{
     delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch,
 };
 use clickhouse::{error::Error, Client};
 use serde::Serialize;
+use tokio::sync::Semaphore;
 select!(Asset, "assets");
 select_where_symbol!(Asset, "assets");
@@ -11,14 +14,16 @@ delete_where_symbols!("assets");
 optimize!("assets");
 pub async fn update_status_where_symbol<T>(
-    clickhouse_client: &Client,
+    client: &Client,
+    concurrency_limiter: &Arc<Semaphore>,
     symbol: &T,
     status: bool,
 ) -> Result<(), Error>
 where
     T: AsRef<str> + Serialize + Send + Sync,
 {
-    clickhouse_client
+    let _permit = concurrency_limiter.acquire().await.unwrap();
+    client
         .query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?")
         .bind(status)
         .bind(symbol)
@@ -27,14 +32,16 @@ where
 }
 pub async fn update_qty_where_symbol<T>(
-    clickhouse_client: &Client,
+    client: &Client,
+    concurrency_limiter: &Arc<Semaphore>,
     symbol: &T,
     qty: f64,
 ) -> Result<(), Error>
 where
     T: AsRef<str> + Serialize + Send + Sync,
 {
-    clickhouse_client
+    let _permit = concurrency_limiter.acquire().await.unwrap();
+    client
         .query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?")
         .bind(qty)
         .bind(symbol)

@@ -0,0 +1,11 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
types::Backfill, upsert_batch,
};
select_where_symbols!(Backfill, "backfills_bars");
upsert_batch!(Backfill, "backfills_bars");
delete_where_symbols!("backfills_bars");
cleanup!("backfills_bars");
optimize!("backfills_bars");
set_fresh_where_symbols!("backfills_bars");

@@ -0,0 +1,11 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
types::Backfill, upsert_batch,
};
select_where_symbols!(Backfill, "backfills_news");
upsert_batch!(Backfill, "backfills_news");
delete_where_symbols!("backfills_news");
cleanup!("backfills_news");
optimize!("backfills_news");
set_fresh_where_symbols!("backfills_news");

@@ -0,0 +1,21 @@
use std::sync::Arc;
use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
use clickhouse::Client;
use tokio::sync::Semaphore;
upsert!(Bar, "bars");
upsert_batch!(Bar, "bars");
delete_where_symbols!("bars");
optimize!("bars");
pub async fn cleanup(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
) -> Result<(), clickhouse::error::Error> {
    let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets) OR symbol NOT IN (SELECT symbol FROM backfills_bars)")
.execute()
.await
}

@@ -1,16 +1,19 @@
+use std::sync::Arc;
 use crate::{optimize, types::Calendar};
-use clickhouse::error::Error;
-use tokio::try_join;
+use clickhouse::{error::Error, Client};
+use tokio::{sync::Semaphore, try_join};
 optimize!("calendar");
-pub async fn upsert_batch_and_delete<'a, T>(
-    client: &clickhouse::Client,
-    records: T,
+pub async fn upsert_batch_and_delete<'a, I>(
+    client: &Client,
+    concurrency_limiter: &Arc<Semaphore>,
+    records: I,
 ) -> Result<(), Error>
 where
-    T: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone,
-    T::IntoIter: Send,
+    I: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone,
+    I::IntoIter: Send,
 {
     let upsert_future = async {
         let mut insert = client.insert("calendar")?;
@@ -34,5 +37,6 @@ where
         .await
     };
+    let _permit = concurrency_limiter.acquire_many(2).await.unwrap();
     try_join!(upsert_future, delete_future).map(|_| ())
 }

@@ -0,0 +1,224 @@
pub mod assets;
pub mod backfills_bars;
pub mod backfills_news;
pub mod bars;
pub mod calendar;
pub mod news;
pub mod orders;
pub mod ta;
use clickhouse::{error::Error, Client};
use tokio::try_join;
#[macro_export]
macro_rules! select {
($record:ty, $table_name:expr) => {
pub async fn select(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<Vec<$record>, clickhouse::error::Error> {
            // Hold the permit (named, not `_`) until the query completes.
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! select_where_symbol {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbol<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbol: &T,
) -> Result<Option<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol = ?",
$table_name
))
.bind(symbol)
.fetch_optional::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! select_where_symbols {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbols: &[T],
) -> Result<Vec<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol IN ?",
$table_name
))
.bind(symbols)
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! upsert {
($record:ty, $table_name:expr) => {
pub async fn upsert(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
record: &$record,
) -> Result<(), clickhouse::error::Error> {
            let _permit = concurrency_limiter.acquire().await.unwrap();
let mut insert = client.insert($table_name)?;
insert.write(record).await?;
insert.end().await
}
};
}
#[macro_export]
macro_rules! upsert_batch {
($record:ty, $table_name:expr) => {
pub async fn upsert_batch<'a, I>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
records: I,
) -> Result<(), clickhouse::error::Error>
where
I: IntoIterator<Item = &'a $record> + Send + Sync,
I::IntoIter: Send,
{
            let _permit = concurrency_limiter.acquire().await.unwrap();
let mut insert = client.insert($table_name)?;
for record in records {
insert.write(record).await?;
}
insert.end().await
}
};
}
#[macro_export]
macro_rules! delete_where_symbols {
($table_name:expr) => {
pub async fn delete_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
.bind(symbols)
.execute()
.await
}
};
}
#[macro_export]
macro_rules! cleanup {
($table_name:expr) => {
pub async fn cleanup(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
$table_name
))
.execute()
.await
}
};
}
#[macro_export]
macro_rules! optimize {
($table_name:expr) => {
pub async fn optimize(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
.execute()
.await
}
};
}
#[macro_export]
macro_rules! set_fresh_where_symbols {
($table_name:expr) => {
pub async fn set_fresh_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
fresh: bool,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"ALTER TABLE {} UPDATE fresh = ? WHERE symbol IN ?",
$table_name
))
.bind(fresh)
.bind(symbols)
.execute()
.await
}
};
}
pub async fn cleanup_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
try_join!(
bars::cleanup(clickhouse_client, concurrency_limiter),
news::cleanup(clickhouse_client, concurrency_limiter),
backfills_bars::cleanup(clickhouse_client, concurrency_limiter),
backfills_news::cleanup(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}
pub async fn optimize_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
try_join!(
assets::optimize(clickhouse_client, concurrency_limiter),
bars::optimize(clickhouse_client, concurrency_limiter),
news::optimize(clickhouse_client, concurrency_limiter),
backfills_bars::optimize(clickhouse_client, concurrency_limiter),
backfills_news::optimize(clickhouse_client, concurrency_limiter),
orders::optimize(clickhouse_client, concurrency_limiter),
calendar::optimize(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}
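// A minimal sketch (illustrative names, not from this crate) of the gating
// pattern every macro above expands to. Binding the permit to `_permit`
// rather than `_` is what keeps it alive for the whole statement:
//
//     use std::sync::Arc;
//     use tokio::sync::Semaphore;
//
//     async fn gated_query(limiter: Arc<Semaphore>) {
//         let _permit = limiter.acquire().await.unwrap();
//         // ... execute the ClickHouse statement while the permit is held ...
//     } // `_permit` drops here, releasing the slot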

@@ -1,24 +1,33 @@
+use std::sync::Arc;
 use crate::{optimize, types::News, upsert, upsert_batch};
 use clickhouse::{error::Error, Client};
 use serde::Serialize;
+use tokio::sync::Semaphore;
 upsert!(News, "news");
 upsert_batch!(News, "news");
 optimize!("news");
-pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error>
+pub async fn delete_where_symbols<T>(
+    client: &Client,
+    concurrency_limiter: &Arc<Semaphore>,
+    symbols: &[T],
+) -> Result<(), Error>
 where
     T: AsRef<str> + Serialize + Send + Sync,
 {
-    clickhouse_client
+    let _permit = concurrency_limiter.acquire().await.unwrap();
+    client
         .query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))")
         .bind(symbols)
         .execute()
         .await
 }
-pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> {
-    clickhouse_client
+pub async fn cleanup(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
+    let _permit = concurrency_limiter.acquire().await.unwrap();
+    client
         .query(
             "DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))",
         )

@@ -0,0 +1,30 @@
use crate::types::Bar;
use clickhouse::{error::Error, Client};
use std::sync::Arc;
use tokio::sync::Semaphore;
// Downsamples stored bars into hourly OHLCV candles per symbol: first open,
// max high, min low, last close, and summed volume/trades for each hour.
pub async fn select(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
) -> Result<Vec<Bar>, Error> {
    let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(
"
SELECT symbol,
toStartOfHour(bars.time) AS time,
any(bars.open) AS open,
max(bars.high) AS high,
min(bars.low) AS low,
anyLast(bars.close) AS close,
sum(bars.volume) AS volume,
sum(bars.trades) AS trades
FROM bars FINAL
GROUP BY ALL
ORDER BY symbol,
time
",
)
.fetch_all::<Bar>()
.await
}

@@ -0,0 +1,75 @@
use super::BarWindow;
use burn::{
data::dataloader::batcher::Batcher,
tensor::{self, backend::Backend, Tensor},
};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
#[derive(Clone, Debug)]
pub struct BarWindowBatcher<B: Backend> {
pub device: B::Device,
}
#[derive(Clone, Debug)]
pub struct BarWindowBatch<B: Backend> {
pub hour_tensor: Tensor<B, 2, tensor::Int>,
pub day_tensor: Tensor<B, 2, tensor::Int>,
pub numerical_tensor: Tensor<B, 3>,
pub target_tensor: Tensor<B, 2>,
}
impl<B: Backend<FloatElem = f32, IntElem = i32>> Batcher<BarWindow, BarWindowBatch<B>>
for BarWindowBatcher<B>
{
fn batch(&self, items: Vec<BarWindow>) -> BarWindowBatch<B> {
let batch_size = items.len();
let (hour_tensors, day_tensors, numerical_tensors, target_tensors) = items
.into_par_iter()
.fold(
|| {
(
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
)
},
|(mut hour_tensors, mut day_tensors, mut numerical_tensors, mut target_tensors),
item| {
hour_tensors.push(Tensor::from_data(item.hours, &self.device));
day_tensors.push(Tensor::from_data(item.days, &self.device));
numerical_tensors.push(Tensor::from_data(item.numerical, &self.device));
target_tensors.push(Tensor::from_data(item.target, &self.device));
(hour_tensors, day_tensors, numerical_tensors, target_tensors)
},
)
.reduce(
|| {
(
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
)
},
|(mut hour_tensors, mut day_tensors, mut numerical_tensors, mut target_tensors),
item| {
hour_tensors.extend(item.0);
day_tensors.extend(item.1);
numerical_tensors.extend(item.2);
target_tensors.extend(item.3);
(hour_tensors, day_tensors, numerical_tensors, target_tensors)
},
);
BarWindowBatch {
hour_tensor: Tensor::stack(hour_tensors, 0).to_device(&self.device),
day_tensor: Tensor::stack(day_tensors, 0).to_device(&self.device),
numerical_tensor: Tensor::stack(numerical_tensors, 0).to_device(&self.device),
target_tensor: Tensor::stack(target_tensors, 0).to_device(&self.device),
}
}
}

src/lib/qrust/ml/dataset.rs
@@ -0,0 +1,219 @@
use crate::types::{
ta::{calculate_indicators, IndicatedBar, HEAD_SIZE, NUMERICAL_FIELD_COUNT},
Bar,
};
use burn::{
data::dataset::{transform::ComposedDataset, Dataset},
tensor::Data,
};
pub const WINDOW_SIZE: usize = 48;
#[derive(Clone, Debug)]
pub struct BarWindow {
pub hours: Data<i32, 1>,
pub days: Data<i32, 1>,
pub numerical: Data<f32, 2>,
pub target: Data<f32, 1>,
}
#[derive(Clone, Debug)]
struct SingleSymbolDataset {
hours: Vec<i32>,
days: Vec<i32>,
numerical: Vec<[f32; NUMERICAL_FIELD_COUNT]>,
targets: Vec<f32>,
}
impl SingleSymbolDataset {
#[allow(clippy::cast_possible_truncation)]
pub fn new(bars: Vec<IndicatedBar>) -> Self {
if !bars.is_empty() {
let symbol = &bars[0].symbol;
assert!(bars.iter().all(|bar| bar.symbol == *symbol));
}
        let (hours, days, numerical, targets) = bars.windows(2).skip(HEAD_SIZE - 1).fold(
            (
                // `saturating_sub` avoids an underflow panic when `bars` is empty.
                Vec::with_capacity(bars.len().saturating_sub(1)),
                Vec::with_capacity(bars.len().saturating_sub(1)),
                Vec::with_capacity(bars.len().saturating_sub(1)),
                Vec::with_capacity(bars.len().saturating_sub(1)),
            ),
|(mut hours, mut days, mut numerical, mut targets), bar| {
hours.push(i32::from(bar[0].hour));
days.push(i32::from(bar[0].day));
numerical.push([
bar[0].open as f32,
(bar[0].open_pct as f32).min(f32::MAX),
bar[0].high as f32,
(bar[0].high_pct as f32).min(f32::MAX),
bar[0].low as f32,
(bar[0].low_pct as f32).min(f32::MAX),
bar[0].close as f32,
(bar[0].close_pct as f32).min(f32::MAX),
bar[0].volume as f32,
(bar[0].volume_pct as f32).min(f32::MAX),
bar[0].trades as f32,
(bar[0].trades_pct as f32).min(f32::MAX),
bar[0].sma_3 as f32,
bar[0].sma_6 as f32,
bar[0].sma_12 as f32,
bar[0].sma_24 as f32,
bar[0].sma_48 as f32,
bar[0].sma_72 as f32,
bar[0].ema_3 as f32,
bar[0].ema_6 as f32,
bar[0].ema_12 as f32,
bar[0].ema_24 as f32,
bar[0].ema_48 as f32,
bar[0].ema_72 as f32,
bar[0].macd as f32,
bar[0].macd_signal as f32,
bar[0].obv as f32,
bar[0].rsi as f32,
bar[0].bbands_lower as f32,
bar[0].bbands_mean as f32,
bar[0].bbands_upper as f32,
]);
targets.push(bar[1].close_pct as f32);
(hours, days, numerical, targets)
},
);
Self {
hours,
days,
numerical,
targets,
}
}
}
impl Dataset<BarWindow> for SingleSymbolDataset {
    fn len(&self) -> usize {
        // Saturating: fewer than WINDOW_SIZE targets means no complete window,
        // not an underflow panic.
        (self.targets.len() + 1).saturating_sub(WINDOW_SIZE)
    }
#[allow(clippy::single_range_in_vec_init)]
fn get(&self, idx: usize) -> Option<BarWindow> {
if idx >= self.len() {
return None;
}
let hours: [i32; WINDOW_SIZE] = self.hours[idx..idx + WINDOW_SIZE].try_into().unwrap();
let days: [i32; WINDOW_SIZE] = self.days[idx..idx + WINDOW_SIZE].try_into().unwrap();
let numerical: [[f32; NUMERICAL_FIELD_COUNT]; WINDOW_SIZE] =
self.numerical[idx..idx + WINDOW_SIZE].try_into().unwrap();
let target: [f32; 1] = [self.targets[idx + WINDOW_SIZE - 1]];
Some(BarWindow {
hours: Data::from(hours),
days: Data::from(days),
numerical: Data::from(numerical),
target: Data::from(target),
})
}
}
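// Sizing check: `bars.windows(2)` yields `bars.len() - 1` pairs and
// `skip(HEAD_SIZE - 1)` drops the indicator warm-up, so `targets` holds
// `bars.len() - HEAD_SIZE` rows and `len()` is
// `bars.len() - HEAD_SIZE - WINDOW_SIZE + 1` sliding windows, exactly the
// arithmetic exercised by `generate_random_dataset` in the tests below.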
pub struct MultipleSymbolDataset {
composed_dataset: ComposedDataset<SingleSymbolDataset>,
}
impl MultipleSymbolDataset {
pub fn new(bars: Vec<Bar>) -> Self {
let groups = calculate_indicators(bars)
.into_iter()
.map(SingleSymbolDataset::new)
.collect::<Vec<_>>();
Self {
composed_dataset: ComposedDataset::new(groups),
}
}
}
impl Dataset<BarWindow> for MultipleSymbolDataset {
fn len(&self) -> usize {
self.composed_dataset.len()
}
fn get(&self, idx: usize) -> Option<BarWindow> {
self.composed_dataset.get(idx)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::{
distributions::{Distribution, Uniform},
Rng,
};
use time::OffsetDateTime;
fn generate_random_dataset(length: usize) -> MultipleSymbolDataset {
let mut rng = rand::thread_rng();
let uniform = Uniform::new(1.0, 100.0);
let mut bars = Vec::with_capacity(length);
for _ in 0..=(length + (HEAD_SIZE - 1) + (WINDOW_SIZE - 1)) {
bars.push(Bar {
symbol: "AAPL".to_string(),
time: OffsetDateTime::now_utc(),
open: uniform.sample(&mut rng),
high: uniform.sample(&mut rng),
low: uniform.sample(&mut rng),
close: uniform.sample(&mut rng),
volume: uniform.sample(&mut rng),
trades: rng.gen_range(1..100),
});
}
MultipleSymbolDataset::new(bars)
}
#[test]
fn test_single_symbol_dataset() {
let length = 100;
let dataset = generate_random_dataset(length);
assert_eq!(dataset.len(), length);
}
#[test]
fn test_single_symbol_dataset_window() {
let length = 100;
let dataset = generate_random_dataset(length);
let item = dataset.get(0).unwrap();
assert_eq!(
item.numerical.shape.dims,
[WINDOW_SIZE, NUMERICAL_FIELD_COUNT]
);
assert_eq!(item.target.shape.dims, [1]);
}
#[test]
fn test_single_symbol_dataset_last_window() {
let length = 100;
let dataset = generate_random_dataset(length);
let item = dataset.get(dataset.len() - 1).unwrap();
assert_eq!(
item.numerical.shape.dims,
[WINDOW_SIZE, NUMERICAL_FIELD_COUNT]
);
assert_eq!(item.target.shape.dims, [1]);
}
#[test]
fn test_single_symbol_dataset_out_of_bounds() {
let length = 100;
let dataset = generate_random_dataset(length);
assert!(dataset.get(dataset.len()).is_none());
}
}

src/lib/qrust/ml/mod.rs
@@ -0,0 +1,21 @@
pub mod batcher;
pub mod dataset;
pub mod model;
pub use batcher::{BarWindowBatch, BarWindowBatcher};
pub use dataset::{BarWindow, MultipleSymbolDataset};
pub use model::{Model, ModelConfig};
use burn::{
backend::{
wgpu::{AutoGraphicsApi, WgpuDevice},
Autodiff, Wgpu,
},
tensor::backend::Backend,
};
pub type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
pub type MyAutodiffBackend = Autodiff<MyBackend>;
pub type MyDevice = <Autodiff<Wgpu> as Backend>::Device;
pub const DEVICE: MyDevice = WgpuDevice::BestAvailable;
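// Example (sketch): `ModelConfig::new().init::<MyAutodiffBackend>(&DEVICE)`
// yields a trainable model on the best available wgpu device; `MyBackend`
// works the same way for inference-only use.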

src/lib/qrust/ml/model.rs
@@ -0,0 +1,160 @@
use super::BarWindowBatch;
use crate::types::ta::NUMERICAL_FIELD_COUNT;
use burn::{
config::Config,
module::Module,
nn::{
loss::{MseLoss, Reduction},
Dropout, DropoutConfig, Embedding, EmbeddingConfig, Linear, LinearConfig, Lstm, LstmConfig,
},
tensor::{
self,
backend::{AutodiffBackend, Backend},
Tensor,
},
train::{RegressionOutput, TrainOutput, TrainStep, ValidStep},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
hour_embedding: Embedding<B>,
day_embedding: Embedding<B>,
lstm_1: Lstm<B>,
dropout_1: Dropout,
lstm_2: Lstm<B>,
dropout_2: Dropout,
lstm_3: Lstm<B>,
dropout_3: Dropout,
lstm_4: Lstm<B>,
dropout_4: Dropout,
linear: Linear<B>,
}
#[derive(Config, Debug)]
pub struct ModelConfig {
#[config(default = "3")]
pub hour_features: usize,
#[config(default = "2")]
pub day_features: usize,
#[config(default = "{NUMERICAL_FIELD_COUNT}")]
pub numerical_features: usize,
#[config(default = "0.2")]
pub dropout: f64,
}
impl ModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
let num_features = self.numerical_features + self.hour_features + self.day_features;
let lstm_1_hidden_size = 512;
let lstm_2_hidden_size = 256;
let lstm_3_hidden_size = 64;
let lstm_4_hidden_size = 32;
Model {
hour_embedding: EmbeddingConfig::new(24, self.hour_features).init(device),
day_embedding: EmbeddingConfig::new(7, self.day_features).init(device),
lstm_1: LstmConfig::new(num_features, lstm_1_hidden_size, true).init(device),
dropout_1: DropoutConfig::new(self.dropout).init(),
lstm_2: LstmConfig::new(lstm_1_hidden_size, lstm_2_hidden_size, true).init(device),
dropout_2: DropoutConfig::new(self.dropout).init(),
lstm_3: LstmConfig::new(lstm_2_hidden_size, lstm_3_hidden_size, true).init(device),
dropout_3: DropoutConfig::new(self.dropout).init(),
lstm_4: LstmConfig::new(lstm_3_hidden_size, lstm_4_hidden_size, true).init(device),
dropout_4: DropoutConfig::new(self.dropout).init(),
linear: LinearConfig::new(lstm_4_hidden_size, 1).init(device),
}
}
}
impl<B: Backend> Model<B> {
pub fn forward(
&self,
hour: Tensor<B, 2, tensor::Int>,
day: Tensor<B, 2, tensor::Int>,
numerical: Tensor<B, 3>,
) -> Tensor<B, 2> {
let hour = self.hour_embedding.forward(hour);
let day = self.day_embedding.forward(day);
let x = Tensor::cat(vec![hour, day, numerical], 2);
let (_, x) = self.lstm_1.forward(x, None);
let x = self.dropout_1.forward(x);
let (_, x) = self.lstm_2.forward(x, None);
let x = self.dropout_2.forward(x);
let (_, x) = self.lstm_3.forward(x, None);
let x = self.dropout_3.forward(x);
let (_, x) = self.lstm_4.forward(x, None);
let x = self.dropout_4.forward(x);
        // Keep only the final timestep of the last LSTM's sequence output,
        // then drop the window dimension before the linear head.
        let [batch_size, window_size, features] = x.shape().dims;
        let x = x.slice([0..batch_size, window_size - 1..window_size, 0..features]);
        let x = x.squeeze(1);
self.linear.forward(x)
}
pub fn forward_regression(
&self,
hour: Tensor<B, 2, tensor::Int>,
day: Tensor<B, 2, tensor::Int>,
numerical: Tensor<B, 3>,
target: Tensor<B, 2>,
) -> RegressionOutput<B> {
let output = self.forward(hour, day, numerical);
let loss = MseLoss::new().forward(output.clone(), target.clone(), Reduction::Mean);
RegressionOutput::new(loss, output, target)
}
}
impl<B: AutodiffBackend> TrainStep<BarWindowBatch<B>, RegressionOutput<B>> for Model<B> {
fn step(&self, batch: BarWindowBatch<B>) -> TrainOutput<RegressionOutput<B>> {
let item = self.forward_regression(
batch.hour_tensor,
batch.day_tensor,
batch.numerical_tensor,
batch.target_tensor,
);
TrainOutput::new(self, item.loss.backward(), item)
}
}
impl<B: Backend> ValidStep<BarWindowBatch<B>, RegressionOutput<B>> for Model<B> {
fn step(&self, batch: BarWindowBatch<B>) -> RegressionOutput<B> {
self.forward_regression(
batch.hour_tensor,
batch.day_tensor,
batch.numerical_tensor,
batch.target_tensor,
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::{backend::Wgpu, tensor::Distribution};
#[test]
#[ignore]
fn test_model() {
let device = Default::default();
let distribution = Distribution::Normal(0.0, 1.0);
let config = ModelConfig::new().with_numerical_features(7);
let model = config.init::<Wgpu>(&device);
let hour = Tensor::ones([2, 10], &device);
let day = Tensor::ones([2, 10], &device);
let numerical = Tensor::random([2, 10, 7], distribution, &device);
let output = model.forward(hour, day, numerical);
assert_eq!(output.shape().dims, [2, 1]);
}
}

src/lib/qrust/mod.rs
@@ -0,0 +1,6 @@
pub mod alpaca;
pub mod database;
pub mod ml;
pub mod ta;
pub mod types;
pub mod utils;

src/lib/qrust/ta/bbands.rs
@@ -0,0 +1,149 @@
use std::{borrow::Borrow, collections::VecDeque, iter::Scan, num::NonZeroUsize};
pub struct BbandsState {
window: VecDeque<f64>,
sum: f64,
squared_sum: f64,
multiplier: f64,
}
#[allow(clippy::type_complexity)]
pub trait Bbands<T>: Iterator + Sized {
fn bbands(
self,
period: NonZeroUsize,
multiplier: f64, // Typically 2.0
) -> Scan<Self, BbandsState, fn(&mut BbandsState, T) -> Option<(f64, f64, f64)>>;
}
impl<I, T> Bbands<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn bbands(
self,
period: NonZeroUsize,
multiplier: f64,
) -> Scan<Self, BbandsState, fn(&mut BbandsState, T) -> Option<(f64, f64, f64)>> {
self.scan(
BbandsState {
window: VecDeque::from(vec![0.0; period.get()]),
sum: 0.0,
squared_sum: 0.0,
multiplier,
},
|state: &mut BbandsState, value: T| {
let value = *value.borrow();
let front = state.window.pop_front().unwrap();
state.sum -= front;
state.squared_sum -= front.powi(2);
state.window.push_back(value);
state.sum += value;
state.squared_sum += value.powi(2);
let mean = state.sum / state.window.len() as f64;
let variance =
((state.squared_sum / state.window.len() as f64) - mean.powi(2)).max(0.0);
let standard_deviation = variance.sqrt();
let upper_band = mean + state.multiplier * standard_deviation;
let lower_band = mean - state.multiplier * standard_deviation;
Some((upper_band, mean, lower_band))
},
)
}
}
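// Note: the window is pre-filled with zeros, so the first `period - 1`
// outputs are warm-up values computed against those zeros (visible in
// `test_bbands` below) rather than being withheld.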
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bbands() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let bbands = data
.into_iter()
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
.map(|(upper, mean, lower)| {
(
(upper * 100.0).round() / 100.0,
(mean * 100.0).round() / 100.0,
(lower * 100.0).round() / 100.0,
)
})
.collect::<Vec<_>>();
assert_eq!(
bbands,
vec![
(1.28, 0.33, -0.61),
(2.63, 1.0, -0.63),
(3.63, 2.0, 0.37),
(4.63, 3.0, 1.37),
(5.63, 4.0, 2.37)
]
);
}
#[test]
fn test_bbands_empty() {
let data = Vec::<f64>::new();
let bbands = data
.into_iter()
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
.collect::<Vec<_>>();
assert_eq!(bbands, Vec::<(f64, f64, f64)>::new());
}
#[test]
fn test_bbands_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let bbands = data
.into_iter()
.bbands(NonZeroUsize::new(1).unwrap(), 2.0)
.collect::<Vec<_>>();
assert_eq!(
bbands,
vec![
(1.0, 1.0, 1.0),
(2.0, 2.0, 2.0),
(3.0, 3.0, 3.0),
(4.0, 4.0, 4.0),
(5.0, 5.0, 5.0)
]
);
}
#[test]
fn test_bbands_borrow() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
let bbands = data
.iter()
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
.map(|(upper, mean, lower)| {
(
(upper * 100.0).round() / 100.0,
(mean * 100.0).round() / 100.0,
(lower * 100.0).round() / 100.0,
)
})
.collect::<Vec<_>>();
assert_eq!(
bbands,
vec![
(1.28, 0.33, -0.61),
(2.63, 1.0, -0.63),
(3.63, 2.0, 0.37),
(4.63, 3.0, 1.37),
(5.63, 4.0, 2.37)
]
);
}
}

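Because the window is seeded with zeros, the first period - 1 outputs mix real prices with padding. A usage sketch that keeps only full-window bands; the qrust::ta import path is assumed from the re-exports in ta/mod.rs:

use qrust::ta::Bbands; // assumed import path
use std::num::NonZeroUsize;

fn full_window_bbands(closes: &[f64]) -> Vec<(f64, f64, f64)> {
    let period = NonZeroUsize::new(20).unwrap();
    // Drop the zero-padded warm-up values; each remaining (upper, mean,
    // lower) triple then covers exactly `period` real observations.
    closes.iter().bbands(period, 2.0).skip(period.get() - 1).collect()
}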
src/lib/qrust/ta/deriv.rs Normal file

@@ -0,0 +1,59 @@
use std::{borrow::Borrow, iter::Scan};
pub struct DerivState {
pub last: f64,
}
#[allow(clippy::type_complexity)]
pub trait Deriv<T>: Iterator + Sized {
fn deriv(self) -> Scan<Self, DerivState, fn(&mut DerivState, T) -> Option<f64>>;
}
impl<I, T> Deriv<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn deriv(self) -> Scan<Self, DerivState, fn(&mut DerivState, T) -> Option<f64>> {
self.scan(
DerivState { last: 0.0 },
|state: &mut DerivState, value: T| {
let value = *value.borrow();
let deriv = value - state.last;
state.last = value;
Some(deriv)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_deriv() {
let data = vec![1.0, 3.0, 6.0, 3.0, 1.0];
let deriv = data.into_iter().deriv().collect::<Vec<_>>();
assert_eq!(deriv, vec![1.0, 2.0, 3.0, -3.0, -2.0]);
}
#[test]
fn test_deriv_empty() {
let data = Vec::<f64>::new();
let deriv = data.into_iter().deriv().collect::<Vec<_>>();
assert_eq!(deriv, Vec::<f64>::new());
}
#[test]
fn test_deriv_borrow() {
let data = [1.0, 3.0, 6.0, 3.0, 1.0];
let deriv = data.iter().deriv().collect::<Vec<_>>();
assert_eq!(deriv, vec![1.0, 2.0, 3.0, -3.0, -2.0]);
}
}

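DerivState seeds last with 0.0, so the first output equals the first input rather than a true difference. A small sketch for callers wanting strict first differences (import path assumed):

use qrust::ta::Deriv; // assumed import path

fn first_differences(series: &[f64]) -> Vec<f64> {
    // Element 0 would be `series[0] - 0.0`; skip it to keep only x[i] - x[i-1].
    series.iter().deriv().skip(1).collect()
}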
src/lib/qrust/ta/ema.rs Normal file

@@ -0,0 +1,95 @@
use std::{
borrow::Borrow,
iter::{Peekable, Scan},
num::NonZeroUsize,
};
pub struct EmaState {
weight: f64,
ema: f64,
}
#[allow(clippy::type_complexity)]
pub trait Ema<T>: Iterator + Sized {
fn ema(
self,
period: NonZeroUsize,
) -> Scan<Peekable<Self>, EmaState, fn(&mut EmaState, T) -> Option<f64>>;
}
impl<I, T> Ema<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn ema(
self,
period: NonZeroUsize,
) -> Scan<Peekable<Self>, EmaState, fn(&mut EmaState, T) -> Option<f64>> {
let smoothing = 2.0;
let weight = smoothing / (1.0 + period.get() as f64);
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
EmaState { weight, ema: first },
|state: &mut EmaState, value: T| {
let value = *value.borrow();
state.ema = (value * state.weight) + (state.ema * (1.0 - state.weight));
Some(state.ema)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ema() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let ema = data
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, vec![1.0, 1.5, 2.25, 3.125, 4.0625]);
}
#[test]
fn test_ema_empty() {
let data = Vec::<f64>::new();
let ema = data
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, Vec::<f64>::new());
}
#[test]
fn test_ema_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let ema = data
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_ema_borrow() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
let ema = data
.iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, vec![1.0, 1.5, 2.25, 3.125, 4.0625]);
}
}

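The recurrence is ema_t = alpha * x_t + (1 - alpha) * ema_(t-1) with alpha = 2 / (period + 1), and peeking the first element means the stream is seeded at the first input rather than at zero. A worked sketch (import path assumed):

use qrust::ta::Ema; // assumed import path
use std::num::NonZeroUsize;

fn main() {
    // alpha = 2 / (3 + 1) = 0.5; seeding with the first element means the
    // first output equals the first input, with no zero bias.
    let closes = [100.0, 102.0, 98.0];
    let ema: Vec<f64> = closes.iter().ema(NonZeroUsize::new(3).unwrap()).collect();
    assert_eq!(ema[0], 100.0);
    assert_eq!(ema[1], 101.0); // 102 * 0.5 + 100 * 0.5
}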
src/lib/qrust/ta/macd.rs Normal file

@@ -0,0 +1,216 @@
use std::{
borrow::Borrow,
iter::{Peekable, Scan},
num::NonZeroUsize,
};
pub struct MacdState {
short_weight: f64,
long_weight: f64,
signal_weight: f64,
short_ema: f64,
long_ema: f64,
signal_ema: f64,
}
#[allow(clippy::type_complexity)]
pub trait Macd<T>: Iterator + Sized {
fn macd(
self,
short_period: NonZeroUsize, // Typically 12
long_period: NonZeroUsize, // Typically 26
signal_period: NonZeroUsize, // Typically 9
) -> Scan<Peekable<Self>, MacdState, fn(&mut MacdState, T) -> Option<(f64, f64)>>;
}
impl<I, T> Macd<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn macd(
self,
short_period: NonZeroUsize,
long_period: NonZeroUsize,
signal_period: NonZeroUsize,
) -> Scan<Peekable<Self>, MacdState, fn(&mut MacdState, T) -> Option<(f64, f64)>> {
let smoothing = 2.0;
let short_weight = smoothing / (1.0 + short_period.get() as f64);
let long_weight = smoothing / (1.0 + long_period.get() as f64);
let signal_weight = smoothing / (1.0 + signal_period.get() as f64);
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
MacdState {
short_weight,
long_weight,
signal_weight,
short_ema: first,
long_ema: first,
signal_ema: 0.0,
},
|state: &mut MacdState, value: T| {
let value = *value.borrow();
state.short_ema =
(value * state.short_weight) + (state.short_ema * (1.0 - state.short_weight));
state.long_ema =
(value * state.long_weight) + (state.long_ema * (1.0 - state.long_weight));
let macd = state.short_ema - state.long_ema;
state.signal_ema =
(macd * state.signal_weight) + (state.signal_ema * (1.0 - state.signal_weight));
Some((macd, state.signal_ema))
},
)
}
}
#[cfg(test)]
mod tests {
use super::super::ema::Ema;
use super::*;
#[test]
fn test_macd() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let short_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let long_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(5).unwrap())
.collect::<Vec<_>>();
let macd = short_ema
.into_iter()
.zip(long_ema)
.map(|(short, long)| short - long)
.collect::<Vec<_>>();
let signal = macd
.clone()
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
assert_eq!(
data.into_iter()
.macd(
NonZeroUsize::new(3).unwrap(),
NonZeroUsize::new(5).unwrap(),
NonZeroUsize::new(3).unwrap()
)
.collect::<Vec<_>>(),
expected
);
}
#[test]
fn test_macd_empty() {
let data = Vec::<f64>::new();
assert_eq!(
data.into_iter()
.macd(
NonZeroUsize::new(3).unwrap(),
NonZeroUsize::new(5).unwrap(),
NonZeroUsize::new(3).unwrap()
)
.collect::<Vec<_>>(),
vec![]
);
}
#[test]
fn test_macd_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let short_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
let long_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
let macd = short_ema
.into_iter()
.zip(long_ema)
.map(|(short, long)| short - long)
.collect::<Vec<_>>();
let signal = macd
.clone()
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
assert_eq!(
data.into_iter()
.macd(
NonZeroUsize::new(1).unwrap(),
NonZeroUsize::new(1).unwrap(),
NonZeroUsize::new(1).unwrap()
)
.collect::<Vec<_>>(),
expected
);
}
#[test]
fn test_macd_borrow() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
let short_ema = data
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let long_ema = data
.into_iter()
.ema(NonZeroUsize::new(5).unwrap())
.collect::<Vec<_>>();
let macd = short_ema
.into_iter()
.zip(long_ema)
.map(|(short, long)| short - long)
.collect::<Vec<_>>();
let signal = macd
.clone()
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
assert_eq!(
data.iter()
.macd(
NonZeroUsize::new(3).unwrap(),
NonZeroUsize::new(5).unwrap(),
NonZeroUsize::new(3).unwrap()
)
.collect::<Vec<_>>(),
expected
);
}
}

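The adapter yields (macd, signal) pairs; the commonly charted MACD histogram is their difference. A sketch with the conventional 12/26/9 periods (import path assumed):

use qrust::ta::Macd; // assumed import path
use std::num::NonZeroUsize;

fn macd_histogram(closes: &[f64]) -> Vec<f64> {
    // Positive histogram values indicate short-term momentum running ahead
    // of the signal line.
    closes
        .iter()
        .macd(
            NonZeroUsize::new(12).unwrap(),
            NonZeroUsize::new(26).unwrap(),
            NonZeroUsize::new(9).unwrap(),
        )
        .map(|(macd, signal)| macd - signal)
        .collect()
}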
src/lib/qrust/ta/mod.rs Normal file

@@ -0,0 +1,17 @@
pub mod bbands;
pub mod deriv;
pub mod ema;
pub mod macd;
pub mod obv;
pub mod pct;
pub mod rsi;
pub mod sma;
pub use bbands::Bbands;
pub use deriv::Deriv;
pub use ema::Ema;
pub use macd::Macd;
pub use obv::Obv;
pub use pct::Pct;
pub use rsi::Rsi;
pub use sma::Sma;

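Because every adapter trait is re-exported at the module root, downstream code can pull several with one use and chain them on any f64 iterator. A composition sketch (crate name assumed):

use qrust::ta::{Ema, Pct}; // assumed crate name
use std::num::NonZeroUsize;

fn smoothed_returns(closes: &[f64]) -> Vec<f64> {
    // Percent returns first, then an EMA over them; the adapters compose
    // like any other Iterator combinators.
    closes.iter().pct().skip(1).ema(NonZeroUsize::new(9).unwrap()).collect()
}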
src/lib/qrust/ta/obv.rs Normal file

@@ -0,0 +1,73 @@
use std::{
borrow::Borrow,
iter::{Peekable, Scan},
};
pub struct ObvState {
last: f64,
obv: f64,
}
#[allow(clippy::type_complexity)]
pub trait Obv<T>: Iterator + Sized {
fn obv(self) -> Scan<Peekable<Self>, ObvState, fn(&mut ObvState, T) -> Option<f64>>;
}
impl<I, T> Obv<T> for I
where
I: Iterator<Item = T>,
T: Borrow<(f64, f64)>,
{
fn obv(self) -> Scan<Peekable<Self>, ObvState, fn(&mut ObvState, T) -> Option<f64>> {
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
ObvState {
last: first.0,
obv: 0.0,
},
|state: &mut ObvState, value: T| {
let (close, volume) = *value.borrow();
if close > state.last {
state.obv += volume;
} else if close < state.last {
state.obv -= volume;
}
state.last = close;
Some(state.obv)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_obv() {
let data = vec![(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (2.0, 4.0), (1.0, 5.0)];
let obv = data.into_iter().obv().collect::<Vec<_>>();
assert_eq!(obv, vec![0.0, 2.0, 5.0, 1.0, -4.0]);
}
#[test]
fn test_obv_empty() {
let data = Vec::<(f64, f64)>::new();
let obv = data.into_iter().obv().collect::<Vec<_>>();
assert_eq!(obv, Vec::<f64>::new());
}
#[test]
fn test_obv_borrow() {
let data = [(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (2.0, 4.0), (1.0, 5.0)];
let obv = data.iter().obv().collect::<Vec<_>>();
assert_eq!(obv, vec![0.0, 2.0, 5.0, 1.0, -4.0]);
}
}

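Unlike the other adapters, obv consumes (close, volume) pairs rather than bare prices; feeding it from parallel series is a one-line zip. Sketch (import path assumed):

use qrust::ta::Obv; // assumed import path

fn on_balance_volume(closes: &[f64], volumes: &[f64]) -> Vec<f64> {
    // Zip price and volume into the (close, volume) pairs obv expects.
    closes.iter().copied().zip(volumes.iter().copied()).obv().collect()
}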
src/lib/qrust/ta/pct.rs Normal file

@@ -0,0 +1,64 @@
use std::{borrow::Borrow, iter::Scan};
pub struct PctState {
pub last: f64,
}
#[allow(clippy::type_complexity)]
pub trait Pct<T>: Iterator + Sized {
fn pct(self) -> Scan<Self, PctState, fn(&mut PctState, T) -> Option<f64>>;
}
impl<I, T> Pct<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn pct(self) -> Scan<Self, PctState, fn(&mut PctState, T) -> Option<f64>> {
self.scan(PctState { last: 0.0 }, |state: &mut PctState, value: T| {
let value = *value.borrow();
let pct = value / state.last - 1.0;
state.last = value;
Some(pct)
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pct() {
let data = vec![1.0, 2.0, 4.0, 2.0, 1.0];
let pct = data.into_iter().pct().collect::<Vec<_>>();
assert_eq!(pct, vec![f64::INFINITY, 1.0, 1.0, -0.5, -0.5]);
}
#[test]
fn test_pct_empty() {
let data = Vec::<f64>::new();
let pct = data.into_iter().pct().collect::<Vec<_>>();
assert_eq!(pct, Vec::<f64>::new());
}
#[test]
fn test_pct_0() {
let data = vec![1.0, 0.0, 4.0, 2.0, 1.0];
let pct = data.into_iter().pct().collect::<Vec<_>>();
assert_eq!(pct, vec![f64::INFINITY, -1.0, f64::INFINITY, -0.5, -0.5]);
}
#[test]
fn test_pct_borrow() {
let data = [1.0, 2.0, 4.0, 2.0, 1.0];
let pct = data.iter().pct().collect::<Vec<_>>();
assert_eq!(pct, vec![f64::INFINITY, 1.0, 1.0, -0.5, -0.5]);
}
}

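Since last is seeded with 0.0, the very first output is infinite (x / 0 - 1), and a zero price mid-series produces the same blowup one step later. A guarded usage sketch (import path assumed):

use qrust::ta::Pct; // assumed import path

fn finite_returns(closes: &[f64]) -> Vec<f64> {
    // Skip the seeded first element and filter any division-by-zero
    // artifacts from zero prices mid-series.
    closes.iter().pct().skip(1).filter(|pct| pct.is_finite()).collect()
}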
src/lib/qrust/ta/rsi.rs Normal file

@@ -0,0 +1,135 @@
use std::{
borrow::Borrow,
collections::VecDeque,
iter::{Peekable, Scan},
num::NonZeroUsize,
};
pub struct RsiState {
last: f64,
window_gains: VecDeque<f64>,
window_losses: VecDeque<f64>,
sum_gains: f64,
sum_losses: f64,
}
#[allow(clippy::type_complexity)]
pub trait Rsi<T>: Iterator + Sized {
fn rsi(
self,
period: NonZeroUsize, // Typically 14
) -> Scan<Peekable<Self>, RsiState, fn(&mut RsiState, T) -> Option<f64>>;
}
impl<I, T> Rsi<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn rsi(
self,
period: NonZeroUsize,
) -> Scan<Peekable<Self>, RsiState, fn(&mut RsiState, T) -> Option<f64>> {
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
RsiState {
last: first,
window_gains: VecDeque::from(vec![0.0; period.get()]),
window_losses: VecDeque::from(vec![0.0; period.get()]),
sum_gains: 0.0,
sum_losses: 0.0,
},
|state, value| {
let value = *value.borrow();
state.sum_gains -= state.window_gains.pop_front().unwrap();
state.sum_losses -= state.window_losses.pop_front().unwrap();
let gain = (value - state.last).max(0.0);
let loss = (state.last - value).max(0.0);
state.last = value;
state.window_gains.push_back(gain);
state.window_losses.push_back(loss);
state.sum_gains += gain;
state.sum_losses += loss;
let avg_loss = state.sum_losses / state.window_losses.len() as f64;
// No losses in the window: RS is undefined, so RSI saturates at 100.
if avg_loss == 0.0 {
return Some(100.0);
}
let avg_gain = state.sum_gains / state.window_gains.len() as f64;
let rs = avg_gain / avg_loss;
Some(100.0 - (100.0 / (1.0 + rs)))
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rsi() {
let data = vec![1.0, 4.0, 7.0, 4.0, 1.0];
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.map(|v| (v * 100.0).round() / 100.0)
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 66.67, 33.33]);
}
#[test]
fn test_rsi_empty() {
let data = Vec::<f64>::new();
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(rsi, Vec::<f64>::new());
}
#[test]
fn test_rsi_no_loss() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 100.0, 100.0]);
}
#[test]
fn test_rsi_no_gain() {
let data = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 0.0, 0.0, 0.0, 0.0]);
}
#[test]
fn test_rsi_borrow() {
let data = [1.0, 4.0, 7.0, 4.0, 1.0];
let rsi = data
.iter()
.rsi(NonZeroUsize::new(3).unwrap())
.map(|v| (v * 100.0).round() / 100.0)
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 66.67, 33.33]);
}
}

src/lib/qrust/ta/sma.rs Normal file

@@ -0,0 +1,88 @@
use std::{borrow::Borrow, collections::VecDeque, iter::Scan, num::NonZeroUsize};
pub struct SmaState {
window: VecDeque<f64>,
sum: f64,
}
#[allow(clippy::type_complexity)]
pub trait Sma<T>: Iterator + Sized {
fn sma(self, period: NonZeroUsize)
-> Scan<Self, SmaState, fn(&mut SmaState, T) -> Option<f64>>;
}
impl<I, T> Sma<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn sma(
self,
period: NonZeroUsize,
) -> Scan<Self, SmaState, fn(&mut SmaState, T) -> Option<f64>> {
self.scan(
SmaState {
window: VecDeque::from(vec![0.0; period.get()]),
sum: 0.0,
},
|state: &mut SmaState, value: T| {
let value = *value.borrow();
state.sum -= state.window.pop_front().unwrap();
state.window.push_back(value);
state.sum += value;
Some(state.sum / state.window.len() as f64)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sma() {
let data = vec![3.0, 6.0, 9.0, 12.0, 15.0];
let sma = data
.into_iter()
.sma(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, vec![1.0, 3.0, 6.0, 9.0, 12.0]);
}
#[test]
fn test_sma_empty() {
let data = Vec::<f64>::new();
let sma = data
.into_iter()
.sma(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, Vec::<f64>::new());
}
#[test]
fn test_sma_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let sma = data
.into_iter()
.sma(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_sma_borrow() {
let data = [3.0, 6.0, 9.0, 12.0, 15.0];
let sma = data
.iter()
.sma(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, vec![1.0, 3.0, 6.0, 9.0, 12.0]);
}
}


@@ -1,13 +1,7 @@
use crate::config::ALPACA_API_URL;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use serde_aux::field_attributes::{
deserialize_number_from_string, deserialize_option_number_from_string,
};
use std::time::Duration;
use time::OffsetDateTime;
use uuid::Uuid;
@@ -79,38 +73,3 @@ pub struct Account {
#[serde(deserialize_with = "deserialize_number_from_string")]
pub regt_buying_power: f64,
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/account", *ALPACA_API_URL))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Account>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

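A call-site sketch for this retrying getter. The APCA-* header names and the 200-requests-per-minute quota are Alpaca conventions assumed here, not taken from this diff:

use governor::{Quota, RateLimiter};
use reqwest::header::{HeaderMap, HeaderValue};
use std::num::NonZeroU32;

async fn fetch_account() -> Result<Account, reqwest::Error> {
    let mut headers = HeaderMap::new();
    // Assumed Alpaca auth headers; values are placeholders.
    headers.insert("APCA-API-KEY-ID", HeaderValue::from_static("key"));
    headers.insert("APCA-API-SECRET-KEY", HeaderValue::from_static("secret"));
    let client = reqwest::Client::builder().default_headers(headers).build()?;
    // Assumed quota; tune to the account's actual rate limit.
    let limiter = RateLimiter::direct(Quota::per_minute(NonZeroU32::new(200).unwrap()));
    // `None` falls back to the default ExponentialBackoff.
    get(&client, &limiter, None).await
}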

@@ -0,0 +1,39 @@
use super::position::Position;
use crate::types::{self, alpaca::shared::asset};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_option_number_from_string;
use uuid::Uuid;
pub use asset::{Class, Exchange, Status};
#[allow(clippy::struct_excessive_bools)]
#[derive(Deserialize, Clone)]
pub struct Asset {
pub id: Uuid,
pub class: Class,
pub exchange: Exchange,
pub symbol: String,
pub name: String,
pub status: Status,
pub tradable: bool,
pub marginable: bool,
pub shortable: bool,
pub easy_to_borrow: bool,
pub fractionable: bool,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub maintenance_margin_requirement: Option<f32>,
pub attributes: Option<Vec<String>>,
}
impl From<(Asset, Option<Position>)> for types::Asset {
fn from((asset, position): (Asset, Option<Position>)) -> Self {
Self {
symbol: asset.symbol,
class: asset.class.into(),
exchange: asset.exchange.into(),
status: asset.status.into(),
time_added: time::OffsetDateTime::now_utc(),
qty: position.map(|position| position.qty).unwrap_or_default(),
}
}
}


@@ -0,0 +1,39 @@
use crate::types;
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Deserialize)]
pub struct Bar {
#[serde(rename = "t")]
#[serde(with = "time::serde::rfc3339")]
pub time: OffsetDateTime,
#[serde(rename = "o")]
pub open: f64,
#[serde(rename = "h")]
pub high: f64,
#[serde(rename = "l")]
pub low: f64,
#[serde(rename = "c")]
pub close: f64,
#[serde(rename = "v")]
pub volume: f64,
#[serde(rename = "n")]
pub trades: i64,
#[serde(rename = "vw")]
pub vwap: f64,
}
impl From<(Bar, String)> for types::Bar {
fn from((bar, symbol): (Bar, String)) -> Self {
Self {
symbol,
time: bar.time,
open: bar.open,
high: bar.high,
low: bar.low,
close: bar.close,
volume: bar.volume,
trades: bar.trades,
}
}
}

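A deserialization sketch showing how the single-letter Alpaca keys map onto this struct; serde_json is assumed to be available, and the JSON mirrors one element of an Alpaca bars response:

fn parse_bar_example() {
    let json = r#"{"t":"2024-03-11T16:00:00Z","o":1.0,"h":2.0,"l":0.5,"c":1.5,"v":100.0,"n":42,"vw":1.25}"#;
    let bar: Bar = serde_json::from_str(json).unwrap();
    // The From impl attaches the symbol and drops `vwap` on the way to the
    // domain type.
    let domain: types::Bar = (bar, "AAPL".to_string()).into();
    assert_eq!(domain.trades, 42);
}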

@@ -0,0 +1,26 @@
use crate::{
types,
utils::{de, time::EST_OFFSET},
};
use serde::Deserialize;
use time::{Date, OffsetDateTime, Time};
#[derive(Deserialize)]
pub struct Calendar {
pub date: Date,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub open: Time,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub close: Time,
pub settlement_date: Date,
}
impl From<Calendar> for types::Calendar {
fn from(calendar: Calendar) -> Self {
Self {
date: calendar.date,
open: OffsetDateTime::new_in_offset(calendar.date, calendar.open, *EST_OFFSET),
close: OffsetDateTime::new_in_offset(calendar.date, calendar.close, *EST_OFFSET),
}
}
}


@@ -0,0 +1,13 @@
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Deserialize)]
pub struct Clock {
#[serde(with = "time::serde::rfc3339")]
pub timestamp: OffsetDateTime,
pub is_open: bool,
#[serde(with = "time::serde::rfc3339")]
pub next_open: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
pub next_close: OffsetDateTime,
}


@@ -0,0 +1,57 @@
use crate::{
types::{self, alpaca::shared::news::strip},
utils::de,
};
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImageSize {
Thumb,
Small,
Large,
}
#[derive(Deserialize)]
pub struct Image {
pub size: ImageSize,
pub url: String,
}
#[derive(Deserialize)]
pub struct News {
pub id: i64,
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "created_at")]
pub time_created: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "updated_at")]
pub time_updated: OffsetDateTime,
#[serde(deserialize_with = "de::add_slash_to_symbols")]
pub symbols: Vec<String>,
pub headline: String,
pub author: String,
pub source: String,
pub summary: String,
pub content: String,
pub url: Option<String>,
pub images: Vec<Image>,
}
impl From<News> for types::News {
fn from(news: News) -> Self {
Self {
id: news.id,
time_created: news.time_created,
time_updated: news.time_updated,
symbols: news.symbols,
headline: strip(&news.headline),
author: strip(&news.author),
source: strip(&news.source),
summary: news.summary,
content: news.content,
url: news.url.unwrap_or_default(),
}
}
}


@@ -0,0 +1,3 @@
use crate::types::alpaca::shared::order;
pub use order::{Order, Side};


@@ -0,0 +1,61 @@
use crate::{
types::alpaca::api::incoming::{
asset::{Class, Exchange},
order,
},
utils::de,
};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_number_from_string;
use uuid::Uuid;
#[derive(Deserialize, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Long,
Short,
}
impl From<Side> for order::Side {
fn from(side: Side) -> Self {
match side {
Side::Long => Self::Buy,
Side::Short => Self::Sell,
}
}
}
#[derive(Deserialize, Clone)]
pub struct Position {
pub asset_id: Uuid,
#[serde(deserialize_with = "de::add_slash_to_symbol")]
pub symbol: String,
pub exchange: Exchange,
pub asset_class: Class,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub avg_entry_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty_available: f64,
pub side: Side,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub cost_basis: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub current_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub lastday_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub change_today: f64,
pub asset_marginable: bool,
}


@@ -0,0 +1,6 @@
pub mod incoming;
pub mod outgoing;
pub const ALPACA_US_EQUITY_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news";


@@ -0,0 +1,23 @@
use crate::types::alpaca::shared::asset;
use serde::Serialize;
pub use asset::{Class, Exchange, Status};
#[derive(Serialize)]
pub struct Asset {
pub status: Option<Status>,
pub class: Option<Class>,
pub exchange: Option<Exchange>,
pub attributes: Option<Vec<String>>,
}
impl Default for Asset {
fn default() -> Self {
Self {
status: None,
class: Some(Class::UsEquity),
exchange: None,
attributes: None,
}
}
}


@@ -1,12 +1,14 @@
use crate::{
config::ALPACA_SOURCE,
types::alpaca::shared::{Sort, Source},
alpaca::bars::MAX_LIMIT,
types::alpaca::shared,
utils::{ser, ONE_MINUTE},
};
use serde::Serialize;
use std::time::Duration;
use time::OffsetDateTime;
pub use shared::{Sort, Source};
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
@@ -53,10 +55,10 @@ impl Default for UsEquity {
timeframe: ONE_MINUTE,
start: None,
end: None,
limit: Some(10000),
limit: Some(MAX_LIMIT),
adjustment: Some(Adjustment::All),
asof: None,
feed: Some(*ALPACA_SOURCE),
feed: Some(Source::Iex),
currency: None,
page_token: None,
sort: Some(Sort::Asc),
@@ -91,7 +93,7 @@ impl Default for Crypto {
timeframe: ONE_MINUTE,
start: None,
end: None,
limit: Some(10000),
limit: Some(MAX_LIMIT),
page_token: None,
sort: Some(Sort::Asc),
}


@@ -1,3 +1,4 @@
pub mod asset;
pub mod bar;
pub mod calendar;
pub mod news;


@@ -1,10 +1,10 @@
use crate::{types::alpaca::shared::Sort, utils::ser};
use crate::{alpaca::news::MAX_LIMIT, types::alpaca::shared::Sort, utils::ser};
use serde::Serialize;
use time::OffsetDateTime;
#[derive(Serialize)]
pub struct News {
#[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")]
#[serde(serialize_with = "ser::remove_slash_and_join_symbols")]
pub symbols: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
@@ -30,7 +30,7 @@ impl Default for News {
symbols: vec![],
start: None,
end: None,
limit: Some(50),
limit: Some(MAX_LIMIT),
include_content: Some(true),
exclude_contentless: Some(false),
page_token: None,


@@ -1,10 +1,12 @@
use crate::{
types::alpaca::shared::{order::Side, Sort},
types::alpaca::shared::{order, Sort},
utils::ser,
};
use serde::Serialize;
use time::OffsetDateTime;
pub use order::Side;
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]


@@ -1,7 +1,7 @@
use crate::{impl_from_enum, types};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Class {
UsEquity,
@@ -10,7 +10,7 @@ pub enum Class {
impl_from_enum!(types::Class, Class, UsEquity, Crypto);
#[derive(Deserialize)]
#[derive(Serialize, Deserialize, Clone, Copy)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Exchange {
Amex,
@@ -36,7 +36,7 @@ impl_from_enum!(
Crypto
);
#[derive(Deserialize)]
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum Status {
Active,


@@ -1,4 +1,3 @@
use html_escape::decode_html_entities;
use lazy_static::lazy_static;
use regex::Regex;
@@ -7,12 +6,21 @@ lazy_static! {
static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap();
}
pub fn normalize_html_content(content: &str) -> String {
pub fn strip(content: &str) -> String {
let content = content.replace('\n', " ");
let content = RE_TAGS.replace_all(&content, "");
let content = RE_SPACES.replace_all(&content, " ");
let content = decode_html_entities(&content);
let content = content.trim();
content.to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_strip() {
let content = "<p> <b> Hello, </b> <i> World! </i> </p>";
assert_eq!(strip(content), "Hello, World!");
}
}


@@ -223,3 +223,53 @@ impl Order {
orders
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_normalize() {
let order_template = Order {
id: Uuid::new_v4(),
client_order_id: Uuid::new_v4(),
created_at: OffsetDateTime::now_utc(),
updated_at: None,
submitted_at: OffsetDateTime::now_utc(),
filled_at: None,
expired_at: None,
cancel_requested_at: None,
canceled_at: None,
failed_at: None,
replaced_at: None,
replaced_by: None,
replaces: None,
asset_id: Uuid::new_v4(),
symbol: "AAPL".to_string(),
asset_class: super::super::asset::Class::UsEquity,
notional: None,
qty: None,
filled_qty: 0.0,
filled_avg_price: None,
order_class: Class::Simple,
order_type: Type::Market,
side: Side::Buy,
time_in_force: TimeInForce::Day,
limit_price: None,
stop_price: None,
status: Status::New,
extended_hours: false,
legs: None,
trail_percent: None,
trail_price: None,
hwm: None,
};
let mut order = order_template.clone();
order.legs = Some(vec![order_template.clone(), order_template.clone()]);
order.legs.as_mut().unwrap()[0].legs = Some(vec![order_template.clone()]);
let orders = order.normalize();
assert_eq!(orders.len(), 4);
}
}


@@ -28,15 +28,14 @@ pub struct Message {
impl From<Message> for Bar {
fn from(bar: Message) -> Self {
Self {
time: bar.time,
symbol: bar.symbol,
time: bar.time,
open: bar.open,
high: bar.high,
low: bar.low,
close: bar.close,
volume: bar.volume,
trades: bar.trades,
vwap: bar.vwap,
}
}
}


@@ -1,5 +1,5 @@
use crate::{
types::{alpaca::shared::news::normalize_html_content, news::Sentiment, News},
types::{alpaca::shared::news::strip, News},
utils::de,
};
use serde::Deserialize;
@@ -31,13 +31,11 @@ impl From<Message> for News {
time_created: news.time_created,
time_updated: news.time_updated,
symbols: news.symbols,
headline: normalize_html_content(&news.headline),
author: normalize_html_content(&news.author),
source: normalize_html_content(&news.source),
summary: normalize_html_content(&news.summary),
content: normalize_html_content(&news.content),
sentiment: Sentiment::Neutral,
confidence: 0.0,
headline: strip(&news.headline),
author: strip(&news.author),
source: strip(&news.source),
summary: news.summary,
content: news.content,
url: news.url.unwrap_or_default(),
}
}


@@ -6,13 +6,13 @@ use serde::Deserialize;
pub enum Message {
#[serde(rename_all = "camelCase")]
Market {
trades: Vec<String>,
quotes: Vec<String>,
bars: Vec<String>,
updated_bars: Vec<String>,
daily_bars: Vec<String>,
statuses: Vec<String>,
trades: Option<Vec<String>>,
quotes: Option<Vec<String>>,
daily_bars: Option<Vec<String>>,
orderbooks: Option<Vec<String>>,
statuses: Option<Vec<String>>,
lulds: Option<Vec<String>>,
cancel_errors: Option<Vec<String>>,
},


@@ -1,10 +1,7 @@
pub mod incoming;
pub mod outgoing;
use crate::{
config::{ALPACA_API_KEY, ALPACA_API_SECRET},
types::alpaca::websocket,
};
use crate::types::alpaca::websocket;
use core::panic;
use futures_util::{
stream::{SplitSink, SplitStream},
@@ -17,6 +14,8 @@ use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub async fn authenticate(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
api_key: String,
api_secret: String,
) {
match stream.next().await.unwrap().unwrap() {
Message::Text(data)
@@ -32,8 +31,8 @@ pub async fn authenticate(
sink.send(Message::Text(
to_string(&websocket::data::outgoing::Message::Auth(
websocket::auth::Message {
key: (*ALPACA_API_KEY).clone(),
secret: (*ALPACA_API_SECRET).clone(),
key: api_key,
secret: api_secret,
},
))
.unwrap(),

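A connection sketch for this authenticate helper; the IEX stream URL is an Alpaca convention assumed here, and crypto and news feeds use different endpoints:

use futures_util::StreamExt;
use tokio_tungstenite::connect_async;

async fn connect_and_auth(api_key: String, api_secret: String) {
    // Assumed feed URL; credentials are passed through to the auth message.
    let (ws, _response) = connect_async("wss://stream.data.alpaca.markets/v2/iex")
        .await
        .unwrap();
    let (mut sink, mut stream) = ws.split();
    authenticate(&mut sink, &mut stream, api_key, api_secret).await;
}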

@@ -1,4 +1,5 @@
use crate::utils::ser;
use nonempty::NonEmpty;
use serde::Serialize;
#[derive(Serialize)]
@@ -6,14 +7,14 @@ use serde::Serialize;
pub enum Market {
#[serde(rename_all = "camelCase")]
UsEquity {
bars: Vec<String>,
updated_bars: Vec<String>,
statuses: Vec<String>,
bars: NonEmpty<String>,
updated_bars: NonEmpty<String>,
statuses: NonEmpty<String>,
},
#[serde(rename_all = "camelCase")]
Crypto {
bars: Vec<String>,
updated_bars: Vec<String>,
bars: NonEmpty<String>,
updated_bars: NonEmpty<String>,
},
}
@@ -23,12 +24,12 @@ pub enum Message {
Market(Market),
News {
#[serde(serialize_with = "ser::remove_slash_from_symbols")]
news: Vec<String>,
news: NonEmpty<String>,
},
}
impl Message {
pub fn new_market_us_equity(symbols: Vec<String>) -> Self {
pub fn new_market_us_equity(symbols: NonEmpty<String>) -> Self {
Self::Market(Market::UsEquity {
bars: symbols.clone(),
updated_bars: symbols.clone(),
@@ -36,14 +37,14 @@ impl Message {
})
}
pub fn new_market_crypto(symbols: Vec<String>) -> Self {
pub fn new_market_crypto(symbols: NonEmpty<String>) -> Self {
Self::Market(Market::Crypto {
bars: symbols.clone(),
updated_bars: symbols,
})
}
pub fn new_news(symbols: Vec<String>) -> Self {
pub fn new_news(symbols: NonEmpty<String>) -> Self {
Self::News { news: symbols }
}
}

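With NonEmpty, an empty subscription list is unrepresentable at the type level. A construction sketch using the nonempty! macro, with serde_json assumed for serialization:

use nonempty::nonempty;

fn subscribe_message() -> String {
    // nonempty! guarantees at least one symbol at compile time.
    let msg = Message::new_market_us_equity(nonempty![
        "AAPL".to_string(),
        "MSFT".to_string()
    ]);
    serde_json::to_string(&msg).unwrap()
}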
Some files were not shown because too many files have changed in this diff.