53 Commits
ollama ... main

Author SHA1 Message Date
90b7f10a77 Update and fix bugs
It's good to be back

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-05-10 17:49:16 +01:00
d7e9350257 Add initial ML implementation
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-29 13:40:49 +00:00
f715881b07 Remove unsafe blocks
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-26 09:51:52 +00:00
d0ad9f65b1 Remove vwap
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-26 09:51:16 +00:00
ce8c4db422 Remove local XKCD image
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-26 09:50:39 +00:00
46508d1b4f Update reqwest
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-20 20:00:35 +00:00
2ad42c5462 Foldify for loops
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-20 19:43:48 +00:00
733e6373e9 Add tests
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-20 19:43:47 +00:00
d072b849c0 Reorganize crate source code
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-20 19:43:46 +00:00
718e794f51 Reorder Bar struct
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-19 15:48:45 +00:00
b7a175d5b4 Improve ser function naming
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-17 09:22:45 +00:00
e9012d6ec3 Add backfill progress logging
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-15 17:26:58 +00:00
10365745aa Attempt to fix bugs related to empty vecs
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-14 22:38:20 +00:00
8202255132 Fix backfill freshness
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-14 12:47:52 +00:00
0d276d537c Add websocket infinite inserting
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-14 01:46:18 +00:00
1707d74cf7 Improve alpaca request error handling
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-13 19:45:06 +00:00
f3f9c6336b Remove rust-bert
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-13 12:09:50 +00:00
5ed0c7670a Fix backfill sentiment batching bug
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-12 21:00:11 +00:00
d2d20e2978 Add automatic websocket reconnection
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 23:41:06 +00:00
d02f958865 Optimize backfill early saving allocations
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 20:41:59 +00:00
2d8972dce2 Fix possible crashes on .unwrap()s
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 20:15:19 +00:00
7bacc2565a Fix CI
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 16:53:22 +00:00
b60cbc891d Add backfill early saving
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 16:53:12 +00:00
2de86b46f7 Improve backfill error logging
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 19:51:41 +00:00
8c7ee3d12d Add shared lib
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 18:28:40 +00:00
a15fd2c3c9 Separate data management code
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 16:59:21 +00:00
acfc0ca4c9 Add pipelined backfilling
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 11:22:24 +00:00
681d7393d7 Add multiple asset adding route
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-09 20:13:36 +00:00
080f91b044 Fix README
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-26 19:01:22 +00:00
3006264af1 Fix calendar EST offset
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-26 18:55:15 +00:00
a84daea61c Add local market calendar storage
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-22 12:35:01 +00:00
0b9c6ca122 Add defaults for outgoing types
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-22 11:31:28 +00:00
4665891316 Fix error on initialization with no symbols
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-19 17:51:44 +00:00
4f73058792 Add calendar
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-17 20:36:02 +00:00
152a0b4682 Fix bad request response handling
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-17 20:35:50 +00:00
ae5044142d Fix status message deserialization
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-16 19:39:36 +00:00
a1781cdf29 Remove manual pongs
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-15 14:17:47 +00:00
cdaa2d20a9 Update random bits and bobs
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-15 01:09:16 +00:00
4b194e168f Add paper URL support
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-14 21:15:27 +00:00
6f85b9b0e8 Fix string to number deserialization
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-14 18:53:58 +00:00
6adf2b46c8 Add partial account management
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-14 17:38:56 +00:00
648d413ac7 Add order/position management
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-14 17:07:30 +00:00
6ec71ee144 Add position types
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-14 10:48:37 +00:00
5961717520 Add order types
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-12 16:45:11 +00:00
dee21d5324 Add asset status management
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-09 15:43:42 +00:00
76bf2fddcb Clean up error propagation
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-08 18:13:52 +00:00
52e88f4bc9 Remove asset_status thread
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-07 20:40:11 +00:00
85eef2bf0b Refactor threads to use trait implementations
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-05 13:47:43 +00:00
a796feb299 Lower incoming data log level
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-05 00:32:46 +00:00
a2bcb6d17e Make sentiment predictions blocking
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-05 00:30:32 +00:00
caaa31133a Improve outgoing Alpaca API types
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-05 00:30:11 +00:00
61c573cbc7 Remove stored abbreviation
- Alpaca is fuck

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-04 21:24:14 +00:00
65c9ae8b25 Add finbert sentiment analysis
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-03 18:58:40 +00:00
149 changed files with 10446 additions and 2493 deletions

.gitignore (vendored, 1 line changed)

@@ -2,6 +2,7 @@
# will have compiled files and executables
debug/
target/
log/
# These are backup files generated by rustfmt
**/*.rs.bk


@@ -22,7 +22,7 @@ build:
  cache:
    <<: *global_cache
  script:
    - cargo +nightly build
    - cargo +nightly build --workspace
test:
  image: registry.karaolidis.com/karaolidis/qrust/rust
@@ -30,7 +30,7 @@ test:
  cache:
    <<: *global_cache
  script:
    - cargo +nightly test
    - cargo +nightly test --workspace
lint:
  image: registry.karaolidis.com/karaolidis/qrust/rust
@@ -39,7 +39,7 @@ lint:
    <<: *global_cache
  script:
    - cargo +nightly fmt --all -- --check
    - cargo +nightly clippy --all-targets --all-features
    - cargo +nightly clippy --workspace --all-targets --all-features
depcheck:
  image: registry.karaolidis.com/karaolidis/qrust/rust
@@ -48,7 +48,7 @@ depcheck:
    <<: *global_cache
  script:
    - cargo +nightly outdated
    - cargo +nightly udeps
    - cargo +nightly udeps --workspace --all-targets
build-release:
  image: registry.karaolidis.com/karaolidis/qrust/rust

Cargo.lock (generated, 3196 lines changed): file diff suppressed because it is too large.


@@ -3,6 +3,18 @@ name = "qrust"
version = "0.1.0"
edition = "2021"
[lib]
name = "qrust"
path = "src/lib/qrust/mod.rs"
[[bin]]
name = "qrust"
path = "src/bin/qrust/mod.rs"
[[bin]]
name = "trainer"
path = "src/bin/trainer/mod.rs"
[profile.release]
panic = 'abort'
strip = true
@@ -12,9 +24,9 @@ codegen-units = 1
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = "0.7.4"
axum = "0.7.5"
dotenv = "0.15.0"
tokio = { version = "1.32.0", features = [
tokio = { version = "1.37.0", features = [
"macros",
"rt-multi-thread",
] }
@@ -22,33 +34,55 @@ tokio-tungstenite = { version = "0.21.0", features = [
"tokio-native-tls",
"native-tls",
] }
log = "0.4.20"
log4rs = "1.2.0"
serde = "1.0.188"
serde_json = "1.0.105"
serde_repr = "0.1.18"
serde_with = "3.5.1"
futures-util = "0.3.28"
reqwest = { version = "0.11.20", features = [
log = "0.4.21"
log4rs = "1.3.0"
serde = "1.0.201"
serde_json = "1.0.117"
serde_repr = "0.1.19"
serde_with = "3.8.1"
serde-aux = "4.5.0"
futures-util = "0.3.30"
reqwest = { version = "0.12.4", features = [
"json",
"serde_json",
] }
http = "1.0.0"
governor = "0.6.0"
http = "1.1.0"
governor = "0.6.3"
clickhouse = { version = "0.11.6", features = [
"watch",
"time",
"uuid",
] }
uuid = "1.6.1"
time = { version = "0.3.31", features = [
uuid = { version = "1.8.0", features = [
"serde",
"v4",
] }
time = { version = "0.3.36", features = [
"serde",
"serde-well-known",
"serde-human-readable",
"formatting",
"macros",
"serde-well-known",
"local-offset",
] }
backoff = { version = "0.4.0", features = [
"tokio",
] }
regex = "1.10.3"
html-escape = "0.2.13"
regex = "1.10.4"
async-trait = "0.1.80"
itertools = "0.12.1"
lazy_static = "1.4.0"
nonempty = { version = "0.10.0", features = [
"serialize",
] }
rand = "0.8.5"
rayon = "1.10.0"
burn = { version = "0.13.2", features = [
"wgpu",
"cuda",
"tui",
"metrics",
"train",
] }
[dev-dependencies]
serde_test = "1.0.176"
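
The [lib] and [[bin]] sections above split the crate into a shared library and two binaries. Going by the paths declared in the manifest, the source tree now looks roughly like this (the per-target descriptions are inferred, not stated in the diff):

src/
  lib/qrust/mod.rs     shared qrust library (alpaca client, types, database, utils)
  bin/qrust/mod.rs     the trading/data service binary
  bin/trainer/mod.rs   the ML trainer binary (presumably using the new burn dependency)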


@@ -1,3 +1,5 @@
# QRust
# qrust
QRust (/kɹʌst/, QuantitativeRust) is an algorithmic trading library written in Rust.
![XKCD - Engineer Syllogism](https://imgs.xkcd.com/comics/engineer_syllogism.png)
`qrust` (/kɹʌst/, QuantitativeRust) is an algorithmic trading library written in Rust.


@@ -4,7 +4,14 @@ appenders:
    encoder:
      pattern: "{d} {h({l})} {M}::{L} - {m}{n}"
  file:
    kind: file
    path: "./log/output.log"
    encoder:
      pattern: "{d} {l} {M}::{L} - {m}{n}"
root:
  level: info
  appenders:
    - stdout
    - file

src/bin/qrust/config.rs (new file, 86 lines)

@@ -0,0 +1,86 @@
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use lazy_static::lazy_static;
use qrust::types::alpaca::shared::{Mode, Source};
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client,
};
use std::{env, num::NonZeroU32, sync::Arc};
use tokio::sync::Semaphore;
lazy_static! {
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
.expect("ALPACA_MODE must be set.")
.parse()
.expect("ALPACA_MODE must be 'live' or 'paper'");
pub static ref ALPACA_API_BASE: String = match *ALPACA_MODE {
Mode::Live => String::from("api"),
Mode::Paper => String::from("paper-api"),
};
pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE")
.expect("ALPACA_SOURCE must be set.")
.parse()
.expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'");
pub static ref ALPACA_API_KEY: String =
env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
pub static ref ALPACA_API_SECRET: String =
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
pub static ref CLICKHOUSE_BATCH_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE")
.expect("BATCH_BACKFILL_BARS_SIZE must be set.")
.parse()
.expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer.");
pub static ref CLICKHOUSE_BATCH_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE")
.expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
.parse()
.expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS")
.expect("CLICKHOUSE_MAX_CONNECTIONS must be set.")
.parse()
.expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer.");
}
pub struct Config {
pub alpaca_client: Client,
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
pub clickhouse_client: clickhouse::Client,
pub clickhouse_concurrency_limiter: Arc<Semaphore>,
}
impl Config {
pub fn from_env() -> Self {
Self {
alpaca_client: Client::builder()
.default_headers(HeaderMap::from_iter([
(
HeaderName::from_static("apca-api-key-id"),
HeaderValue::from_str(&ALPACA_API_KEY)
.expect("Alpaca API key must not contain invalid characters."),
),
(
HeaderName::from_static("apca-api-secret-key"),
HeaderValue::from_str(&ALPACA_API_SECRET)
.expect("Alpaca API secret must not contain invalid characters."),
),
]))
.build()
.unwrap(),
alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE {
Source::Iex => NonZeroU32::new(200).unwrap(),
Source::Sip => NonZeroU32::new(10_000).unwrap(),
Source::Otc => unimplemented!("OTC rate limit not implemented."),
})),
clickhouse_client: clickhouse::Client::default()
.with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
.with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
.with_password(
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
)
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
}
}
pub fn arc_from_env() -> Arc<Self> {
Arc::new(Self::from_env())
}
}
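
Config::from_env and the lazy_static block above read their settings from the environment, loaded via dotenv in main (below). A minimal example .env is sketched here; the variable names come from this file, while every value is a placeholder assumption:

ALPACA_MODE=paper                  # 'live' or 'paper'
ALPACA_SOURCE=iex                  # 'iex', 'sip', or 'otc'
ALPACA_API_KEY=<your-key-id>
ALPACA_API_SECRET=<your-secret>
BATCH_BACKFILL_BARS_SIZE=10000     # placeholder value
BATCH_BACKFILL_NEWS_SIZE=500       # placeholder value
CLICKHOUSE_URL=http://localhost:8123
CLICKHOUSE_USER=default
CLICKHOUSE_PASSWORD=<password>
CLICKHOUSE_DB=qrust
CLICKHOUSE_MAX_CONNECTIONS=10      # placeholder value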

src/bin/qrust/init.rs (new file, 136 lines)

@@ -0,0 +1,136 @@
use crate::{
config::{Config, ALPACA_API_BASE},
database,
};
use log::{info, warn};
use qrust::{alpaca, types};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::join;
pub async fn check_account(config: &Arc<Config>) {
let account = alpaca::account::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,
&ALPACA_API_BASE,
)
.await
.unwrap();
assert!(
!(account.status != types::alpaca::api::incoming::account::Status::Active),
"Account status is not active: {:?}.",
account.status
);
assert!(
!account.trade_suspend_by_user,
"Account trading is suspended by user."
);
assert!(!account.trading_blocked, "Account trading is blocked.");
assert!(!account.blocked, "Account is blocked.");
if account.cash == 0.0 {
warn!("Account cash is zero, qrust will not be able to trade.");
}
info!(
"qrust running on {} account with {} {}, avoid transferring funds without shutting down.",
*ALPACA_API_BASE, account.currency, account.cash
);
}
pub async fn rehydrate_orders(config: &Arc<Config>) {
let mut orders = vec![];
let mut after = OffsetDateTime::UNIX_EPOCH;
loop {
let message = alpaca::orders::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&types::alpaca::api::outgoing::order::Order {
status: Some(types::alpaca::api::outgoing::order::Status::All),
after: Some(after),
..Default::default()
},
None,
&ALPACA_API_BASE,
)
.await
.unwrap();
if message.is_empty() {
break;
}
orders.extend(message);
after = orders.last().unwrap().submitted_at;
}
let orders = orders
.into_iter()
.flat_map(&types::alpaca::api::incoming::order::Order::normalize)
.collect::<Vec<_>>();
database::orders::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&orders,
)
.await
.unwrap();
}
pub async fn rehydrate_positions(config: &Arc<Config>) {
let positions_future = async {
alpaca::positions::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(|position| (position.symbol.clone(), position))
.collect::<HashMap<_, _>>()
};
let assets_future = async {
database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
};
let (mut positions, assets) = join!(positions_future, assets_future);
let assets = assets
.into_iter()
.map(|mut asset| {
if let Some(position) = positions.remove(&asset.symbol) {
asset.qty = position.qty_available;
} else {
asset.qty = 0.0;
}
asset
})
.collect::<Vec<_>>();
database::assets::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&assets,
)
.await
.unwrap();
for position in positions.values() {
warn!(
"Position for unmonitored asset: {}, {} shares.",
position.symbol, position.qty
);
}
}

src/bin/qrust/mod.rs (new file, 115 lines)

@@ -0,0 +1,115 @@
#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
#![allow(clippy::missing_docs_in_private_items)]
#![feature(hash_extract_if)]
mod config;
mod init;
mod routes;
mod threads;
use config::{
Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE,
CLICKHOUSE_BATCH_NEWS_SIZE, CLICKHOUSE_MAX_CONNECTIONS,
};
use dotenv::dotenv;
use log::info;
use log4rs::config::Deserializers;
use nonempty::NonEmpty;
use qrust::{create_send_await, database};
use tokio::{join, spawn, sync::mpsc, try_join};
#[tokio::main]
async fn main() {
dotenv().ok();
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
let config = Config::arc_from_env();
let _ = *ALPACA_MODE;
let _ = *ALPACA_API_BASE;
let _ = *ALPACA_SOURCE;
let _ = *CLICKHOUSE_BATCH_BARS_SIZE;
let _ = *CLICKHOUSE_BATCH_NEWS_SIZE;
let _ = *CLICKHOUSE_MAX_CONNECTIONS;
info!("Marking all assets as stale.");
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol, asset.class))
.collect::<Vec<_>>();
let symbols = assets.iter().map(|(symbol, _)| symbol).collect::<Vec<_>>();
try_join!(
database::backfills_bars::set_fresh_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
false,
&symbols
),
database::backfills_news::set_fresh_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
false,
&symbols
)
)
.unwrap();
info!("Cleaning up database.");
database::cleanup_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
info!("Optimizing database.");
database::optimize_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
info!("Rehydrating account data.");
init::check_account(&config).await;
join!(
init::rehydrate_orders(&config),
init::rehydrate_positions(&config)
);
info!("Starting threads.");
spawn(threads::trading::run(config.clone()));
let (data_sender, data_receiver) = mpsc::channel::<threads::data::Message>(100);
let (clock_sender, clock_receiver) = mpsc::channel::<threads::clock::Message>(1);
spawn(threads::data::run(
config.clone(),
data_receiver,
clock_receiver,
));
spawn(threads::clock::run(config.clone(), clock_sender));
if let Some(assets) = NonEmpty::from_vec(assets) {
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Enable,
assets
);
}
routes::run(config, data_sender).await;
}
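
main and the route/thread modules below lean on the create_send_await! macro exported by the qrust library; its definition is not part of this diff. Judging purely from the call sites, a plausible sketch (a reconstruction, not the actual implementation) is:

// Hypothetical reconstruction: build a (message, oneshot receiver) pair with
// the given constructor, send the message over the mpsc sender, then wait for
// the receiving thread to acknowledge completion.
#[macro_export]
macro_rules! create_send_await {
    ($sender:expr, $constructor:expr, $($arg:expr),* $(,)?) => {
        let (message, receiver) = $constructor($($arg),*);
        $sender.send(message).await.unwrap();
        receiver.await.unwrap();
    };
}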


@@ -0,0 +1,197 @@
use crate::{
config::{Config, ALPACA_API_BASE},
create_send_await, database, threads,
};
use axum::{extract::Path, Extension, Json};
use http::StatusCode;
use nonempty::{nonempty, NonEmpty};
use qrust::{
alpaca,
types::{self, Asset},
};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::mpsc;
pub async fn get(
Extension(config): Extension<Arc<Config>>,
) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> {
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok((StatusCode::OK, Json(assets)))
}
pub async fn get_where_symbol(
Extension(config): Extension<Arc<Config>>,
Path(symbol): Path<String>,
) -> Result<(StatusCode, Json<Asset>), StatusCode> {
let asset = database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
asset.map_or(Err(StatusCode::NOT_FOUND), |asset| {
Ok((StatusCode::OK, Json(asset)))
})
}
#[derive(Deserialize)]
pub struct AddAssetsRequest {
symbols: Vec<String>,
}
#[derive(Serialize)]
pub struct AddAssetsResponse {
added: Vec<String>,
skipped: Vec<String>,
failed: Vec<String>,
}
pub async fn add(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Json(request): Json<AddAssetsRequest>,
) -> Result<(StatusCode, Json<AddAssetsResponse>), StatusCode> {
let database_symbols = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.into_iter()
.map(|asset| asset.symbol)
.collect::<HashSet<_>>();
let mut alpaca_assets = alpaca::assets::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&request.symbols,
None,
&ALPACA_API_BASE,
)
.await
.map_err(|e| e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))?
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>();
let num_symbols = request.symbols.len();
let (assets, skipped, failed) = request.symbols.into_iter().fold(
(Vec::with_capacity(num_symbols), vec![], vec![]),
|(mut assets, mut skipped, mut failed), symbol| {
if database_symbols.contains(&symbol) {
skipped.push(symbol);
} else if let Some(asset) = alpaca_assets.remove(&symbol) {
if asset.status == types::alpaca::api::incoming::asset::Status::Active
&& asset.tradable
&& asset.fractionable
{
assets.push((asset.symbol, asset.class.into()));
} else {
failed.push(asset.symbol);
}
} else {
failed.push(symbol);
}
(assets, skipped, failed)
},
);
if let Some(assets) = NonEmpty::from_vec(assets.clone()) {
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
assets
);
}
Ok((
StatusCode::OK,
Json(AddAssetsResponse {
added: assets.into_iter().map(|asset| asset.0).collect(),
skipped,
failed,
}),
))
}
pub async fn add_symbol(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
if database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.is_some()
{
return Err(StatusCode::CONFLICT);
}
let asset = alpaca::assets::get_by_symbol(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbol,
None,
&ALPACA_API_BASE,
)
.await
.map_err(|e| e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))?;
if asset.status != types::alpaca::api::incoming::asset::Status::Active
|| !asset.tradable
|| !asset.fractionable
{
return Err(StatusCode::FORBIDDEN);
}
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
nonempty![(asset.symbol, asset.class.into())]
);
Ok(StatusCode::CREATED)
}
pub async fn delete(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
let asset = database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Remove,
nonempty![(asset.symbol, asset.class)]
);
Ok(StatusCode::NO_CONTENT)
}


@@ -1,5 +1,5 @@
pub mod assets;
pub mod health;
mod assets;
mod health;
use crate::{config::Config, threads};
use axum::{
@@ -10,18 +10,16 @@ use log::info;
use std::{net::SocketAddr, sync::Arc};
use tokio::{net::TcpListener, sync::mpsc};
pub async fn run(
app_config: Arc<Config>,
asset_status_sender: mpsc::Sender<threads::data::asset_status::Message>,
) {
pub async fn run(config: Arc<Config>, data_sender: mpsc::Sender<threads::data::Message>) {
let app = Router::new()
.route("/health", get(health::get))
.route("/assets", get(assets::get))
.route("/assets/:symbol", get(assets::get_where_symbol))
.route("/assets", post(assets::add))
.route("/assets/:symbol", post(assets::add_symbol))
.route("/assets/:symbol", delete(assets::delete))
.layer(Extension(app_config))
.layer(Extension(asset_status_sender));
.layer(Extension(config))
.layer(Extension(data_sender));
let addr = SocketAddr::from(([0, 0, 0, 0], 7878));
let listener = TcpListener::bind(addr).await.unwrap();


@@ -0,0 +1,89 @@
use crate::{
config::{Config, ALPACA_API_BASE},
database,
};
use log::info;
use qrust::{
alpaca,
types::{self, Calendar},
utils::{backoff, duration_until},
};
use std::sync::Arc;
use tokio::{join, sync::mpsc, time::sleep};
#[derive(PartialEq, Eq)]
pub enum Status {
Open,
Closed,
}
pub struct Message {
pub status: Status,
}
impl From<types::alpaca::api::incoming::clock::Clock> for Message {
fn from(clock: types::alpaca::api::incoming::clock::Clock) -> Self {
Self {
status: if clock.is_open {
Status::Open
} else {
Status::Closed
},
}
}
}
pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
loop {
let clock_future = async {
alpaca::clock::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
Some(backoff::infinite()),
&ALPACA_API_BASE,
)
.await
.unwrap()
};
let calendar_future = async {
alpaca::calendar::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&types::alpaca::api::outgoing::calendar::Calendar::default(),
Some(backoff::infinite()),
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(Calendar::from)
.collect::<Vec<_>>()
};
let (clock, calendar) = join!(clock_future, calendar_future);
let sleep_until = duration_until(if clock.is_open {
info!("Market is open, will close at {}.", clock.next_close);
clock.next_close
} else {
info!("Market is closed, will reopen at {}.", clock.next_open);
clock.next_open
});
let sleep_future = sleep(sleep_until);
let calendar_future = async {
database::calendar::upsert_batch_and_delete(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&calendar,
)
.await
.unwrap();
};
join!(sleep_future, calendar_future);
sender.send(clock.into()).await.unwrap();
}
}


@@ -0,0 +1,238 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE},
database,
threads::data::ThreadType,
};
use async_trait::async_trait;
use log::{error, info};
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
self,
alpaca::{
api::{ALPACA_CRYPTO_DATA_API_URL, ALPACA_US_EQUITY_DATA_API_URL},
shared::{Sort, Source},
},
Backfill, Bar, Class,
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::time::sleep;
pub struct Handler {
pub config: Arc<Config>,
pub data_url: &'static str,
pub api_query_constructor: fn(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar,
}
pub fn us_equity_query_constructor(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar {
types::alpaca::api::outgoing::bar::Bar::UsEquity(types::alpaca::api::outgoing::bar::UsEquity {
symbols,
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
sort: Some(Sort::Asc),
feed: Some(*ALPACA_SOURCE),
..Default::default()
})
}
pub fn crypto_query_constructor(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar {
types::alpaca::api::outgoing::bar::Bar::Crypto(types::alpaca::api::outgoing::bar::Crypto {
symbols,
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
sort: Some(Sort::Asc),
..Default::default()
})
}
#[async_trait]
impl super::Handler for Handler {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
database::backfills_bars::select_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_bars::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::bars::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
if *ALPACA_SOURCE == Source::Sip {
return;
}
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();
info!("Queing bar backfill for {:?} in {:?}.", symbols, run_delay);
sleep(run_delay).await;
}
async fn backfill(&self, jobs: NonEmpty<Job>) {
let symbols = Vec::from(jobs.clone().map(|job| job.symbol));
let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let freshness = jobs
.into_iter()
.map(|job| (job.symbol, job.fresh))
.collect::<HashMap<_, _>>();
let mut bars = Vec::with_capacity(*CLICKHOUSE_BATCH_BARS_SIZE);
let mut last_times = HashMap::new();
let mut next_page_token = None;
info!("Backfilling bars for {:?}.", symbols);
loop {
let message = alpaca::bars::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
self.data_url,
&(self.api_query_constructor)(
symbols.clone(),
fetch_from,
fetch_to,
next_page_token.clone(),
),
None,
)
.await;
if let Err(err) = message {
error!("Failed to backfill bars for {:?}: {:?}.", symbols, err);
return;
}
let message = message.unwrap();
for (symbol, bars_vec) in message.bars {
if let Some(last) = bars_vec.last() {
last_times.insert(symbol.clone(), last.time);
}
for bar in bars_vec {
bars.push(Bar::from((bar, symbol.clone())));
}
}
if bars.len() < *CLICKHOUSE_BATCH_BARS_SIZE && message.next_page_token.is_some() {
continue;
}
database::bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&bars,
)
.await
.unwrap();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill {
fresh: freshness[&symbol],
symbol,
time,
})
.collect::<Vec<_>>();
database::backfills_bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
bars.clear();
}
database::backfills_bars::set_fresh_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
true,
&symbols,
)
.await
.unwrap();
info!("Backfilled bars for {:?}.", symbols);
}
fn max_limit(&self) -> i64 {
alpaca::bars::MAX_LIMIT
}
fn log_string(&self) -> &'static str {
"bars"
}
}
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
let data_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => ALPACA_US_EQUITY_DATA_API_URL,
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_API_URL,
_ => unreachable!(),
};
let api_query_constructor = match thread_type {
ThreadType::Bars(Class::UsEquity) => us_equity_query_constructor,
ThreadType::Bars(Class::Crypto) => crypto_query_constructor,
_ => unreachable!(),
};
Box::new(Handler {
config,
data_url,
api_query_constructor,
})
}


@@ -0,0 +1,243 @@
pub mod bars;
pub mod news;
use async_trait::async_trait;
use itertools::Itertools;
use log::{info, warn};
use nonempty::{nonempty, NonEmpty};
use qrust::{
types::Backfill,
utils::{last_minute, ONE_SECOND},
};
use std::{collections::HashMap, hash::Hash, sync::Arc};
use time::OffsetDateTime;
use tokio::{
spawn,
sync::{mpsc, oneshot, Mutex},
task::JoinHandle,
try_join,
};
use uuid::Uuid;
pub enum Action {
Backfill,
Purge,
}
pub struct Message {
pub action: Action,
pub symbols: NonEmpty<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Action, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel::<()>();
(
Self {
action,
symbols,
response: sender,
},
receiver,
)
}
}
#[derive(Clone)]
pub struct Job {
pub symbol: String,
pub fetch_from: OffsetDateTime,
pub fetch_to: OffsetDateTime,
pub fresh: bool,
}
#[async_trait]
pub trait Handler: Send + Sync {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error>;
async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn queue_backfill(&self, jobs: &NonEmpty<Job>);
async fn backfill(&self, jobs: NonEmpty<Job>);
fn max_limit(&self) -> i64;
fn log_string(&self) -> &'static str;
}
pub struct Jobs {
pub symbol_to_uuid: HashMap<String, Uuid>,
pub uuid_to_job: HashMap<Uuid, JoinHandle<()>>,
}
impl Jobs {
pub fn insert(&mut self, jobs: Vec<String>, fut: JoinHandle<()>) {
let uuid = Uuid::new_v4();
for symbol in jobs {
self.symbol_to_uuid.insert(symbol.clone(), uuid);
}
self.uuid_to_job.insert(uuid, fut);
}
pub fn contains_key(&self, symbol: &str) -> bool {
self.symbol_to_uuid.contains_key(symbol)
}
pub fn remove(&mut self, symbol: &str) -> Option<JoinHandle<()>> {
self.symbol_to_uuid
.remove(symbol)
.and_then(|uuid| self.uuid_to_job.remove(&uuid))
}
pub fn remove_many<T>(&mut self, symbols: &[T])
where
T: AsRef<str> + Hash + Eq,
{
for symbol in symbols {
self.symbol_to_uuid
.remove(symbol.as_ref())
.and_then(|uuid| self.uuid_to_job.remove(&uuid));
}
}
pub fn len(&self) -> usize {
self.symbol_to_uuid.len()
}
}
pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
let backfill_jobs = Arc::new(Mutex::new(Jobs {
symbol_to_uuid: HashMap::new(),
uuid_to_job: HashMap::new(),
}));
loop {
let message = receiver.recv().await.unwrap();
spawn(handle_message(
handler.clone(),
backfill_jobs.clone(),
message,
));
}
}
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
backfill_jobs: Arc<Mutex<Jobs>>,
message: Message,
) {
let backfill_jobs_clone = backfill_jobs.clone();
let mut backfill_jobs = backfill_jobs.lock().await;
let symbols = Vec::from(message.symbols);
match message.action {
Action::Backfill => {
let log_string = handler.log_string();
let max_limit = handler.max_limit();
let backfills = handler
.select_latest_backfills(&symbols)
.await
.unwrap()
.into_iter()
.map(|backfill| (backfill.symbol.clone(), backfill))
.collect::<HashMap<_, _>>();
let mut jobs = Vec::with_capacity(symbols.len());
for symbol in symbols {
if backfill_jobs.contains_key(&symbol) {
warn!(
"Backfill for {} {} is already running, skipping.",
symbol, log_string
);
continue;
}
let backfill = backfills.get(&symbol);
let fetch_from = backfill.map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
backfill.time + ONE_SECOND
});
let fetch_to = last_minute();
if fetch_from > fetch_to {
info!("No need to backfill {} {}.", symbol, log_string,);
return;
}
let fresh = backfill.map_or(false, |backfill| backfill.fresh);
jobs.push(Job {
symbol,
fetch_from,
fetch_to,
fresh,
});
}
let mut current_minutes = 0;
let job_groups = jobs
.into_iter()
.sorted_unstable_by_key(|job| job.fetch_from)
.fold(Vec::<NonEmpty<Job>>::new(), |mut job_groups, job| {
let minutes = (job.fetch_to - job.fetch_from).whole_minutes();
if let Some(job_group) = job_groups.last_mut() {
if current_minutes + minutes <= max_limit {
job_group.push(job);
current_minutes += minutes;
return job_groups;
}
}
job_groups.push(nonempty![job]);
current_minutes = minutes;
job_groups
});
for job_group in job_groups {
let symbols = job_group
.iter()
.map(|job| job.symbol.clone())
.collect::<Vec<_>>();
let handler = handler.clone();
let symbols_clone = symbols.clone();
let backfill_jobs_clone = backfill_jobs_clone.clone();
let fut = spawn(async move {
handler.queue_backfill(&job_group).await;
handler.backfill(job_group).await;
let mut backfill_jobs = backfill_jobs_clone.lock().await;
backfill_jobs.remove_many(&symbols_clone);
let remaining = backfill_jobs.len();
drop(backfill_jobs);
info!("{} {} backfills remaining.", remaining, log_string);
});
backfill_jobs.insert(symbols, fut);
}
}
Action::Purge => {
for symbol in &symbols {
if let Some(job) = backfill_jobs.remove(symbol) {
job.abort();
let _ = job.await;
}
}
try_join!(
handler.delete_backfills(&symbols),
handler.delete_data(&symbols)
)
.unwrap();
}
}
message.response.send(()).unwrap();
}


@@ -0,0 +1,186 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_NEWS_SIZE},
database,
};
use async_trait::async_trait;
use log::{error, info};
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
self,
alpaca::shared::{Sort, Source},
Backfill, News,
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::time::sleep;
pub struct Handler {
pub config: Arc<Config>,
}
#[async_trait]
impl super::Handler for Handler {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
database::backfills_news::select_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_news::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::news::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
if *ALPACA_SOURCE == Source::Sip {
return;
}
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();
info!("Queing news backfill for {:?} in {:?}.", symbols, run_delay);
sleep(run_delay).await;
}
#[allow(clippy::too_many_lines)]
#[allow(clippy::iter_with_drain)]
async fn backfill(&self, jobs: NonEmpty<Job>) {
let symbols = Vec::from(jobs.clone().map(|job| job.symbol));
let symbols_set = symbols.clone().into_iter().collect::<HashSet<_>>();
let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let freshness = jobs
.into_iter()
.map(|job| (job.symbol, job.fresh))
.collect::<HashMap<_, _>>();
let mut news = Vec::with_capacity(*CLICKHOUSE_BATCH_NEWS_SIZE);
let mut last_times = HashMap::new();
let mut next_page_token = None;
info!("Backfilling news for {:?}.", symbols);
loop {
let message = alpaca::news::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
&types::alpaca::api::outgoing::news::News {
symbols: symbols.clone(),
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token.clone(),
sort: Some(Sort::Asc),
..Default::default()
},
None,
)
.await;
if let Err(err) = message {
error!("Failed to backfill news for {:?}: {:?}.", symbols, err);
return;
}
let message = message.unwrap();
for news_item in message.news {
let news_item = News::from(news_item);
for symbol in &news_item.symbols {
if symbols_set.contains(symbol) {
last_times.insert(symbol.clone(), news_item.time_created);
}
}
news.push(news_item);
}
if news.len() < *CLICKHOUSE_BATCH_NEWS_SIZE && message.next_page_token.is_some() {
continue;
}
database::news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&news,
)
.await
.unwrap();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill {
fresh: freshness[&symbol],
symbol,
time,
})
.collect::<Vec<_>>();
database::backfills_news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
news.clear();
}
database::backfills_news::set_fresh_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
true,
&symbols,
)
.await
.unwrap();
info!("Backfilled news for {:?}.", symbols);
}
fn max_limit(&self) -> i64 {
alpaca::news::MAX_LIMIT
}
fn log_string(&self) -> &'static str {
"news"
}
}
pub fn create_handler(config: Arc<Config>) -> Box<dyn super::Handler> {
Box::new(Handler { config })
}


@@ -0,0 +1,391 @@
mod backfill;
mod websocket;
use super::clock;
use crate::{
config::{Config, ALPACA_API_BASE, ALPACA_SOURCE},
create_send_await, database,
};
use itertools::{Either, Itertools};
use log::error;
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
alpaca::websocket::{
ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL,
ALPACA_US_EQUITY_DATA_WEBSOCKET_URL,
},
Asset, Class,
},
};
use std::{collections::HashMap, sync::Arc};
use tokio::{
join, select, spawn,
sync::{mpsc, oneshot},
};
#[derive(Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Action {
Add,
Enable,
Remove,
Disable,
}
pub struct Message {
pub action: Action,
pub assets: NonEmpty<(String, Class)>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Action, assets: NonEmpty<(String, Class)>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel();
(
Self {
action,
assets,
response: sender,
},
receiver,
)
}
}
#[derive(Clone, Copy)]
pub enum ThreadType {
Bars(Class),
News,
}
pub async fn run(
config: Arc<Config>,
mut receiver: mpsc::Receiver<Message>,
mut clock_receiver: mpsc::Receiver<clock::Message>,
) {
let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::UsEquity));
let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::Crypto));
let (news_websocket_sender, news_backfill_sender) =
init_thread(config.clone(), ThreadType::News);
loop {
select! {
Some(message) = receiver.recv() => {
spawn(handle_message(
config.clone(),
bars_us_equity_websocket_sender.clone(),
bars_us_equity_backfill_sender.clone(),
bars_crypto_websocket_sender.clone(),
bars_crypto_backfill_sender.clone(),
news_websocket_sender.clone(),
news_backfill_sender.clone(),
message,
));
}
Some(message) = clock_receiver.recv() => {
spawn(handle_clock_message(
config.clone(),
bars_us_equity_backfill_sender.clone(),
bars_crypto_backfill_sender.clone(),
news_backfill_sender.clone(),
message,
));
}
else => panic!("Communication channel unexpectedly closed.")
}
}
}
fn init_thread(
config: Arc<Config>,
thread_type: ThreadType,
) -> (
mpsc::Sender<websocket::Message>,
mpsc::Sender<backfill::Message>,
) {
let websocket_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => {
format!("{}/{}", ALPACA_US_EQUITY_DATA_WEBSOCKET_URL, *ALPACA_SOURCE)
}
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_WEBSOCKET_URL.into(),
ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(),
};
let backfill_handler = match thread_type {
ThreadType::Bars(_) => backfill::bars::create_handler(config.clone(), thread_type),
ThreadType::News => backfill::news::create_handler(config.clone()),
};
let (backfill_sender, backfill_receiver) = mpsc::channel(100);
spawn(backfill::run(backfill_handler.into(), backfill_receiver));
let websocket_handler = match thread_type {
ThreadType::Bars(_) => websocket::bars::create_handler(config, thread_type),
ThreadType::News => websocket::news::create_handler(&config),
};
let (websocket_sender, websocket_receiver) = mpsc::channel(100);
spawn(websocket::run(
websocket_handler.into(),
websocket_receiver,
websocket_url,
));
(websocket_sender, backfill_sender)
}
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_lines)]
async fn handle_message(
config: Arc<Config>,
bars_us_equity_websocket_sender: mpsc::Sender<websocket::Message>,
bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
bars_crypto_websocket_sender: mpsc::Sender<websocket::Message>,
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
news_websocket_sender: mpsc::Sender<websocket::Message>,
news_backfill_sender: mpsc::Sender<backfill::Message>,
message: Message,
) {
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = message
.assets
.clone()
.into_iter()
.partition_map(|asset| match asset.1 {
Class::UsEquity => Either::Left(asset.0),
Class::Crypto => Either::Right(asset.0),
});
let symbols = message.assets.map(|(symbol, _)| symbol);
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols.clone()) {
create_send_await!(
bars_us_equity_websocket_sender,
websocket::Message::new,
message.action.into(),
us_equity_symbols
);
}
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols.clone()) {
create_send_await!(
bars_crypto_websocket_sender,
websocket::Message::new,
message.action.into(),
crypto_symbols
);
}
};
let news_future = async {
create_send_await!(
news_websocket_sender,
websocket::Message::new,
message.action.into(),
symbols.clone()
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
if message.action == Action::Disable {
message.response.send(()).unwrap();
return;
}
match message.action {
Action::Add | Action::Enable => {
let symbols = Vec::from(symbols.clone());
let assets = async {
alpaca::assets::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbols,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>()
};
let positions = async {
alpaca::positions::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbols,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(|position| (position.symbol.clone(), position))
.collect::<HashMap<_, _>>()
};
let (mut assets, mut positions) = join!(assets, positions);
let batch =
symbols
.iter()
.fold(Vec::with_capacity(symbols.len()), |mut batch, symbol| {
if let Some(asset) = assets.remove(symbol) {
let position = positions.remove(symbol);
batch.push(Asset::from((asset, position)));
} else {
error!("Failed to find asset for symbol: {}.", symbol);
}
batch
});
database::assets::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&batch,
)
.await
.unwrap();
}
Action::Remove => {
database::assets::delete_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&Vec::from(symbols.clone()),
)
.await
.unwrap();
}
Action::Disable => unreachable!(),
}
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
us_equity_symbols
);
}
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
crypto_symbols
);
}
};
let news_future = async {
create_send_await!(
news_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
symbols
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
message.response.send(()).unwrap();
}
async fn handle_clock_message(
config: Arc<Config>,
bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
news_backfill_sender: mpsc::Sender<backfill::Message>,
message: clock::Message,
) {
if message.status == clock::Status::Closed {
database::cleanup_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
}
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets
.clone()
.into_iter()
.partition_map(|asset| match asset.class {
Class::UsEquity => Either::Left(asset.symbol),
Class::Crypto => Either::Right(asset.symbol),
});
let symbols = assets
.into_iter()
.map(|asset| asset.symbol)
.collect::<Vec<_>>();
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
us_equity_symbols
);
}
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
crypto_symbols
);
}
};
let news_future = async {
if let Some(symbols) = NonEmpty::from_vec(symbols) {
create_send_await!(
news_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
symbols
);
}
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
}


@@ -0,0 +1,171 @@
use super::State;
use crate::{
config::{Config, CLICKHOUSE_BATCH_BARS_SIZE},
database,
threads::data::ThreadType,
};
use async_trait::async_trait;
use clickhouse::inserter::Inserter;
use log::{debug, error, info};
use nonempty::NonEmpty;
use qrust::{
types::{alpaca::websocket, Bar, Class},
utils::ONE_SECOND,
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::{Mutex, RwLock};
pub struct Handler {
pub config: Arc<Config>,
pub inserter: Arc<Mutex<Inserter<Bar>>>,
pub subscription_message_constructor:
fn(NonEmpty<String>) -> websocket::data::outgoing::subscribe::Message,
}
#[async_trait]
impl super::Handler for Handler {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message {
(self.subscription_message_constructor)(symbols)
}
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::Market {
bars: symbols, ..
} = message
else {
unreachable!()
};
let symbols = symbols.into_iter().collect::<HashSet<_>>();
let mut state = state.write().await;
let newly_subscribed = state
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = state
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to bars for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from bars for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::Bar(message)
| websocket::data::incoming::Message::UpdatedBar(message) => {
let bar = Bar::from(message);
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
self.inserter.lock().await.write(&bar).await.unwrap();
}
websocket::data::incoming::Message::Status(message) => {
debug!(
"Received status message for {}: {:?}.",
message.symbol, message.status
);
match message.status {
websocket::data::incoming::status::Status::TradingHalt(_)
| websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
false,
)
.await
.unwrap();
}
websocket::data::incoming::status::Status::Resume(_)
| websocket::data::incoming::status::Status::TradingResumption(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
true,
)
.await
.unwrap();
}
_ => {}
}
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
fn log_string(&self) -> &'static str {
"bars"
}
async fn run_inserter(&self) {
super::run_inserter(self.inserter.clone()).await;
}
}
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
let inserter = Arc::new(Mutex::new(
config
.clickhouse_client
.inserter("bars")
.unwrap()
.with_period(Some(ONE_SECOND))
.with_max_entries((*CLICKHOUSE_BATCH_BARS_SIZE).try_into().unwrap()),
));
let subscription_message_constructor = match thread_type {
ThreadType::Bars(Class::UsEquity) => {
websocket::data::outgoing::subscribe::Message::new_market_us_equity
}
ThreadType::Bars(Class::Crypto) => {
websocket::data::outgoing::subscribe::Message::new_market_crypto
}
_ => unreachable!(),
};
Box::new(Handler {
config,
inserter,
subscription_message_constructor,
})
}


@@ -0,0 +1,353 @@
pub mod bars;
pub mod news;
use crate::config::{ALPACA_API_KEY, ALPACA_API_SECRET};
use async_trait::async_trait;
use backoff::{future::retry_notify, ExponentialBackoff};
use clickhouse::{inserter::Inserter, Row};
use futures_util::{future::join_all, SinkExt, StreamExt};
use log::error;
use nonempty::NonEmpty;
use qrust::types::alpaca::{self, websocket};
use serde::Serialize;
use serde_json::{from_str, to_string};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::Duration,
};
use tokio::{
net::TcpStream,
select, spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
};
use tokio_tungstenite::{connect_async, tungstenite, MaybeTlsStream, WebSocketStream};
pub enum Action {
Subscribe,
Unsubscribe,
}
impl From<super::Action> for Option<Action> {
fn from(action: super::Action) -> Self {
match action {
super::Action::Add | super::Action::Enable => Some(Action::Subscribe),
super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe),
}
}
}
pub struct Message {
pub action: Option<Action>,
pub symbols: NonEmpty<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Option<Action>, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel();
(
Self {
action,
symbols,
response: sender,
},
receiver,
)
}
}
pub struct State {
pub active_subscriptions: HashSet<String>,
pub pending_subscriptions: HashMap<String, oneshot::Sender<()>>,
pub pending_unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}
#[async_trait]
pub trait Handler: Send + Sync + 'static {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message;
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
);
fn log_string(&self) -> &'static str;
async fn run_inserter(&self);
}
pub async fn run(
handler: Arc<Box<dyn Handler>>,
mut receiver: mpsc::Receiver<Message>,
websocket_url: String,
) {
let state = Arc::new(RwLock::new(State {
active_subscriptions: HashSet::new(),
pending_subscriptions: HashMap::new(),
pending_unsubscriptions: HashMap::new(),
}));
let handler_clone = handler.clone();
spawn(async move { handler_clone.run_inserter().await });
let (sink_sender, sink_receiver) = mpsc::channel(100);
let (stream_sender, mut stream_receiver) = mpsc::channel(10_000);
spawn(run_connection(
handler.clone(),
sink_receiver,
stream_sender,
websocket_url.clone(),
state.clone(),
));
loop {
select! {
Some(message) = receiver.recv() => {
spawn(handle_message(
handler.clone(),
state.clone(),
sink_sender.clone(),
message,
));
}
Some(message) = stream_receiver.recv() => {
match message {
tungstenite::Message::Text(message) => {
let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);
if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}.", message);
continue;
}
for message in parsed_message.unwrap() {
let handler = handler.clone();
let state = state.clone();
spawn(async move {
handler.handle_websocket_message(state, message).await;
});
}
}
tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}.", message),
}
}
else => panic!("Communication channel unexpectedly closed.")
}
}
}
#[allow(clippy::too_many_lines)]
async fn run_connection(
handler: Arc<Box<dyn Handler>>,
mut sink_receiver: mpsc::Receiver<tungstenite::Message>,
stream_sender: mpsc::Sender<tungstenite::Message>,
websocket_url: String,
state: Arc<RwLock<State>>,
) {
let mut peek = None;
'connection: loop {
let (websocket, _): (WebSocketStream<MaybeTlsStream<TcpStream>>, _) = retry_notify(
ExponentialBackoff::default(),
|| async {
connect_async(websocket_url.clone())
.await
.map_err(Into::into)
},
|e, duration: Duration| {
error!(
"Failed to connect to {} websocket, will retry in {} seconds: {}.",
handler.log_string(),
duration.as_secs(),
e
);
},
)
.await
.unwrap();
let (mut sink, mut stream) = websocket.split();
alpaca::websocket::data::authenticate(
&mut sink,
&mut stream,
(*ALPACA_API_KEY).to_string(),
(*ALPACA_API_SECRET).to_string(),
)
.await;
let mut state = state.write().await;
state
.pending_unsubscriptions
.drain()
.for_each(|(_, sender)| {
sender.send(()).unwrap();
});
let (recovered_subscriptions, receivers) = state
.active_subscriptions
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
state.pending_subscriptions.extend(recovered_subscriptions);
let pending_subscriptions = state
.pending_subscriptions
.keys()
.cloned()
.collect::<Vec<_>>();
drop(state);
if let Some(pending_subscriptions) = NonEmpty::from_vec(pending_subscriptions) {
if let Err(err) = sink
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(pending_subscriptions),
))
.unwrap(),
))
.await
{
error!("Failed to send websocket message: {:?}.", err);
continue;
}
}
join_all(receivers).await;
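// `peek` holds the message whose send did not complete on the previous
// connection; resend it before resuming normal operation.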
if peek.is_some() {
if let Err(err) = sink.send(peek.clone().unwrap()).await {
error!("Failed to send websocket message: {:?}.", err);
continue;
}
peek = None;
}
loop {
select! {
Some(message) = sink_receiver.recv() => {
peek = Some(message.clone());
if let Err(err) = sink.send(message).await {
error!("Failed to send websocket message: {:?}.", err);
continue 'connection;
};
peek = None;
}
message = stream.next() => {
if message.is_none() {
error!("Websocket stream unexpectedly closed.");
continue 'connection;
}
let message = message.unwrap();
if let Err(err) = message {
error!("Failed to receive websocket message: {:?}.", err);
continue 'connection;
}
let message = message.unwrap();
if message.is_close() {
error!("Websocket connection closed.");
continue 'connection;
}
stream_sender.send(message).await.unwrap();
}
else => error!("Communication channel unexpectedly closed.")
}
}
}
}
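// Handles a single subscribe/unsubscribe request: registers a oneshot sender
// per symbol in the shared state, forwards the request to the connection task,
// waits until the incoming subscription message confirms every symbol, and only
// then acknowledges the caller; a message with no action is acknowledged
// immediately.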
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
pending: Arc<RwLock<State>>,
sink_sender: mpsc::Sender<tungstenite::Message>,
message: Message,
) {
match message.action {
Some(Action::Subscribe) => {
let (pending_subscriptions, receivers) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
pending
.write()
.await
.pending_subscriptions
.extend(pending_subscriptions);
sink_sender
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(message.symbols),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
Some(Action::Unsubscribe) => {
let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip();
pending
.write()
.await
.pending_unsubscriptions
.extend(pending_unsubscriptions);
sink_sender
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Unsubscribe(
handler.create_subscription_message(message.symbols.clone()),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
None => {}
}
message.response.send(()).unwrap();
}
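// Commits the ClickHouse inserter whenever its period elapses so buffered rows
// are flushed even when little data is arriving.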
async fn run_inserter<T>(inserter: Arc<Mutex<Inserter<T>>>)
where
T: Row + Serialize,
{
loop {
let time_left = inserter.lock().await.time_left().unwrap();
tokio::time::sleep(time_left).await;
inserter.lock().await.commit().await.unwrap();
}
}


@@ -0,0 +1,119 @@
use super::State;
use crate::config::{Config, CLICKHOUSE_BATCH_NEWS_SIZE};
use async_trait::async_trait;
use clickhouse::inserter::Inserter;
use log::{debug, error, info};
use nonempty::NonEmpty;
use qrust::{
types::{alpaca::websocket, News},
utils::ONE_SECOND,
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::{Mutex, RwLock};
pub struct Handler {
pub inserter: Arc<Mutex<Inserter<News>>>,
}
#[async_trait]
impl super::Handler for Handler {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message {
websocket::data::outgoing::subscribe::Message::new_news(symbols)
}
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::News { news: symbols } =
message
else {
unreachable!()
};
let symbols = symbols.into_iter().collect::<HashSet<_>>();
let mut state = state.write().await;
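// The subscription message lists everything currently subscribed on this
// connection: pending subscriptions that now appear are confirmed, and pending
// unsubscriptions that no longer appear are confirmed.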
let newly_subscribed = state
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = state
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to news for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from news for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::News(message) => {
let news = News::from(message);
debug!(
"Received news for {:?}: {}.",
news.symbols, news.time_created
);
self.inserter.lock().await.write(&news).await.unwrap();
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
fn log_string(&self) -> &'static str {
"news"
}
async fn run_inserter(&self) {
super::run_inserter(self.inserter.clone()).await;
}
}
pub fn create_handler(config: &Arc<Config>) -> Box<dyn super::Handler> {
let inserter = Arc::new(Mutex::new(
config
.clickhouse_client
.inserter("news")
.unwrap()
.with_period(Some(ONE_SECOND))
.with_max_entries((*CLICKHOUSE_BATCH_NEWS_SIZE).try_into().unwrap()),
));
Box::new(Handler { inserter })
}


@@ -1,2 +1,3 @@
pub mod clock;
pub mod data;
pub mod trading;


@@ -0,0 +1,27 @@
mod websocket;
use crate::config::{Config, ALPACA_API_BASE, ALPACA_API_KEY, ALPACA_API_SECRET};
use futures_util::StreamExt;
use qrust::types::alpaca;
use std::sync::Arc;
use tokio::spawn;
use tokio_tungstenite::connect_async;
pub async fn run(config: Arc<Config>) {
let (websocket, _) =
connect_async(&format!("wss://{}.alpaca.markets/stream", *ALPACA_API_BASE))
.await
.unwrap();
let (mut websocket_sink, mut websocket_stream) = websocket.split();
alpaca::websocket::trading::authenticate(
&mut websocket_sink,
&mut websocket_stream,
(*ALPACA_API_KEY).to_string(),
(*ALPACA_API_SECRET).to_string(),
)
.await;
alpaca::websocket::trading::subscribe(&mut websocket_sink, &mut websocket_stream).await;
spawn(websocket::run(config, websocket_stream));
}


@@ -0,0 +1,79 @@
use crate::{config::Config, database};
use futures_util::{stream::SplitStream, StreamExt};
use log::{debug, error};
use qrust::types::{alpaca::websocket, Order};
use serde_json::from_str;
use std::sync::Arc;
use tokio::{net::TcpStream, spawn};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
pub async fn run(
config: Arc<Config>,
mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) {
loop {
let message = websocket_stream.next().await.unwrap().unwrap();
match message {
tungstenite::Message::Binary(message) => {
let parsed_message = from_str::<websocket::trading::incoming::Message>(
&String::from_utf8_lossy(&message),
);
if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}.", message);
continue;
}
spawn(handle_websocket_message(
config.clone(),
parsed_message.unwrap(),
));
}
tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}.", message),
}
}
}
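// Persists the incoming order update and, for fills and partial fills, syncs
// the asset's position quantity to the value reported by Alpaca.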
async fn handle_websocket_message(
config: Arc<Config>,
message: websocket::trading::incoming::Message,
) {
match message {
websocket::trading::incoming::Message::Order(message) => {
debug!(
"Received order message for {}: {:?}.",
message.order.symbol, message.event
);
let order = Order::from(message.order);
database::orders::upsert(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&order,
)
.await
.unwrap();
match message.event {
websocket::trading::incoming::order::Event::Fill { position_qty, .. }
| websocket::trading::incoming::order::Event::PartialFill {
position_qty, ..
} => {
database::assets::update_qty_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&order.symbol,
position_qty,
)
.await
.unwrap();
}
_ => (),
}
}
_ => unreachable!(),
}
}

src/bin/trainer/mod.rs

@@ -0,0 +1,133 @@
#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
use burn::{
config::Config,
data::{
dataloader::{DataLoaderBuilder, Dataset},
dataset::transform::{PartialDataset, ShuffledDataset},
},
module::Module,
optim::AdamConfig,
record::CompactRecorder,
tensor::backend::AutodiffBackend,
train::LearnerBuilder,
};
use dotenv::dotenv;
use log::info;
use qrust::{
database,
ml::{
BarWindow, BarWindowBatcher, ModelConfig, MultipleSymbolDataset, MyAutodiffBackend, DEVICE,
},
types::Bar,
};
use std::{env, fs, path::Path, sync::Arc};
use tokio::sync::Semaphore;
#[derive(Config)]
pub struct TrainingConfig {
pub model: ModelConfig,
pub optimizer: AdamConfig,
#[config(default = 100)]
pub epochs: usize,
#[config(default = 256)]
pub batch_size: usize,
#[config(default = 16)]
pub num_workers: usize,
#[config(default = 0)]
pub seed: u64,
#[config(default = 0.2)]
pub valid_pct: f64,
#[config(default = 1.0e-4)]
pub learning_rate: f64,
}
#[tokio::main]
async fn main() {
dotenv().ok();
let dir = Path::new(file!()).parent().unwrap();
let model_config = ModelConfig::new();
let optimizer = AdamConfig::new();
let training_config = TrainingConfig::new(model_config, optimizer);
let clickhouse_client = clickhouse::Client::default()
.with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
.with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
.with_password(env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."))
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set."));
let clickhouse_concurrency_limiter = Arc::new(Semaphore::new(Semaphore::MAX_PERMITS));
let bars = database::ta::select(&clickhouse_client, &clickhouse_concurrency_limiter)
.await
.unwrap();
info!("Loaded {} bars.", bars.len());
train::<MyAutodiffBackend>(
bars,
&training_config,
dir.join("artifacts").to_str().unwrap(),
&DEVICE,
);
}
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cast_sign_loss)]
#[allow(clippy::cast_precision_loss)]
fn train<B: AutodiffBackend<FloatElem = f32, IntElem = i32>>(
bars: Vec<Bar>,
config: &TrainingConfig,
dir: &str,
device: &B::Device,
) {
B::seed(config.seed);
fs::create_dir_all(dir).unwrap();
let dataset = MultipleSymbolDataset::new(bars);
let dataset = ShuffledDataset::with_seed(dataset, config.seed);
let dataset = Arc::new(dataset);
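// Hold out the last `valid_pct` of the shuffled dataset for validation.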
let split = (dataset.len() as f64 * (1.0 - config.valid_pct)) as usize;
let train: PartialDataset<Arc<ShuffledDataset<MultipleSymbolDataset, BarWindow>>, BarWindow> =
PartialDataset::new(dataset.clone(), 0, split);
let batcher_train = BarWindowBatcher::<B> {
device: device.clone(),
};
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.num_workers(config.num_workers)
.build(train);
let valid: PartialDataset<Arc<ShuffledDataset<MultipleSymbolDataset, BarWindow>>, BarWindow> =
PartialDataset::new(dataset.clone(), split, dataset.len());
let batcher_valid = BarWindowBatcher::<B::InnerBackend> {
device: device.clone(),
};
let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.num_workers(config.num_workers)
.build(valid);
let learner = LearnerBuilder::new(dir)
.with_file_checkpointer(CompactRecorder::new())
.devices(vec![device.clone()])
.num_epochs(config.epochs)
.build(
config.model.init::<B>(device),
config.optimizer.init(),
config.learning_rate,
);
let trained = learner.fit(dataloader_train, dataloader_valid);
trained.save_file(dir, &CompactRecorder::new()).unwrap();
}


@@ -1,78 +0,0 @@
use crate::types::alpaca::Source;
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client,
};
use std::{env, num::NonZeroU32, sync::Arc};
pub const ALPACA_ASSET_API_URL: &str = "https://api.alpaca.markets/v2/assets";
pub const ALPACA_CLOCK_API_URL: &str = "https://api.alpaca.markets/v2/clock";
pub const ALPACA_STOCK_DATA_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
pub const ALPACA_CRYPTO_DATA_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
pub const ALPACA_NEWS_DATA_URL: &str = "https://data.alpaca.markets/v1beta1/news";
pub const ALPACA_STOCK_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
pub const ALPACA_CRYPTO_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta3/crypto/us";
pub const ALPACA_NEWS_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";
pub struct Config {
pub alpaca_api_key: String,
pub alpaca_api_secret: String,
pub alpaca_client: Client,
pub alpaca_rate_limit: DefaultDirectRateLimiter,
pub alpaca_source: Source,
pub clickhouse_client: clickhouse::Client,
}
impl Config {
pub fn from_env() -> Self {
let alpaca_api_key = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
let alpaca_api_secret =
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
let alpaca_source: Source = env::var("ALPACA_SOURCE")
.expect("ALPACA_SOURCE must be set.")
.parse()
.expect("ALPACA_SOURCE must be a either 'iex' or 'sip'.");
let clickhouse_url = env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set.");
let clickhouse_user = env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set.");
let clickhouse_password =
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set.");
let clickhouse_db = env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.");
Self {
alpaca_client: Client::builder()
.default_headers(HeaderMap::from_iter([
(
HeaderName::from_static("apca-api-key-id"),
HeaderValue::from_str(&alpaca_api_key)
.expect("Alpaca API key must not contain invalid characters."),
),
(
HeaderName::from_static("apca-api-secret-key"),
HeaderValue::from_str(&alpaca_api_secret)
.expect("Alpaca API secret must not contain invalid characters."),
),
]))
.build()
.unwrap(),
alpaca_rate_limit: RateLimiter::direct(Quota::per_minute(match alpaca_source {
Source::Iex => unsafe { NonZeroU32::new_unchecked(200) },
Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) },
})),
alpaca_source,
clickhouse_client: clickhouse::Client::default()
.with_url(clickhouse_url)
.with_user(clickhouse_user)
.with_password(clickhouse_password)
.with_database(clickhouse_db),
alpaca_api_key,
alpaca_api_secret,
}
}
pub fn arc_from_env() -> Arc<Self> {
Arc::new(Self::from_env())
}
}


@@ -1,48 +0,0 @@
use crate::types::Asset;
use clickhouse::Client;
use serde::Serialize;
pub async fn select(clickhouse_client: &Client) -> Vec<Asset> {
clickhouse_client
.query("SELECT ?fields FROM assets FINAL")
.fetch_all::<Asset>()
.await
.unwrap()
}
pub async fn select_where_symbol<T>(clickhouse_client: &Client, symbol: &T) -> Option<Asset>
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query("SELECT ?fields FROM assets FINAL WHERE symbol = ? OR abbreviation = ?")
.bind(symbol)
.bind(symbol)
.fetch_optional::<Asset>()
.await
.unwrap()
}
pub async fn upsert_batch<T>(clickhouse_client: &Client, assets: T)
where
T: IntoIterator<Item = Asset> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = clickhouse_client.insert("assets").unwrap();
for asset in assets {
insert.write(&asset).await.unwrap();
}
insert.end().await.unwrap();
}
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query("DELETE FROM assets WHERE symbol IN ?")
.bind(symbols)
.execute()
.await
.unwrap();
}


@@ -1,93 +0,0 @@
use crate::{database::assets, threads::data::ThreadType, types::Backfill};
use clickhouse::Client;
use serde::Serialize;
use tokio::join;
pub async fn select_latest_where_symbol<T>(
clickhouse_client: &Client,
thread_type: &ThreadType,
symbol: &T,
) -> Option<Backfill>
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol = ? ORDER BY time DESC LIMIT 1",
match thread_type {
ThreadType::Bars(_) => "backfills_bars",
ThreadType::News => "backfills_news",
}
))
.bind(symbol)
.fetch_optional::<Backfill>()
.await
.unwrap()
}
pub async fn upsert(clickhouse_client: &Client, thread_type: &ThreadType, backfill: &Backfill) {
let mut insert = clickhouse_client
.insert(match thread_type {
ThreadType::Bars(_) => "backfills_bars",
ThreadType::News => "backfills_news",
})
.unwrap();
insert.write(backfill).await.unwrap();
insert.end().await.unwrap();
}
pub async fn delete_where_symbols<T>(
clickhouse_client: &Client,
thread_type: &ThreadType,
symbols: &[T],
) where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query(&format!(
"DELETE FROM {} WHERE symbol IN ?",
match thread_type {
ThreadType::Bars(_) => "backfills_bars",
ThreadType::News => "backfills_news",
}
))
.bind(symbols)
.execute()
.await
.unwrap();
}
pub async fn cleanup(clickhouse_client: &Client) {
let assets = assets::select(clickhouse_client).await;
let bars_symbols = assets
.clone()
.into_iter()
.map(|asset| asset.symbol)
.collect::<Vec<_>>();
let news_symbols = assets
.into_iter()
.map(|asset| asset.abbreviation)
.collect::<Vec<_>>();
let delete_bars_future = async {
clickhouse_client
.query("DELETE FROM backfills_bars WHERE symbol NOT IN ?")
.bind(bars_symbols)
.execute()
.await
.unwrap();
};
let delete_news_future = async {
clickhouse_client
.query("DELETE FROM backfills_news WHERE symbol NOT IN ?")
.bind(news_symbols)
.execute()
.await
.unwrap();
};
join!(delete_bars_future, delete_news_future);
}


@@ -1,50 +0,0 @@
use super::assets;
use crate::types::Bar;
use clickhouse::Client;
use serde::Serialize;
pub async fn upsert(clickhouse_client: &Client, bar: &Bar) {
let mut insert = clickhouse_client.insert("bars").unwrap();
insert.write(bar).await.unwrap();
insert.end().await.unwrap();
}
pub async fn upsert_batch<T>(clickhouse_client: &Client, bars: T)
where
T: IntoIterator<Item = Bar> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = clickhouse_client.insert("bars").unwrap();
for bar in bars {
insert.write(&bar).await.unwrap();
}
insert.end().await.unwrap();
}
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query("DELETE FROM bars WHERE symbol IN ?")
.bind(symbols)
.execute()
.await
.unwrap();
}
pub async fn cleanup(clickhouse_client: &Client) {
let assets = assets::select(clickhouse_client).await;
let symbols = assets
.into_iter()
.map(|asset| asset.symbol)
.collect::<Vec<_>>();
clickhouse_client
.query("DELETE FROM bars WHERE symbol NOT IN ?")
.bind(symbols)
.execute()
.await
.unwrap();
}


@@ -1,4 +0,0 @@
pub mod assets;
pub mod backfills;
pub mod bars;
pub mod news;


@@ -1,50 +0,0 @@
use super::assets;
use crate::types::News;
use clickhouse::Client;
use serde::Serialize;
pub async fn upsert(clickhouse_client: &Client, news: &News) {
let mut insert = clickhouse_client.insert("news").unwrap();
insert.write(news).await.unwrap();
insert.end().await.unwrap();
}
pub async fn upsert_batch<T>(clickhouse_client: &Client, news: T)
where
T: IntoIterator<Item = News> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = clickhouse_client.insert("news").unwrap();
for news in news {
insert.write(&news).await.unwrap();
}
insert.end().await.unwrap();
}
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query("DELETE FROM news WHERE hasAny(symbols, ?)")
.bind(symbols)
.execute()
.await
.unwrap();
}
pub async fn cleanup(clickhouse_client: &Client) {
let assets = assets::select(clickhouse_client).await;
let symbols = assets
.into_iter()
.map(|asset| asset.abbreviation)
.collect::<Vec<_>>();
clickhouse_client
.query("DELETE FROM news WHERE NOT hasAny(symbols, ?)")
.bind(symbols)
.execute()
.await
.unwrap();
}


@@ -0,0 +1,39 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::account::Account;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/account", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Account>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}


@@ -0,0 +1,132 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{
incoming::asset::{Asset, Class},
outgoing,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use itertools::Itertools;
use log::warn;
use reqwest::{Client, Error};
use std::{collections::HashSet, time::Duration};
use tokio::try_join;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::asset::Asset,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/assets", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Asset>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get assets, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!(
"https://{}.alpaca.markets/v2/assets/{}",
api_base, symbol
))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Asset>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get asset, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
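// For several symbols, fetch the full US-equity and crypto asset lists in
// parallel and intersect them with the requested set, rather than issuing one
// request per symbol.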
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
if symbols.is_empty() {
return Ok(vec![]);
}
if symbols.len() == 1 {
let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(vec![asset]);
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let backoff_clone = backoff.clone();
let us_equity_query = outgoing::asset::Asset {
class: Some(Class::UsEquity),
..Default::default()
};
let us_equity_assets = get(
client,
rate_limiter,
&us_equity_query,
backoff_clone,
api_base,
);
let crypto_query = outgoing::asset::Asset {
class: Some(Class::Crypto),
..Default::default()
};
let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base);
let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;
Ok(crypto_assets
.into_iter()
.chain(us_equity_assets)
.dedup_by(|a, b| a.symbol == b.symbol)
.filter(|asset| symbols.contains(&asset.symbol))
.collect())
}


@@ -0,0 +1,50 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::bar::Bar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::{collections::HashMap, time::Duration};
pub const MAX_LIMIT: i64 = 10_000;
#[derive(Deserialize)]
pub struct Message {
pub bars: HashMap<String, Vec<Bar>>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
data_url: &str,
query: &outgoing::bar::Bar,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(data_url)
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Message>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get historical bars, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}


@@ -0,0 +1,41 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::calendar::Calendar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::calendar::Calendar,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Calendar>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/calendar", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Calendar>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get calendar, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}


@@ -0,0 +1,39 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::clock::Clock;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/clock", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Clock>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get clock, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}


@@ -0,0 +1,27 @@
pub mod account;
pub mod assets;
pub mod bars;
pub mod calendar;
pub mod clock;
pub mod news;
pub mod orders;
pub mod positions;
use http::StatusCode;
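// Classify reqwest errors for the backoff crate: 400/403/404 responses and
// builder/request/redirect/decode/body errors are permanent, while any other
// status (e.g. 429 or 5xx) is surfaced as transient and retried.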
pub fn error_to_backoff(err: reqwest::Error) -> backoff::Error<reqwest::Error> {
if err.is_status() {
return match err.status() {
Some(StatusCode::BAD_REQUEST | StatusCode::FORBIDDEN | StatusCode::NOT_FOUND)
| None => backoff::Error::Permanent(err),
_ => err.into(),
};
}
if err.is_builder() || err.is_request() || err.is_redirect() || err.is_decode() || err.is_body()
{
return backoff::Error::Permanent(err);
}
err.into()
}


@@ -0,0 +1,49 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::news::News, outgoing, ALPACA_NEWS_DATA_API_URL};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
pub const MAX_LIMIT: i64 = 50;
#[derive(Deserialize)]
pub struct Message {
pub news: Vec<News>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(ALPACA_NEWS_DATA_API_URL)
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Message>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get historical news, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}


@@ -0,0 +1,43 @@
use super::error_to_backoff;
use crate::types::alpaca::{api::outgoing, shared::order};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub use order::Order;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::order::Order,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Order>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/orders", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Order>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get orders, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}


@@ -0,0 +1,109 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::position::Position;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use http::StatusCode;
use log::warn;
use reqwest::Client;
use std::{collections::HashSet, time::Duration};
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/positions", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Position>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get positions, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
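// A 404 from the positions endpoint is treated as "no open position for this
// symbol" and mapped to Ok(None) instead of being retried.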
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Option<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
let response = client
.get(&format!(
"https://{}.alpaca.markets/v2/positions/{}",
api_base, symbol
))
.send()
.await
.map_err(error_to_backoff)?;
if response.status() == StatusCode::NOT_FOUND {
return Ok(None);
}
response
.error_for_status()
.map_err(error_to_backoff)?
.json::<Position>()
.await
.map_err(error_to_backoff)
.map(Some)
},
|e, duration: Duration| {
warn!(
"Failed to get position, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
if symbols.is_empty() {
return Ok(vec![]);
}
if symbols.len() == 1 {
let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(position.into_iter().collect());
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let positions = get(client, rate_limiter, backoff, api_base).await?;
Ok(positions
.into_iter()
.filter(|position| symbols.contains(&position.symbol))
.collect())
}


@@ -0,0 +1,50 @@
use std::sync::Arc;
use crate::{
delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch,
};
use clickhouse::{error::Error, Client};
use serde::Serialize;
use tokio::sync::Semaphore;
select!(Asset, "assets");
select_where_symbol!(Asset, "assets");
upsert_batch!(Asset, "assets");
delete_where_symbols!("assets");
optimize!("assets");
pub async fn update_status_where_symbol<T>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbol: &T,
status: bool,
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?")
.bind(status)
.bind(symbol)
.execute()
.await
}
pub async fn update_qty_where_symbol<T>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbol: &T,
qty: f64,
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?")
.bind(qty)
.bind(symbol)
.execute()
.await
}


@@ -0,0 +1,11 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
types::Backfill, upsert_batch,
};
select_where_symbols!(Backfill, "backfills_bars");
upsert_batch!(Backfill, "backfills_bars");
delete_where_symbols!("backfills_bars");
cleanup!("backfills_bars");
optimize!("backfills_bars");
set_fresh_where_symbols!("backfills_bars");


@@ -0,0 +1,11 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
types::Backfill, upsert_batch,
};
select_where_symbols!(Backfill, "backfills_news");
upsert_batch!(Backfill, "backfills_news");
delete_where_symbols!("backfills_news");
cleanup!("backfills_news");
optimize!("backfills_news");
set_fresh_where_symbols!("backfills_news");


@@ -0,0 +1,21 @@
use std::sync::Arc;
use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
use clickhouse::Client;
use tokio::sync::Semaphore;
upsert!(Bar, "bars");
upsert_batch!(Bar, "bars");
delete_where_symbols!("bars");
optimize!("bars");
pub async fn cleanup(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
) -> Result<(), clickhouse::error::Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets) OR symbol NOT IN (SELECT symbol FROM backfills_bars)")
.execute()
.await
}


@@ -0,0 +1,42 @@
use std::sync::Arc;
use crate::{optimize, types::Calendar};
use clickhouse::{error::Error, Client};
use tokio::{sync::Semaphore, try_join};
optimize!("calendar");
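// Refreshes the calendar table: inserts the incoming rows while deleting any
// dates not present in them, holding two semaphore permits because both
// statements run concurrently.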
pub async fn upsert_batch_and_delete<'a, I>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
records: I,
) -> Result<(), Error>
where
I: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone,
I::IntoIter: Send,
{
let upsert_future = async {
let mut insert = client.insert("calendar")?;
for record in records.clone() {
insert.write(record).await?;
}
insert.end().await
};
let delete_future = async {
let dates = records
.clone()
.into_iter()
.map(|r| r.date)
.collect::<Vec<_>>();
client
.query("DELETE FROM calendar WHERE date NOT IN ?")
.bind(dates)
.execute()
.await
};
let _permit = concurrency_limiter.acquire_many(2).await.unwrap();
try_join!(upsert_future, delete_future).map(|_| ())
}


@@ -0,0 +1,224 @@
pub mod assets;
pub mod backfills_bars;
pub mod backfills_news;
pub mod bars;
pub mod calendar;
pub mod news;
pub mod orders;
pub mod ta;
use clickhouse::{error::Error, Client};
use tokio::try_join;
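// Each macro below expands to an async helper for a single table. Every helper
// acquires a permit from the shared semaphore for the duration of its query so
// the number of concurrent ClickHouse requests stays bounded.
// Hypothetical call site (sketch, names assumed):
//   let assets = database::assets::select(&client, &limiter).await?;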
#[macro_export]
macro_rules! select {
($record:ty, $table_name:expr) => {
pub async fn select(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<Vec<$record>, clickhouse::error::Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! select_where_symbol {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbol<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbol: &T,
) -> Result<Option<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol = ?",
$table_name
))
.bind(symbol)
.fetch_optional::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! select_where_symbols {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbols: &[T],
) -> Result<Vec<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol IN ?",
$table_name
))
.bind(symbols)
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! upsert {
($record:ty, $table_name:expr) => {
pub async fn upsert(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
record: &$record,
) -> Result<(), clickhouse::error::Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
let mut insert = client.insert($table_name)?;
insert.write(record).await?;
insert.end().await
}
};
}
#[macro_export]
macro_rules! upsert_batch {
($record:ty, $table_name:expr) => {
pub async fn upsert_batch<'a, I>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
records: I,
) -> Result<(), clickhouse::error::Error>
where
I: IntoIterator<Item = &'a $record> + Send + Sync,
I::IntoIter: Send,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
let mut insert = client.insert($table_name)?;
for record in records {
insert.write(record).await?;
}
insert.end().await
}
};
}
#[macro_export]
macro_rules! delete_where_symbols {
($table_name:expr) => {
pub async fn delete_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
.bind(symbols)
.execute()
.await
}
};
}
#[macro_export]
macro_rules! cleanup {
($table_name:expr) => {
pub async fn cleanup(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
$table_name
))
.execute()
.await
}
};
}
#[macro_export]
macro_rules! optimize {
($table_name:expr) => {
pub async fn optimize(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
.execute()
.await
}
};
}
#[macro_export]
macro_rules! set_fresh_where_symbols {
($table_name:expr) => {
pub async fn set_fresh_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
fresh: bool,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"ALTER TABLE {} UPDATE fresh = ? WHERE symbol IN ?",
$table_name
))
.bind(fresh)
.bind(symbols)
.execute()
.await
}
};
}
pub async fn cleanup_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
try_join!(
bars::cleanup(clickhouse_client, concurrency_limiter),
news::cleanup(clickhouse_client, concurrency_limiter),
backfills_bars::cleanup(clickhouse_client, concurrency_limiter),
backfills_news::cleanup(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}
pub async fn optimize_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
try_join!(
assets::optimize(clickhouse_client, concurrency_limiter),
bars::optimize(clickhouse_client, concurrency_limiter),
news::optimize(clickhouse_client, concurrency_limiter),
backfills_bars::optimize(clickhouse_client, concurrency_limiter),
backfills_news::optimize(clickhouse_client, concurrency_limiter),
orders::optimize(clickhouse_client, concurrency_limiter),
calendar::optimize(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}


@@ -0,0 +1,36 @@
use std::sync::Arc;
use crate::{optimize, types::News, upsert, upsert_batch};
use clickhouse::{error::Error, Client};
use serde::Serialize;
use tokio::sync::Semaphore;
upsert!(News, "news");
upsert_batch!(News, "news");
optimize!("news");
pub async fn delete_where_symbols<T>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbols: &[T],
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))")
.bind(symbols)
.execute()
.await
}
pub async fn cleanup(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(
"DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))",
)
.execute()
.await
}


@@ -0,0 +1,5 @@
use crate::{optimize, types::Order, upsert, upsert_batch};
upsert!(Order, "orders");
upsert_batch!(Order, "orders");
optimize!("orders");


@@ -0,0 +1,30 @@
use crate::types::Bar;
use clickhouse::{error::Error, Client};
use std::sync::Arc;
use tokio::sync::Semaphore;
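// Aggregates raw bars into hourly OHLCV rows; ClickHouse's GROUP BY ALL groups
// by every non-aggregated column (symbol and the truncated hour), and the
// ordering keeps each symbol's bars contiguous for windowing downstream.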
pub async fn select(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
) -> Result<Vec<Bar>, Error> {
let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(
"
SELECT symbol,
toStartOfHour(bars.time) AS time,
any(bars.open) AS open,
max(bars.high) AS high,
min(bars.low) AS low,
anyLast(bars.close) AS close,
sum(bars.volume) AS volume,
sum(bars.trades) AS trades
FROM bars FINAL
GROUP BY ALL
ORDER BY symbol,
time
",
)
.fetch_all::<Bar>()
.await
}


@@ -0,0 +1,75 @@
use super::BarWindow;
use burn::{
data::dataloader::batcher::Batcher,
tensor::{self, backend::Backend, Tensor},
};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
#[derive(Clone, Debug)]
pub struct BarWindowBatcher<B: Backend> {
pub device: B::Device,
}
#[derive(Clone, Debug)]
pub struct BarWindowBatch<B: Backend> {
pub hour_tensor: Tensor<B, 2, tensor::Int>,
pub day_tensor: Tensor<B, 2, tensor::Int>,
pub numerical_tensor: Tensor<B, 3>,
pub target_tensor: Tensor<B, 2>,
}
impl<B: Backend<FloatElem = f32, IntElem = i32>> Batcher<BarWindow, BarWindowBatch<B>>
for BarWindowBatcher<B>
{
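// Convert each window into tensors in parallel with rayon, then stack the
// per-item tensors into a single batch on the target device.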
fn batch(&self, items: Vec<BarWindow>) -> BarWindowBatch<B> {
let batch_size = items.len();
let (hour_tensors, day_tensors, numerical_tensors, target_tensors) = items
.into_par_iter()
.fold(
|| {
(
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
)
},
|(mut hour_tensors, mut day_tensors, mut numerical_tensors, mut target_tensors),
item| {
hour_tensors.push(Tensor::from_data(item.hours, &self.device));
day_tensors.push(Tensor::from_data(item.days, &self.device));
numerical_tensors.push(Tensor::from_data(item.numerical, &self.device));
target_tensors.push(Tensor::from_data(item.target, &self.device));
(hour_tensors, day_tensors, numerical_tensors, target_tensors)
},
)
.reduce(
|| {
(
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
)
},
|(mut hour_tensors, mut day_tensors, mut numerical_tensors, mut target_tensors),
item| {
hour_tensors.extend(item.0);
day_tensors.extend(item.1);
numerical_tensors.extend(item.2);
target_tensors.extend(item.3);
(hour_tensors, day_tensors, numerical_tensors, target_tensors)
},
);
BarWindowBatch {
hour_tensor: Tensor::stack(hour_tensors, 0).to_device(&self.device),
day_tensor: Tensor::stack(day_tensors, 0).to_device(&self.device),
numerical_tensor: Tensor::stack(numerical_tensors, 0).to_device(&self.device),
target_tensor: Tensor::stack(target_tensors, 0).to_device(&self.device),
}
}
}

src/lib/qrust/ml/dataset.rs

@@ -0,0 +1,219 @@
use crate::types::{
ta::{calculate_indicators, IndicatedBar, HEAD_SIZE, NUMERICAL_FIELD_COUNT},
Bar,
};
use burn::{
data::dataset::{transform::ComposedDataset, Dataset},
tensor::Data,
};
pub const WINDOW_SIZE: usize = 48;
#[derive(Clone, Debug)]
pub struct BarWindow {
pub hours: Data<i32, 1>,
pub days: Data<i32, 1>,
pub numerical: Data<f32, 2>,
pub target: Data<f32, 1>,
}
#[derive(Clone, Debug)]
struct SingleSymbolDataset {
hours: Vec<i32>,
days: Vec<i32>,
numerical: Vec<[f32; NUMERICAL_FIELD_COUNT]>,
targets: Vec<f32>,
}
impl SingleSymbolDataset {
#[allow(clippy::cast_possible_truncation)]
pub fn new(bars: Vec<IndicatedBar>) -> Self {
if !bars.is_empty() {
let symbol = &bars[0].symbol;
assert!(bars.iter().all(|bar| bar.symbol == *symbol));
}
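// Slide over consecutive bar pairs: features come from the earlier bar and the
// target is the next bar's close_pct. The first HEAD_SIZE - 1 pairs are
// skipped, presumably so every indicator has a full warm-up window behind it.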
let (hours, days, numerical, targets) = bars.windows(2).skip(HEAD_SIZE - 1).fold(
(
Vec::with_capacity(bars.len() - 1),
Vec::with_capacity(bars.len() - 1),
Vec::with_capacity(bars.len() - 1),
Vec::with_capacity(bars.len() - 1),
),
|(mut hours, mut days, mut numerical, mut targets), bar| {
hours.push(i32::from(bar[0].hour));
days.push(i32::from(bar[0].day));
numerical.push([
bar[0].open as f32,
(bar[0].open_pct as f32).min(f32::MAX),
bar[0].high as f32,
(bar[0].high_pct as f32).min(f32::MAX),
bar[0].low as f32,
(bar[0].low_pct as f32).min(f32::MAX),
bar[0].close as f32,
(bar[0].close_pct as f32).min(f32::MAX),
bar[0].volume as f32,
(bar[0].volume_pct as f32).min(f32::MAX),
bar[0].trades as f32,
(bar[0].trades_pct as f32).min(f32::MAX),
bar[0].sma_3 as f32,
bar[0].sma_6 as f32,
bar[0].sma_12 as f32,
bar[0].sma_24 as f32,
bar[0].sma_48 as f32,
bar[0].sma_72 as f32,
bar[0].ema_3 as f32,
bar[0].ema_6 as f32,
bar[0].ema_12 as f32,
bar[0].ema_24 as f32,
bar[0].ema_48 as f32,
bar[0].ema_72 as f32,
bar[0].macd as f32,
bar[0].macd_signal as f32,
bar[0].obv as f32,
bar[0].rsi as f32,
bar[0].bbands_lower as f32,
bar[0].bbands_mean as f32,
bar[0].bbands_upper as f32,
]);
targets.push(bar[1].close_pct as f32);
(hours, days, numerical, targets)
},
);
Self {
hours,
days,
numerical,
targets,
}
}
}
impl Dataset<BarWindow> for SingleSymbolDataset {
fn len(&self) -> usize {
self.targets.len() - WINDOW_SIZE + 1
}
#[allow(clippy::single_range_in_vec_init)]
fn get(&self, idx: usize) -> Option<BarWindow> {
if idx >= self.len() {
return None;
}
let hours: [i32; WINDOW_SIZE] = self.hours[idx..idx + WINDOW_SIZE].try_into().unwrap();
let days: [i32; WINDOW_SIZE] = self.days[idx..idx + WINDOW_SIZE].try_into().unwrap();
let numerical: [[f32; NUMERICAL_FIELD_COUNT]; WINDOW_SIZE] =
self.numerical[idx..idx + WINDOW_SIZE].try_into().unwrap();
let target: [f32; 1] = [self.targets[idx + WINDOW_SIZE - 1]];
Some(BarWindow {
hours: Data::from(hours),
days: Data::from(days),
numerical: Data::from(numerical),
target: Data::from(target),
})
}
}
pub struct MultipleSymbolDataset {
composed_dataset: ComposedDataset<SingleSymbolDataset>,
}
impl MultipleSymbolDataset {
pub fn new(bars: Vec<Bar>) -> Self {
let groups = calculate_indicators(bars)
.into_iter()
.map(SingleSymbolDataset::new)
.collect::<Vec<_>>();
Self {
composed_dataset: ComposedDataset::new(groups),
}
}
}
impl Dataset<BarWindow> for MultipleSymbolDataset {
fn len(&self) -> usize {
self.composed_dataset.len()
}
fn get(&self, idx: usize) -> Option<BarWindow> {
self.composed_dataset.get(idx)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::{
distributions::{Distribution, Uniform},
Rng,
};
use time::OffsetDateTime;
fn generate_random_dataset(length: usize) -> MultipleSymbolDataset {
let mut rng = rand::thread_rng();
let uniform = Uniform::new(1.0, 100.0);
let mut bars = Vec::with_capacity(length);
for _ in 0..=(length + (HEAD_SIZE - 1) + (WINDOW_SIZE - 1)) {
bars.push(Bar {
symbol: "AAPL".to_string(),
time: OffsetDateTime::now_utc(),
open: uniform.sample(&mut rng),
high: uniform.sample(&mut rng),
low: uniform.sample(&mut rng),
close: uniform.sample(&mut rng),
volume: uniform.sample(&mut rng),
trades: rng.gen_range(1..100),
});
}
MultipleSymbolDataset::new(bars)
}
#[test]
fn test_single_symbol_dataset() {
let length = 100;
let dataset = generate_random_dataset(length);
assert_eq!(dataset.len(), length);
}
#[test]
fn test_single_symbol_dataset_window() {
let length = 100;
let dataset = generate_random_dataset(length);
let item = dataset.get(0).unwrap();
assert_eq!(
item.numerical.shape.dims,
[WINDOW_SIZE, NUMERICAL_FIELD_COUNT]
);
assert_eq!(item.target.shape.dims, [1]);
}
#[test]
fn test_single_symbol_dataset_last_window() {
let length = 100;
let dataset = generate_random_dataset(length);
let item = dataset.get(dataset.len() - 1).unwrap();
assert_eq!(
item.numerical.shape.dims,
[WINDOW_SIZE, NUMERICAL_FIELD_COUNT]
);
assert_eq!(item.target.shape.dims, [1]);
}
#[test]
fn test_single_symbol_dataset_out_of_bounds() {
let length = 100;
let dataset = generate_random_dataset(length);
assert!(dataset.get(dataset.len()).is_none());
}
}

src/lib/qrust/ml/mod.rs

@@ -0,0 +1,21 @@
pub mod batcher;
pub mod dataset;
pub mod model;
pub use batcher::{BarWindowBatch, BarWindowBatcher};
pub use dataset::{BarWindow, MultipleSymbolDataset};
pub use model::{Model, ModelConfig};
use burn::{
backend::{
wgpu::{AutoGraphicsApi, WgpuDevice},
Autodiff, Wgpu,
},
tensor::backend::Backend,
};
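// Default backend: wgpu with f32 floats and i32 ints, wrapped in Autodiff for
// training. Sketch of expected usage (assumed, mirrors the trainer binary):
//   let model = ModelConfig::new().init::<MyAutodiffBackend>(&DEVICE);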
pub type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
pub type MyAutodiffBackend = Autodiff<MyBackend>;
pub type MyDevice = <Autodiff<Wgpu> as Backend>::Device;
pub const DEVICE: MyDevice = WgpuDevice::BestAvailable;

src/lib/qrust/ml/model.rs

@@ -0,0 +1,160 @@
use super::BarWindowBatch;
use crate::types::ta::NUMERICAL_FIELD_COUNT;
use burn::{
config::Config,
module::Module,
nn::{
loss::{MseLoss, Reduction},
Dropout, DropoutConfig, Embedding, EmbeddingConfig, Linear, LinearConfig, Lstm, LstmConfig,
},
tensor::{
self,
backend::{AutodiffBackend, Backend},
Tensor,
},
train::{RegressionOutput, TrainOutput, TrainStep, ValidStep},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
hour_embedding: Embedding<B>,
day_embedding: Embedding<B>,
lstm_1: Lstm<B>,
dropout_1: Dropout,
lstm_2: Lstm<B>,
dropout_2: Dropout,
lstm_3: Lstm<B>,
dropout_3: Dropout,
lstm_4: Lstm<B>,
dropout_4: Dropout,
linear: Linear<B>,
}
#[derive(Config, Debug)]
pub struct ModelConfig {
#[config(default = "3")]
pub hour_features: usize,
#[config(default = "2")]
pub day_features: usize,
#[config(default = "{NUMERICAL_FIELD_COUNT}")]
pub numerical_features: usize,
#[config(default = "0.2")]
pub dropout: f64,
}
impl ModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
let num_features = self.numerical_features + self.hour_features + self.day_features;
let lstm_1_hidden_size = 512;
let lstm_2_hidden_size = 256;
let lstm_3_hidden_size = 64;
let lstm_4_hidden_size = 32;
Model {
hour_embedding: EmbeddingConfig::new(24, self.hour_features).init(device),
day_embedding: EmbeddingConfig::new(7, self.day_features).init(device),
lstm_1: LstmConfig::new(num_features, lstm_1_hidden_size, true).init(device),
dropout_1: DropoutConfig::new(self.dropout).init(),
lstm_2: LstmConfig::new(lstm_1_hidden_size, lstm_2_hidden_size, true).init(device),
dropout_2: DropoutConfig::new(self.dropout).init(),
lstm_3: LstmConfig::new(lstm_2_hidden_size, lstm_3_hidden_size, true).init(device),
dropout_3: DropoutConfig::new(self.dropout).init(),
lstm_4: LstmConfig::new(lstm_3_hidden_size, lstm_4_hidden_size, true).init(device),
dropout_4: DropoutConfig::new(self.dropout).init(),
linear: LinearConfig::new(lstm_4_hidden_size, 1).init(device),
}
}
}
impl<B: Backend> Model<B> {
pub fn forward(
&self,
hour: Tensor<B, 2, tensor::Int>,
day: Tensor<B, 2, tensor::Int>,
numerical: Tensor<B, 3>,
) -> Tensor<B, 2> {
let hour = self.hour_embedding.forward(hour);
let day = self.day_embedding.forward(day);
let x = Tensor::cat(vec![hour, day, numerical], 2);
let (_, x) = self.lstm_1.forward(x, None);
let x = self.dropout_1.forward(x);
let (_, x) = self.lstm_2.forward(x, None);
let x = self.dropout_2.forward(x);
let (_, x) = self.lstm_3.forward(x, None);
let x = self.dropout_3.forward(x);
let (_, x) = self.lstm_4.forward(x, None);
let x = self.dropout_4.forward(x);
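// Keep only the final timestep of the last LSTM's output and project it to a
// single predicted value per item.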
let [batch_size, window_size, features] = x.shape().dims;
let x = x.slice([0..batch_size, window_size - 1..window_size, 0..features]);
let x = x.squeeze(1);
self.linear.forward(x)
}
pub fn forward_regression(
&self,
hour: Tensor<B, 2, tensor::Int>,
day: Tensor<B, 2, tensor::Int>,
numerical: Tensor<B, 3>,
target: Tensor<B, 2>,
) -> RegressionOutput<B> {
let output = self.forward(hour, day, numerical);
let loss = MseLoss::new().forward(output.clone(), target.clone(), Reduction::Mean);
RegressionOutput::new(loss, output, target)
}
}
impl<B: AutodiffBackend> TrainStep<BarWindowBatch<B>, RegressionOutput<B>> for Model<B> {
fn step(&self, batch: BarWindowBatch<B>) -> TrainOutput<RegressionOutput<B>> {
let item = self.forward_regression(
batch.hour_tensor,
batch.day_tensor,
batch.numerical_tensor,
batch.target_tensor,
);
TrainOutput::new(self, item.loss.backward(), item)
}
}
impl<B: Backend> ValidStep<BarWindowBatch<B>, RegressionOutput<B>> for Model<B> {
fn step(&self, batch: BarWindowBatch<B>) -> RegressionOutput<B> {
self.forward_regression(
batch.hour_tensor,
batch.day_tensor,
batch.numerical_tensor,
batch.target_tensor,
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::{backend::Wgpu, tensor::Distribution};
#[test]
#[ignore]
fn test_model() {
let device = Default::default();
let distribution = Distribution::Normal(0.0, 1.0);
let config = ModelConfig::new().with_numerical_features(7);
let model = config.init::<Wgpu>(&device);
let hour = Tensor::ones([2, 10], &device);
let day = Tensor::ones([2, 10], &device);
let numerical = Tensor::random([2, 10, 7], distribution, &device);
let output = model.forward(hour, day, numerical);
assert_eq!(output.shape().dims, [2, 1]);
}
}

src/lib/qrust/mod.rs

@@ -0,0 +1,6 @@
pub mod alpaca;
pub mod database;
pub mod ml;
pub mod ta;
pub mod types;
pub mod utils;

src/lib/qrust/ta/bbands.rs

@@ -0,0 +1,149 @@
use std::{borrow::Borrow, collections::VecDeque, iter::Scan, num::NonZeroUsize};
pub struct BbandsState {
window: VecDeque<f64>,
sum: f64,
squared_sum: f64,
multiplier: f64,
}
#[allow(clippy::type_complexity)]
pub trait Bbands<T>: Iterator + Sized {
fn bbands(
self,
period: NonZeroUsize,
multiplier: f64, // Typically 2.0
) -> Scan<Self, BbandsState, fn(&mut BbandsState, T) -> Option<(f64, f64, f64)>>;
}
impl<I, T> Bbands<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn bbands(
self,
period: NonZeroUsize,
multiplier: f64,
) -> Scan<Self, BbandsState, fn(&mut BbandsState, T) -> Option<(f64, f64, f64)>> {
self.scan(
BbandsState {
window: VecDeque::from(vec![0.0; period.get()]),
sum: 0.0,
squared_sum: 0.0,
multiplier,
},
|state: &mut BbandsState, value: T| {
let value = *value.borrow();
let front = state.window.pop_front().unwrap();
state.sum -= front;
state.squared_sum -= front.powi(2);
state.window.push_back(value);
state.sum += value;
state.squared_sum += value.powi(2);
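// Rolling statistics over the window: Var = E[x^2] - E[x]^2 and the bands are
// mean +/- multiplier * stddev. The window starts zero-filled, so the first
// period - 1 outputs are warm-up values.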
let mean = state.sum / state.window.len() as f64;
let variance =
((state.squared_sum / state.window.len() as f64) - mean.powi(2)).max(0.0);
let standard_deviation = variance.sqrt();
let upper_band = mean + state.multiplier * standard_deviation;
let lower_band = mean - state.multiplier * standard_deviation;
Some((upper_band, mean, lower_band))
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bbands() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let bbands = data
.into_iter()
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
.map(|(upper, mean, lower)| {
(
(upper * 100.0).round() / 100.0,
(mean * 100.0).round() / 100.0,
(lower * 100.0).round() / 100.0,
)
})
.collect::<Vec<_>>();
assert_eq!(
bbands,
vec![
(1.28, 0.33, -0.61),
(2.63, 1.0, -0.63),
(3.63, 2.0, 0.37),
(4.63, 3.0, 1.37),
(5.63, 4.0, 2.37)
]
);
}
#[test]
fn test_bbands_empty() {
let data = Vec::<f64>::new();
let bbands = data
.into_iter()
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
.collect::<Vec<_>>();
assert_eq!(bbands, Vec::<(f64, f64, f64)>::new());
}
#[test]
fn test_bbands_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let bbands = data
.into_iter()
.bbands(NonZeroUsize::new(1).unwrap(), 2.0)
.collect::<Vec<_>>();
assert_eq!(
bbands,
vec![
(1.0, 1.0, 1.0),
(2.0, 2.0, 2.0),
(3.0, 3.0, 3.0),
(4.0, 4.0, 4.0),
(5.0, 5.0, 5.0)
]
);
}
#[test]
fn test_bbands_borrow() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
let bbands = data
.iter()
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
.map(|(upper, mean, lower)| {
(
(upper * 100.0).round() / 100.0,
(mean * 100.0).round() / 100.0,
(lower * 100.0).round() / 100.0,
)
})
.collect::<Vec<_>>();
assert_eq!(
bbands,
vec![
(1.28, 0.33, -0.61),
(2.63, 1.0, -0.63),
(3.63, 2.0, 0.37),
(4.63, 3.0, 1.37),
(5.63, 4.0, 2.37)
]
);
}
}

59
src/lib/qrust/ta/deriv.rs Normal file
View File

@@ -0,0 +1,59 @@
use std::{borrow::Borrow, iter::Scan};
pub struct DerivState {
pub last: f64,
}
#[allow(clippy::type_complexity)]
pub trait Deriv<T>: Iterator + Sized {
fn deriv(self) -> Scan<Self, DerivState, fn(&mut DerivState, T) -> Option<f64>>;
}
impl<I, T> Deriv<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn deriv(self) -> Scan<Self, DerivState, fn(&mut DerivState, T) -> Option<f64>> {
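// First difference: each output is value - previous value; the state is seeded with 0.0, so the first output equals the first input.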
self.scan(
DerivState { last: 0.0 },
|state: &mut DerivState, value: T| {
let value = *value.borrow();
let deriv = value - state.last;
state.last = value;
Some(deriv)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_deriv() {
let data = vec![1.0, 3.0, 6.0, 3.0, 1.0];
let deriv = data.into_iter().deriv().collect::<Vec<_>>();
assert_eq!(deriv, vec![1.0, 2.0, 3.0, -3.0, -2.0]);
}
#[test]
fn test_deriv_empty() {
let data = Vec::<f64>::new();
let deriv = data.into_iter().deriv().collect::<Vec<_>>();
assert_eq!(deriv, Vec::<f64>::new());
}
#[test]
fn test_deriv_borrow() {
let data = [1.0, 3.0, 6.0, 3.0, 1.0];
let deriv = data.iter().deriv().collect::<Vec<_>>();
assert_eq!(deriv, vec![1.0, 2.0, 3.0, -3.0, -2.0]);
}
}

95
src/lib/qrust/ta/ema.rs Normal file
View File

@@ -0,0 +1,95 @@
use std::{
borrow::Borrow,
iter::{Peekable, Scan},
num::NonZeroUsize,
};
pub struct EmaState {
weight: f64,
ema: f64,
}
#[allow(clippy::type_complexity)]
pub trait Ema<T>: Iterator + Sized {
fn ema(
self,
period: NonZeroUsize,
) -> Scan<Peekable<Self>, EmaState, fn(&mut EmaState, T) -> Option<f64>>;
}
impl<I, T> Ema<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn ema(
self,
period: NonZeroUsize,
) -> Scan<Peekable<Self>, EmaState, fn(&mut EmaState, T) -> Option<f64>> {
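// Smoothing factor is 2 / (period + 1); the EMA is seeded with the first value, so the series starts without a warm-up gap.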
let smoothing = 2.0;
let weight = smoothing / (1.0 + period.get() as f64);
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
EmaState { weight, ema: first },
|state: &mut EmaState, value: T| {
let value = *value.borrow();
state.ema = (value * state.weight) + (state.ema * (1.0 - state.weight));
Some(state.ema)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ema() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let ema = data
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, vec![1.0, 1.5, 2.25, 3.125, 4.0625]);
}
#[test]
fn test_ema_empty() {
let data = Vec::<f64>::new();
let ema = data
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, Vec::<f64>::new());
}
#[test]
fn test_ema_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let ema = data
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_ema_borrow() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
let ema = data
.iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, vec![1.0, 1.5, 2.25, 3.125, 4.0625]);
}
}

216
src/lib/qrust/ta/macd.rs Normal file
View File

@@ -0,0 +1,216 @@
use std::{
borrow::Borrow,
iter::{Peekable, Scan},
num::NonZeroUsize,
};
pub struct MacdState {
short_weight: f64,
long_weight: f64,
signal_weight: f64,
short_ema: f64,
long_ema: f64,
signal_ema: f64,
}
#[allow(clippy::type_complexity)]
pub trait Macd<T>: Iterator + Sized {
fn macd(
self,
short_period: NonZeroUsize, // Typically 12
long_period: NonZeroUsize, // Typically 26
signal_period: NonZeroUsize, // Typically 9
) -> Scan<Peekable<Self>, MacdState, fn(&mut MacdState, T) -> Option<(f64, f64)>>;
}
impl<I, T> Macd<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn macd(
self,
short_period: NonZeroUsize,
long_period: NonZeroUsize,
signal_period: NonZeroUsize,
) -> Scan<Peekable<Self>, MacdState, fn(&mut MacdState, T) -> Option<(f64, f64)>> {
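// MACD line = short EMA - long EMA; the signal line is an EMA of the MACD series itself.
// Both price EMAs are seeded with the first input, so the first MACD value (and the signal seed) is 0.0.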
let smoothing = 2.0;
let short_weight = smoothing / (1.0 + short_period.get() as f64);
let long_weight = smoothing / (1.0 + long_period.get() as f64);
let signal_weight = smoothing / (1.0 + signal_period.get() as f64);
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
MacdState {
short_weight,
long_weight,
signal_weight,
short_ema: first,
long_ema: first,
signal_ema: 0.0,
},
|state: &mut MacdState, value: T| {
let value = *value.borrow();
state.short_ema =
(value * state.short_weight) + (state.short_ema * (1.0 - state.short_weight));
state.long_ema =
(value * state.long_weight) + (state.long_ema * (1.0 - state.long_weight));
let macd = state.short_ema - state.long_ema;
state.signal_ema =
(macd * state.signal_weight) + (state.signal_ema * (1.0 - state.signal_weight));
Some((macd, state.signal_ema))
},
)
}
}
#[cfg(test)]
mod tests {
use super::super::ema::Ema;
use super::*;
#[test]
fn test_macd() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let short_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let long_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(5).unwrap())
.collect::<Vec<_>>();
let macd = short_ema
.into_iter()
.zip(long_ema)
.map(|(short, long)| short - long)
.collect::<Vec<_>>();
let signal = macd
.clone()
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
assert_eq!(
data.into_iter()
.macd(
NonZeroUsize::new(3).unwrap(),
NonZeroUsize::new(5).unwrap(),
NonZeroUsize::new(3).unwrap()
)
.collect::<Vec<_>>(),
expected
);
}
#[test]
fn test_macd_empty() {
let data = Vec::<f64>::new();
assert_eq!(
data.into_iter()
.macd(
NonZeroUsize::new(3).unwrap(),
NonZeroUsize::new(5).unwrap(),
NonZeroUsize::new(3).unwrap()
)
.collect::<Vec<_>>(),
vec![]
);
}
#[test]
fn test_macd_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let short_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
let long_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
let macd = short_ema
.into_iter()
.zip(long_ema)
.map(|(short, long)| short - long)
.collect::<Vec<_>>();
let signal = macd
.clone()
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
assert_eq!(
data.into_iter()
.macd(
NonZeroUsize::new(1).unwrap(),
NonZeroUsize::new(1).unwrap(),
NonZeroUsize::new(1).unwrap()
)
.collect::<Vec<_>>(),
expected
);
}
#[test]
fn test_macd_borrow() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
let short_ema = data
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let long_ema = data
.into_iter()
.ema(NonZeroUsize::new(5).unwrap())
.collect::<Vec<_>>();
let macd = short_ema
.into_iter()
.zip(long_ema)
.map(|(short, long)| short - long)
.collect::<Vec<_>>();
let signal = macd
.clone()
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
assert_eq!(
data.iter()
.macd(
NonZeroUsize::new(3).unwrap(),
NonZeroUsize::new(5).unwrap(),
NonZeroUsize::new(3).unwrap()
)
.collect::<Vec<_>>(),
expected
);
}
}

17
src/lib/qrust/ta/mod.rs Normal file
View File

@@ -0,0 +1,17 @@
pub mod bbands;
pub mod deriv;
pub mod ema;
pub mod macd;
pub mod obv;
pub mod pct;
pub mod rsi;
pub mod sma;
pub use bbands::Bbands;
pub use deriv::Deriv;
pub use ema::Ema;
pub use macd::Macd;
pub use obv::Obv;
pub use pct::Pct;
pub use rsi::Rsi;
pub use sma::Sma;

73
src/lib/qrust/ta/obv.rs Normal file
View File

@@ -0,0 +1,73 @@
use std::{
borrow::Borrow,
iter::{Peekable, Scan},
};
pub struct ObvState {
last: f64,
obv: f64,
}
#[allow(clippy::type_complexity)]
pub trait Obv<T>: Iterator + Sized {
fn obv(self) -> Scan<Peekable<Self>, ObvState, fn(&mut ObvState, T) -> Option<f64>>;
}
impl<I, T> Obv<T> for I
where
I: Iterator<Item = T>,
T: Borrow<(f64, f64)>,
{
fn obv(self) -> Scan<Peekable<Self>, ObvState, fn(&mut ObvState, T) -> Option<f64>> {
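// On-balance volume over (close, volume) pairs: add volume when the close rises, subtract it when it falls, carry the total forward otherwise.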
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
ObvState {
last: first.0,
obv: 0.0,
},
|state: &mut ObvState, value: T| {
let (close, volume) = *value.borrow();
if close > state.last {
state.obv += volume;
} else if close < state.last {
state.obv -= volume;
}
state.last = close;
Some(state.obv)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_obv() {
let data = vec![(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (2.0, 4.0), (1.0, 5.0)];
let obv = data.into_iter().obv().collect::<Vec<_>>();
assert_eq!(obv, vec![0.0, 2.0, 5.0, 1.0, -4.0]);
}
#[test]
fn test_obv_empty() {
let data = Vec::<(f64, f64)>::new();
let obv = data.into_iter().obv().collect::<Vec<_>>();
assert_eq!(obv, Vec::<f64>::new());
}
#[test]
fn test_obv_borrow() {
let data = [(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (2.0, 4.0), (1.0, 5.0)];
let obv = data.iter().obv().collect::<Vec<_>>();
assert_eq!(obv, vec![0.0, 2.0, 5.0, 1.0, -4.0]);
}
}

64
src/lib/qrust/ta/pct.rs Normal file
View File

@@ -0,0 +1,64 @@
use std::{borrow::Borrow, iter::Scan};
pub struct PctState {
pub last: f64,
}
#[allow(clippy::type_complexity)]
pub trait Pct<T>: Iterator + Sized {
fn pct(self) -> Scan<Self, PctState, fn(&mut PctState, T) -> Option<f64>>;
}
impl<I, T> Pct<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn pct(self) -> Scan<Self, PctState, fn(&mut PctState, T) -> Option<f64>> {
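// Percentage change: value / previous - 1.0; the 0.0 seed means the first output divides by zero (INFINITY in the tests below).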
self.scan(PctState { last: 0.0 }, |state: &mut PctState, value: T| {
let value = *value.borrow();
let pct = value / state.last - 1.0;
state.last = value;
Some(pct)
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pct() {
let data = vec![1.0, 2.0, 4.0, 2.0, 1.0];
let pct = data.into_iter().pct().collect::<Vec<_>>();
assert_eq!(pct, vec![f64::INFINITY, 1.0, 1.0, -0.5, -0.5]);
}
#[test]
fn test_pct_empty() {
let data = Vec::<f64>::new();
let pct = data.into_iter().pct().collect::<Vec<_>>();
assert_eq!(pct, Vec::<f64>::new());
}
#[test]
fn test_pct_0() {
let data = vec![1.0, 0.0, 4.0, 2.0, 1.0];
let pct = data.into_iter().pct().collect::<Vec<_>>();
assert_eq!(pct, vec![f64::INFINITY, -1.0, f64::INFINITY, -0.5, -0.5]);
}
#[test]
fn test_pct_borrow() {
let data = [1.0, 2.0, 4.0, 2.0, 1.0];
let pct = data.iter().pct().collect::<Vec<_>>();
assert_eq!(pct, vec![f64::INFINITY, 1.0, 1.0, -0.5, -0.5]);
}
}

135
src/lib/qrust/ta/rsi.rs Normal file
View File

@@ -0,0 +1,135 @@
use std::{
borrow::Borrow,
collections::VecDeque,
iter::{Peekable, Scan},
num::NonZeroUsize,
};
pub struct RsiState {
last: f64,
window_gains: VecDeque<f64>,
window_losses: VecDeque<f64>,
sum_gains: f64,
sum_losses: f64,
}
#[allow(clippy::type_complexity)]
pub trait Rsi<T>: Iterator + Sized {
fn rsi(
self,
period: NonZeroUsize, // Typically 14
) -> Scan<Peekable<Self>, RsiState, fn(&mut RsiState, T) -> Option<f64>>;
}
impl<I, T> Rsi<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn rsi(
self,
period: NonZeroUsize,
) -> Scan<Peekable<Self>, RsiState, fn(&mut RsiState, T) -> Option<f64>> {
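// RSI = 100 - 100 / (1 + RS), with RS = average gain / average loss over the window;
// 100 is returned directly when there are no losses, avoiding a division by zero.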
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
RsiState {
last: first,
window_gains: VecDeque::from(vec![0.0; period.get()]),
window_losses: VecDeque::from(vec![0.0; period.get()]),
sum_gains: 0.0,
sum_losses: 0.0,
},
|state, value| {
let value = *value.borrow();
state.sum_gains -= state.window_gains.pop_front().unwrap();
state.sum_losses -= state.window_losses.pop_front().unwrap();
let gain = (value - state.last).max(0.0);
let loss = (state.last - value).max(0.0);
state.last = value;
state.window_gains.push_back(gain);
state.window_losses.push_back(loss);
state.sum_gains += gain;
state.sum_losses += loss;
let avg_loss = state.sum_losses / state.window_losses.len() as f64;
if avg_loss == 0.0 {
return Some(100.0);
}
let avg_gain = state.sum_gains / state.window_gains.len() as f64;
let rs = avg_gain / avg_loss;
Some(100.0 - (100.0 / (1.0 + rs)))
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rsi() {
let data = vec![1.0, 4.0, 7.0, 4.0, 1.0];
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.map(|v| (v * 100.0).round() / 100.0)
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 66.67, 33.33]);
}
#[test]
fn test_rsi_empty() {
let data = Vec::<f64>::new();
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(rsi, Vec::<f64>::new());
}
#[test]
fn test_rsi_no_loss() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 100.0, 100.0]);
}
#[test]
fn test_rsi_no_gain() {
let data = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 0.0, 0.0, 0.0, 0.0]);
}
#[test]
fn test_rsi_borrow() {
let data = [1.0, 4.0, 7.0, 4.0, 1.0];
let rsi = data
.iter()
.rsi(NonZeroUsize::new(3).unwrap())
.map(|v| (v * 100.0).round() / 100.0)
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 66.67, 33.33]);
}
}

88
src/lib/qrust/ta/sma.rs Normal file
View File

@@ -0,0 +1,88 @@
use std::{borrow::Borrow, collections::VecDeque, iter::Scan, num::NonZeroUsize};
pub struct SmaState {
window: VecDeque<f64>,
sum: f64,
}
#[allow(clippy::type_complexity)]
pub trait Sma<T>: Iterator + Sized {
fn sma(self, period: NonZeroUsize)
-> Scan<Self, SmaState, fn(&mut SmaState, T) -> Option<f64>>;
}
impl<I, T> Sma<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn sma(
self,
period: NonZeroUsize,
) -> Scan<Self, SmaState, fn(&mut SmaState, T) -> Option<f64>> {
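// The window is pre-filled with zeros, so early outputs average the zero padding with the real data and ramp up until the window is full.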
self.scan(
SmaState {
window: VecDeque::from(vec![0.0; period.get()]),
sum: 0.0,
},
|state: &mut SmaState, value: T| {
let value = *value.borrow();
state.sum -= state.window.pop_front().unwrap();
state.window.push_back(value);
state.sum += value;
Some(state.sum / state.window.len() as f64)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sma() {
let data = vec![3.0, 6.0, 9.0, 12.0, 15.0];
let sma = data
.into_iter()
.sma(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, vec![1.0, 3.0, 6.0, 9.0, 12.0]);
}
#[test]
fn test_sma_empty() {
let data = Vec::<f64>::new();
let sma = data
.into_iter()
.sma(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, Vec::<f64>::new());
}
#[test]
fn test_sma_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let sma = data
.into_iter()
.sma(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_sma_borrow() {
let data = [3.0, 6.0, 9.0, 12.0, 15.0];
let sma = data
.iter()
.sma(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, vec![1.0, 3.0, 6.0, 9.0, 12.0]);
}
}

View File

@@ -0,0 +1,75 @@
use serde::Deserialize;
use serde_aux::field_attributes::{
deserialize_number_from_string, deserialize_option_number_from_string,
};
use time::OffsetDateTime;
use uuid::Uuid;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Status {
Onboarding,
SubmissionFailed,
Submitted,
AccountUpdated,
ApprovalPending,
Active,
Rejected,
}
#[derive(Deserialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct Account {
pub id: Uuid,
#[serde(rename = "account_number")]
pub number: String,
pub status: Status,
pub currency: String,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub cash: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub non_marginable_buying_power: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub accrued_fees: f64,
#[serde(default)]
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub pending_transfer_in: Option<f64>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub pending_transfer_out: Option<f64>,
pub pattern_day_trader: bool,
#[serde(default)]
pub trade_suspend_by_user: bool,
pub trading_blocked: bool,
pub transfers_blocked: bool,
#[serde(rename = "account_blocked")]
pub blocked: bool,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
pub shorting_enabled: bool,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub long_market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub short_market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub equity: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub last_equity: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub multiplier: i8,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub buying_power: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub initial_margin: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub maintenance_margin: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub sma: f64,
pub daytrade_count: i64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub last_maintenance_margin: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub daytrading_buying_power: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub regt_buying_power: f64,
}

View File

@@ -0,0 +1,39 @@
use super::position::Position;
use crate::types::{self, alpaca::shared::asset};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_option_number_from_string;
use uuid::Uuid;
pub use asset::{Class, Exchange, Status};
#[allow(clippy::struct_excessive_bools)]
#[derive(Deserialize, Clone)]
pub struct Asset {
pub id: Uuid,
pub class: Class,
pub exchange: Exchange,
pub symbol: String,
pub name: String,
pub status: Status,
pub tradable: bool,
pub marginable: bool,
pub shortable: bool,
pub easy_to_borrow: bool,
pub fractionable: bool,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub maintenance_margin_requirement: Option<f32>,
pub attributes: Option<Vec<String>>,
}
impl From<(Asset, Option<Position>)> for types::Asset {
fn from((asset, position): (Asset, Option<Position>)) -> Self {
Self {
symbol: asset.symbol,
class: asset.class.into(),
exchange: asset.exchange.into(),
status: asset.status.into(),
time_added: time::OffsetDateTime::now_utc(),
qty: position.map(|position| position.qty).unwrap_or_default(),
}
}
}

View File

@@ -1,9 +1,8 @@
use crate::types;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Deserialize)]
pub struct Bar {
#[serde(rename = "t")]
#[serde(with = "time::serde::rfc3339")]
@@ -27,21 +26,14 @@ pub struct Bar {
impl From<(Bar, String)> for types::Bar {
fn from((bar, symbol): (Bar, String)) -> Self {
Self {
time: bar.time,
symbol,
time: bar.time,
open: bar.open,
high: bar.high,
low: bar.low,
close: bar.close,
volume: bar.volume,
trades: bar.trades,
vwap: bar.vwap,
}
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Message {
pub bars: HashMap<String, Vec<Bar>>,
pub next_page_token: Option<String>,
}

View File

@@ -0,0 +1,26 @@
use crate::{
types,
utils::{de, time::EST_OFFSET},
};
use serde::Deserialize;
use time::{Date, OffsetDateTime, Time};
#[derive(Deserialize)]
pub struct Calendar {
pub date: Date,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub open: Time,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub close: Time,
pub settlement_date: Date,
}
impl From<Calendar> for types::Calendar {
fn from(calendar: Calendar) -> Self {
Self {
date: calendar.date,
open: OffsetDateTime::new_in_offset(calendar.date, calendar.open, *EST_OFFSET),
close: OffsetDateTime::new_in_offset(calendar.date, calendar.close, *EST_OFFSET),
}
}
}

View File

@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Deserialize)]
pub struct Clock {
#[serde(with = "time::serde::rfc3339")]
pub timestamp: OffsetDateTime,

View File

@@ -0,0 +1,8 @@
pub mod account;
pub mod asset;
pub mod bar;
pub mod calendar;
pub mod clock;
pub mod news;
pub mod order;
pub mod position;

View File

@@ -0,0 +1,57 @@
use crate::{
types::{self, alpaca::shared::news::strip},
utils::de,
};
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImageSize {
Thumb,
Small,
Large,
}
#[derive(Deserialize)]
pub struct Image {
pub size: ImageSize,
pub url: String,
}
#[derive(Deserialize)]
pub struct News {
pub id: i64,
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "created_at")]
pub time_created: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "updated_at")]
pub time_updated: OffsetDateTime,
#[serde(deserialize_with = "de::add_slash_to_symbols")]
pub symbols: Vec<String>,
pub headline: String,
pub author: String,
pub source: String,
pub summary: String,
pub content: String,
pub url: Option<String>,
pub images: Vec<Image>,
}
impl From<News> for types::News {
fn from(news: News) -> Self {
Self {
id: news.id,
time_created: news.time_created,
time_updated: news.time_updated,
symbols: news.symbols,
headline: strip(&news.headline),
author: strip(&news.author),
source: strip(&news.source),
summary: news.summary,
content: news.content,
url: news.url.unwrap_or_default(),
}
}
}

View File

@@ -0,0 +1,3 @@
use crate::types::alpaca::shared::order;
pub use order::{Order, Side};

View File

@@ -0,0 +1,61 @@
use crate::{
types::alpaca::api::incoming::{
asset::{Class, Exchange},
order,
},
utils::de,
};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_number_from_string;
use uuid::Uuid;
#[derive(Deserialize, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Long,
Short,
}
impl From<Side> for order::Side {
fn from(side: Side) -> Self {
match side {
Side::Long => Self::Buy,
Side::Short => Self::Sell,
}
}
}
#[derive(Deserialize, Clone)]
pub struct Position {
pub asset_id: Uuid,
#[serde(deserialize_with = "de::add_slash_to_symbol")]
pub symbol: String,
pub exchange: Exchange,
pub asset_class: Class,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub avg_entry_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty_available: f64,
pub side: Side,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub cost_basis: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub current_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub lastday_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub change_today: f64,
pub asset_marginable: bool,
}

View File

@@ -0,0 +1,6 @@
pub mod incoming;
pub mod outgoing;
pub const ALPACA_US_EQUITY_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news";

View File

@@ -0,0 +1,23 @@
use crate::types::alpaca::shared::asset;
use serde::Serialize;
pub use asset::{Class, Exchange, Status};
#[derive(Serialize)]
pub struct Asset {
pub status: Option<Status>,
pub class: Option<Class>,
pub exchange: Option<Exchange>,
pub attributes: Option<Vec<String>>,
}
impl Default for Asset {
fn default() -> Self {
Self {
status: None,
class: Some(Class::UsEquity),
exchange: None,
attributes: None,
}
}
}

View File

@@ -0,0 +1,108 @@
use crate::{
alpaca::bars::MAX_LIMIT,
types::alpaca::shared,
utils::{ser, ONE_MINUTE},
};
use serde::Serialize;
use std::time::Duration;
use time::OffsetDateTime;
pub use shared::{Sort, Source};
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum Adjustment {
Raw,
Split,
Dividend,
All,
}
#[derive(Serialize)]
pub struct UsEquity {
#[serde(serialize_with = "ser::join_symbols")]
pub symbols: Vec<String>,
#[serde(serialize_with = "ser::timeframe")]
pub timeframe: Duration,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub start: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub end: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub adjustment: Option<Adjustment>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub asof: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub feed: Option<Source>,
#[serde(skip_serializing_if = "Option::is_none")]
pub currency: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort: Option<Sort>,
}
impl Default for UsEquity {
fn default() -> Self {
Self {
symbols: vec![],
timeframe: ONE_MINUTE,
start: None,
end: None,
limit: Some(MAX_LIMIT),
adjustment: Some(Adjustment::All),
asof: None,
feed: Some(Source::Iex),
currency: None,
page_token: None,
sort: Some(Sort::Asc),
}
}
}
#[derive(Serialize)]
pub struct Crypto {
#[serde(serialize_with = "ser::join_symbols")]
pub symbols: Vec<String>,
#[serde(serialize_with = "ser::timeframe")]
pub timeframe: Duration,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub start: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub end: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort: Option<Sort>,
}
impl Default for Crypto {
fn default() -> Self {
Self {
symbols: vec![],
timeframe: ONE_MINUTE,
start: None,
end: None,
limit: Some(MAX_LIMIT),
page_token: None,
sort: Some(Sort::Asc),
}
}
}
#[derive(Serialize)]
#[serde(untagged)]
pub enum Bar {
UsEquity(UsEquity),
Crypto(Crypto),
}

View File

@@ -0,0 +1,31 @@
use crate::utils::time::MAX_TIMESTAMP;
use serde::Serialize;
use time::OffsetDateTime;
#[derive(Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
#[allow(dead_code)]
pub enum DateType {
Trading,
Settlement,
}
#[derive(Serialize)]
pub struct Calendar {
#[serde(with = "time::serde::rfc3339")]
pub start: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
pub end: OffsetDateTime,
#[serde(rename = "date")]
pub date_type: DateType,
}
impl Default for Calendar {
fn default() -> Self {
Self {
start: OffsetDateTime::UNIX_EPOCH,
end: *MAX_TIMESTAMP,
date_type: DateType::Trading,
}
}
}

View File

@@ -1,4 +1,5 @@
pub mod asset;
pub mod bar;
pub mod clock;
pub mod calendar;
pub mod news;
pub mod order;

View File

@@ -0,0 +1,40 @@
use crate::{alpaca::news::MAX_LIMIT, types::alpaca::shared::Sort, utils::ser};
use serde::Serialize;
use time::OffsetDateTime;
#[derive(Serialize)]
pub struct News {
#[serde(serialize_with = "ser::remove_slash_and_join_symbols")]
pub symbols: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub start: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub end: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub include_content: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub exclude_contentless: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort: Option<Sort>,
}
impl Default for News {
fn default() -> Self {
Self {
symbols: vec![],
start: None,
end: None,
limit: Some(MAX_LIMIT),
include_content: Some(true),
exclude_contentless: Some(false),
page_token: None,
sort: Some(Sort::Asc),
}
}
}

View File

@@ -0,0 +1,55 @@
use crate::{
types::alpaca::shared::{order, Sort},
utils::ser,
};
use serde::Serialize;
use time::OffsetDateTime;
pub use order::Side;
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum Status {
Open,
Closed,
All,
}
#[derive(Serialize)]
pub struct Order {
#[serde(skip_serializing_if = "Option::is_none")]
pub status: Option<Status>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub after: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub until: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub direction: Option<Sort>,
#[serde(skip_serializing_if = "Option::is_none")]
pub nested: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(serialize_with = "ser::join_symbols_option")]
pub symbols: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub side: Option<Side>,
}
impl Default for Order {
fn default() -> Self {
Self {
status: Some(Status::All),
limit: Some(500),
after: None,
until: None,
direction: Some(Sort::Asc),
nested: Some(true),
symbols: None,
side: None,
}
}
}

View File

@@ -0,0 +1,3 @@
pub mod api;
pub mod shared;
pub mod websocket;

View File

@@ -0,0 +1,53 @@
use crate::{impl_from_enum, types};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Class {
UsEquity,
Crypto,
}
impl_from_enum!(types::Class, Class, UsEquity, Crypto);
#[derive(Serialize, Deserialize, Clone, Copy)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Exchange {
Amex,
Arca,
Bats,
Nyse,
Nasdaq,
Nysearca,
Otc,
Crypto,
}
impl_from_enum!(
types::Exchange,
Exchange,
Amex,
Arca,
Bats,
Nyse,
Nasdaq,
Nysearca,
Otc,
Crypto
);
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum Status {
Active,
Inactive,
}
impl From<Status> for bool {
fn from(status: Status) -> Self {
match status {
Status::Active => true,
Status::Inactive => false,
}
}
}

View File

@@ -0,0 +1,10 @@
pub mod asset;
pub mod mode;
pub mod news;
pub mod order;
pub mod sort;
pub mod source;
pub use mode::Mode;
pub use sort::Sort;
pub use source::Source;

View File

@@ -0,0 +1,33 @@
use serde::{Deserialize, Serialize};
use std::{
fmt::{Display, Formatter},
str::FromStr,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Mode {
Live,
Paper,
}
impl FromStr for Mode {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"live" => Ok(Self::Live),
"paper" => Ok(Self::Paper),
_ => Err(format!("Unknown mode: {s}")),
}
}
}
impl Display for Mode {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
match self {
Self::Live => write!(f, "live"),
Self::Paper => write!(f, "paper"),
}
}
}

View File

@@ -0,0 +1,26 @@
use lazy_static::lazy_static;
use regex::Regex;
lazy_static! {
static ref RE_TAGS: Regex = Regex::new("<[^>]+>").unwrap();
static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap();
}
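// Strip HTML tags, then collapse newlines, non-breaking spaces, and runs of whitespace into single spaces.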
pub fn strip(content: &str) -> String {
let content = content.replace('\n', " ");
let content = RE_TAGS.replace_all(&content, "");
let content = RE_SPACES.replace_all(&content, " ");
let content = content.trim();
content.to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_strip() {
let content = "<p> <b> Hello, </b> <i> World! </i> </p>";
assert_eq!(strip(content), "Hello, World!");
}
}

View File

@@ -0,0 +1,275 @@
use crate::{impl_from_enum, types};
use serde::{Deserialize, Serialize};
use serde_aux::field_attributes::{
deserialize_number_from_string, deserialize_option_number_from_string,
};
use time::OffsetDateTime;
use uuid::Uuid;
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Class {
#[serde(alias = "")]
Simple,
Bracket,
Oco,
Oto,
}
impl_from_enum!(types::order::Class, Class, Simple, Bracket, Oco, Oto);
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Type {
Market,
Limit,
Stop,
StopLimit,
TrailingStop,
}
impl_from_enum!(
types::order::Type,
Type,
Market,
Limit,
Stop,
StopLimit,
TrailingStop
);
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Buy,
Sell,
}
impl_from_enum!(types::order::Side, Side, Buy, Sell);
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TimeInForce {
Day,
Gtc,
Opg,
Cls,
Ioc,
Fok,
}
impl_from_enum!(
types::order::TimeInForce,
TimeInForce,
Day,
Gtc,
Opg,
Cls,
Ioc,
Fok
);
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Status {
New,
PartiallyFilled,
Filled,
DoneForDay,
Canceled,
Expired,
Replaced,
PendingCancel,
PendingReplace,
Accepted,
PendingNew,
AcceptedForBidding,
Stopped,
Rejected,
Suspended,
Calculated,
}
impl_from_enum!(
types::order::Status,
Status,
New,
PartiallyFilled,
Filled,
DoneForDay,
Canceled,
Expired,
Replaced,
PendingCancel,
PendingReplace,
Accepted,
PendingNew,
AcceptedForBidding,
Stopped,
Rejected,
Suspended,
Calculated
);
#[derive(Deserialize, Clone, Debug, PartialEq)]
#[allow(clippy::struct_field_names)]
pub struct Order {
pub id: Uuid,
pub client_order_id: Uuid,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339::option")]
pub updated_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339")]
pub submitted_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339::option")]
pub filled_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub expired_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub cancel_requested_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub canceled_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub failed_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub replaced_at: Option<OffsetDateTime>,
pub replaced_by: Option<Uuid>,
pub replaces: Option<Uuid>,
pub asset_id: Uuid,
pub symbol: String,
pub asset_class: super::asset::Class,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub notional: Option<f64>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub qty: Option<f64>,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub filled_qty: f64,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub filled_avg_price: Option<f64>,
pub order_class: Class,
#[serde(rename = "type")]
pub order_type: Type,
pub side: Side,
pub time_in_force: TimeInForce,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub limit_price: Option<f64>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub stop_price: Option<f64>,
pub status: Status,
pub extended_hours: bool,
pub legs: Option<Vec<Order>>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub trail_percent: Option<f64>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub trail_price: Option<f64>,
pub hwm: Option<f64>,
}
impl From<Order> for types::Order {
fn from(order: Order) -> Self {
Self {
id: order.id,
client_order_id: order.client_order_id,
time_submitted: order.submitted_at,
time_created: order.created_at,
time_updated: order.updated_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_filled: order.filled_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_expired: order.expired_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_cancel_requested: order
.cancel_requested_at
.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_canceled: order.canceled_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_failed: order.failed_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_replaced: order.replaced_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
replaced_by: order.replaced_by.unwrap_or_default(),
replaces: order.replaces.unwrap_or_default(),
symbol: order.symbol,
order_class: order.order_class.into(),
order_type: order.order_type.into(),
side: order.side.into(),
time_in_force: order.time_in_force.into(),
notional: order.notional.unwrap_or_default(),
qty: order.qty.unwrap_or_default(),
filled_qty: order.filled_qty,
filled_avg_price: order.filled_avg_price.unwrap_or_default(),
status: order.status.into(),
extended_hours: order.extended_hours,
limit_price: order.limit_price.unwrap_or_default(),
stop_price: order.stop_price.unwrap_or_default(),
trail_percent: order.trail_percent.unwrap_or_default(),
trail_price: order.trail_price.unwrap_or_default(),
hwm: order.hwm.unwrap_or_default(),
legs: order
.legs
.unwrap_or_default()
.into_iter()
.map(|order| order.id)
.collect(),
}
}
}
impl Order {
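// Flatten this order and any nested legs (depth-first) into a flat list of types::Order.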
pub fn normalize(self) -> Vec<types::Order> {
let mut orders = vec![self.clone().into()];
if let Some(legs) = self.legs {
for leg in legs {
orders.extend(leg.normalize());
}
}
orders
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_normalize() {
let order_template = Order {
id: Uuid::new_v4(),
client_order_id: Uuid::new_v4(),
created_at: OffsetDateTime::now_utc(),
updated_at: None,
submitted_at: OffsetDateTime::now_utc(),
filled_at: None,
expired_at: None,
cancel_requested_at: None,
canceled_at: None,
failed_at: None,
replaced_at: None,
replaced_by: None,
replaces: None,
asset_id: Uuid::new_v4(),
symbol: "AAPL".to_string(),
asset_class: super::super::asset::Class::UsEquity,
notional: None,
qty: None,
filled_qty: 0.0,
filled_avg_price: None,
order_class: Class::Simple,
order_type: Type::Market,
side: Side::Buy,
time_in_force: TimeInForce::Day,
limit_price: None,
stop_price: None,
status: Status::New,
extended_hours: false,
legs: None,
trail_percent: None,
trail_price: None,
hwm: None,
};
let mut order = order_template.clone();
order.legs = Some(vec![order_template.clone(), order_template.clone()]);
order.legs.as_mut().unwrap()[0].legs = Some(vec![order_template.clone()]);
let orders = order.normalize();
assert_eq!(orders.len(), 4);
}
}

View File

@@ -0,0 +1,9 @@
use serde::Serialize;
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum Sort {
Asc,
Desc,
}

View File

@@ -1,12 +1,15 @@
use serde::{Deserialize, Serialize};
use std::{
fmt::{Display, Formatter},
str::FromStr,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Source {
Iex,
Sip,
Otc,
}
impl FromStr for Source {
@@ -26,6 +29,7 @@ impl Display for Source {
match self {
Self::Iex => write!(f, "iex"),
Self::Sip => write!(f, "sip"),
Self::Otc => write!(f, "otc"),
}
}
}

View File

@@ -0,0 +1,7 @@
use serde::Serialize;
#[derive(Serialize)]
pub struct Message {
pub key: String,
pub secret: String,
}

View File

@@ -1,8 +1,8 @@
use crate::types;
use serde::{Deserialize, Serialize};
use crate::types::Bar;
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Deserialize, Debug, PartialEq)]
pub struct Message {
#[serde(rename = "t")]
#[serde(with = "time::serde::rfc3339")]
@@ -25,18 +25,17 @@ pub struct Message {
pub vwap: f64,
}
impl From<Message> for types::Bar {
impl From<Message> for Bar {
fn from(bar: Message) -> Self {
Self {
time: bar.time,
symbol: bar.symbol,
time: bar.time,
open: bar.open,
high: bar.high,
low: bar.low,
close: bar.close,
volume: bar.volume,
trades: bar.trades,
vwap: bar.vwap,
}
}
}

View File

@@ -0,0 +1,8 @@
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct Message {
pub code: u16,
#[serde(rename = "msg")]
pub message: String,
}

View File

@@ -1,12 +1,13 @@
pub mod bar;
pub mod error;
pub mod news;
pub mod status;
pub mod subscription;
pub mod success;
use serde::{Deserialize, Serialize};
use serde::Deserialize;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "T")]
pub enum Message {
#[serde(rename = "success")]
@@ -19,6 +20,8 @@ pub enum Message {
UpdatedBar(bar::Message),
#[serde(rename = "n")]
News(news::Message),
#[serde(rename = "s")]
Status(status::Message),
#[serde(rename = "error")]
Error(error::Message),
}

View File

@@ -0,0 +1,42 @@
use crate::{
types::{alpaca::shared::news::strip, News},
utils::de,
};
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct Message {
pub id: i64,
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "created_at")]
pub time_created: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "updated_at")]
pub time_updated: OffsetDateTime,
#[serde(deserialize_with = "de::add_slash_to_symbols")]
pub symbols: Vec<String>,
pub headline: String,
pub author: String,
pub source: String,
pub summary: String,
pub content: String,
pub url: Option<String>,
}
impl From<Message> for News {
fn from(news: Message) -> Self {
Self {
id: news.id,
time_created: news.time_created,
time_updated: news.time_updated,
symbols: news.symbols,
headline: strip(&news.headline),
author: strip(&news.author),
source: strip(&news.source),
summary: news.summary,
content: news.content,
url: news.url.unwrap_or_default(),
}
}
}

View File

@@ -0,0 +1,154 @@
use serde::Deserialize;
use serde_with::{serde_as, NoneAsEmptyString};
use time::OffsetDateTime;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(tag = "sc", content = "sm")]
pub enum Status {
#[serde(rename = "2")]
#[serde(alias = "H")]
TradingHalt(String),
#[serde(rename = "3")]
Resume(String),
#[serde(rename = "5")]
PriceIndication(String),
#[serde(rename = "6")]
TradingRangeIndication(String),
#[serde(rename = "7")]
MarketImbalanceBuy(String),
#[serde(rename = "8")]
MarketImbalanceSell(String),
#[serde(rename = "9")]
MarketOnCloseImbalanceBuy(String),
#[serde(rename = "A")]
MarketOnCloseImbalanceSell(String),
#[serde(rename = "C")]
NoMarketImbalance(String),
#[serde(rename = "D")]
NoMarketOnCloseImbalance(String),
#[serde(rename = "E")]
ShortSaleRestriction(String),
#[serde(rename = "F")]
LimitUpLimitDown(String),
#[serde(rename = "Q")]
QuotationResumption(String),
#[serde(rename = "T")]
TradingResumption(String),
#[serde(rename = "P")]
VolatilityTradingPause(String),
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(tag = "rc", content = "rm")]
pub enum Reason {
#[serde(rename = "D")]
NewsReleased(String),
#[serde(rename = "I")]
OrderImbalance(String),
#[serde(rename = "M")]
LimitUpLimitDown(String),
#[serde(rename = "P")]
NewsPending(String),
#[serde(rename = "X")]
Operational(String),
#[serde(rename = "Y")]
SubPennyTrading(String),
#[serde(rename = "1")]
MarketWideCircuitBreakerL1Breached(String),
#[serde(rename = "2")]
MarketWideCircuitBreakerL2Breached(String),
#[serde(rename = "3")]
MarketWideCircuitBreakerL3Breached(String),
#[serde(rename = "T1")]
HaltNewsPending(String),
#[serde(rename = "T2")]
HaltNewsDissemination(String),
#[serde(rename = "T5")]
SingleStockTradingPauseInEffect(String),
#[serde(rename = "T6")]
RegulatoryHaltExtraordinaryMarketActivity(String),
#[serde(rename = "T8")]
HaltETF(String),
#[serde(rename = "T12")]
TradingHaltedForInformationRequestedByNASDAQ(String),
#[serde(rename = "H4")]
HaltNonCompliance(String),
#[serde(rename = "H9")]
HaltFilingsNotCurrent(String),
#[serde(rename = "H10")]
HaltSECTradingSuspension(String),
#[serde(rename = "H11")]
HaltRegulatoryConcern(String),
#[serde(rename = "01")]
OperationsHaltContactMarketOperations(String),
#[serde(rename = "IPO1")]
IPOIssueNotYetTrading(String),
#[serde(rename = "M1")]
CorporateAction(String),
#[serde(rename = "M2")]
QuotationNotAvailable(String),
#[serde(rename = "LUDP")]
VolatilityTradingPause(String),
#[serde(rename = "LUDS")]
VolatilityTradingPauseStraddleCondition(String),
#[serde(rename = "MWC1")]
MarketWideCircuitBreakerHaltL1(String),
#[serde(rename = "MWC2")]
MarketWideCircuitBreakerHaltL2(String),
#[serde(rename = "MWC3")]
MarketWideCircuitBreakerHaltL3(String),
#[serde(rename = "MWC0")]
MarketWideCircuitBreakerHaltCarryOverFromPreviousDay(String),
#[serde(rename = "T3")]
NewsAndResumptionTimes(String),
#[serde(rename = "T7")]
SingleStockTradingPauseQuotationOnlyPeriod(String),
#[serde(rename = "R4")]
QualificationsIssuesReviewedResolvedQuotationsTradingToResume(String),
#[serde(rename = "R9")]
FilingRequirementsSatisfiedResolvedQuotationsTradingToResume(String),
#[serde(rename = "C3")]
IssuerNewsNotForthcomingQuotationsTradingToResume(String),
#[serde(rename = "C4")]
QualificationsHaltEndedMaintReqMetResume(String),
#[serde(rename = "C9")]
QualificationsHaltConcludedFilingsMetQuotesTradesToResume(String),
#[serde(rename = "C11")]
TradeHaltConcludedByOtherRegulatoryAuthQuotesTradesResume(String),
#[serde(rename = "R1")]
NewIssueAvailable(String),
#[serde(rename = "R")]
IssueAvailable(String),
#[serde(rename = "IPOQ")]
IPOSecurityReleasedForQuotation(String),
#[serde(rename = "IPOE")]
IPOSecurityPositioningWindowExtension(String),
#[serde(rename = "MWCQ")]
MarketWideCircuitBreakerResumption(String),
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub enum Tape {
A,
B,
C,
O,
}
#[serde_as]
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[allow(clippy::struct_field_names)]
pub struct Message {
#[serde(rename = "t")]
#[serde(with = "time::serde::rfc3339")]
pub time: OffsetDateTime,
#[serde(rename = "S")]
pub symbol: String,
#[serde(flatten)]
pub status: Status,
#[serde(flatten)]
#[serde_as(as = "NoneAsEmptyString")]
pub reason: Option<Reason>,
#[serde(rename = "z")]
pub tape: Tape,
}

View File

@@ -0,0 +1,23 @@
use crate::utils::de;
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(untagged)]
pub enum Message {
#[serde(rename_all = "camelCase")]
Market {
bars: Vec<String>,
updated_bars: Vec<String>,
statuses: Vec<String>,
trades: Option<Vec<String>>,
quotes: Option<Vec<String>>,
daily_bars: Option<Vec<String>>,
orderbooks: Option<Vec<String>>,
lulds: Option<Vec<String>>,
cancel_errors: Option<Vec<String>>,
},
News {
#[serde(deserialize_with = "de::add_slash_to_symbols")]
news: Vec<String>,
},
}

View File

@@ -0,0 +1,9 @@
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(tag = "msg")]
#[serde(rename_all = "snake_case")]
pub enum Message {
Connected,
Authenticated,
}

View File

@@ -0,0 +1,53 @@
pub mod incoming;
pub mod outgoing;
use crate::types::alpaca::websocket;
use core::panic;
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use serde_json::{from_str, to_string};
use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub async fn authenticate(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
api_key: String,
api_secret: String,
) {
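// Handshake: expect a "connected" message, send the key/secret auth payload, then expect "authenticated"; panic on anything else.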
match stream.next().await.unwrap().unwrap() {
Message::Text(data)
if from_str::<Vec<websocket::data::incoming::Message>>(&data)
.unwrap()
.first()
== Some(&websocket::data::incoming::Message::Success(
websocket::data::incoming::success::Message::Connected,
)) => {}
_ => panic!("Failed to connect to Alpaca websocket."),
}
sink.send(Message::Text(
to_string(&websocket::data::outgoing::Message::Auth(
websocket::auth::Message {
key: api_key,
secret: api_secret,
},
))
.unwrap(),
))
.await
.unwrap();
match stream.next().await.unwrap().unwrap() {
Message::Text(data)
if from_str::<Vec<websocket::data::incoming::Message>>(&data)
.unwrap()
.first()
== Some(&websocket::data::incoming::Message::Success(
websocket::data::incoming::success::Message::Authenticated,
)) => {}
_ => panic!("Failed to authenticate with Alpaca websocket."),
};
}

View File

@@ -1,11 +1,11 @@
pub mod auth;
pub mod subscribe;
use crate::types::alpaca::websocket::auth;
use serde::Serialize;
#[derive(Serialize)]
#[serde(tag = "action")]
#[serde(rename_all = "camelCase")]
#[serde(rename_all = "snake_case")]
pub enum Message {
Auth(auth::Message),
Subscribe(subscribe::Message),

View File

@@ -0,0 +1,50 @@
use crate::utils::ser;
use nonempty::NonEmpty;
use serde::Serialize;
#[derive(Serialize)]
#[serde(untagged)]
pub enum Market {
#[serde(rename_all = "camelCase")]
UsEquity {
bars: NonEmpty<String>,
updated_bars: NonEmpty<String>,
statuses: NonEmpty<String>,
},
#[serde(rename_all = "camelCase")]
Crypto {
bars: NonEmpty<String>,
updated_bars: NonEmpty<String>,
},
}
#[derive(Serialize)]
#[serde(untagged)]
pub enum Message {
Market(Market),
News {
#[serde(serialize_with = "ser::remove_slash_from_symbols")]
news: NonEmpty<String>,
},
}
impl Message {
pub fn new_market_us_equity(symbols: NonEmpty<String>) -> Self {
Self::Market(Market::UsEquity {
bars: symbols.clone(),
updated_bars: symbols.clone(),
statuses: symbols,
})
}
pub fn new_market_crypto(symbols: NonEmpty<String>) -> Self {
Self::Market(Market::Crypto {
bars: symbols.clone(),
updated_bars: symbols,
})
}
pub fn new_news(symbols: NonEmpty<String>) -> Self {
Self::News { news: symbols }
}
}

View File

@@ -0,0 +1,8 @@
pub mod auth;
pub mod data;
pub mod trading;
pub const ALPACA_US_EQUITY_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str =
"wss://stream.data.alpaca.markets/v1beta3/crypto/us";
pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";

View File

@@ -0,0 +1,22 @@
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Status {
Authorized,
Unauthorized,
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub enum Action {
#[serde(rename = "authenticate")]
Auth,
#[serde(rename = "listen")]
Subscribe,
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct Message {
pub status: Status,
pub action: Action,
}

View File

@@ -0,0 +1,16 @@
pub mod auth;
pub mod order;
pub mod subscription;
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "stream", content = "data")]
pub enum Message {
#[serde(rename = "authorization")]
Auth(auth::Message),
#[serde(rename = "listening")]
Subscription(subscription::Message),
#[serde(rename = "trade_updates")]
Order(order::Message),
}

Some files were not shown because too many files have changed in this diff.