Commits: main...ollama (1 commit)

bbd902c6fa  Add Ollama news sentiment analysis
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-02 10:56:45 +00:00
155 changed files with 2662 additions and 10396 deletions
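The hunks rendered below only show the ollama service being wired into docker-compose; the sentiment-analysis code itself is not visible on this page. As context, here is a minimal hedged sketch of how a Rust service might call Ollama's /api/generate endpoint to score a headline. The model name, prompt, and function shape are illustrative assumptions, not the commit's actual implementation.

```rust
// Hedged sketch only: one plausible shape for a news-sentiment call to
// Ollama. The model, prompt, and wiring are assumptions for illustration;
// the commit's real Ollama code is not shown in the visible hunks.
use serde::Deserialize;

#[derive(Deserialize)]
struct GenerateResponse {
    // Ollama's /api/generate returns the completion in a `response` field.
    response: String,
}

async fn news_sentiment(
    client: &reqwest::Client,
    ollama_url: &str,
    headline: &str,
) -> Result<String, reqwest::Error> {
    let body = serde_json::json!({
        "model": "mistral", // assumed model; not confirmed by this diff
        "prompt": format!(
            "Classify the sentiment of this financial headline as \
             positive, negative, or neutral: {headline}"
        ),
        "stream": false,
    });
    let response: GenerateResponse = client
        .post(format!("{ollama_url}/api/generate"))
        .json(&body)
        .send()
        .await?
        .json()
        .await?;
    Ok(response.response.trim().to_lowercase())
}
```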

.gitignore (vendored), 1 line changed

@@ -2,7 +2,6 @@
# will have compiled files and executables
debug/
target/
log/
# These are backup files generated by rustfmt
**/*.rs.bk


@@ -22,7 +22,7 @@ build:
cache:
<<: *global_cache
script:
- cargo +nightly build --workspace
- cargo +nightly build
test:
image: registry.karaolidis.com/karaolidis/qrust/rust
@@ -30,7 +30,7 @@ test:
cache:
<<: *global_cache
script:
- cargo +nightly test --workspace
- cargo +nightly test
lint:
image: registry.karaolidis.com/karaolidis/qrust/rust
@@ -39,7 +39,7 @@ lint:
<<: *global_cache
script:
- cargo +nightly fmt --all -- --check
- cargo +nightly clippy --workspace --all-targets --all-features
- cargo +nightly clippy --all-targets --all-features
depcheck:
image: registry.karaolidis.com/karaolidis/qrust/rust
@@ -48,7 +48,7 @@ depcheck:
<<: *global_cache
script:
- cargo +nightly outdated
- cargo +nightly udeps --workspace --all-targets
- cargo +nightly udeps
build-release:
image: registry.karaolidis.com/karaolidis/qrust/rust

Cargo.lock (generated), 3208 lines changed
File diff suppressed because it is too large.


@@ -3,18 +3,6 @@ name = "qrust"
version = "0.1.0"
edition = "2021"
[lib]
name = "qrust"
path = "src/lib/qrust/mod.rs"
[[bin]]
name = "qrust"
path = "src/bin/qrust/mod.rs"
[[bin]]
name = "trainer"
path = "src/bin/trainer/mod.rs"
[profile.release]
panic = 'abort'
strip = true
@@ -24,9 +12,9 @@ codegen-units = 1
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = "0.7.5"
axum = "0.7.4"
dotenv = "0.15.0"
tokio = { version = "1.37.0", features = [
tokio = { version = "1.32.0", features = [
"macros",
"rt-multi-thread",
] }
@@ -34,55 +22,32 @@ tokio-tungstenite = { version = "0.21.0", features = [
"tokio-native-tls",
"native-tls",
] }
log = "0.4.21"
log4rs = "1.3.0"
serde = "1.0.201"
serde_json = "1.0.117"
serde_repr = "0.1.19"
serde_with = "3.8.1"
serde-aux = "4.5.0"
futures-util = "0.3.30"
reqwest = { version = "0.12.4", features = [
log = "0.4.20"
log4rs = "1.2.0"
serde = "1.0.188"
serde_json = "1.0.105"
serde_repr = "0.1.18"
futures-util = "0.3.28"
reqwest = { version = "0.11.20", features = [
"json",
"serde_json",
] }
http = "1.1.0"
governor = "0.6.3"
http = "1.0.0"
governor = "0.6.0"
clickhouse = { version = "0.11.6", features = [
"watch",
"time",
"uuid",
] }
uuid = { version = "1.8.0", features = [
uuid = "1.6.1"
time = { version = "0.3.31", features = [
"serde",
"v4",
] }
time = { version = "0.3.36", features = [
"serde",
"serde-well-known",
"serde-human-readable",
"formatting",
"macros",
"local-offset",
"serde-well-known",
] }
backoff = { version = "0.4.0", features = [
"tokio",
] }
regex = "1.10.4"
async-trait = "0.1.80"
itertools = "0.12.1"
lazy_static = "1.4.0"
nonempty = { version = "0.10.0", features = [
"serialize",
] }
rand = "0.8.5"
rayon = "1.10.0"
burn = { version = "0.13.2", features = [
"wgpu",
"cuda",
"tui",
"metrics",
"train",
] }
[dev-dependencies]
serde_test = "1.0.176"
regex = "1.10.3"
html-escape = "0.2.13"


@@ -1,5 +1,3 @@
# qrust
# QRust
![XKCD - Engineer Syllogism](https://imgs.xkcd.com/comics/engineer_syllogism.png)
`qrust` (/kɹʌst/, QuantitativeRust) is an algorithmic trading library written in Rust.
QRust (/kɹʌst/, QuantitativeRust) is an algorithmic trading library written in Rust.


@@ -4,6 +4,11 @@ services:
file: support/clickhouse/docker-compose.yml
service: clickhouse
ollama:
extends:
file: support/ollama/docker-compose.yml
service: ollama
grafana:
extends:
file: support/grafana/docker-compose.yml
@@ -19,10 +24,12 @@ services:
- 7878:7878
depends_on:
- clickhouse
- ollama
env_file:
- .env.docker
volumes:
clickhouse-lib:
clickhouse-log:
ollama:
grafana-lib:
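The compose changes above add an ollama service, a named ollama volume, and a depends_on entry for the app container; the matching application config is not shown on this page. A hedged sketch following the lazy_static env pattern used elsewhere in this repository (the OLLAMA_URL variable name is an assumption):

```rust
use lazy_static::lazy_static;
use std::env;

lazy_static! {
    // Assumed key: the commit's real environment variable is not visible here.
    pub static ref OLLAMA_URL: String =
        env::var("OLLAMA_URL").expect("OLLAMA_URL must be set.");
}
```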


@@ -4,14 +4,7 @@ appenders:
encoder:
pattern: "{d} {h({l})} {M}::{L} - {m}{n}"
file:
kind: file
path: "./log/output.log"
encoder:
pattern: "{d} {l} {M}::{L} - {m}{n}"
root:
level: info
appenders:
- stdout
- file


@@ -1,86 +0,0 @@
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use lazy_static::lazy_static;
use qrust::types::alpaca::shared::{Mode, Source};
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client,
};
use std::{env, num::NonZeroU32, sync::Arc};
use tokio::sync::Semaphore;
lazy_static! {
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
.expect("ALPACA_MODE must be set.")
.parse()
.expect("ALPACA_MODE must be 'live' or 'paper'");
pub static ref ALPACA_API_BASE: String = match *ALPACA_MODE {
Mode::Live => String::from("api"),
Mode::Paper => String::from("paper-api"),
};
pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE")
.expect("ALPACA_SOURCE must be set.")
.parse()
.expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'");
pub static ref ALPACA_API_KEY: String =
env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
pub static ref ALPACA_API_SECRET: String =
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
pub static ref CLICKHOUSE_BATCH_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE")
.expect("BATCH_BACKFILL_BARS_SIZE must be set.")
.parse()
.expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer.");
pub static ref CLICKHOUSE_BATCH_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE")
.expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
.parse()
.expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS")
.expect("CLICKHOUSE_MAX_CONNECTIONS must be set.")
.parse()
.expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer.");
}
pub struct Config {
pub alpaca_client: Client,
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
pub clickhouse_client: clickhouse::Client,
pub clickhouse_concurrency_limiter: Arc<Semaphore>,
}
impl Config {
pub fn from_env() -> Self {
Self {
alpaca_client: Client::builder()
.default_headers(HeaderMap::from_iter([
(
HeaderName::from_static("apca-api-key-id"),
HeaderValue::from_str(&ALPACA_API_KEY)
.expect("Alpaca API key must not contain invalid characters."),
),
(
HeaderName::from_static("apca-api-secret-key"),
HeaderValue::from_str(&ALPACA_API_SECRET)
.expect("Alpaca API secret must not contain invalid characters."),
),
]))
.build()
.unwrap(),
alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE {
Source::Iex => NonZeroU32::new(200).unwrap(),
Source::Sip => NonZeroU32::new(10_000).unwrap(),
Source::Otc => unimplemented!("OTC rate limit not implemented."),
})),
clickhouse_client: clickhouse::Client::default()
.with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
.with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
.with_password(
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
)
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
}
}
pub fn arc_from_env() -> Arc<Self> {
Arc::new(Self::from_env())
}
}


@@ -1,136 +0,0 @@
use crate::{
config::{Config, ALPACA_API_BASE},
database,
};
use log::{info, warn};
use qrust::{alpaca, types};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::join;
pub async fn check_account(config: &Arc<Config>) {
let account = alpaca::account::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,
&ALPACA_API_BASE,
)
.await
.unwrap();
assert!(
!(account.status != types::alpaca::api::incoming::account::Status::Active),
"Account status is not active: {:?}.",
account.status
);
assert!(
!account.trade_suspend_by_user,
"Account trading is suspended by user."
);
assert!(!account.trading_blocked, "Account trading is blocked.");
assert!(!account.blocked, "Account is blocked.");
if account.cash == 0.0 {
warn!("Account cash is zero, qrust will not be able to trade.");
}
info!(
"qrust running on {} account with {} {}, avoid transferring funds without shutting down.",
*ALPACA_API_BASE, account.currency, account.cash
);
}
pub async fn rehydrate_orders(config: &Arc<Config>) {
let mut orders = vec![];
let mut after = OffsetDateTime::UNIX_EPOCH;
loop {
let message = alpaca::orders::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&types::alpaca::api::outgoing::order::Order {
status: Some(types::alpaca::api::outgoing::order::Status::All),
after: Some(after),
..Default::default()
},
None,
&ALPACA_API_BASE,
)
.await
.unwrap();
if message.is_empty() {
break;
}
orders.extend(message);
after = orders.last().unwrap().submitted_at;
}
let orders = orders
.into_iter()
.flat_map(&types::alpaca::api::incoming::order::Order::normalize)
.collect::<Vec<_>>();
database::orders::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&orders,
)
.await
.unwrap();
}
pub async fn rehydrate_positions(config: &Arc<Config>) {
let positions_future = async {
alpaca::positions::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(|position| (position.symbol.clone(), position))
.collect::<HashMap<_, _>>()
};
let assets_future = async {
database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
};
let (mut positions, assets) = join!(positions_future, assets_future);
let assets = assets
.into_iter()
.map(|mut asset| {
if let Some(position) = positions.remove(&asset.symbol) {
asset.qty = position.qty_available;
} else {
asset.qty = 0.0;
}
asset
})
.collect::<Vec<_>>();
database::assets::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&assets,
)
.await
.unwrap();
for position in positions.values() {
warn!(
"Position for unmonitored asset: {}, {} shares.",
position.symbol, position.qty
);
}
}


@@ -1,115 +0,0 @@
#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
#![allow(clippy::missing_docs_in_private_items)]
#![feature(hash_extract_if)]
mod config;
mod init;
mod routes;
mod threads;
use config::{
Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE,
CLICKHOUSE_BATCH_NEWS_SIZE, CLICKHOUSE_MAX_CONNECTIONS,
};
use dotenv::dotenv;
use log::info;
use log4rs::config::Deserializers;
use nonempty::NonEmpty;
use qrust::{create_send_await, database};
use tokio::{join, spawn, sync::mpsc, try_join};
#[tokio::main]
async fn main() {
dotenv().ok();
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
let config = Config::arc_from_env();
let _ = *ALPACA_MODE;
let _ = *ALPACA_API_BASE;
let _ = *ALPACA_SOURCE;
let _ = *CLICKHOUSE_BATCH_BARS_SIZE;
let _ = *CLICKHOUSE_BATCH_NEWS_SIZE;
let _ = *CLICKHOUSE_MAX_CONNECTIONS;
info!("Marking all assets as stale.");
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol, asset.class))
.collect::<Vec<_>>();
let symbols = assets.iter().map(|(symbol, _)| symbol).collect::<Vec<_>>();
try_join!(
database::backfills_bars::set_fresh_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
false,
&symbols
),
database::backfills_news::set_fresh_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
false,
&symbols
)
)
.unwrap();
info!("Cleaning up database.");
database::cleanup_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
info!("Optimizing database.");
database::optimize_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
info!("Rehydrating account data.");
init::check_account(&config).await;
join!(
init::rehydrate_orders(&config),
init::rehydrate_positions(&config)
);
info!("Starting threads.");
spawn(threads::trading::run(config.clone()));
let (data_sender, data_receiver) = mpsc::channel::<threads::data::Message>(100);
let (clock_sender, clock_receiver) = mpsc::channel::<threads::clock::Message>(1);
spawn(threads::data::run(
config.clone(),
data_receiver,
clock_receiver,
));
spawn(threads::clock::run(config.clone(), clock_sender));
if let Some(assets) = NonEmpty::from_vec(assets) {
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Enable,
assets
);
}
routes::run(config, data_sender).await;
}


@@ -1,197 +0,0 @@
use crate::{
config::{Config, ALPACA_API_BASE},
create_send_await, database, threads,
};
use axum::{extract::Path, Extension, Json};
use http::StatusCode;
use nonempty::{nonempty, NonEmpty};
use qrust::{
alpaca,
types::{self, Asset},
};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::mpsc;
pub async fn get(
Extension(config): Extension<Arc<Config>>,
) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> {
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok((StatusCode::OK, Json(assets)))
}
pub async fn get_where_symbol(
Extension(config): Extension<Arc<Config>>,
Path(symbol): Path<String>,
) -> Result<(StatusCode, Json<Asset>), StatusCode> {
let asset = database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
asset.map_or(Err(StatusCode::NOT_FOUND), |asset| {
Ok((StatusCode::OK, Json(asset)))
})
}
#[derive(Deserialize)]
pub struct AddAssetsRequest {
symbols: Vec<String>,
}
#[derive(Serialize)]
pub struct AddAssetsResponse {
added: Vec<String>,
skipped: Vec<String>,
failed: Vec<String>,
}
pub async fn add(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Json(request): Json<AddAssetsRequest>,
) -> Result<(StatusCode, Json<AddAssetsResponse>), StatusCode> {
let database_symbols = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.into_iter()
.map(|asset| asset.symbol)
.collect::<HashSet<_>>();
let mut alpaca_assets = alpaca::assets::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&request.symbols,
None,
&ALPACA_API_BASE,
)
.await
.map_err(|e| e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))?
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>();
let num_symbols = request.symbols.len();
let (assets, skipped, failed) = request.symbols.into_iter().fold(
(Vec::with_capacity(num_symbols), vec![], vec![]),
|(mut assets, mut skipped, mut failed), symbol| {
if database_symbols.contains(&symbol) {
skipped.push(symbol);
} else if let Some(asset) = alpaca_assets.remove(&symbol) {
if asset.status == types::alpaca::api::incoming::asset::Status::Active
&& asset.tradable
&& asset.fractionable
{
assets.push((asset.symbol, asset.class.into()));
} else {
failed.push(asset.symbol);
}
} else {
failed.push(symbol);
}
(assets, skipped, failed)
},
);
if let Some(assets) = NonEmpty::from_vec(assets.clone()) {
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
assets
);
}
Ok((
StatusCode::OK,
Json(AddAssetsResponse {
added: assets.into_iter().map(|asset| asset.0).collect(),
skipped,
failed,
}),
))
}
pub async fn add_symbol(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
if database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.is_some()
{
return Err(StatusCode::CONFLICT);
}
let asset = alpaca::assets::get_by_symbol(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbol,
None,
&ALPACA_API_BASE,
)
.await
.map_err(|e| e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))?;
if asset.status != types::alpaca::api::incoming::asset::Status::Active
|| !asset.tradable
|| !asset.fractionable
{
return Err(StatusCode::FORBIDDEN);
}
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
nonempty![(asset.symbol, asset.class.into())]
);
Ok(StatusCode::CREATED)
}
pub async fn delete(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
let asset = database::assets::select_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Remove,
nonempty![(asset.symbol, asset.class)]
);
Ok(StatusCode::NO_CONTENT)
}


@@ -1,89 +0,0 @@
use crate::{
config::{Config, ALPACA_API_BASE},
database,
};
use log::info;
use qrust::{
alpaca,
types::{self, Calendar},
utils::{backoff, duration_until},
};
use std::sync::Arc;
use tokio::{join, sync::mpsc, time::sleep};
#[derive(PartialEq, Eq)]
pub enum Status {
Open,
Closed,
}
pub struct Message {
pub status: Status,
}
impl From<types::alpaca::api::incoming::clock::Clock> for Message {
fn from(clock: types::alpaca::api::incoming::clock::Clock) -> Self {
Self {
status: if clock.is_open {
Status::Open
} else {
Status::Closed
},
}
}
}
pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
loop {
let clock_future = async {
alpaca::clock::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
Some(backoff::infinite()),
&ALPACA_API_BASE,
)
.await
.unwrap()
};
let calendar_future = async {
alpaca::calendar::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&types::alpaca::api::outgoing::calendar::Calendar::default(),
Some(backoff::infinite()),
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(Calendar::from)
.collect::<Vec<_>>()
};
let (clock, calendar) = join!(clock_future, calendar_future);
let sleep_until = duration_until(if clock.is_open {
info!("Market is open, will close at {}.", clock.next_close);
clock.next_close
} else {
info!("Market is closed, will reopen at {}.", clock.next_open);
clock.next_open
});
let sleep_future = sleep(sleep_until);
let calendar_future = async {
database::calendar::upsert_batch_and_delete(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&calendar,
)
.await
.unwrap();
};
join!(sleep_future, calendar_future);
sender.send(clock.into()).await.unwrap();
}
}


@@ -1,238 +0,0 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE},
database,
threads::data::ThreadType,
};
use async_trait::async_trait;
use log::{error, info};
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
self,
alpaca::{
api::{ALPACA_CRYPTO_DATA_API_URL, ALPACA_US_EQUITY_DATA_API_URL},
shared::{Sort, Source},
},
Backfill, Bar, Class,
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::time::sleep;
pub struct Handler {
pub config: Arc<Config>,
pub data_url: &'static str,
pub api_query_constructor: fn(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar,
}
pub fn us_equity_query_constructor(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar {
types::alpaca::api::outgoing::bar::Bar::UsEquity(types::alpaca::api::outgoing::bar::UsEquity {
symbols,
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
sort: Some(Sort::Asc),
feed: Some(*ALPACA_SOURCE),
..Default::default()
})
}
pub fn crypto_query_constructor(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar {
types::alpaca::api::outgoing::bar::Bar::Crypto(types::alpaca::api::outgoing::bar::Crypto {
symbols,
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
sort: Some(Sort::Asc),
..Default::default()
})
}
#[async_trait]
impl super::Handler for Handler {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
database::backfills_bars::select_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_bars::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::bars::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
if *ALPACA_SOURCE == Source::Sip {
return;
}
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();
info!("Queing bar backfill for {:?} in {:?}.", symbols, run_delay);
sleep(run_delay).await;
}
async fn backfill(&self, jobs: NonEmpty<Job>) {
let symbols = Vec::from(jobs.clone().map(|job| job.symbol));
let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let freshness = jobs
.into_iter()
.map(|job| (job.symbol, job.fresh))
.collect::<HashMap<_, _>>();
let mut bars = Vec::with_capacity(*CLICKHOUSE_BATCH_BARS_SIZE);
let mut last_times = HashMap::new();
let mut next_page_token = None;
info!("Backfilling bars for {:?}.", symbols);
loop {
let message = alpaca::bars::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
self.data_url,
&(self.api_query_constructor)(
symbols.clone(),
fetch_from,
fetch_to,
next_page_token.clone(),
),
None,
)
.await;
if let Err(err) = message {
error!("Failed to backfill bars for {:?}: {:?}.", symbols, err);
return;
}
let message = message.unwrap();
for (symbol, bars_vec) in message.bars {
if let Some(last) = bars_vec.last() {
last_times.insert(symbol.clone(), last.time);
}
for bar in bars_vec {
bars.push(Bar::from((bar, symbol.clone())));
}
}
if bars.len() < *CLICKHOUSE_BATCH_BARS_SIZE && message.next_page_token.is_some() {
continue;
}
database::bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&bars,
)
.await
.unwrap();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill {
fresh: freshness[&symbol],
symbol,
time,
})
.collect::<Vec<_>>();
database::backfills_bars::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
bars.clear();
}
database::backfills_bars::set_fresh_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
true,
&symbols,
)
.await
.unwrap();
info!("Backfilled bars for {:?}.", symbols);
}
fn max_limit(&self) -> i64 {
alpaca::bars::MAX_LIMIT
}
fn log_string(&self) -> &'static str {
"bars"
}
}
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
let data_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => ALPACA_US_EQUITY_DATA_API_URL,
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_API_URL,
_ => unreachable!(),
};
let api_query_constructor = match thread_type {
ThreadType::Bars(Class::UsEquity) => us_equity_query_constructor,
ThreadType::Bars(Class::Crypto) => crypto_query_constructor,
_ => unreachable!(),
};
Box::new(Handler {
config,
data_url,
api_query_constructor,
})
}


@@ -1,243 +0,0 @@
pub mod bars;
pub mod news;
use async_trait::async_trait;
use itertools::Itertools;
use log::{info, warn};
use nonempty::{nonempty, NonEmpty};
use qrust::{
types::Backfill,
utils::{last_minute, ONE_SECOND},
};
use std::{collections::HashMap, hash::Hash, sync::Arc};
use time::OffsetDateTime;
use tokio::{
spawn,
sync::{mpsc, oneshot, Mutex},
task::JoinHandle,
try_join,
};
use uuid::Uuid;
pub enum Action {
Backfill,
Purge,
}
pub struct Message {
pub action: Action,
pub symbols: NonEmpty<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Action, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel::<()>();
(
Self {
action,
symbols,
response: sender,
},
receiver,
)
}
}
#[derive(Clone)]
pub struct Job {
pub symbol: String,
pub fetch_from: OffsetDateTime,
pub fetch_to: OffsetDateTime,
pub fresh: bool,
}
#[async_trait]
pub trait Handler: Send + Sync {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error>;
async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn queue_backfill(&self, jobs: &NonEmpty<Job>);
async fn backfill(&self, jobs: NonEmpty<Job>);
fn max_limit(&self) -> i64;
fn log_string(&self) -> &'static str;
}
pub struct Jobs {
pub symbol_to_uuid: HashMap<String, Uuid>,
pub uuid_to_job: HashMap<Uuid, JoinHandle<()>>,
}
impl Jobs {
pub fn insert(&mut self, jobs: Vec<String>, fut: JoinHandle<()>) {
let uuid = Uuid::new_v4();
for symbol in jobs {
self.symbol_to_uuid.insert(symbol.clone(), uuid);
}
self.uuid_to_job.insert(uuid, fut);
}
pub fn contains_key(&self, symbol: &str) -> bool {
self.symbol_to_uuid.contains_key(symbol)
}
pub fn remove(&mut self, symbol: &str) -> Option<JoinHandle<()>> {
self.symbol_to_uuid
.remove(symbol)
.and_then(|uuid| self.uuid_to_job.remove(&uuid))
}
pub fn remove_many<T>(&mut self, symbols: &[T])
where
T: AsRef<str> + Hash + Eq,
{
for symbol in symbols {
self.symbol_to_uuid
.remove(symbol.as_ref())
.and_then(|uuid| self.uuid_to_job.remove(&uuid));
}
}
pub fn len(&self) -> usize {
self.symbol_to_uuid.len()
}
}
pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
let backfill_jobs = Arc::new(Mutex::new(Jobs {
symbol_to_uuid: HashMap::new(),
uuid_to_job: HashMap::new(),
}));
loop {
let message = receiver.recv().await.unwrap();
spawn(handle_message(
handler.clone(),
backfill_jobs.clone(),
message,
));
}
}
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
backfill_jobs: Arc<Mutex<Jobs>>,
message: Message,
) {
let backfill_jobs_clone = backfill_jobs.clone();
let mut backfill_jobs = backfill_jobs.lock().await;
let symbols = Vec::from(message.symbols);
match message.action {
Action::Backfill => {
let log_string = handler.log_string();
let max_limit = handler.max_limit();
let backfills = handler
.select_latest_backfills(&symbols)
.await
.unwrap()
.into_iter()
.map(|backfill| (backfill.symbol.clone(), backfill))
.collect::<HashMap<_, _>>();
let mut jobs = Vec::with_capacity(symbols.len());
for symbol in symbols {
if backfill_jobs.contains_key(&symbol) {
warn!(
"Backfill for {} {} is already running, skipping.",
symbol, log_string
);
continue;
}
let backfill = backfills.get(&symbol);
let fetch_from = backfill.map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
backfill.time + ONE_SECOND
});
let fetch_to = last_minute();
if fetch_from > fetch_to {
info!("No need to backfill {} {}.", symbol, log_string,);
return;
}
let fresh = backfill.map_or(false, |backfill| backfill.fresh);
jobs.push(Job {
symbol,
fetch_from,
fetch_to,
fresh,
});
}
let mut current_minutes = 0;
let job_groups = jobs
.into_iter()
.sorted_unstable_by_key(|job| job.fetch_from)
.fold(Vec::<NonEmpty<Job>>::new(), |mut job_groups, job| {
let minutes = (job.fetch_to - job.fetch_from).whole_minutes();
if let Some(job_group) = job_groups.last_mut() {
if current_minutes + minutes <= max_limit {
job_group.push(job);
current_minutes += minutes;
return job_groups;
}
}
job_groups.push(nonempty![job]);
current_minutes = minutes;
job_groups
});
for job_group in job_groups {
let symbols = job_group
.iter()
.map(|job| job.symbol.clone())
.collect::<Vec<_>>();
let handler = handler.clone();
let symbols_clone = symbols.clone();
let backfill_jobs_clone = backfill_jobs_clone.clone();
let fut = spawn(async move {
handler.queue_backfill(&job_group).await;
handler.backfill(job_group).await;
let mut backfill_jobs = backfill_jobs_clone.lock().await;
backfill_jobs.remove_many(&symbols_clone);
let remaining = backfill_jobs.len();
drop(backfill_jobs);
info!("{} {} backfills remaining.", remaining, log_string);
});
backfill_jobs.insert(symbols, fut);
}
}
Action::Purge => {
for symbol in &symbols {
if let Some(job) = backfill_jobs.remove(symbol) {
job.abort();
let _ = job.await;
}
}
try_join!(
handler.delete_backfills(&symbols),
handler.delete_data(&symbols)
)
.unwrap();
}
}
message.response.send(()).unwrap();
}


@@ -1,186 +0,0 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_NEWS_SIZE},
database,
};
use async_trait::async_trait;
use log::{error, info};
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
self,
alpaca::shared::{Sort, Source},
Backfill, News,
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::time::sleep;
pub struct Handler {
pub config: Arc<Config>,
}
#[async_trait]
impl super::Handler for Handler {
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error> {
database::backfills_news::select_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_news::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::news::delete_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
symbols,
)
.await
}
async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
if *ALPACA_SOURCE == Source::Sip {
return;
}
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();
info!("Queing news backfill for {:?} in {:?}.", symbols, run_delay);
sleep(run_delay).await;
}
#[allow(clippy::too_many_lines)]
#[allow(clippy::iter_with_drain)]
async fn backfill(&self, jobs: NonEmpty<Job>) {
let symbols = Vec::from(jobs.clone().map(|job| job.symbol));
let symbols_set = symbols.clone().into_iter().collect::<HashSet<_>>();
let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
let freshness = jobs
.into_iter()
.map(|job| (job.symbol, job.fresh))
.collect::<HashMap<_, _>>();
let mut news = Vec::with_capacity(*CLICKHOUSE_BATCH_NEWS_SIZE);
let mut last_times = HashMap::new();
let mut next_page_token = None;
info!("Backfilling news for {:?}.", symbols);
loop {
let message = alpaca::news::get(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
&types::alpaca::api::outgoing::news::News {
symbols: symbols.clone(),
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token.clone(),
sort: Some(Sort::Asc),
..Default::default()
},
None,
)
.await;
if let Err(err) = message {
error!("Failed to backfill news for {:?}: {:?}.", symbols, err);
return;
}
let message = message.unwrap();
for news_item in message.news {
let news_item = News::from(news_item);
for symbol in &news_item.symbols {
if symbols_set.contains(symbol) {
last_times.insert(symbol.clone(), news_item.time_created);
}
}
news.push(news_item);
}
if news.len() < *CLICKHOUSE_BATCH_NEWS_SIZE && message.next_page_token.is_some() {
continue;
}
database::news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&news,
)
.await
.unwrap();
let backfilled = last_times
.drain()
.map(|(symbol, time)| Backfill {
fresh: freshness[&symbol],
symbol,
time,
})
.collect::<Vec<_>>();
database::backfills_news::upsert_batch(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&backfilled,
)
.await
.unwrap();
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
news.clear();
}
database::backfills_news::set_fresh_where_symbols(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
true,
&symbols,
)
.await
.unwrap();
info!("Backfilled news for {:?}.", symbols);
}
fn max_limit(&self) -> i64 {
alpaca::news::MAX_LIMIT
}
fn log_string(&self) -> &'static str {
"news"
}
}
pub fn create_handler(config: Arc<Config>) -> Box<dyn super::Handler> {
Box::new(Handler { config })
}


@@ -1,391 +0,0 @@
mod backfill;
mod websocket;
use super::clock;
use crate::{
config::{Config, ALPACA_API_BASE, ALPACA_SOURCE},
create_send_await, database,
};
use itertools::{Either, Itertools};
use log::error;
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
alpaca::websocket::{
ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL,
ALPACA_US_EQUITY_DATA_WEBSOCKET_URL,
},
Asset, Class,
},
};
use std::{collections::HashMap, sync::Arc};
use tokio::{
join, select, spawn,
sync::{mpsc, oneshot},
};
#[derive(Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Action {
Add,
Enable,
Remove,
Disable,
}
pub struct Message {
pub action: Action,
pub assets: NonEmpty<(String, Class)>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Action, assets: NonEmpty<(String, Class)>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel();
(
Self {
action,
assets,
response: sender,
},
receiver,
)
}
}
#[derive(Clone, Copy)]
pub enum ThreadType {
Bars(Class),
News,
}
pub async fn run(
config: Arc<Config>,
mut receiver: mpsc::Receiver<Message>,
mut clock_receiver: mpsc::Receiver<clock::Message>,
) {
let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::UsEquity));
let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::Crypto));
let (news_websocket_sender, news_backfill_sender) =
init_thread(config.clone(), ThreadType::News);
loop {
select! {
Some(message) = receiver.recv() => {
spawn(handle_message(
config.clone(),
bars_us_equity_websocket_sender.clone(),
bars_us_equity_backfill_sender.clone(),
bars_crypto_websocket_sender.clone(),
bars_crypto_backfill_sender.clone(),
news_websocket_sender.clone(),
news_backfill_sender.clone(),
message,
));
}
Some(message) = clock_receiver.recv() => {
spawn(handle_clock_message(
config.clone(),
bars_us_equity_backfill_sender.clone(),
bars_crypto_backfill_sender.clone(),
news_backfill_sender.clone(),
message,
));
}
else => panic!("Communication channel unexpectedly closed.")
}
}
}
fn init_thread(
config: Arc<Config>,
thread_type: ThreadType,
) -> (
mpsc::Sender<websocket::Message>,
mpsc::Sender<backfill::Message>,
) {
let websocket_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => {
format!("{}/{}", ALPACA_US_EQUITY_DATA_WEBSOCKET_URL, *ALPACA_SOURCE)
}
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_WEBSOCKET_URL.into(),
ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(),
};
let backfill_handler = match thread_type {
ThreadType::Bars(_) => backfill::bars::create_handler(config.clone(), thread_type),
ThreadType::News => backfill::news::create_handler(config.clone()),
};
let (backfill_sender, backfill_receiver) = mpsc::channel(100);
spawn(backfill::run(backfill_handler.into(), backfill_receiver));
let websocket_handler = match thread_type {
ThreadType::Bars(_) => websocket::bars::create_handler(config, thread_type),
ThreadType::News => websocket::news::create_handler(&config),
};
let (websocket_sender, websocket_receiver) = mpsc::channel(100);
spawn(websocket::run(
websocket_handler.into(),
websocket_receiver,
websocket_url,
));
(websocket_sender, backfill_sender)
}
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_lines)]
async fn handle_message(
config: Arc<Config>,
bars_us_equity_websocket_sender: mpsc::Sender<websocket::Message>,
bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
bars_crypto_websocket_sender: mpsc::Sender<websocket::Message>,
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
news_websocket_sender: mpsc::Sender<websocket::Message>,
news_backfill_sender: mpsc::Sender<backfill::Message>,
message: Message,
) {
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = message
.assets
.clone()
.into_iter()
.partition_map(|asset| match asset.1 {
Class::UsEquity => Either::Left(asset.0),
Class::Crypto => Either::Right(asset.0),
});
let symbols = message.assets.map(|(symbol, _)| symbol);
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols.clone()) {
create_send_await!(
bars_us_equity_websocket_sender,
websocket::Message::new,
message.action.into(),
us_equity_symbols
);
}
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols.clone()) {
create_send_await!(
bars_crypto_websocket_sender,
websocket::Message::new,
message.action.into(),
crypto_symbols
);
}
};
let news_future = async {
create_send_await!(
news_websocket_sender,
websocket::Message::new,
message.action.into(),
symbols.clone()
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
if message.action == Action::Disable {
message.response.send(()).unwrap();
return;
}
match message.action {
Action::Add | Action::Enable => {
let symbols = Vec::from(symbols.clone());
let assets = async {
alpaca::assets::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbols,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>()
};
let positions = async {
alpaca::positions::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbols,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(|position| (position.symbol.clone(), position))
.collect::<HashMap<_, _>>()
};
let (mut assets, mut positions) = join!(assets, positions);
let batch =
symbols
.iter()
.fold(Vec::with_capacity(symbols.len()), |mut batch, symbol| {
if let Some(asset) = assets.remove(symbol) {
let position = positions.remove(symbol);
batch.push(Asset::from((asset, position)));
} else {
error!("Failed to find asset for symbol: {}.", symbol);
}
batch
});
database::assets::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&batch,
)
.await
.unwrap();
}
Action::Remove => {
database::assets::delete_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&Vec::from(symbols.clone()),
)
.await
.unwrap();
}
Action::Disable => unreachable!(),
}
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
us_equity_symbols
);
}
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
crypto_symbols
);
}
};
let news_future = async {
create_send_await!(
news_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
symbols
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
message.response.send(()).unwrap();
}
async fn handle_clock_message(
config: Arc<Config>,
bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
news_backfill_sender: mpsc::Sender<backfill::Message>,
message: clock::Message,
) {
if message.status == clock::Status::Closed {
database::cleanup_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
}
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets
.clone()
.into_iter()
.partition_map(|asset| match asset.class {
Class::UsEquity => Either::Left(asset.symbol),
Class::Crypto => Either::Right(asset.symbol),
});
let symbols = assets
.into_iter()
.map(|asset| asset.symbol)
.collect::<Vec<_>>();
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
us_equity_symbols
);
}
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
crypto_symbols
);
}
};
let news_future = async {
if let Some(symbols) = NonEmpty::from_vec(symbols) {
create_send_await!(
news_backfill_sender,
backfill::Message::new,
backfill::Action::Backfill,
symbols
);
}
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
}


@@ -1,171 +0,0 @@
use super::State;
use crate::{
config::{Config, CLICKHOUSE_BATCH_BARS_SIZE},
database,
threads::data::ThreadType,
};
use async_trait::async_trait;
use clickhouse::inserter::Inserter;
use log::{debug, error, info};
use nonempty::NonEmpty;
use qrust::{
types::{alpaca::websocket, Bar, Class},
utils::ONE_SECOND,
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::{Mutex, RwLock};
pub struct Handler {
pub config: Arc<Config>,
pub inserter: Arc<Mutex<Inserter<Bar>>>,
pub subscription_message_constructor:
fn(NonEmpty<String>) -> websocket::data::outgoing::subscribe::Message,
}
#[async_trait]
impl super::Handler for Handler {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message {
(self.subscription_message_constructor)(symbols)
}
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::Market {
bars: symbols, ..
} = message
else {
unreachable!()
};
let symbols = symbols.into_iter().collect::<HashSet<_>>();
let mut state = state.write().await;
let newly_subscribed = state
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = state
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to bars for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from bars for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::Bar(message)
| websocket::data::incoming::Message::UpdatedBar(message) => {
let bar = Bar::from(message);
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
self.inserter.lock().await.write(&bar).await.unwrap();
}
websocket::data::incoming::Message::Status(message) => {
debug!(
"Received status message for {}: {:?}.",
message.symbol, message.status
);
match message.status {
websocket::data::incoming::status::Status::TradingHalt(_)
| websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
false,
)
.await
.unwrap();
}
websocket::data::incoming::status::Status::Resume(_)
| websocket::data::incoming::status::Status::TradingResumption(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&self.config.clickhouse_concurrency_limiter,
&message.symbol,
true,
)
.await
.unwrap();
}
_ => {}
}
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
fn log_string(&self) -> &'static str {
"bars"
}
async fn run_inserter(&self) {
super::run_inserter(self.inserter.clone()).await;
}
}
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
let inserter = Arc::new(Mutex::new(
config
.clickhouse_client
.inserter("bars")
.unwrap()
.with_period(Some(ONE_SECOND))
.with_max_entries((*CLICKHOUSE_BATCH_BARS_SIZE).try_into().unwrap()),
));
let subscription_message_constructor = match thread_type {
ThreadType::Bars(Class::UsEquity) => {
websocket::data::outgoing::subscribe::Message::new_market_us_equity
}
ThreadType::Bars(Class::Crypto) => {
websocket::data::outgoing::subscribe::Message::new_market_crypto
}
_ => unreachable!(),
};
Box::new(Handler {
config,
inserter,
subscription_message_constructor,
})
}


@@ -1,353 +0,0 @@
pub mod bars;
pub mod news;
use crate::config::{ALPACA_API_KEY, ALPACA_API_SECRET};
use async_trait::async_trait;
use backoff::{future::retry_notify, ExponentialBackoff};
use clickhouse::{inserter::Inserter, Row};
use futures_util::{future::join_all, SinkExt, StreamExt};
use log::error;
use nonempty::NonEmpty;
use qrust::types::alpaca::{self, websocket};
use serde::Serialize;
use serde_json::{from_str, to_string};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::Duration,
};
use tokio::{
net::TcpStream,
select, spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
};
use tokio_tungstenite::{connect_async, tungstenite, MaybeTlsStream, WebSocketStream};
pub enum Action {
Subscribe,
Unsubscribe,
}
impl From<super::Action> for Option<Action> {
fn from(action: super::Action) -> Self {
match action {
super::Action::Add | super::Action::Enable => Some(Action::Subscribe),
super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe),
}
}
}
pub struct Message {
pub action: Option<Action>,
pub symbols: NonEmpty<String>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Option<Action>, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel();
(
Self {
action,
symbols,
response: sender,
},
receiver,
)
}
}
pub struct State {
pub active_subscriptions: HashSet<String>,
pub pending_subscriptions: HashMap<String, oneshot::Sender<()>>,
pub pending_unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}
#[async_trait]
pub trait Handler: Send + Sync + 'static {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message;
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
);
fn log_string(&self) -> &'static str;
async fn run_inserter(&self);
}
pub async fn run(
handler: Arc<Box<dyn Handler>>,
mut receiver: mpsc::Receiver<Message>,
websocket_url: String,
) {
let state = Arc::new(RwLock::new(State {
active_subscriptions: HashSet::new(),
pending_subscriptions: HashMap::new(),
pending_unsubscriptions: HashMap::new(),
}));
let handler_clone = handler.clone();
spawn(async move { handler_clone.run_inserter().await });
let (sink_sender, sink_receiver) = mpsc::channel(100);
let (stream_sender, mut stream_receiver) = mpsc::channel(10_000);
spawn(run_connection(
handler.clone(),
sink_receiver,
stream_sender,
websocket_url.clone(),
state.clone(),
));
loop {
select! {
Some(message) = receiver.recv() => {
spawn(handle_message(
handler.clone(),
state.clone(),
sink_sender.clone(),
message,
));
}
Some(message) = stream_receiver.recv() => {
match message {
tungstenite::Message::Text(message) => {
let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);
if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}.", message);
continue;
}
for message in parsed_message.unwrap() {
let handler = handler.clone();
let state = state.clone();
spawn(async move {
handler.handle_websocket_message(state, message).await;
});
}
}
tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}.", message),
}
}
else => panic!("Communication channel unexpectedly closed.")
}
}
}
#[allow(clippy::too_many_lines)]
async fn run_connection(
handler: Arc<Box<dyn Handler>>,
mut sink_receiver: mpsc::Receiver<tungstenite::Message>,
stream_sender: mpsc::Sender<tungstenite::Message>,
websocket_url: String,
state: Arc<RwLock<State>>,
) {
let mut peek = None;
'connection: loop {
let (websocket, _): (WebSocketStream<MaybeTlsStream<TcpStream>>, _) = retry_notify(
ExponentialBackoff::default(),
|| async {
connect_async(websocket_url.clone())
.await
.map_err(Into::into)
},
|e, duration: Duration| {
error!(
"Failed to connect to {} websocket, will retry in {} seconds: {}.",
handler.log_string(),
duration.as_secs(),
e
);
},
)
.await
.unwrap();
let (mut sink, mut stream) = websocket.split();
alpaca::websocket::data::authenticate(
&mut sink,
&mut stream,
(*ALPACA_API_KEY).to_string(),
(*ALPACA_API_SECRET).to_string(),
)
.await;
let mut state = state.write().await;
state
.pending_unsubscriptions
.drain()
.for_each(|(_, sender)| {
sender.send(()).unwrap();
});
let (recovered_subscriptions, receivers) = state
.active_subscriptions
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
state.pending_subscriptions.extend(recovered_subscriptions);
let pending_subscriptions = state
.pending_subscriptions
.keys()
.cloned()
.collect::<Vec<_>>();
drop(state);
if let Some(pending_subscriptions) = NonEmpty::from_vec(pending_subscriptions) {
if let Err(err) = sink
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(pending_subscriptions),
))
.unwrap(),
))
.await
{
error!("Failed to send websocket message: {:?}.", err);
continue;
}
}
join_all(receivers).await;
if peek.is_some() {
if let Err(err) = sink.send(peek.clone().unwrap()).await {
error!("Failed to send websocket message: {:?}.", err);
continue;
}
peek = None;
}
loop {
select! {
Some(message) = sink_receiver.recv() => {
peek = Some(message.clone());
if let Err(err) = sink.send(message).await {
error!("Failed to send websocket message: {:?}.", err);
continue 'connection;
};
peek = None;
}
message = stream.next() => {
if message.is_none() {
error!("Websocket stream unexpectedly closed.");
continue 'connection;
}
let message = message.unwrap();
if let Err(err) = message {
error!("Failed to receive websocket message: {:?}.", err);
continue 'connection;
}
let message = message.unwrap();
if message.is_close() {
error!("Websocket connection closed.");
continue 'connection;
}
stream_sender.send(message).await.unwrap();
}
else => error!("Communication channel unexpectedly closed.")
}
}
}
}
async fn handle_message(
handler: Arc<Box<dyn Handler>>,
pending: Arc<RwLock<State>>,
sink_sender: mpsc::Sender<tungstenite::Message>,
message: Message,
) {
match message.action {
Some(Action::Subscribe) => {
let (pending_subscriptions, receivers) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
pending
.write()
.await
.pending_subscriptions
.extend(pending_subscriptions);
sink_sender
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Subscribe(
handler.create_subscription_message(message.symbols),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
Some(Action::Unsubscribe) => {
let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
.symbols
.iter()
.map(|symbol| {
let (sender, receiver) = oneshot::channel();
((symbol.clone(), sender), receiver)
})
.unzip();
pending
.write()
.await
.pending_unsubscriptions
.extend(pending_unsubscriptions);
sink_sender
.send(tungstenite::Message::Text(
to_string(&websocket::data::outgoing::Message::Unsubscribe(
handler.create_subscription_message(message.symbols.clone()),
))
.unwrap(),
))
.await
.unwrap();
join_all(receivers).await;
}
None => {}
}
message.response.send(()).unwrap();
}
async fn run_inserter<T>(inserter: Arc<Mutex<Inserter<T>>>)
where
T: Row + Serialize,
{
loop {
let time_left = inserter.lock().await.time_left().unwrap();
tokio::time::sleep(time_left).await;
inserter.lock().await.commit().await.unwrap();
}
}


@@ -1,119 +0,0 @@
use super::State;
use crate::config::{Config, CLICKHOUSE_BATCH_NEWS_SIZE};
use async_trait::async_trait;
use clickhouse::inserter::Inserter;
use log::{debug, error, info};
use nonempty::NonEmpty;
use qrust::{
types::{alpaca::websocket, News},
utils::ONE_SECOND,
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::{Mutex, RwLock};
pub struct Handler {
pub inserter: Arc<Mutex<Inserter<News>>>,
}
#[async_trait]
impl super::Handler for Handler {
fn create_subscription_message(
&self,
symbols: NonEmpty<String>,
) -> websocket::data::outgoing::subscribe::Message {
websocket::data::outgoing::subscribe::Message::new_news(symbols)
}
async fn handle_websocket_message(
&self,
state: Arc<RwLock<State>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::News { news: symbols } =
message
else {
unreachable!()
};
let symbols = symbols.into_iter().collect::<HashSet<_>>();
let mut state = state.write().await;
let newly_subscribed = state
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = state
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
state
.active_subscriptions
.extend(newly_subscribed.keys().cloned());
drop(state);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to news for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from news for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::News(message) => {
let news = News::from(message);
debug!(
"Received news for {:?}: {}.",
news.symbols, news.time_created
);
self.inserter.lock().await.write(&news).await.unwrap();
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
fn log_string(&self) -> &'static str {
"news"
}
async fn run_inserter(&self) {
super::run_inserter(self.inserter.clone()).await;
}
}
pub fn create_handler(config: &Arc<Config>) -> Box<dyn super::Handler> {
let inserter = Arc::new(Mutex::new(
config
.clickhouse_client
.inserter("news")
.unwrap()
.with_period(Some(ONE_SECOND))
.with_max_entries((*CLICKHOUSE_BATCH_NEWS_SIZE).try_into().unwrap()),
));
Box::new(Handler { inserter })
}


@@ -1,27 +0,0 @@
mod websocket;
use crate::config::{Config, ALPACA_API_BASE, ALPACA_API_KEY, ALPACA_API_SECRET};
use futures_util::StreamExt;
use qrust::types::alpaca;
use std::sync::Arc;
use tokio::spawn;
use tokio_tungstenite::connect_async;
pub async fn run(config: Arc<Config>) {
let (websocket, _) =
connect_async(&format!("wss://{}.alpaca.markets/stream", *ALPACA_API_BASE))
.await
.unwrap();
let (mut websocket_sink, mut websocket_stream) = websocket.split();
alpaca::websocket::trading::authenticate(
&mut websocket_sink,
&mut websocket_stream,
(*ALPACA_API_KEY).to_string(),
(*ALPACA_API_SECRET).to_string(),
)
.await;
alpaca::websocket::trading::subscribe(&mut websocket_sink, &mut websocket_stream).await;
spawn(websocket::run(config, websocket_stream));
}

View File

@@ -1,79 +0,0 @@
use crate::{config::Config, database};
use futures_util::{stream::SplitStream, StreamExt};
use log::{debug, error};
use qrust::types::{alpaca::websocket, Order};
use serde_json::from_str;
use std::sync::Arc;
use tokio::{net::TcpStream, spawn};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
pub async fn run(
config: Arc<Config>,
mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) {
loop {
let message = websocket_stream.next().await.unwrap().unwrap();
match message {
tungstenite::Message::Binary(message) => {
let parsed_message = from_str::<websocket::trading::incoming::Message>(
&String::from_utf8_lossy(&message),
);
if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}.", message);
continue;
}
spawn(handle_websocket_message(
config.clone(),
parsed_message.unwrap(),
));
}
tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}.", message),
}
}
}
async fn handle_websocket_message(
config: Arc<Config>,
message: websocket::trading::incoming::Message,
) {
match message {
websocket::trading::incoming::Message::Order(message) => {
debug!(
"Received order message for {}: {:?}.",
message.order.symbol, message.event
);
let order = Order::from(message.order);
database::orders::upsert(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&order,
)
.await
.unwrap();
match message.event {
websocket::trading::incoming::order::Event::Fill { position_qty, .. }
| websocket::trading::incoming::order::Event::PartialFill {
position_qty, ..
} => {
database::assets::update_qty_where_symbol(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&order.symbol,
position_qty,
)
.await
.unwrap();
}
_ => (),
}
}
_ => unreachable!(),
}
}

View File

@@ -1,133 +0,0 @@
#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
use burn::{
config::Config,
data::{
dataloader::{DataLoaderBuilder, Dataset},
dataset::transform::{PartialDataset, ShuffledDataset},
},
module::Module,
optim::AdamConfig,
record::CompactRecorder,
tensor::backend::AutodiffBackend,
train::LearnerBuilder,
};
use dotenv::dotenv;
use log::info;
use qrust::{
database,
ml::{
BarWindow, BarWindowBatcher, ModelConfig, MultipleSymbolDataset, MyAutodiffBackend, DEVICE,
},
types::Bar,
};
use std::{env, fs, path::Path, sync::Arc};
use tokio::sync::Semaphore;
#[derive(Config)]
pub struct TrainingConfig {
pub model: ModelConfig,
pub optimizer: AdamConfig,
#[config(default = 100)]
pub epochs: usize,
#[config(default = 256)]
pub batch_size: usize,
#[config(default = 16)]
pub num_workers: usize,
#[config(default = 0)]
pub seed: u64,
#[config(default = 0.2)]
pub valid_pct: f64,
#[config(default = 1.0e-4)]
pub learning_rate: f64,
}
#[tokio::main]
async fn main() {
dotenv().ok();
let dir = Path::new(file!()).parent().unwrap();
let model_config = ModelConfig::new();
let optimizer = AdamConfig::new();
let training_config = TrainingConfig::new(model_config, optimizer);
let clickhouse_client = clickhouse::Client::default()
.with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
.with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
.with_password(env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."))
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set."));
let clickhouse_concurrency_limiter = Arc::new(Semaphore::new(Semaphore::MAX_PERMITS));
let bars = database::ta::select(&clickhouse_client, &clickhouse_concurrency_limiter)
.await
.unwrap();
info!("Loaded {} bars.", bars.len());
train::<MyAutodiffBackend>(
bars,
&training_config,
dir.join("artifacts").to_str().unwrap(),
&DEVICE,
);
}
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cast_sign_loss)]
#[allow(clippy::cast_precision_loss)]
fn train<B: AutodiffBackend<FloatElem = f32, IntElem = i32>>(
bars: Vec<Bar>,
config: &TrainingConfig,
dir: &str,
device: &B::Device,
) {
B::seed(config.seed);
fs::create_dir_all(dir).unwrap();
let dataset = MultipleSymbolDataset::new(bars);
let dataset = ShuffledDataset::with_seed(dataset, config.seed);
let dataset = Arc::new(dataset);
let split = (dataset.len() as f64 * (1.0 - config.valid_pct)) as usize;
let train: PartialDataset<Arc<ShuffledDataset<MultipleSymbolDataset, BarWindow>>, BarWindow> =
PartialDataset::new(dataset.clone(), 0, split);
let batcher_train = BarWindowBatcher::<B> {
device: device.clone(),
};
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.num_workers(config.num_workers)
.build(train);
let valid: PartialDataset<Arc<ShuffledDataset<MultipleSymbolDataset, BarWindow>>, BarWindow> =
PartialDataset::new(dataset.clone(), split, dataset.len());
let batcher_valid = BarWindowBatcher::<B::InnerBackend> {
device: device.clone(),
};
let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.num_workers(config.num_workers)
.build(valid);
let learner = LearnerBuilder::new(dir)
.with_file_checkpointer(CompactRecorder::new())
.devices(vec![device.clone()])
.num_epochs(config.epochs)
.build(
config.model.init::<B>(device),
config.optimizer.init(),
config.learning_rate,
);
let trained = learner.fit(dataloader_train, dataloader_valid);
trained.save_file(dir, &CompactRecorder::new()).unwrap();
}

91
src/config.rs Normal file
View File

@@ -0,0 +1,91 @@
use crate::types::alpaca::Source;
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client,
};
use std::{env, num::NonZeroU32, sync::Arc, time::Duration};
pub const ALPACA_ASSET_API_URL: &str = "https://api.alpaca.markets/v2/assets";
pub const ALPACA_CLOCK_API_URL: &str = "https://api.alpaca.markets/v2/clock";
pub const ALPACA_STOCK_DATA_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
pub const ALPACA_CRYPTO_DATA_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
pub const ALPACA_NEWS_DATA_URL: &str = "https://data.alpaca.markets/v1beta1/news";
pub const ALPACA_STOCK_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
pub const ALPACA_CRYPTO_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta3/crypto/us";
pub const ALPACA_NEWS_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";
pub struct Config {
pub alpaca_api_key: String,
pub alpaca_api_secret: String,
pub alpaca_rate_limit: DefaultDirectRateLimiter,
pub alpaca_source: Source,
pub alpaca_client: Client,
pub ollama_url: String,
pub ollama_model: String,
pub ollama_client: Client,
pub clickhouse_client: clickhouse::Client,
}
impl Config {
pub fn from_env() -> Self {
let alpaca_api_key = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
let alpaca_api_secret =
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
let alpaca_source: Source = env::var("ALPACA_SOURCE")
.expect("ALPACA_SOURCE must be set.")
.parse()
.expect("ALPACA_SOURCE must be a either 'iex' or 'sip'.");
let ollama_url = env::var("OLLAMA_URL").expect("OLLAMA_URL must be set.");
let ollama_model = env::var("OLLAMA_MODEL").expect("OLLAMA_MODEL must be set.");
let clickhouse_url = env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set.");
let clickhouse_user = env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set.");
let clickhouse_password =
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set.");
let clickhouse_db = env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.");
Self {
alpaca_client: Client::builder()
.default_headers(HeaderMap::from_iter([
(
HeaderName::from_static("apca-api-key-id"),
HeaderValue::from_str(&alpaca_api_key)
.expect("Alpaca API key must not contain invalid characters."),
),
(
HeaderName::from_static("apca-api-secret-key"),
HeaderValue::from_str(&alpaca_api_secret)
.expect("Alpaca API secret must not contain invalid characters."),
),
]))
.timeout(Duration::from_secs(60))
.build()
.unwrap(),
alpaca_rate_limit: RateLimiter::direct(Quota::per_minute(match alpaca_source {
                Source::Iex => NonZeroU32::new(190).unwrap(),
                Source::Sip => NonZeroU32::new(9990).unwrap(),
})),
alpaca_source,
alpaca_api_key,
alpaca_api_secret,
ollama_url,
ollama_model,
ollama_client: Client::builder()
.timeout(Duration::from_secs(15))
.build()
.unwrap(),
clickhouse_client: clickhouse::Client::default()
.with_url(clickhouse_url)
.with_user(clickhouse_user)
.with_password(clickhouse_password)
.with_database(clickhouse_db),
}
}
pub fn arc_from_env() -> Arc<Self> {
Arc::new(Self::from_env())
}
}
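A minimal sketch of the intended call site; the surrounding `main` is an assumption, only `from_env`/`arc_from_env` come from the file above:
use dotenv::dotenv;

#[tokio::main]
async fn main() {
    dotenv().ok();
    // Panics with a descriptive message if any required variable is missing.
    let config = Config::arc_from_env();
    // ... hand `config` to the data and news threads ...
}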

48
src/database/assets.rs Normal file
View File

@@ -0,0 +1,48 @@
use crate::types::Asset;
use clickhouse::Client;
use serde::Serialize;
pub async fn select(clickhouse_client: &Client) -> Vec<Asset> {
clickhouse_client
.query("SELECT ?fields FROM assets FINAL")
.fetch_all::<Asset>()
.await
.unwrap()
}
pub async fn select_where_symbol<T>(clickhouse_client: &Client, symbol: &T) -> Option<Asset>
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query("SELECT ?fields FROM assets FINAL WHERE symbol = ? OR abbreviation = ?")
.bind(symbol)
.bind(symbol)
.fetch_optional::<Asset>()
.await
.unwrap()
}
pub async fn upsert_batch<T>(clickhouse_client: &Client, assets: T)
where
T: IntoIterator<Item = Asset> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = clickhouse_client.insert("assets").unwrap();
for asset in assets {
insert.write(&asset).await.unwrap();
}
insert.end().await.unwrap();
}
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query("DELETE FROM assets WHERE symbol IN ?")
.bind(symbols)
.execute()
.await
.unwrap();
}
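Hypothetical call sites for the helpers above (the symbol is illustrative):
// Load every tracked asset, then look a single symbol back up.
let assets = select(&config.clickhouse_client).await;
let aapl = select_where_symbol(&config.clickhouse_client, &"AAPL").await;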

93
src/database/backfills.rs Normal file
View File

@@ -0,0 +1,93 @@
use crate::{database::assets, threads::data::ThreadType, types::Backfill};
use clickhouse::Client;
use serde::Serialize;
use tokio::join;
pub async fn select_latest_where_symbol<T>(
clickhouse_client: &Client,
thread_type: &ThreadType,
symbol: &T,
) -> Option<Backfill>
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol = ? ORDER BY time DESC LIMIT 1",
match thread_type {
ThreadType::Bars(_) => "backfills_bars",
ThreadType::News => "backfills_news",
}
))
.bind(symbol)
.fetch_optional::<Backfill>()
.await
.unwrap()
}
pub async fn upsert(clickhouse_client: &Client, thread_type: &ThreadType, backfill: &Backfill) {
let mut insert = clickhouse_client
.insert(match thread_type {
ThreadType::Bars(_) => "backfills_bars",
ThreadType::News => "backfills_news",
})
.unwrap();
insert.write(backfill).await.unwrap();
insert.end().await.unwrap();
}
pub async fn delete_where_symbols<T>(
clickhouse_client: &Client,
thread_type: &ThreadType,
symbols: &[T],
) where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query(&format!(
"DELETE FROM {} WHERE symbol IN ?",
match thread_type {
ThreadType::Bars(_) => "backfills_bars",
ThreadType::News => "backfills_news",
}
))
.bind(symbols)
.execute()
.await
.unwrap();
}
pub async fn cleanup(clickhouse_client: &Client) {
let assets = assets::select(clickhouse_client).await;
let bars_symbols = assets
.clone()
.into_iter()
.map(|asset| asset.symbol)
.collect::<Vec<_>>();
let news_symbols = assets
.into_iter()
.map(|asset| asset.abbreviation)
.collect::<Vec<_>>();
let delete_bars_future = async {
clickhouse_client
.query("DELETE FROM backfills_bars WHERE symbol NOT IN ?")
.bind(bars_symbols)
.execute()
.await
.unwrap();
};
let delete_news_future = async {
clickhouse_client
.query("DELETE FROM backfills_news WHERE symbol NOT IN ?")
.bind(news_symbols)
.execute()
.await
.unwrap();
};
join!(delete_bars_future, delete_news_future);
}

50
src/database/bars.rs Normal file
View File

@@ -0,0 +1,50 @@
use super::assets;
use crate::types::Bar;
use clickhouse::Client;
use serde::Serialize;
pub async fn upsert(clickhouse_client: &Client, bar: &Bar) {
let mut insert = clickhouse_client.insert("bars").unwrap();
insert.write(bar).await.unwrap();
insert.end().await.unwrap();
}
pub async fn upsert_batch<T>(clickhouse_client: &Client, bars: T)
where
T: IntoIterator<Item = Bar> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = clickhouse_client.insert("bars").unwrap();
for bar in bars {
insert.write(&bar).await.unwrap();
}
insert.end().await.unwrap();
}
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query("DELETE FROM bars WHERE symbol IN ?")
.bind(symbols)
.execute()
.await
.unwrap();
}
pub async fn cleanup(clickhouse_client: &Client) {
let assets = assets::select(clickhouse_client).await;
let symbols = assets
.into_iter()
.map(|asset| asset.symbol)
.collect::<Vec<_>>();
clickhouse_client
.query("DELETE FROM bars WHERE symbol NOT IN ?")
.bind(symbols)
.execute()
.await
.unwrap();
}

4
src/database/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod assets;
pub mod backfills;
pub mod bars;
pub mod news;

50
src/database/news.rs Normal file
View File

@@ -0,0 +1,50 @@
use super::assets;
use crate::types::News;
use clickhouse::Client;
use serde::Serialize;
pub async fn upsert(clickhouse_client: &Client, news: &News) {
let mut insert = clickhouse_client.insert("news").unwrap();
insert.write(news).await.unwrap();
insert.end().await.unwrap();
}
pub async fn upsert_batch<T>(clickhouse_client: &Client, news: T)
where
T: IntoIterator<Item = News> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = clickhouse_client.insert("news").unwrap();
for news in news {
insert.write(&news).await.unwrap();
}
insert.end().await.unwrap();
}
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query("DELETE FROM news WHERE hasAny(symbols, ?)")
.bind(symbols)
.execute()
.await
.unwrap();
}
pub async fn cleanup(clickhouse_client: &Client) {
let assets = assets::select(clickhouse_client).await;
let symbols = assets
.into_iter()
.map(|asset| asset.abbreviation)
.collect::<Vec<_>>();
clickhouse_client
.query("DELETE FROM news WHERE NOT hasAny(symbols, ?)")
.bind(symbols)
.execute()
.await
.unwrap();
}

View File

@@ -1,39 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::account::Account;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/account", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Account>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,132 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{
incoming::asset::{Asset, Class},
outgoing,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use itertools::Itertools;
use log::warn;
use reqwest::{Client, Error};
use std::{collections::HashSet, time::Duration};
use tokio::try_join;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::asset::Asset,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/assets", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Asset>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get assets, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!(
"https://{}.alpaca.markets/v2/assets/{}",
api_base, symbol
))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Asset>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get asset, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
if symbols.is_empty() {
return Ok(vec![]);
}
if symbols.len() == 1 {
let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(vec![asset]);
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let backoff_clone = backoff.clone();
let us_equity_query = outgoing::asset::Asset {
class: Some(Class::UsEquity),
..Default::default()
};
let us_equity_assets = get(
client,
rate_limiter,
&us_equity_query,
backoff_clone,
api_base,
);
let crypto_query = outgoing::asset::Asset {
class: Some(Class::Crypto),
..Default::default()
};
let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base);
let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;
Ok(crypto_assets
.into_iter()
.chain(us_equity_assets)
.dedup_by(|a, b| a.symbol == b.symbol)
.filter(|asset| symbols.contains(&asset.symbol))
.collect())
}

View File

@@ -1,50 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::bar::Bar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::{collections::HashMap, time::Duration};
pub const MAX_LIMIT: i64 = 10_000;
#[derive(Deserialize)]
pub struct Message {
pub bars: HashMap<String, Vec<Bar>>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
data_url: &str,
query: &outgoing::bar::Bar,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(data_url)
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Message>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get historical bars, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,41 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::calendar::Calendar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::calendar::Calendar,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Calendar>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/calendar", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Calendar>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get calendar, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,39 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::clock::Clock;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/clock", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Clock>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get clock, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,27 +0,0 @@
pub mod account;
pub mod assets;
pub mod bars;
pub mod calendar;
pub mod clock;
pub mod news;
pub mod orders;
pub mod positions;
use http::StatusCode;
pub fn error_to_backoff(err: reqwest::Error) -> backoff::Error<reqwest::Error> {
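    // 400/403/404 indicate caller errors that retrying cannot fix; other
    // status codes (e.g. 429 or 5xx) are treated as transient and retried.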
if err.is_status() {
return match err.status() {
Some(StatusCode::BAD_REQUEST | StatusCode::FORBIDDEN | StatusCode::NOT_FOUND)
| None => backoff::Error::Permanent(err),
_ => err.into(),
};
}
if err.is_builder() || err.is_request() || err.is_redirect() || err.is_decode() || err.is_body()
{
return backoff::Error::Permanent(err);
}
err.into()
}

View File

@@ -1,49 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::news::News, outgoing, ALPACA_NEWS_DATA_API_URL};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
pub const MAX_LIMIT: i64 = 50;
#[derive(Deserialize)]
pub struct Message {
pub news: Vec<News>,
pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(ALPACA_NEWS_DATA_API_URL)
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Message>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get historical news, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,43 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::{api::outgoing, shared::order};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub use order::Order;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::order::Order,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Order>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/orders", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Order>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get orders, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,109 +0,0 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::position::Position;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use http::StatusCode;
use log::warn;
use reqwest::Client;
use std::{collections::HashSet, time::Duration};
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/positions", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Position>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get positions, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Option<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
let response = client
.get(&format!(
"https://{}.alpaca.markets/v2/positions/{}",
api_base, symbol
))
.send()
.await
.map_err(error_to_backoff)?;
if response.status() == StatusCode::NOT_FOUND {
return Ok(None);
}
response
.error_for_status()
.map_err(error_to_backoff)?
.json::<Position>()
.await
.map_err(error_to_backoff)
.map(Some)
},
|e, duration: Duration| {
warn!(
"Failed to get position, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbols(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbols: &[String],
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
if symbols.is_empty() {
return Ok(vec![]);
}
if symbols.len() == 1 {
let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
return Ok(position.into_iter().collect());
}
let symbols = symbols.iter().collect::<HashSet<_>>();
let positions = get(client, rate_limiter, backoff, api_base).await?;
Ok(positions
.into_iter()
.filter(|position| symbols.contains(&position.symbol))
.collect())
}

View File

@@ -1,50 +0,0 @@
use std::sync::Arc;
use crate::{
delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch,
};
use clickhouse::{error::Error, Client};
use serde::Serialize;
use tokio::sync::Semaphore;
select!(Asset, "assets");
select_where_symbol!(Asset, "assets");
upsert_batch!(Asset, "assets");
delete_where_symbols!("assets");
optimize!("assets");
pub async fn update_status_where_symbol<T>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbol: &T,
status: bool,
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
    // Bind the permit to a named variable so it is held for the whole query
    // (`let _ = ...` would drop it immediately).
    let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?")
.bind(status)
.bind(symbol)
.execute()
.await
}
pub async fn update_qty_where_symbol<T>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbol: &T,
qty: f64,
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
    let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?")
.bind(qty)
.bind(symbol)
.execute()
.await
}

View File

@@ -1,11 +0,0 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
types::Backfill, upsert_batch,
};
select_where_symbols!(Backfill, "backfills_bars");
upsert_batch!(Backfill, "backfills_bars");
delete_where_symbols!("backfills_bars");
cleanup!("backfills_bars");
optimize!("backfills_bars");
set_fresh_where_symbols!("backfills_bars");

View File

@@ -1,11 +0,0 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
types::Backfill, upsert_batch,
};
select_where_symbols!(Backfill, "backfills_news");
upsert_batch!(Backfill, "backfills_news");
delete_where_symbols!("backfills_news");
cleanup!("backfills_news");
optimize!("backfills_news");
set_fresh_where_symbols!("backfills_news");

View File

@@ -1,21 +0,0 @@
use std::sync::Arc;
use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
use clickhouse::Client;
use tokio::sync::Semaphore;
upsert!(Bar, "bars");
upsert_batch!(Bar, "bars");
delete_where_symbols!("bars");
optimize!("bars");
pub async fn cleanup(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
) -> Result<(), clickhouse::error::Error> {
    let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets) OR symbol NOT IN (SELECT symbol FROM backfills_bars)")
.execute()
.await
}

View File

@@ -1,42 +0,0 @@
use std::sync::Arc;
use crate::{optimize, types::Calendar};
use clickhouse::{error::Error, Client};
use tokio::{sync::Semaphore, try_join};
optimize!("calendar");
pub async fn upsert_batch_and_delete<'a, I>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
records: I,
) -> Result<(), Error>
where
I: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone,
I::IntoIter: Send,
{
let upsert_future = async {
let mut insert = client.insert("calendar")?;
for record in records.clone() {
insert.write(record).await?;
}
insert.end().await
};
let delete_future = async {
let dates = records
.clone()
.into_iter()
.map(|r| r.date)
.collect::<Vec<_>>();
client
.query("DELETE FROM calendar WHERE date NOT IN ?")
.bind(dates)
.execute()
.await
};
    let _permits = concurrency_limiter.acquire_many(2).await.unwrap();
try_join!(upsert_future, delete_future).map(|_| ())
}

View File

@@ -1,224 +0,0 @@
pub mod assets;
pub mod backfills_bars;
pub mod backfills_news;
pub mod bars;
pub mod calendar;
pub mod news;
pub mod orders;
pub mod ta;
use clickhouse::{error::Error, Client};
use tokio::try_join;
#[macro_export]
macro_rules! select {
($record:ty, $table_name:expr) => {
pub async fn select(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<Vec<$record>, clickhouse::error::Error> {
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! select_where_symbol {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbol<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbol: &T,
) -> Result<Option<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol = ?",
$table_name
))
.bind(symbol)
.fetch_optional::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! select_where_symbols {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbols: &[T],
) -> Result<Vec<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol IN ?",
$table_name
))
.bind(symbols)
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! upsert {
($record:ty, $table_name:expr) => {
pub async fn upsert(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
record: &$record,
) -> Result<(), clickhouse::error::Error> {
            let _permit = concurrency_limiter.acquire().await.unwrap();
let mut insert = client.insert($table_name)?;
insert.write(record).await?;
insert.end().await
}
};
}
#[macro_export]
macro_rules! upsert_batch {
($record:ty, $table_name:expr) => {
pub async fn upsert_batch<'a, I>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
records: I,
) -> Result<(), clickhouse::error::Error>
where
I: IntoIterator<Item = &'a $record> + Send + Sync,
I::IntoIter: Send,
{
            let _permit = concurrency_limiter.acquire().await.unwrap();
let mut insert = client.insert($table_name)?;
for record in records {
insert.write(record).await?;
}
insert.end().await
}
};
}
#[macro_export]
macro_rules! delete_where_symbols {
($table_name:expr) => {
pub async fn delete_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
.bind(symbols)
.execute()
.await
}
};
}
#[macro_export]
macro_rules! cleanup {
($table_name:expr) => {
pub async fn cleanup(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
$table_name
))
.execute()
.await
}
};
}
#[macro_export]
macro_rules! optimize {
($table_name:expr) => {
pub async fn optimize(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), clickhouse::error::Error> {
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
.execute()
.await
}
};
}
#[macro_export]
macro_rules! set_fresh_where_symbols {
($table_name:expr) => {
pub async fn set_fresh_where_symbols<T>(
client: &clickhouse::Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
fresh: bool,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
            let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(&format!(
"ALTER TABLE {} UPDATE fresh = ? WHERE symbol IN ?",
$table_name
))
.bind(fresh)
.bind(symbols)
.execute()
.await
}
};
}
pub async fn cleanup_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
try_join!(
bars::cleanup(clickhouse_client, concurrency_limiter),
news::cleanup(clickhouse_client, concurrency_limiter),
backfills_bars::cleanup(clickhouse_client, concurrency_limiter),
backfills_news::cleanup(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}
pub async fn optimize_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
try_join!(
assets::optimize(clickhouse_client, concurrency_limiter),
bars::optimize(clickhouse_client, concurrency_limiter),
news::optimize(clickhouse_client, concurrency_limiter),
backfills_bars::optimize(clickhouse_client, concurrency_limiter),
backfills_news::optimize(clickhouse_client, concurrency_limiter),
orders::optimize(clickhouse_client, concurrency_limiter),
calendar::optimize(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}
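For reference, `upsert!(Order, "orders")` in the modules above expands to the equivalent of the following function (direct substitution of the macro body):
pub async fn upsert(
    client: &clickhouse::Client,
    concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
    record: &Order,
) -> Result<(), clickhouse::error::Error> {
    let _permit = concurrency_limiter.acquire().await.unwrap();
    let mut insert = client.insert("orders")?;
    insert.write(record).await?;
    insert.end().await
}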

View File

@@ -1,36 +0,0 @@
use std::sync::Arc;
use crate::{optimize, types::News, upsert, upsert_batch};
use clickhouse::{error::Error, Client};
use serde::Serialize;
use tokio::sync::Semaphore;
upsert!(News, "news");
upsert_batch!(News, "news");
optimize!("news");
pub async fn delete_where_symbols<T>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbols: &[T],
) -> Result<(), Error>
where
T: AsRef<str> + Serialize + Send + Sync,
{
    let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))")
.bind(symbols)
.execute()
.await
}
pub async fn cleanup(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
    let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(
"DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))",
)
.execute()
.await
}

View File

@@ -1,5 +0,0 @@
use crate::{optimize, types::Order, upsert, upsert_batch};
upsert!(Order, "orders");
upsert_batch!(Order, "orders");
optimize!("orders");

View File

@@ -1,30 +0,0 @@
use crate::types::Bar;
use clickhouse::{error::Error, Client};
use std::sync::Arc;
use tokio::sync::Semaphore;
pub async fn select(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
) -> Result<Vec<Bar>, Error> {
    let _permit = concurrency_limiter.acquire().await.unwrap();
client
.query(
"
SELECT symbol,
toStartOfHour(bars.time) AS time,
any(bars.open) AS open,
max(bars.high) AS high,
min(bars.low) AS low,
anyLast(bars.close) AS close,
sum(bars.volume) AS volume,
sum(bars.trades) AS trades
FROM bars FINAL
GROUP BY ALL
ORDER BY symbol,
time
",
)
.fetch_all::<Bar>()
.await
}

View File

@@ -1,75 +0,0 @@
use super::BarWindow;
use burn::{
data::dataloader::batcher::Batcher,
tensor::{self, backend::Backend, Tensor},
};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
#[derive(Clone, Debug)]
pub struct BarWindowBatcher<B: Backend> {
pub device: B::Device,
}
#[derive(Clone, Debug)]
pub struct BarWindowBatch<B: Backend> {
pub hour_tensor: Tensor<B, 2, tensor::Int>,
pub day_tensor: Tensor<B, 2, tensor::Int>,
pub numerical_tensor: Tensor<B, 3>,
pub target_tensor: Tensor<B, 2>,
}
impl<B: Backend<FloatElem = f32, IntElem = i32>> Batcher<BarWindow, BarWindowBatch<B>>
for BarWindowBatcher<B>
{
fn batch(&self, items: Vec<BarWindow>) -> BarWindowBatch<B> {
let batch_size = items.len();
let (hour_tensors, day_tensors, numerical_tensors, target_tensors) = items
.into_par_iter()
.fold(
|| {
(
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
)
},
|(mut hour_tensors, mut day_tensors, mut numerical_tensors, mut target_tensors),
item| {
hour_tensors.push(Tensor::from_data(item.hours, &self.device));
day_tensors.push(Tensor::from_data(item.days, &self.device));
numerical_tensors.push(Tensor::from_data(item.numerical, &self.device));
target_tensors.push(Tensor::from_data(item.target, &self.device));
(hour_tensors, day_tensors, numerical_tensors, target_tensors)
},
)
.reduce(
|| {
(
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
Vec::with_capacity(batch_size),
)
},
|(mut hour_tensors, mut day_tensors, mut numerical_tensors, mut target_tensors),
item| {
hour_tensors.extend(item.0);
day_tensors.extend(item.1);
numerical_tensors.extend(item.2);
target_tensors.extend(item.3);
(hour_tensors, day_tensors, numerical_tensors, target_tensors)
},
);
BarWindowBatch {
hour_tensor: Tensor::stack(hour_tensors, 0).to_device(&self.device),
day_tensor: Tensor::stack(day_tensors, 0).to_device(&self.device),
numerical_tensor: Tensor::stack(numerical_tensors, 0).to_device(&self.device),
target_tensor: Tensor::stack(target_tensors, 0).to_device(&self.device),
}
}
}

View File

@@ -1,219 +0,0 @@
use crate::types::{
ta::{calculate_indicators, IndicatedBar, HEAD_SIZE, NUMERICAL_FIELD_COUNT},
Bar,
};
use burn::{
data::dataset::{transform::ComposedDataset, Dataset},
tensor::Data,
};
pub const WINDOW_SIZE: usize = 48;
#[derive(Clone, Debug)]
pub struct BarWindow {
pub hours: Data<i32, 1>,
pub days: Data<i32, 1>,
pub numerical: Data<f32, 2>,
pub target: Data<f32, 1>,
}
#[derive(Clone, Debug)]
struct SingleSymbolDataset {
hours: Vec<i32>,
days: Vec<i32>,
numerical: Vec<[f32; NUMERICAL_FIELD_COUNT]>,
targets: Vec<f32>,
}
impl SingleSymbolDataset {
#[allow(clippy::cast_possible_truncation)]
pub fn new(bars: Vec<IndicatedBar>) -> Self {
if !bars.is_empty() {
let symbol = &bars[0].symbol;
assert!(bars.iter().all(|bar| bar.symbol == *symbol));
}
let (hours, days, numerical, targets) = bars.windows(2).skip(HEAD_SIZE - 1).fold(
(
                Vec::with_capacity(bars.len().saturating_sub(1)),
                Vec::with_capacity(bars.len().saturating_sub(1)),
                Vec::with_capacity(bars.len().saturating_sub(1)),
                Vec::with_capacity(bars.len().saturating_sub(1)),
),
|(mut hours, mut days, mut numerical, mut targets), bar| {
hours.push(i32::from(bar[0].hour));
days.push(i32::from(bar[0].day));
numerical.push([
bar[0].open as f32,
(bar[0].open_pct as f32).min(f32::MAX),
bar[0].high as f32,
(bar[0].high_pct as f32).min(f32::MAX),
bar[0].low as f32,
(bar[0].low_pct as f32).min(f32::MAX),
bar[0].close as f32,
(bar[0].close_pct as f32).min(f32::MAX),
bar[0].volume as f32,
(bar[0].volume_pct as f32).min(f32::MAX),
bar[0].trades as f32,
(bar[0].trades_pct as f32).min(f32::MAX),
bar[0].sma_3 as f32,
bar[0].sma_6 as f32,
bar[0].sma_12 as f32,
bar[0].sma_24 as f32,
bar[0].sma_48 as f32,
bar[0].sma_72 as f32,
bar[0].ema_3 as f32,
bar[0].ema_6 as f32,
bar[0].ema_12 as f32,
bar[0].ema_24 as f32,
bar[0].ema_48 as f32,
bar[0].ema_72 as f32,
bar[0].macd as f32,
bar[0].macd_signal as f32,
bar[0].obv as f32,
bar[0].rsi as f32,
bar[0].bbands_lower as f32,
bar[0].bbands_mean as f32,
bar[0].bbands_upper as f32,
]);
targets.push(bar[1].close_pct as f32);
(hours, days, numerical, targets)
},
);
Self {
hours,
days,
numerical,
targets,
}
}
}
impl Dataset<BarWindow> for SingleSymbolDataset {
fn len(&self) -> usize {
        (self.targets.len() + 1).saturating_sub(WINDOW_SIZE)
}
#[allow(clippy::single_range_in_vec_init)]
fn get(&self, idx: usize) -> Option<BarWindow> {
if idx >= self.len() {
return None;
}
let hours: [i32; WINDOW_SIZE] = self.hours[idx..idx + WINDOW_SIZE].try_into().unwrap();
let days: [i32; WINDOW_SIZE] = self.days[idx..idx + WINDOW_SIZE].try_into().unwrap();
let numerical: [[f32; NUMERICAL_FIELD_COUNT]; WINDOW_SIZE] =
self.numerical[idx..idx + WINDOW_SIZE].try_into().unwrap();
let target: [f32; 1] = [self.targets[idx + WINDOW_SIZE - 1]];
Some(BarWindow {
hours: Data::from(hours),
days: Data::from(days),
numerical: Data::from(numerical),
target: Data::from(target),
})
}
}
pub struct MultipleSymbolDataset {
composed_dataset: ComposedDataset<SingleSymbolDataset>,
}
impl MultipleSymbolDataset {
pub fn new(bars: Vec<Bar>) -> Self {
let groups = calculate_indicators(bars)
.into_iter()
.map(SingleSymbolDataset::new)
.collect::<Vec<_>>();
Self {
composed_dataset: ComposedDataset::new(groups),
}
}
}
impl Dataset<BarWindow> for MultipleSymbolDataset {
fn len(&self) -> usize {
self.composed_dataset.len()
}
fn get(&self, idx: usize) -> Option<BarWindow> {
self.composed_dataset.get(idx)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::{
distributions::{Distribution, Uniform},
Rng,
};
use time::OffsetDateTime;
fn generate_random_dataset(length: usize) -> MultipleSymbolDataset {
let mut rng = rand::thread_rng();
let uniform = Uniform::new(1.0, 100.0);
let mut bars = Vec::with_capacity(length);
for _ in 0..=(length + (HEAD_SIZE - 1) + (WINDOW_SIZE - 1)) {
bars.push(Bar {
symbol: "AAPL".to_string(),
time: OffsetDateTime::now_utc(),
open: uniform.sample(&mut rng),
high: uniform.sample(&mut rng),
low: uniform.sample(&mut rng),
close: uniform.sample(&mut rng),
volume: uniform.sample(&mut rng),
trades: rng.gen_range(1..100),
});
}
MultipleSymbolDataset::new(bars)
}
#[test]
fn test_single_symbol_dataset() {
let length = 100;
let dataset = generate_random_dataset(length);
assert_eq!(dataset.len(), length);
}
#[test]
fn test_single_symbol_dataset_window() {
let length = 100;
let dataset = generate_random_dataset(length);
let item = dataset.get(0).unwrap();
assert_eq!(
item.numerical.shape.dims,
[WINDOW_SIZE, NUMERICAL_FIELD_COUNT]
);
assert_eq!(item.target.shape.dims, [1]);
}
#[test]
fn test_single_symbol_dataset_last_window() {
let length = 100;
let dataset = generate_random_dataset(length);
let item = dataset.get(dataset.len() - 1).unwrap();
assert_eq!(
item.numerical.shape.dims,
[WINDOW_SIZE, NUMERICAL_FIELD_COUNT]
);
assert_eq!(item.target.shape.dims, [1]);
}
#[test]
fn test_single_symbol_dataset_out_of_bounds() {
let length = 100;
let dataset = generate_random_dataset(length);
assert!(dataset.get(dataset.len()).is_none());
}
}

View File

@@ -1,21 +0,0 @@
pub mod batcher;
pub mod dataset;
pub mod model;
pub use batcher::{BarWindowBatch, BarWindowBatcher};
pub use dataset::{BarWindow, MultipleSymbolDataset};
pub use model::{Model, ModelConfig};
use burn::{
backend::{
wgpu::{AutoGraphicsApi, WgpuDevice},
Autodiff, Wgpu,
},
tensor::backend::Backend,
};
pub type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
pub type MyAutodiffBackend = Autodiff<MyBackend>;
pub type MyDevice = <Autodiff<Wgpu> as Backend>::Device;
pub const DEVICE: MyDevice = WgpuDevice::BestAvailable;

View File

@@ -1,160 +0,0 @@
use super::BarWindowBatch;
use crate::types::ta::NUMERICAL_FIELD_COUNT;
use burn::{
config::Config,
module::Module,
nn::{
loss::{MseLoss, Reduction},
Dropout, DropoutConfig, Embedding, EmbeddingConfig, Linear, LinearConfig, Lstm, LstmConfig,
},
tensor::{
self,
backend::{AutodiffBackend, Backend},
Tensor,
},
train::{RegressionOutput, TrainOutput, TrainStep, ValidStep},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
hour_embedding: Embedding<B>,
day_embedding: Embedding<B>,
lstm_1: Lstm<B>,
dropout_1: Dropout,
lstm_2: Lstm<B>,
dropout_2: Dropout,
lstm_3: Lstm<B>,
dropout_3: Dropout,
lstm_4: Lstm<B>,
dropout_4: Dropout,
linear: Linear<B>,
}
#[derive(Config, Debug)]
pub struct ModelConfig {
#[config(default = "3")]
pub hour_features: usize,
#[config(default = "2")]
pub day_features: usize,
#[config(default = "{NUMERICAL_FIELD_COUNT}")]
pub numerical_features: usize,
#[config(default = "0.2")]
pub dropout: f64,
}
impl ModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
let num_features = self.numerical_features + self.hour_features + self.day_features;
let lstm_1_hidden_size = 512;
let lstm_2_hidden_size = 256;
let lstm_3_hidden_size = 64;
let lstm_4_hidden_size = 32;
Model {
hour_embedding: EmbeddingConfig::new(24, self.hour_features).init(device),
day_embedding: EmbeddingConfig::new(7, self.day_features).init(device),
lstm_1: LstmConfig::new(num_features, lstm_1_hidden_size, true).init(device),
dropout_1: DropoutConfig::new(self.dropout).init(),
lstm_2: LstmConfig::new(lstm_1_hidden_size, lstm_2_hidden_size, true).init(device),
dropout_2: DropoutConfig::new(self.dropout).init(),
lstm_3: LstmConfig::new(lstm_2_hidden_size, lstm_3_hidden_size, true).init(device),
dropout_3: DropoutConfig::new(self.dropout).init(),
lstm_4: LstmConfig::new(lstm_3_hidden_size, lstm_4_hidden_size, true).init(device),
dropout_4: DropoutConfig::new(self.dropout).init(),
linear: LinearConfig::new(lstm_4_hidden_size, 1).init(device),
}
}
}
impl<B: Backend> Model<B> {
pub fn forward(
&self,
hour: Tensor<B, 2, tensor::Int>,
day: Tensor<B, 2, tensor::Int>,
numerical: Tensor<B, 3>,
) -> Tensor<B, 2> {
let hour = self.hour_embedding.forward(hour);
let day = self.day_embedding.forward(day);
let x = Tensor::cat(vec![hour, day, numerical], 2);
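        // x: [batch, window, hour_features + day_features + numerical_features]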
let (_, x) = self.lstm_1.forward(x, None);
let x = self.dropout_1.forward(x);
let (_, x) = self.lstm_2.forward(x, None);
let x = self.dropout_2.forward(x);
let (_, x) = self.lstm_3.forward(x, None);
let x = self.dropout_3.forward(x);
let (_, x) = self.lstm_4.forward(x, None);
let x = self.dropout_4.forward(x);
let [batch_size, window_size, features] = x.shape().dims;
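        // Keep only the final timestep of the last LSTM's output sequence
        // before the regression head.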
let x = x.slice([0..batch_size, window_size - 1..window_size, 0..features]);
let x = x.squeeze(1);
self.linear.forward(x)
}
pub fn forward_regression(
&self,
hour: Tensor<B, 2, tensor::Int>,
day: Tensor<B, 2, tensor::Int>,
numerical: Tensor<B, 3>,
target: Tensor<B, 2>,
) -> RegressionOutput<B> {
let output = self.forward(hour, day, numerical);
let loss = MseLoss::new().forward(output.clone(), target.clone(), Reduction::Mean);
RegressionOutput::new(loss, output, target)
}
}
impl<B: AutodiffBackend> TrainStep<BarWindowBatch<B>, RegressionOutput<B>> for Model<B> {
fn step(&self, batch: BarWindowBatch<B>) -> TrainOutput<RegressionOutput<B>> {
let item = self.forward_regression(
batch.hour_tensor,
batch.day_tensor,
batch.numerical_tensor,
batch.target_tensor,
);
TrainOutput::new(self, item.loss.backward(), item)
}
}
impl<B: Backend> ValidStep<BarWindowBatch<B>, RegressionOutput<B>> for Model<B> {
fn step(&self, batch: BarWindowBatch<B>) -> RegressionOutput<B> {
self.forward_regression(
batch.hour_tensor,
batch.day_tensor,
batch.numerical_tensor,
batch.target_tensor,
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::{backend::Wgpu, tensor::Distribution};
#[test]
#[ignore]
fn test_model() {
let device = Default::default();
let distribution = Distribution::Normal(0.0, 1.0);
let config = ModelConfig::new().with_numerical_features(7);
let model = config.init::<Wgpu>(&device);
let hour = Tensor::ones([2, 10], &device);
let day = Tensor::ones([2, 10], &device);
let numerical = Tensor::random([2, 10, 7], distribution, &device);
let output = model.forward(hour, day, numerical);
assert_eq!(output.shape().dims, [2, 1]);
}
}
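A sketch of restoring the checkpoint written by the trainer binary; the path and the `Module::load_file` call are assumptions against burn 0.13:
use burn::{module::Module, record::CompactRecorder};

let device = Default::default();
let model = ModelConfig::new()
    .init::<MyBackend>(&device)
    .load_file("artifacts/model", &CompactRecorder::new(), &device)
    .expect("trained checkpoint should exist");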

View File

@@ -1,6 +0,0 @@
pub mod alpaca;
pub mod database;
pub mod ml;
pub mod ta;
pub mod types;
pub mod utils;

View File

@@ -1,149 +0,0 @@
use std::{borrow::Borrow, collections::VecDeque, iter::Scan, num::NonZeroUsize};
pub struct BbandsState {
window: VecDeque<f64>,
sum: f64,
squared_sum: f64,
multiplier: f64,
}
#[allow(clippy::type_complexity)]
pub trait Bbands<T>: Iterator + Sized {
fn bbands(
self,
period: NonZeroUsize,
multiplier: f64, // Typically 2.0
) -> Scan<Self, BbandsState, fn(&mut BbandsState, T) -> Option<(f64, f64, f64)>>;
}
impl<I, T> Bbands<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn bbands(
self,
period: NonZeroUsize,
multiplier: f64,
) -> Scan<Self, BbandsState, fn(&mut BbandsState, T) -> Option<(f64, f64, f64)>> {
self.scan(
BbandsState {
window: VecDeque::from(vec![0.0; period.get()]),
sum: 0.0,
squared_sum: 0.0,
multiplier,
},
|state: &mut BbandsState, value: T| {
let value = *value.borrow();
let front = state.window.pop_front().unwrap();
state.sum -= front;
state.squared_sum -= front.powi(2);
state.window.push_back(value);
state.sum += value;
state.squared_sum += value.powi(2);
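                // Rolling mean and variance from running sums:
                // Var(x) = E[x^2] - E[x]^2, clamped at zero to absorb
                // floating-point error.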
let mean = state.sum / state.window.len() as f64;
let variance =
((state.squared_sum / state.window.len() as f64) - mean.powi(2)).max(0.0);
let standard_deviation = variance.sqrt();
let upper_band = mean + state.multiplier * standard_deviation;
let lower_band = mean - state.multiplier * standard_deviation;
Some((upper_band, mean, lower_band))
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bbands() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let bbands = data
.into_iter()
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
.map(|(upper, mean, lower)| {
(
(upper * 100.0).round() / 100.0,
(mean * 100.0).round() / 100.0,
(lower * 100.0).round() / 100.0,
)
})
.collect::<Vec<_>>();
assert_eq!(
bbands,
vec![
(1.28, 0.33, -0.61),
(2.63, 1.0, -0.63),
(3.63, 2.0, 0.37),
(4.63, 3.0, 1.37),
(5.63, 4.0, 2.37)
]
);
}
#[test]
fn test_bbands_empty() {
let data = Vec::<f64>::new();
let bbands = data
.into_iter()
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
.collect::<Vec<_>>();
assert_eq!(bbands, Vec::<(f64, f64, f64)>::new());
}
#[test]
fn test_bbands_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let bbands = data
.into_iter()
.bbands(NonZeroUsize::new(1).unwrap(), 2.0)
.collect::<Vec<_>>();
assert_eq!(
bbands,
vec![
(1.0, 1.0, 1.0),
(2.0, 2.0, 2.0),
(3.0, 3.0, 3.0),
(4.0, 4.0, 4.0),
(5.0, 5.0, 5.0)
]
);
}
#[test]
fn test_bbands_borrow() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
let bbands = data
.iter()
.bbands(NonZeroUsize::new(3).unwrap(), 2.0)
.map(|(upper, mean, lower)| {
(
(upper * 100.0).round() / 100.0,
(mean * 100.0).round() / 100.0,
(lower * 100.0).round() / 100.0,
)
})
.collect::<Vec<_>>();
assert_eq!(
bbands,
vec![
(1.28, 0.33, -0.61),
(2.63, 1.0, -0.63),
(3.63, 2.0, 0.37),
(4.63, 3.0, 1.37),
(5.63, 4.0, 2.37)
]
);
}
}

View File

@@ -1,59 +0,0 @@
use std::{borrow::Borrow, iter::Scan};
pub struct DerivState {
pub last: f64,
}
#[allow(clippy::type_complexity)]
pub trait Deriv<T>: Iterator + Sized {
fn deriv(self) -> Scan<Self, DerivState, fn(&mut DerivState, T) -> Option<f64>>;
}
impl<I, T> Deriv<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn deriv(self) -> Scan<Self, DerivState, fn(&mut DerivState, T) -> Option<f64>> {
self.scan(
DerivState { last: 0.0 },
|state: &mut DerivState, value: T| {
let value = *value.borrow();
let deriv = value - state.last;
state.last = value;
Some(deriv)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_deriv() {
let data = vec![1.0, 3.0, 6.0, 3.0, 1.0];
let deriv = data.into_iter().deriv().collect::<Vec<_>>();
assert_eq!(deriv, vec![1.0, 2.0, 3.0, -3.0, -2.0]);
}
#[test]
fn test_deriv_empty() {
let data = Vec::<f64>::new();
let deriv = data.into_iter().deriv().collect::<Vec<_>>();
assert_eq!(deriv, Vec::<f64>::new());
}
#[test]
fn test_deriv_borrow() {
let data = [1.0, 3.0, 6.0, 3.0, 1.0];
let deriv = data.iter().deriv().collect::<Vec<_>>();
assert_eq!(deriv, vec![1.0, 2.0, 3.0, -3.0, -2.0]);
}
}

View File

@@ -1,95 +0,0 @@
use std::{
borrow::Borrow,
iter::{Peekable, Scan},
num::NonZeroUsize,
};
pub struct EmaState {
weight: f64,
ema: f64,
}
#[allow(clippy::type_complexity)]
pub trait Ema<T>: Iterator + Sized {
fn ema(
self,
period: NonZeroUsize,
) -> Scan<Peekable<Self>, EmaState, fn(&mut EmaState, T) -> Option<f64>>;
}
impl<I, T> Ema<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn ema(
self,
period: NonZeroUsize,
) -> Scan<Peekable<Self>, EmaState, fn(&mut EmaState, T) -> Option<f64>> {
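        // Standard EMA smoothing factor: alpha = 2 / (period + 1). The state
        // is seeded with the first value so early outputs are not biased
        // toward zero.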
let smoothing = 2.0;
let weight = smoothing / (1.0 + period.get() as f64);
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
EmaState { weight, ema: first },
|state: &mut EmaState, value: T| {
let value = *value.borrow();
state.ema = (value * state.weight) + (state.ema * (1.0 - state.weight));
Some(state.ema)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ema() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let ema = data
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, vec![1.0, 1.5, 2.25, 3.125, 4.0625]);
}
#[test]
fn test_ema_empty() {
let data = Vec::<f64>::new();
let ema = data
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, Vec::<f64>::new());
}
#[test]
fn test_ema_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let ema = data
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_ema_borrow() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
let ema = data
.iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(ema, vec![1.0, 1.5, 2.25, 3.125, 4.0625]);
}
}

View File

@@ -1,216 +0,0 @@
use std::{
borrow::Borrow,
iter::{Peekable, Scan},
num::NonZeroUsize,
};
pub struct MacdState {
short_weight: f64,
long_weight: f64,
signal_weight: f64,
short_ema: f64,
long_ema: f64,
signal_ema: f64,
}
#[allow(clippy::type_complexity)]
pub trait Macd<T>: Iterator + Sized {
fn macd(
self,
short_period: NonZeroUsize, // Typically 12
long_period: NonZeroUsize, // Typically 26
signal_period: NonZeroUsize, // Typically 9
) -> Scan<Peekable<Self>, MacdState, fn(&mut MacdState, T) -> Option<(f64, f64)>>;
}
impl<I, T> Macd<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn macd(
self,
short_period: NonZeroUsize,
long_period: NonZeroUsize,
signal_period: NonZeroUsize,
) -> Scan<Peekable<Self>, MacdState, fn(&mut MacdState, T) -> Option<(f64, f64)>> {
let smoothing = 2.0;
let short_weight = smoothing / (1.0 + short_period.get() as f64);
let long_weight = smoothing / (1.0 + long_period.get() as f64);
let signal_weight = smoothing / (1.0 + signal_period.get() as f64);
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
MacdState {
short_weight,
long_weight,
signal_weight,
short_ema: first,
long_ema: first,
signal_ema: 0.0,
},
|state: &mut MacdState, value: T| {
let value = *value.borrow();
state.short_ema =
(value * state.short_weight) + (state.short_ema * (1.0 - state.short_weight));
state.long_ema =
(value * state.long_weight) + (state.long_ema * (1.0 - state.long_weight));
let macd = state.short_ema - state.long_ema;
state.signal_ema =
(macd * state.signal_weight) + (state.signal_ema * (1.0 - state.signal_weight));
Some((macd, state.signal_ema))
},
)
}
}
#[cfg(test)]
mod tests {
use super::super::ema::Ema;
use super::*;
#[test]
fn test_macd() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let short_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let long_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(5).unwrap())
.collect::<Vec<_>>();
let macd = short_ema
.into_iter()
.zip(long_ema)
.map(|(short, long)| short - long)
.collect::<Vec<_>>();
let signal = macd
.clone()
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
assert_eq!(
data.into_iter()
.macd(
NonZeroUsize::new(3).unwrap(),
NonZeroUsize::new(5).unwrap(),
NonZeroUsize::new(3).unwrap()
)
.collect::<Vec<_>>(),
expected
);
}
#[test]
fn test_macd_empty() {
let data = Vec::<f64>::new();
assert_eq!(
data.into_iter()
.macd(
NonZeroUsize::new(3).unwrap(),
NonZeroUsize::new(5).unwrap(),
NonZeroUsize::new(3).unwrap()
)
.collect::<Vec<_>>(),
vec![]
);
}
#[test]
fn test_macd_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let short_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
let long_ema = data
.clone()
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
let macd = short_ema
.into_iter()
.zip(long_ema)
.map(|(short, long)| short - long)
.collect::<Vec<_>>();
let signal = macd
.clone()
.into_iter()
.ema(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
assert_eq!(
data.into_iter()
.macd(
NonZeroUsize::new(1).unwrap(),
NonZeroUsize::new(1).unwrap(),
NonZeroUsize::new(1).unwrap()
)
.collect::<Vec<_>>(),
expected
);
}
#[test]
fn test_macd_borrow() {
let data = [1.0, 2.0, 3.0, 4.0, 5.0];
let short_ema = data
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let long_ema = data
.into_iter()
.ema(NonZeroUsize::new(5).unwrap())
.collect::<Vec<_>>();
let macd = short_ema
.into_iter()
.zip(long_ema)
.map(|(short, long)| short - long)
.collect::<Vec<_>>();
let signal = macd
.clone()
.into_iter()
.ema(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
let expected = macd.into_iter().zip(signal).collect::<Vec<_>>();
assert_eq!(
data.iter()
.macd(
NonZeroUsize::new(3).unwrap(),
NonZeroUsize::new(5).unwrap(),
NonZeroUsize::new(3).unwrap()
)
.collect::<Vec<_>>(),
expected
);
}
}
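For reference, a minimal usage sketch of the combined iterator with the conventional 12/26/9 periods (the price series here is made up):

let closes = vec![100.0, 101.5, 99.8, 102.2, 103.0];
let macd_and_signal: Vec<(f64, f64)> = closes
    .into_iter()
    .macd(
        NonZeroUsize::new(12).unwrap(),
        NonZeroUsize::new(26).unwrap(),
        NonZeroUsize::new(9).unwrap(),
    )
    .collect(); // each element is (macd, signal)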

View File

@@ -1,17 +0,0 @@
pub mod bbands;
pub mod deriv;
pub mod ema;
pub mod macd;
pub mod obv;
pub mod pct;
pub mod rsi;
pub mod sma;
pub use bbands::Bbands;
pub use deriv::Deriv;
pub use ema::Ema;
pub use macd::Macd;
pub use obv::Obv;
pub use pct::Pct;
pub use rsi::Rsi;
pub use sma::Sma;

View File

@@ -1,73 +0,0 @@
use std::{
borrow::Borrow,
iter::{Peekable, Scan},
};
pub struct ObvState {
last: f64,
obv: f64,
}
#[allow(clippy::type_complexity)]
pub trait Obv<T>: Iterator + Sized {
fn obv(self) -> Scan<Peekable<Self>, ObvState, fn(&mut ObvState, T) -> Option<f64>>;
}
impl<I, T> Obv<T> for I
where
I: Iterator<Item = T>,
T: Borrow<(f64, f64)>,
{
fn obv(self) -> Scan<Peekable<Self>, ObvState, fn(&mut ObvState, T) -> Option<f64>> {
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
ObvState {
last: first.0,
obv: 0.0,
},
|state: &mut ObvState, value: T| {
let (close, volume) = *value.borrow();
if close > state.last {
state.obv += volume;
} else if close < state.last {
state.obv -= volume;
}
state.last = close;
Some(state.obv)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_obv() {
let data = vec![(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (2.0, 4.0), (1.0, 5.0)];
let obv = data.into_iter().obv().collect::<Vec<_>>();
assert_eq!(obv, vec![0.0, 2.0, 5.0, 1.0, -4.0]);
}
#[test]
fn test_obv_empty() {
let data = Vec::<(f64, f64)>::new();
let obv = data.into_iter().obv().collect::<Vec<_>>();
assert_eq!(obv, Vec::<f64>::new());
}
#[test]
fn test_obv_borrow() {
let data = [(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (2.0, 4.0), (1.0, 5.0)];
let obv = data.iter().obv().collect::<Vec<_>>();
assert_eq!(obv, vec![0.0, 2.0, 5.0, 1.0, -4.0]);
}
}

View File

@@ -1,64 +0,0 @@
use std::{borrow::Borrow, iter::Scan};
pub struct PctState {
pub last: f64,
}
#[allow(clippy::type_complexity)]
pub trait Pct<T>: Iterator + Sized {
fn pct(self) -> Scan<Self, PctState, fn(&mut PctState, T) -> Option<f64>>;
}
impl<I, T> Pct<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn pct(self) -> Scan<Self, PctState, fn(&mut PctState, T) -> Option<f64>> {
self.scan(PctState { last: 0.0 }, |state: &mut PctState, value: T| {
let value = *value.borrow();
let pct = value / state.last - 1.0;
state.last = value;
Some(pct)
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pct() {
let data = vec![1.0, 2.0, 4.0, 2.0, 1.0];
let pct = data.into_iter().pct().collect::<Vec<_>>();
assert_eq!(pct, vec![f64::INFINITY, 1.0, 1.0, -0.5, -0.5]);
}
#[test]
fn test_pct_empty() {
let data = Vec::<f64>::new();
let pct = data.into_iter().pct().collect::<Vec<_>>();
assert_eq!(pct, Vec::<f64>::new());
}
#[test]
fn test_pct_0() {
let data = vec![1.0, 0.0, 4.0, 2.0, 1.0];
let pct = data.into_iter().pct().collect::<Vec<_>>();
assert_eq!(pct, vec![f64::INFINITY, -1.0, f64::INFINITY, -0.5, -0.5]);
}
#[test]
fn test_pct_borrow() {
let data = [1.0, 2.0, 4.0, 2.0, 1.0];
let pct = data.iter().pct().collect::<Vec<_>>();
assert_eq!(pct, vec![f64::INFINITY, 1.0, 1.0, -0.5, -0.5]);
}
}
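Note that `PctState` seeds `last` with `0.0`, so the first output is `value / 0.0 - 1.0` (positive or negative infinity, or NaN for a leading zero), as the tests above rely on; a caller that wants a clean percentage series can simply drop it:

let data = vec![1.0, 2.0, 4.0, 2.0, 1.0];
let pct = data.into_iter().pct().skip(1).collect::<Vec<_>>(); // [1.0, 1.0, -0.5, -0.5]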

View File

@@ -1,135 +0,0 @@
use std::{
borrow::Borrow,
collections::VecDeque,
iter::{Peekable, Scan},
num::NonZeroUsize,
};
pub struct RsiState {
last: f64,
window_gains: VecDeque<f64>,
window_losses: VecDeque<f64>,
sum_gains: f64,
sum_losses: f64,
}
#[allow(clippy::type_complexity)]
pub trait Rsi<T>: Iterator + Sized {
fn rsi(
self,
period: NonZeroUsize, // Typically 14
) -> Scan<Peekable<Self>, RsiState, fn(&mut RsiState, T) -> Option<f64>>;
}
impl<I, T> Rsi<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn rsi(
self,
period: NonZeroUsize,
) -> Scan<Peekable<Self>, RsiState, fn(&mut RsiState, T) -> Option<f64>> {
let mut iter = self.peekable();
let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default();
iter.scan(
RsiState {
last: first,
window_gains: VecDeque::from(vec![0.0; period.get()]),
window_losses: VecDeque::from(vec![0.0; period.get()]),
sum_gains: 0.0,
sum_losses: 0.0,
},
|state, value| {
let value = *value.borrow();
state.sum_gains -= state.window_gains.pop_front().unwrap();
state.sum_losses -= state.window_losses.pop_front().unwrap();
let gain = (value - state.last).max(0.0);
let loss = (state.last - value).max(0.0);
state.last = value;
state.window_gains.push_back(gain);
state.window_losses.push_back(loss);
state.sum_gains += gain;
state.sum_losses += loss;
let avg_loss = state.sum_losses / state.window_losses.len() as f64;
if avg_loss == 0.0 {
return Some(100.0);
}
let avg_gain = state.sum_gains / state.window_gains.len() as f64;
let rs = avg_gain / avg_loss;
Some(100.0 - (100.0 / (1.0 + rs)))
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rsi() {
let data = vec![1.0, 4.0, 7.0, 4.0, 1.0];
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.map(|v| (v * 100.0).round() / 100.0)
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 66.67, 33.33]);
}
#[test]
fn test_rsi_empty() {
let data = Vec::<f64>::new();
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(rsi, Vec::<f64>::new());
}
#[test]
fn test_rsi_no_loss() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 100.0, 100.0]);
}
#[test]
fn test_rsi_no_gain() {
let data = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let rsi = data
.into_iter()
.rsi(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 0.0, 0.0, 0.0, 0.0]);
}
#[test]
fn test_rsi_borrow() {
let data = [1.0, 4.0, 7.0, 4.0, 1.0];
let rsi = data
.iter()
.rsi(NonZeroUsize::new(3).unwrap())
.map(|v| (v * 100.0).round() / 100.0)
.collect::<Vec<_>>();
assert_eq!(rsi, vec![100.0, 100.0, 100.0, 66.67, 33.33]);
}
}

View File

@@ -1,88 +0,0 @@
use std::{borrow::Borrow, collections::VecDeque, iter::Scan, num::NonZeroUsize};
pub struct SmaState {
window: VecDeque<f64>,
sum: f64,
}
#[allow(clippy::type_complexity)]
pub trait Sma<T>: Iterator + Sized {
fn sma(self, period: NonZeroUsize)
-> Scan<Self, SmaState, fn(&mut SmaState, T) -> Option<f64>>;
}
impl<I, T> Sma<T> for I
where
I: Iterator<Item = T>,
T: Borrow<f64>,
{
fn sma(
self,
period: NonZeroUsize,
) -> Scan<Self, SmaState, fn(&mut SmaState, T) -> Option<f64>> {
self.scan(
SmaState {
window: VecDeque::from(vec![0.0; period.get()]),
sum: 0.0,
},
|state: &mut SmaState, value: T| {
let value = *value.borrow();
state.sum -= state.window.pop_front().unwrap();
state.window.push_back(value);
state.sum += value;
Some(state.sum / state.window.len() as f64)
},
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sma() {
let data = vec![3.0, 6.0, 9.0, 12.0, 15.0];
let sma = data
.into_iter()
.sma(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, vec![1.0, 3.0, 6.0, 9.0, 12.0]);
}
#[test]
fn test_sma_empty() {
let data = Vec::<f64>::new();
let sma = data
.into_iter()
.sma(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, Vec::<f64>::new());
}
#[test]
fn test_sma_1_period() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let sma = data
.into_iter()
.sma(NonZeroUsize::new(1).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_sma_borrow() {
let data = [3.0, 6.0, 9.0, 12.0, 15.0];
let sma = data
.iter()
.sma(NonZeroUsize::new(3).unwrap())
.collect::<Vec<_>>();
assert_eq!(sma, vec![1.0, 3.0, 6.0, 9.0, 12.0]);
}
}
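Because the window is pre-filled with zeros, the first `period - 1` outputs are warm-up values dragged toward zero (`sma(3)` over `3.0, 6.0, 9.0, ...` yields `1.0, 3.0, 6.0, ...`, as the tests above show); consumers may want to treat those leading elements as burn-in.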

View File

@@ -1,75 +0,0 @@
use serde::Deserialize;
use serde_aux::field_attributes::{
deserialize_number_from_string, deserialize_option_number_from_string,
};
use time::OffsetDateTime;
use uuid::Uuid;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Status {
Onboarding,
SubmissionFailed,
Submitted,
AccountUpdated,
ApprovalPending,
Active,
Rejected,
}
#[derive(Deserialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct Account {
pub id: Uuid,
#[serde(rename = "account_number")]
pub number: String,
pub status: Status,
pub currency: String,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub cash: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub non_marginable_buying_power: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub accrued_fees: f64,
#[serde(default)]
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub pending_transfer_in: Option<f64>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub pending_transfer_out: Option<f64>,
pub pattern_day_trader: bool,
#[serde(default)]
pub trade_suspend_by_user: bool,
pub trading_blocked: bool,
pub transfers_blocked: bool,
#[serde(rename = "account_blocked")]
pub blocked: bool,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
pub shorting_enabled: bool,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub long_market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub short_market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub equity: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub last_equity: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub multiplier: i8,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub buying_power: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub initial_margin: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub maintenance_margin: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub sma: f64,
pub daytrade_count: i64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub last_maintenance_margin: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub daytrading_buying_power: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub regt_buying_power: f64,
}

View File

@@ -1,39 +0,0 @@
use super::position::Position;
use crate::types::{self, alpaca::shared::asset};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_option_number_from_string;
use uuid::Uuid;
pub use asset::{Class, Exchange, Status};
#[allow(clippy::struct_excessive_bools)]
#[derive(Deserialize, Clone)]
pub struct Asset {
pub id: Uuid,
pub class: Class,
pub exchange: Exchange,
pub symbol: String,
pub name: String,
pub status: Status,
pub tradable: bool,
pub marginable: bool,
pub shortable: bool,
pub easy_to_borrow: bool,
pub fractionable: bool,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub maintenance_margin_requirement: Option<f32>,
pub attributes: Option<Vec<String>>,
}
impl From<(Asset, Option<Position>)> for types::Asset {
fn from((asset, position): (Asset, Option<Position>)) -> Self {
Self {
symbol: asset.symbol,
class: asset.class.into(),
exchange: asset.exchange.into(),
status: asset.status.into(),
time_added: time::OffsetDateTime::now_utc(),
qty: position.map(|position| position.qty).unwrap_or_default(),
}
}
}

View File

@@ -1,26 +0,0 @@
use crate::{
types,
utils::{de, time::EST_OFFSET},
};
use serde::Deserialize;
use time::{Date, OffsetDateTime, Time};
#[derive(Deserialize)]
pub struct Calendar {
pub date: Date,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub open: Time,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub close: Time,
pub settlement_date: Date,
}
impl From<Calendar> for types::Calendar {
fn from(calendar: Calendar) -> Self {
Self {
date: calendar.date,
open: OffsetDateTime::new_in_offset(calendar.date, calendar.open, *EST_OFFSET),
close: OffsetDateTime::new_in_offset(calendar.date, calendar.close, *EST_OFFSET),
}
}
}

View File

@@ -1,8 +0,0 @@
pub mod account;
pub mod asset;
pub mod bar;
pub mod calendar;
pub mod clock;
pub mod news;
pub mod order;
pub mod position;

View File

@@ -1,3 +0,0 @@
use crate::types::alpaca::shared::order;
pub use order::{Order, Side};

View File

@@ -1,61 +0,0 @@
use crate::{
types::alpaca::api::incoming::{
asset::{Class, Exchange},
order,
},
utils::de,
};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_number_from_string;
use uuid::Uuid;
#[derive(Deserialize, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Long,
Short,
}
impl From<Side> for order::Side {
fn from(side: Side) -> Self {
match side {
Side::Long => Self::Buy,
Side::Short => Self::Sell,
}
}
}
#[derive(Deserialize, Clone)]
pub struct Position {
pub asset_id: Uuid,
#[serde(deserialize_with = "de::add_slash_to_symbol")]
pub symbol: String,
pub exchange: Exchange,
pub asset_class: Class,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub avg_entry_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty_available: f64,
pub side: Side,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub cost_basis: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub current_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub lastday_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub change_today: f64,
pub asset_marginable: bool,
}

View File

@@ -1,6 +0,0 @@
pub mod incoming;
pub mod outgoing;
pub const ALPACA_US_EQUITY_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news";

View File

@@ -1,23 +0,0 @@
use crate::types::alpaca::shared::asset;
use serde::Serialize;
pub use asset::{Class, Exchange, Status};
#[derive(Serialize)]
pub struct Asset {
pub status: Option<Status>,
pub class: Option<Class>,
pub exchange: Option<Exchange>,
pub attributes: Option<Vec<String>>,
}
impl Default for Asset {
fn default() -> Self {
Self {
status: None,
class: Some(Class::UsEquity),
exchange: None,
attributes: None,
}
}
}

View File

@@ -1,108 +0,0 @@
use crate::{
alpaca::bars::MAX_LIMIT,
types::alpaca::shared,
utils::{ser, ONE_MINUTE},
};
use serde::Serialize;
use std::time::Duration;
use time::OffsetDateTime;
pub use shared::{Sort, Source};
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum Adjustment {
Raw,
Split,
Dividend,
All,
}
#[derive(Serialize)]
pub struct UsEquity {
#[serde(serialize_with = "ser::join_symbols")]
pub symbols: Vec<String>,
#[serde(serialize_with = "ser::timeframe")]
pub timeframe: Duration,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub start: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub end: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub adjustment: Option<Adjustment>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub asof: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub feed: Option<Source>,
#[serde(skip_serializing_if = "Option::is_none")]
pub currency: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort: Option<Sort>,
}
impl Default for UsEquity {
fn default() -> Self {
Self {
symbols: vec![],
timeframe: ONE_MINUTE,
start: None,
end: None,
limit: Some(MAX_LIMIT),
adjustment: Some(Adjustment::All),
asof: None,
feed: Some(Source::Iex),
currency: None,
page_token: None,
sort: Some(Sort::Asc),
}
}
}
#[derive(Serialize)]
pub struct Crypto {
#[serde(serialize_with = "ser::join_symbols")]
pub symbols: Vec<String>,
#[serde(serialize_with = "ser::timeframe")]
pub timeframe: Duration,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub start: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub end: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort: Option<Sort>,
}
impl Default for Crypto {
fn default() -> Self {
Self {
symbols: vec![],
timeframe: ONE_MINUTE,
start: None,
end: None,
limit: Some(MAX_LIMIT),
page_token: None,
sort: Some(Sort::Asc),
}
}
}
#[derive(Serialize)]
#[serde(untagged)]
pub enum Bar {
UsEquity(UsEquity),
Crypto(Crypto),
}

View File

@@ -1,31 +0,0 @@
use crate::utils::time::MAX_TIMESTAMP;
use serde::Serialize;
use time::OffsetDateTime;
#[derive(Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
#[allow(dead_code)]
pub enum DateType {
Trading,
Settlement,
}
#[derive(Serialize)]
pub struct Calendar {
#[serde(with = "time::serde::rfc3339")]
pub start: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
pub end: OffsetDateTime,
#[serde(rename = "date")]
pub date_type: DateType,
}
impl Default for Calendar {
fn default() -> Self {
Self {
start: OffsetDateTime::UNIX_EPOCH,
end: *MAX_TIMESTAMP,
date_type: DateType::Trading,
}
}
}

View File

@@ -1,40 +0,0 @@
use crate::{alpaca::news::MAX_LIMIT, types::alpaca::shared::Sort, utils::ser};
use serde::Serialize;
use time::OffsetDateTime;
#[derive(Serialize)]
pub struct News {
#[serde(serialize_with = "ser::remove_slash_and_join_symbols")]
pub symbols: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub start: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub end: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub include_content: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub exclude_contentless: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort: Option<Sort>,
}
impl Default for News {
fn default() -> Self {
Self {
symbols: vec![],
start: None,
end: None,
limit: Some(MAX_LIMIT),
include_content: Some(true),
exclude_contentless: Some(false),
page_token: None,
sort: Some(Sort::Asc),
}
}
}

View File

@@ -1,55 +0,0 @@
use crate::{
types::alpaca::shared::{order, Sort},
utils::ser,
};
use serde::Serialize;
use time::OffsetDateTime;
pub use order::Side;
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum Status {
Open,
Closed,
All,
}
#[derive(Serialize)]
pub struct Order {
#[serde(skip_serializing_if = "Option::is_none")]
pub status: Option<Status>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub after: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub until: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub direction: Option<Sort>,
#[serde(skip_serializing_if = "Option::is_none")]
pub nested: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(serialize_with = "ser::join_symbols_option")]
pub symbols: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub side: Option<Side>,
}
impl Default for Order {
fn default() -> Self {
Self {
status: Some(Status::All),
limit: Some(500),
after: None,
until: None,
direction: Some(Sort::Asc),
nested: Some(true),
symbols: None,
side: None,
}
}
}

View File

@@ -1,3 +0,0 @@
pub mod api;
pub mod shared;
pub mod websocket;

View File

@@ -1,53 +0,0 @@
use crate::{impl_from_enum, types};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Class {
UsEquity,
Crypto,
}
impl_from_enum!(types::Class, Class, UsEquity, Crypto);
#[derive(Serialize, Deserialize, Clone, Copy)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Exchange {
Amex,
Arca,
Bats,
Nyse,
Nasdaq,
Nysearca,
Otc,
Crypto,
}
impl_from_enum!(
types::Exchange,
Exchange,
Amex,
Arca,
Bats,
Nyse,
Nasdaq,
Nysearca,
Otc,
Crypto
);
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum Status {
Active,
Inactive,
}
impl From<Status> for bool {
fn from(status: Status) -> Self {
match status {
Status::Active => true,
Status::Inactive => false,
}
}
}

View File

@@ -1,10 +0,0 @@
pub mod asset;
pub mod mode;
pub mod news;
pub mod order;
pub mod sort;
pub mod source;
pub use mode::Mode;
pub use sort::Sort;
pub use source::Source;

View File

@@ -1,33 +0,0 @@
use serde::{Deserialize, Serialize};
use std::{
fmt::{Display, Formatter},
str::FromStr,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Mode {
Live,
Paper,
}
impl FromStr for Mode {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"live" => Ok(Self::Live),
"paper" => Ok(Self::Paper),
_ => Err(format!("Unknown mode: {s}")),
}
}
}
impl Display for Mode {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
match self {
Self::Live => write!(f, "live"),
Self::Paper => write!(f, "paper"),
}
}
}
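Since `Mode` implements `FromStr`, it can be parsed directly from configuration; a minimal sketch (the environment variable name is hypothetical):

let mode: Mode = std::env::var("ALPACA_MODE")
    .expect("ALPACA_MODE not set")
    .parse()
    .expect("invalid mode"); // "live" or "paper"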

View File

@@ -1,26 +0,0 @@
use lazy_static::lazy_static;
use regex::Regex;
lazy_static! {
static ref RE_TAGS: Regex = Regex::new("<[^>]+>").unwrap();
static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap();
}
pub fn strip(content: &str) -> String {
let content = content.replace('\n', " ");
let content = RE_TAGS.replace_all(&content, "");
let content = RE_SPACES.replace_all(&content, " ");
let content = content.trim();
content.to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_strip() {
let content = "<p> <b> Hello, </b> <i> World! </i> </p>";
assert_eq!(strip(content), "Hello, World!");
}
}
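`strip` removes tags and collapses whitespace but leaves HTML entities (`&amp;`, `&nbsp;`, ...) intact; a hedged sketch of decoding them first, assuming the `html-escape` crate is available:

let decoded = html_escape::decode_html_entities(content); // Cow<str>
let clean = strip(&decoded);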

View File

@@ -1,275 +0,0 @@
use crate::{impl_from_enum, types};
use serde::{Deserialize, Serialize};
use serde_aux::field_attributes::{
deserialize_number_from_string, deserialize_option_number_from_string,
};
use time::OffsetDateTime;
use uuid::Uuid;
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Class {
#[serde(alias = "")]
Simple,
Bracket,
Oco,
Oto,
}
impl_from_enum!(types::order::Class, Class, Simple, Bracket, Oco, Oto);
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Type {
Market,
Limit,
Stop,
StopLimit,
TrailingStop,
}
impl_from_enum!(
types::order::Type,
Type,
Market,
Limit,
Stop,
StopLimit,
TrailingStop
);
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Buy,
Sell,
}
impl_from_enum!(types::order::Side, Side, Buy, Sell);
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TimeInForce {
Day,
Gtc,
Opg,
Cls,
Ioc,
Fok,
}
impl_from_enum!(
types::order::TimeInForce,
TimeInForce,
Day,
Gtc,
Opg,
Cls,
Ioc,
Fok
);
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Status {
New,
PartiallyFilled,
Filled,
DoneForDay,
Canceled,
Expired,
Replaced,
PendingCancel,
PendingReplace,
Accepted,
PendingNew,
AcceptedForBidding,
Stopped,
Rejected,
Suspended,
Calculated,
}
impl_from_enum!(
types::order::Status,
Status,
New,
PartiallyFilled,
Filled,
DoneForDay,
Canceled,
Expired,
Replaced,
PendingCancel,
PendingReplace,
Accepted,
PendingNew,
AcceptedForBidding,
Stopped,
Rejected,
Suspended,
Calculated
);
#[derive(Deserialize, Clone, Debug, PartialEq)]
#[allow(clippy::struct_field_names)]
pub struct Order {
pub id: Uuid,
pub client_order_id: Uuid,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339::option")]
pub updated_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339")]
pub submitted_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339::option")]
pub filled_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub expired_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub cancel_requested_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub canceled_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub failed_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub replaced_at: Option<OffsetDateTime>,
pub replaced_by: Option<Uuid>,
pub replaces: Option<Uuid>,
pub asset_id: Uuid,
pub symbol: String,
pub asset_class: super::asset::Class,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub notional: Option<f64>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub qty: Option<f64>,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub filled_qty: f64,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub filled_avg_price: Option<f64>,
pub order_class: Class,
#[serde(rename = "type")]
pub order_type: Type,
pub side: Side,
pub time_in_force: TimeInForce,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub limit_price: Option<f64>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub stop_price: Option<f64>,
pub status: Status,
pub extended_hours: bool,
pub legs: Option<Vec<Order>>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub trail_percent: Option<f64>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub trail_price: Option<f64>,
pub hwm: Option<f64>,
}
impl From<Order> for types::Order {
fn from(order: Order) -> Self {
Self {
id: order.id,
client_order_id: order.client_order_id,
time_submitted: order.submitted_at,
time_created: order.created_at,
time_updated: order.updated_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_filled: order.filled_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_expired: order.expired_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_cancel_requested: order
.cancel_requested_at
.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_canceled: order.canceled_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_failed: order.failed_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_replaced: order.replaced_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
replaced_by: order.replaced_by.unwrap_or_default(),
replaces: order.replaces.unwrap_or_default(),
symbol: order.symbol,
order_class: order.order_class.into(),
order_type: order.order_type.into(),
side: order.side.into(),
time_in_force: order.time_in_force.into(),
notional: order.notional.unwrap_or_default(),
qty: order.qty.unwrap_or_default(),
filled_qty: order.filled_qty,
filled_avg_price: order.filled_avg_price.unwrap_or_default(),
status: order.status.into(),
extended_hours: order.extended_hours,
limit_price: order.limit_price.unwrap_or_default(),
stop_price: order.stop_price.unwrap_or_default(),
trail_percent: order.trail_percent.unwrap_or_default(),
trail_price: order.trail_price.unwrap_or_default(),
hwm: order.hwm.unwrap_or_default(),
legs: order
.legs
.unwrap_or_default()
.into_iter()
.map(|order| order.id)
.collect(),
}
}
}
impl Order {
pub fn normalize(self) -> Vec<types::Order> {
let mut orders = vec![self.clone().into()];
if let Some(legs) = self.legs {
for leg in legs {
orders.extend(leg.normalize());
}
}
orders
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_normalize() {
let order_template = Order {
id: Uuid::new_v4(),
client_order_id: Uuid::new_v4(),
created_at: OffsetDateTime::now_utc(),
updated_at: None,
submitted_at: OffsetDateTime::now_utc(),
filled_at: None,
expired_at: None,
cancel_requested_at: None,
canceled_at: None,
failed_at: None,
replaced_at: None,
replaced_by: None,
replaces: None,
asset_id: Uuid::new_v4(),
symbol: "AAPL".to_string(),
asset_class: super::super::asset::Class::UsEquity,
notional: None,
qty: None,
filled_qty: 0.0,
filled_avg_price: None,
order_class: Class::Simple,
order_type: Type::Market,
side: Side::Buy,
time_in_force: TimeInForce::Day,
limit_price: None,
stop_price: None,
status: Status::New,
extended_hours: false,
legs: None,
trail_percent: None,
trail_price: None,
hwm: None,
};
let mut order = order_template.clone();
order.legs = Some(vec![order_template.clone(), order_template.clone()]);
order.legs.as_mut().unwrap()[0].legs = Some(vec![order_template.clone()]);
let orders = order.normalize();
assert_eq!(orders.len(), 4);
}
}

View File

@@ -1,9 +0,0 @@
use serde::Serialize;
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum Sort {
Asc,
Desc,
}

View File

@@ -1,154 +0,0 @@
use serde::Deserialize;
use serde_with::{serde_as, NoneAsEmptyString};
use time::OffsetDateTime;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(tag = "sc", content = "sm")]
pub enum Status {
#[serde(rename = "2")]
#[serde(alias = "H")]
TradingHalt(String),
#[serde(rename = "3")]
Resume(String),
#[serde(rename = "5")]
PriceIndication(String),
#[serde(rename = "6")]
TradingRangeIndication(String),
#[serde(rename = "7")]
MarketImbalanceBuy(String),
#[serde(rename = "8")]
MarketImbalanceSell(String),
#[serde(rename = "9")]
MarketOnCloseImbalanceBuy(String),
#[serde(rename = "A")]
MarketOnCloseImbalanceSell(String),
#[serde(rename = "C")]
NoMarketImbalance(String),
#[serde(rename = "D")]
NoMarketOnCloseImbalance(String),
#[serde(rename = "E")]
ShortSaleRestriction(String),
#[serde(rename = "F")]
LimitUpLimitDown(String),
#[serde(rename = "Q")]
QuotationResumption(String),
#[serde(rename = "T")]
TradingResumption(String),
#[serde(rename = "P")]
VolatilityTradingPause(String),
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(tag = "rc", content = "rm")]
pub enum Reason {
#[serde(rename = "D")]
NewsReleased(String),
#[serde(rename = "I")]
OrderImbalance(String),
#[serde(rename = "M")]
LimitUpLimitDown(String),
#[serde(rename = "P")]
NewsPending(String),
#[serde(rename = "X")]
Operational(String),
#[serde(rename = "Y")]
SubPennyTrading(String),
#[serde(rename = "1")]
MarketWideCircuitBreakerL1Breached(String),
#[serde(rename = "2")]
MarketWideCircuitBreakerL2Breached(String),
#[serde(rename = "3")]
MarketWideCircuitBreakerL3Breached(String),
#[serde(rename = "T1")]
HaltNewsPending(String),
#[serde(rename = "T2")]
HaltNewsDissemination(String),
#[serde(rename = "T5")]
SingleStockTradingPauseInEffect(String),
#[serde(rename = "T6")]
RegulatoryHaltExtraordinaryMarketActivity(String),
#[serde(rename = "T8")]
HaltETF(String),
#[serde(rename = "T12")]
TradingHaltedForInformationRequestedByNASDAQ(String),
#[serde(rename = "H4")]
HaltNonCompliance(String),
#[serde(rename = "H9")]
HaltFilingsNotCurrent(String),
#[serde(rename = "H10")]
HaltSECTradingSuspension(String),
#[serde(rename = "H11")]
HaltRegulatoryConcern(String),
#[serde(rename = "01")]
OperationsHaltContactMarketOperations(String),
#[serde(rename = "IPO1")]
IPOIssueNotYetTrading(String),
#[serde(rename = "M1")]
CorporateAction(String),
#[serde(rename = "M2")]
QuotationNotAvailable(String),
#[serde(rename = "LUDP")]
VolatilityTradingPause(String),
#[serde(rename = "LUDS")]
VolatilityTradingPauseStraddleCondition(String),
#[serde(rename = "MWC1")]
MarketWideCircuitBreakerHaltL1(String),
#[serde(rename = "MWC2")]
MarketWideCircuitBreakerHaltL2(String),
#[serde(rename = "MWC3")]
MarketWideCircuitBreakerHaltL3(String),
#[serde(rename = "MWC0")]
MarketWideCircuitBreakerHaltCarryOverFromPreviousDay(String),
#[serde(rename = "T3")]
NewsAndResumptionTimes(String),
#[serde(rename = "T7")]
SingleStockTradingPauseQuotationOnlyPeriod(String),
#[serde(rename = "R4")]
QualificationsIssuesReviewedResolvedQuotationsTradingToResume(String),
#[serde(rename = "R9")]
FilingRequirementsSatisfiedResolvedQuotationsTradingToResume(String),
#[serde(rename = "C3")]
IssuerNewsNotForthcomingQuotationsTradingToResume(String),
#[serde(rename = "C4")]
QualificationsHaltEndedMaintReqMetResume(String),
#[serde(rename = "C9")]
QualificationsHaltConcludedFilingsMetQuotesTradesToResume(String),
#[serde(rename = "C11")]
TradeHaltConcludedByOtherRegulatoryAuthQuotesTradesResume(String),
#[serde(rename = "R1")]
NewIssueAvailable(String),
#[serde(rename = "R")]
IssueAvailable(String),
#[serde(rename = "IPOQ")]
IPOSecurityReleasedForQuotation(String),
#[serde(rename = "IPOE")]
IPOSecurityPositioningWindowExtension(String),
#[serde(rename = "MWCQ")]
MarketWideCircuitBreakerResumption(String),
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub enum Tape {
A,
B,
C,
O,
}
#[serde_as]
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[allow(clippy::struct_field_names)]
pub struct Message {
#[serde(rename = "t")]
#[serde(with = "time::serde::rfc3339")]
pub time: OffsetDateTime,
#[serde(rename = "S")]
pub symbol: String,
#[serde(flatten)]
pub status: Status,
#[serde(flatten)]
#[serde_as(as = "NoneAsEmptyString")]
pub reason: Option<Reason>,
#[serde(rename = "z")]
pub tape: Tape,
}

View File

@@ -1,23 +0,0 @@
use crate::utils::de;
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(untagged)]
pub enum Message {
#[serde(rename_all = "camelCase")]
Market {
bars: Vec<String>,
updated_bars: Vec<String>,
statuses: Vec<String>,
trades: Option<Vec<String>>,
quotes: Option<Vec<String>>,
daily_bars: Option<Vec<String>>,
orderbooks: Option<Vec<String>>,
lulds: Option<Vec<String>>,
cancel_errors: Option<Vec<String>>,
},
News {
#[serde(deserialize_with = "de::add_slash_to_symbols")]
news: Vec<String>,
},
}

View File

@@ -1,53 +0,0 @@
pub mod incoming;
pub mod outgoing;
use crate::types::alpaca::websocket;
use core::panic;
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use serde_json::{from_str, to_string};
use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub async fn authenticate(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
api_key: String,
api_secret: String,
) {
match stream.next().await.unwrap().unwrap() {
Message::Text(data)
if from_str::<Vec<websocket::data::incoming::Message>>(&data)
.unwrap()
.first()
== Some(&websocket::data::incoming::Message::Success(
websocket::data::incoming::success::Message::Connected,
)) => {}
_ => panic!("Failed to connect to Alpaca websocket."),
}
sink.send(Message::Text(
to_string(&websocket::data::outgoing::Message::Auth(
websocket::auth::Message {
key: api_key,
secret: api_secret,
},
))
.unwrap(),
))
.await
.unwrap();
match stream.next().await.unwrap().unwrap() {
Message::Text(data)
if from_str::<Vec<websocket::data::incoming::Message>>(&data)
.unwrap()
.first()
== Some(&websocket::data::incoming::Message::Success(
websocket::data::incoming::success::Message::Authenticated,
)) => {}
_ => panic!("Failed to authenticate with Alpaca websocket."),
};
}

View File

@@ -1,50 +0,0 @@
use crate::utils::ser;
use nonempty::NonEmpty;
use serde::Serialize;
#[derive(Serialize)]
#[serde(untagged)]
pub enum Market {
#[serde(rename_all = "camelCase")]
UsEquity {
bars: NonEmpty<String>,
updated_bars: NonEmpty<String>,
statuses: NonEmpty<String>,
},
#[serde(rename_all = "camelCase")]
Crypto {
bars: NonEmpty<String>,
updated_bars: NonEmpty<String>,
},
}
#[derive(Serialize)]
#[serde(untagged)]
pub enum Message {
Market(Market),
News {
#[serde(serialize_with = "ser::remove_slash_from_symbols")]
news: NonEmpty<String>,
},
}
impl Message {
pub fn new_market_us_equity(symbols: NonEmpty<String>) -> Self {
Self::Market(Market::UsEquity {
bars: symbols.clone(),
updated_bars: symbols.clone(),
statuses: symbols,
})
}
pub fn new_market_crypto(symbols: NonEmpty<String>) -> Self {
Self::Market(Market::Crypto {
bars: symbols.clone(),
updated_bars: symbols,
})
}
pub fn new_news(symbols: NonEmpty<String>) -> Self {
Self::News { news: symbols }
}
}

View File

@@ -1,8 +0,0 @@
pub mod auth;
pub mod data;
pub mod trading;
pub const ALPACA_US_EQUITY_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str =
"wss://stream.data.alpaca.markets/v1beta3/crypto/us";
pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";

View File

@@ -1,22 +0,0 @@
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Status {
Authorized,
Unauthorized,
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub enum Action {
#[serde(rename = "authenticate")]
Auth,
#[serde(rename = "listen")]
Subscribe,
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct Message {
pub status: Status,
pub action: Action,
}

View File

@@ -1,16 +0,0 @@
pub mod auth;
pub mod order;
pub mod subscription;
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "stream", content = "data")]
pub enum Message {
#[serde(rename = "authorization")]
Auth(auth::Message),
#[serde(rename = "listening")]
Subscription(subscription::Message),
#[serde(rename = "trade_updates")]
Order(order::Message),
}

View File

@@ -1,57 +0,0 @@
use crate::types::alpaca::shared::order;
use serde::Deserialize;
use serde_aux::prelude::deserialize_number_from_string;
use time::OffsetDateTime;
use uuid::Uuid;
pub use order::Order;
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "event")]
pub enum Event {
New,
Fill {
timestamp: OffsetDateTime,
#[serde(deserialize_with = "deserialize_number_from_string")]
position_qty: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
price: f64,
},
PartialFill {
timestamp: OffsetDateTime,
#[serde(deserialize_with = "deserialize_number_from_string")]
position_qty: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
price: f64,
},
Canceled {
timestamp: OffsetDateTime,
},
Expired {
timestamp: OffsetDateTime,
},
DoneForDay,
Replaced {
timestamp: OffsetDateTime,
},
Rejected {
timestamp: OffsetDateTime,
},
PendingNew,
Stopped,
PendingCancel,
PendingReplace,
Calculated,
Suspended,
OrderReplaceRejected,
OrderCancelRejected,
}
#[derive(Deserialize, Debug, PartialEq)]
pub struct Message {
pub execution_id: Uuid,
pub order: Order,
#[serde(flatten)]
pub event: Event,
}

View File

@@ -1,6 +0,0 @@
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct Message {
pub streams: Vec<String>,
}

View File

@@ -1,82 +0,0 @@
pub mod incoming;
pub mod outgoing;
use crate::types::alpaca::websocket;
use core::panic;
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use serde_json::{from_str, to_string};
use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub async fn authenticate(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
api_key: String,
api_secret: String,
) {
sink.send(Message::Text(
to_string(&websocket::trading::outgoing::Message::Auth(
websocket::auth::Message {
key: api_key,
secret: api_secret,
},
))
.unwrap(),
))
.await
.unwrap();
match stream.next().await.unwrap().unwrap() {
Message::Binary(data) => {
let data = String::from_utf8(data).unwrap();
if from_str::<websocket::trading::incoming::Message>(&data).unwrap()
!= websocket::trading::incoming::Message::Auth(
websocket::trading::incoming::auth::Message {
status: websocket::trading::incoming::auth::Status::Authorized,
action: websocket::trading::incoming::auth::Action::Auth,
},
)
{
panic!("Failed to authenticate with Alpaca websocket.");
}
}
_ => panic!("Failed to authenticate with Alpaca websocket."),
};
}
pub async fn subscribe(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) {
sink.send(Message::Text(
to_string(&websocket::trading::outgoing::Message::Subscribe {
data: websocket::trading::outgoing::subscribe::Message {
streams: vec![String::from("trade_updates")],
},
})
.unwrap(),
))
.await
.unwrap();
match stream.next().await.unwrap().unwrap() {
Message::Binary(data) => {
let data = String::from_utf8(data).unwrap();
if from_str::<websocket::trading::incoming::Message>(&data).unwrap()
!= websocket::trading::incoming::Message::Subscription(
websocket::trading::incoming::subscription::Message {
streams: vec![String::from("trade_updates")],
},
)
{
panic!("Failed to subscribe to Alpaca websocket.");
}
}
_ => panic!("Failed to subscribe to Alpaca websocket."),
};
}

View File

@@ -1,15 +0,0 @@
pub mod subscribe;
use crate::types::alpaca::websocket::auth;
use serde::Serialize;
#[derive(Serialize)]
#[serde(tag = "action")]
#[serde(rename_all = "snake_case")]
pub enum Message {
Auth(auth::Message),
#[serde(rename = "listen")]
Subscribe {
data: subscribe::Message,
},
}

View File

@@ -1,6 +0,0 @@
use serde::Serialize;
#[derive(Serialize)]
pub struct Message {
pub streams: Vec<String>,
}

View File

@@ -1,11 +0,0 @@
use clickhouse::Row;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Row)]
pub struct Backfill {
pub symbol: String,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time: OffsetDateTime,
pub fresh: bool,
}

View File

@@ -1,13 +0,0 @@
use clickhouse::Row;
use serde::{Deserialize, Serialize};
use time::{Date, OffsetDateTime};
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Row)]
pub struct Calendar {
#[serde(with = "clickhouse::serde::time::date")]
pub date: Date,
#[serde(with = "clickhouse::serde::time::datetime")]
pub open: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub close: OffsetDateTime,
}

View File

@@ -1,19 +0,0 @@
use clickhouse::Row;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
pub struct News {
pub id: i64,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_created: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_updated: OffsetDateTime,
pub symbols: Vec<String>,
pub headline: String,
pub author: String,
pub source: String,
pub summary: String,
pub content: String,
pub url: String,
}

View File

@@ -1,107 +0,0 @@
use clickhouse::Row;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use time::OffsetDateTime;
use uuid::Uuid;
#[derive(Clone, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum Class {
Simple = 1,
Bracket = 2,
Oco = 3,
Oto = 4,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum Type {
Market = 1,
Limit = 2,
Stop = 3,
StopLimit = 4,
TrailingStop = 5,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum Side {
Buy = 1,
Sell = -1,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum TimeInForce {
Day = 1,
Gtc = 2,
Opg = 3,
Cls = 4,
Ioc = 5,
Fok = 6,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum Status {
New = 1,
PartiallyFilled = 2,
Filled = 3,
DoneForDay = 4,
Canceled = 5,
Expired = 6,
Replaced = 7,
PendingCancel = 8,
PendingReplace = 9,
Accepted = 10,
PendingNew = 11,
AcceptedForBidding = 12,
Stopped = 13,
Rejected = 14,
Suspended = 15,
Calculated = 16,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
#[allow(clippy::struct_field_names)]
pub struct Order {
pub id: Uuid,
pub client_order_id: Uuid,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_submitted: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_created: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_updated: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_filled: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_expired: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_cancel_requested: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_canceled: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_failed: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_replaced: OffsetDateTime,
pub replaced_by: Uuid,
pub replaces: Uuid,
pub symbol: String,
pub order_class: Class,
pub order_type: Type,
pub side: Side,
pub time_in_force: TimeInForce,
pub extended_hours: bool,
pub notional: f64,
pub qty: f64,
pub filled_qty: f64,
pub filled_avg_price: f64,
pub status: Status,
pub limit_price: f64,
pub stop_price: f64,
pub trail_percent: f64,
pub trail_price: f64,
pub hwm: f64,
pub legs: Vec<Uuid>,
}

View File

@@ -1,277 +0,0 @@
use super::Bar;
use crate::ta::{Bbands, Deriv, Ema, Macd, Obv, Pct, Rsi, Sma};
use clickhouse::Row;
use itertools::Itertools;
use rayon::scope;
use serde::{Deserialize, Serialize};
use std::num::NonZeroUsize;
use time::OffsetDateTime;
pub const HEAD_SIZE: usize = 72;
pub const FIELD_COUNT: usize = 33;
pub const NUMERICAL_FIELD_COUNT: usize = FIELD_COUNT - 2;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
pub struct IndicatedBar {
pub symbol: String,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time: OffsetDateTime,
pub hour: u8,
pub day: u8,
pub open: f64,
pub open_pct: f64,
pub high: f64,
pub high_pct: f64,
pub low: f64,
pub low_pct: f64,
pub close: f64,
pub close_pct: f64,
pub volume: f64,
pub volume_pct: f64,
pub trades: f64,
pub trades_pct: f64,
pub sma_3: f64,
pub sma_6: f64,
pub sma_12: f64,
pub sma_24: f64,
pub sma_48: f64,
pub sma_72: f64,
pub ema_3: f64,
pub ema_6: f64,
pub ema_12: f64,
pub ema_24: f64,
pub ema_48: f64,
pub ema_72: f64,
pub macd: f64,
pub macd_signal: f64,
pub obv: f64,
pub rsi: f64,
pub bbands_lower: f64,
pub bbands_mean: f64,
pub bbands_upper: f64,
}
#[allow(clippy::too_many_lines)]
fn _calculate_indicators(bars: &[Bar]) -> Vec<IndicatedBar> {
let length = bars.len();
let (symbol, time, hour, day, open, high, low, close, volume, trades) = bars.iter().fold(
(
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
),
|(
mut symbol,
mut time,
mut hour,
mut day,
mut open,
mut high,
mut low,
mut close,
mut volume,
mut trades,
),
bar| {
symbol.push(bar.symbol.clone());
time.push(bar.time);
hour.push(bar.time.hour());
day.push(bar.time.day());
open.push(bar.open);
high.push(bar.high);
low.push(bar.low);
close.push(bar.close);
volume.push(bar.volume);
trades.push(bar.trades as f64);
(
symbol, time, hour, day, open, high, low, close, volume, trades,
)
},
);
let mut close_deriv = Vec::with_capacity(length);
let mut sma_3 = Vec::with_capacity(length);
let mut sma_6 = Vec::with_capacity(length);
let mut sma_12 = Vec::with_capacity(length);
let mut sma_24 = Vec::with_capacity(length);
let mut sma_48 = Vec::with_capacity(length);
let mut sma_72 = Vec::with_capacity(length);
let mut ema_3 = Vec::with_capacity(length);
let mut ema_6 = Vec::with_capacity(length);
let mut ema_12 = Vec::with_capacity(length);
let mut ema_24 = Vec::with_capacity(length);
let mut ema_48 = Vec::with_capacity(length);
let mut ema_72 = Vec::with_capacity(length);
let mut macd = Vec::with_capacity(length);
let mut macd_signal = Vec::with_capacity(length);
let mut obv = Vec::with_capacity(length);
let mut rsi = Vec::with_capacity(length);
let mut bbands_upper = Vec::with_capacity(length);
let mut bbands_mean = Vec::with_capacity(length);
let mut bbands_lower = Vec::with_capacity(length);
scope(|s| {
s.spawn(|_| close_deriv.extend(close.iter().deriv()));
s.spawn(|_| sma_3.extend(close.iter().sma(NonZeroUsize::new(3).unwrap())));
s.spawn(|_| sma_6.extend(close.iter().sma(NonZeroUsize::new(6).unwrap())));
s.spawn(|_| sma_12.extend(close.iter().sma(NonZeroUsize::new(12).unwrap())));
s.spawn(|_| sma_24.extend(close.iter().sma(NonZeroUsize::new(24).unwrap())));
s.spawn(|_| sma_48.extend(close.iter().sma(NonZeroUsize::new(48).unwrap())));
s.spawn(|_| sma_72.extend(close.iter().sma(NonZeroUsize::new(72).unwrap())));
s.spawn(|_| ema_3.extend(close.iter().ema(NonZeroUsize::new(3).unwrap())));
s.spawn(|_| ema_6.extend(close.iter().ema(NonZeroUsize::new(6).unwrap())));
s.spawn(|_| ema_12.extend(close.iter().ema(NonZeroUsize::new(12).unwrap())));
s.spawn(|_| ema_24.extend(close.iter().ema(NonZeroUsize::new(24).unwrap())));
s.spawn(|_| ema_48.extend(close.iter().ema(NonZeroUsize::new(48).unwrap())));
s.spawn(|_| ema_72.extend(close.iter().ema(NonZeroUsize::new(72).unwrap())));
s.spawn(|_| {
close
.iter()
.macd(
NonZeroUsize::new(12).unwrap(),
NonZeroUsize::new(26).unwrap(),
NonZeroUsize::new(9).unwrap(),
)
.for_each(|(macd_val, signal_val)| {
macd.push(macd_val);
macd_signal.push(signal_val);
});
});
s.spawn(|_| {
obv.extend(bars.iter().map(|bar| (bar.close, bar.volume)).obv());
});
s.spawn(|_| {
rsi.extend(close.iter().rsi(NonZeroUsize::new(14).unwrap()));
});
s.spawn(|_| {
close
.iter()
.bbands(NonZeroUsize::new(20).unwrap(), 2.0)
.for_each(|(upper, mean, lower)| {
bbands_upper.push(upper);
bbands_mean.push(mean);
bbands_lower.push(lower);
});
})
});
let mut open_pct = Vec::with_capacity(length);
let mut high_pct = Vec::with_capacity(length);
let mut low_pct = Vec::with_capacity(length);
let mut close_pct = Vec::with_capacity(length);
let mut volume_pct = Vec::with_capacity(length);
let mut trades_pct = Vec::with_capacity(length);
scope(|s| {
s.spawn(|_| open_pct.extend(open.iter().pct()));
s.spawn(|_| high_pct.extend(high.iter().pct()));
s.spawn(|_| low_pct.extend(low.iter().pct()));
s.spawn(|_| close_pct.extend(close.iter().pct()));
s.spawn(|_| volume_pct.extend(volume.iter().pct()));
s.spawn(|_| trades_pct.extend(trades.iter().pct()));
});
bars.iter()
.enumerate()
.map(|(i, _)| IndicatedBar {
symbol: symbol[i].clone(),
time: time[i],
hour: hour[i],
day: day[i],
open: open[i],
open_pct: open_pct[i],
high: high[i],
high_pct: high_pct[i],
low: low[i],
low_pct: low_pct[i],
close: close[i],
close_pct: close_pct[i],
volume: volume[i],
volume_pct: volume_pct[i],
trades: trades[i],
trades_pct: trades_pct[i],
sma_3: sma_3[i],
sma_6: sma_6[i],
sma_12: sma_12[i],
sma_24: sma_24[i],
sma_48: sma_48[i],
sma_72: sma_72[i],
ema_3: ema_3[i],
ema_6: ema_6[i],
ema_12: ema_12[i],
ema_24: ema_24[i],
ema_48: ema_48[i],
ema_72: ema_72[i],
macd: macd[i],
macd_signal: macd_signal[i],
obv: obv[i],
rsi: rsi[i],
bbands_lower: bbands_lower[i],
bbands_mean: bbands_mean[i],
bbands_upper: bbands_upper[i],
})
.collect()
}
pub fn calculate_indicators<I>(bars: I) -> Vec<Vec<IndicatedBar>>
where
I: IntoIterator<Item = Bar>,
{
bars.into_iter()
.filter(|bar| {
bar.open > 0.0
&& bar.high > 0.0
&& bar.low > 0.0
&& bar.close > 0.0
&& bar.volume > 0.0
&& bar.trades > 0
})
.sorted_by_key(|bar| (bar.symbol.clone(), bar.time))
.group_by(|bar| bar.symbol.clone())
.into_iter()
.map(|(_, group)| _calculate_indicators(&group.collect::<Vec<_>>()))
.collect::<Vec<_>>()
}
#[cfg(test)]
mod tests {
use super::*;
use rand::{
distributions::{Distribution, Uniform},
Rng,
};
#[test]
fn test_calculate_indicators() {
let length = 1_000_000;
let mut rng = rand::thread_rng();
let uniform = Uniform::new(1.0, 100.0);
let mut bars = Vec::with_capacity(length);
for _ in 0..length {
bars.push(Bar {
symbol: "AAPL".to_string(),
time: OffsetDateTime::now_utc(),
open: uniform.sample(&mut rng),
high: uniform.sample(&mut rng),
low: uniform.sample(&mut rng),
close: uniform.sample(&mut rng),
volume: uniform.sample(&mut rng),
trades: rng.gen_range(1..100),
});
}
let indicated_bars = calculate_indicators(bars);
assert_eq!(indicated_bars[0].len(), length);
}
}

View File

@@ -1,8 +0,0 @@
use backoff::ExponentialBackoff;
pub fn infinite() -> ExponentialBackoff {
ExponentialBackoff {
max_elapsed_time: None,
..ExponentialBackoff::default()
}
}
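A sketch of how such a policy is typically driven with `backoff::future::retry` (enabled by the crate's `tokio` feature; `fetch` is a hypothetical fallible async operation):

use backoff::future::retry;

let result = retry(infinite(), || async {
    fetch().await.map_err(backoff::Error::transient)
})
.await;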

View File

@@ -1,192 +0,0 @@
use lazy_static::lazy_static;
use regex::Regex;
use serde::{
de::{self, SeqAccess, Visitor},
Deserializer,
};
use std::fmt;
use time::{format_description::OwnedFormatItem, macros::format_description, Time};
lazy_static! {
// This *will* break if a crypto pair with a single-letter base symbol is ever added
static ref RE_SLASH: Regex = Regex::new(r"^(.{2,})(BTC|USD.?)$").unwrap();
static ref FMT_HH_MM: OwnedFormatItem = format_description!("[hour]:[minute]").into();
}
fn add_slash(pair: &str) -> String {
RE_SLASH.captures(pair).map_or_else(
|| pair.to_string(),
|caps| format!("{}/{}", &caps[1], &caps[2]),
)
}
pub fn add_slash_to_symbol<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
struct StringVisitor;
impl<'de> Visitor<'de> for StringVisitor {
type Value = String;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string without a slash")
}
fn visit_str<E>(self, pair: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(add_slash(pair))
}
fn visit_string<E>(self, pair: String) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(add_slash(&pair))
}
}
deserializer.deserialize_string(StringVisitor)
}
pub fn add_slash_to_symbols<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: Deserializer<'de>,
{
struct VecStringVisitor;
impl<'de> Visitor<'de> for VecStringVisitor {
type Value = Vec<String>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a list of strings without a slash")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Vec<String>, A::Error>
where
A: SeqAccess<'de>,
{
let mut vec = Vec::with_capacity(seq.size_hint().unwrap_or(25));
while let Some(value) = seq.next_element::<String>()? {
vec.push(add_slash(&value));
}
Ok(vec)
}
}
deserializer.deserialize_seq(VecStringVisitor)
}
pub fn human_time_hh_mm<'de, D>(deserializer: D) -> Result<Time, D::Error>
where
D: Deserializer<'de>,
{
struct TimeVisitor;
impl<'de> Visitor<'de> for TimeVisitor {
type Value = Time;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string in the format HH:MM")
}
fn visit_str<E>(self, time: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Time::parse(time, &FMT_HH_MM).map_err(|e| de::Error::custom(e.to_string()))
}
}
deserializer.deserialize_str(TimeVisitor)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serde_test::{assert_de_tokens, Token};

    #[test]
    fn test_add_slash() {
        assert_eq!(add_slash("BTCUSD"), "BTC/USD");
    }

    #[test]
    fn test_add_slash_skip() {
        assert_eq!(add_slash("ABTC"), "ABTC");
    }

    #[derive(PartialEq, Debug, Deserialize)]
    #[serde(transparent)]
    struct AddSlashToSymbol {
        #[serde(deserialize_with = "add_slash_to_symbol")]
        symbol: String,
    }

    #[test]
    fn test_add_slash_to_symbol() {
        assert_de_tokens::<AddSlashToSymbol>(
            &AddSlashToSymbol {
                symbol: String::from("BTC/USD"),
            },
            &[Token::Str("BTCUSD")],
        );
    }

    #[test]
    fn test_add_slash_to_symbol_skip() {
        assert_de_tokens::<AddSlashToSymbol>(
            &AddSlashToSymbol {
                symbol: String::from("ABTC"),
            },
            &[Token::Str("ABTC")],
        );
    }

    #[derive(PartialEq, Debug, Deserialize)]
    #[serde(transparent)]
    struct AddSlashToSymbols {
        #[serde(deserialize_with = "add_slash_to_symbols")]
        symbols: Vec<String>,
    }

    #[test]
    fn test_add_slash_to_symbols() {
        assert_de_tokens::<AddSlashToSymbols>(
            &AddSlashToSymbols {
                symbols: vec![
                    String::from("BTC/USD"),
                    String::from("ETH/USD"),
                    String::from("ABTC"),
                ],
            },
            &[
                Token::Seq { len: Some(3) },
                Token::Str("BTCUSD"),
                Token::Str("ETHUSD"),
                Token::Str("ABTC"),
                Token::SeqEnd,
            ],
        );
    }

    #[derive(PartialEq, Debug, Deserialize)]
    #[serde(transparent)]
    struct HumanTime {
        #[serde(deserialize_with = "human_time_hh_mm")]
        time: Time,
    }

    #[test]
    fn test_human_time_hh_mm() {
        assert_de_tokens::<HumanTime>(
            &HumanTime {
                time: Time::from_hms(12, 34, 0).unwrap(),
            },
            &[Token::Str("12:34")],
        );
    }
}
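Taken together, a usage sketch of these deserializers on a JSON payload (illustrative only: `CryptoQuote` is a hypothetical struct, and serde_json, already a workspace dependency, is assumed):

use serde::Deserialize;
use time::Time;

#[derive(Debug, Deserialize)]
struct CryptoQuote {
    // Upstream sends "BTCUSD"; stored as "BTC/USD".
    #[serde(deserialize_with = "add_slash_to_symbol")]
    symbol: String,
    // Upstream sends "09:30"; parsed into a `time::Time`.
    #[serde(deserialize_with = "human_time_hh_mm")]
    open_time: Time,
}

#[test]
fn test_crypto_quote_deserialization() {
    let quote: CryptoQuote =
        serde_json::from_str(r#"{"symbol":"BTCUSD","open_time":"09:30"}"#).unwrap();
    assert_eq!(quote.symbol, "BTC/USD");
}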

View File

@@ -1,7 +0,0 @@
pub mod backoff;
pub mod de;
pub mod r#macro;
pub mod ser;
pub mod time;

pub use time::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE, ONE_SECOND};

View File

@@ -1,281 +0,0 @@
use serde::{ser::SerializeSeq, Serializer};
use std::time::Duration;

/// Serializes a `Duration` as a human-readable timeframe string: `{n}Min`
/// below one hour, `{n}Hour` below one day, `1Day`, `1Week`, or `{n}Month`
/// for n in {1, 2, 3, 4, 6, 12}. Any duration that does not divide evenly
/// into one of these buckets is rejected.
pub fn timeframe<S>(timeframe: &Duration, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let secs = timeframe.as_secs();
    if secs < 60 || secs % 60 != 0 {
        return Err(serde::ser::Error::custom("Invalid timeframe duration"));
    }

    let mins = secs / 60;
    if mins < 60 {
        return serializer.serialize_str(&format!("{mins}Min"));
    }
    if mins % 60 != 0 {
        return Err(serde::ser::Error::custom("Invalid timeframe duration"));
    }

    let hours = mins / 60;
    if hours < 24 {
        return serializer.serialize_str(&format!("{hours}Hour"));
    }
    if hours % 24 != 0 {
        return Err(serde::ser::Error::custom("Invalid timeframe duration"));
    }

    let days = hours / 24;
    if days == 1 {
        return serializer.serialize_str("1Day");
    }
    if days == 7 {
        return serializer.serialize_str("1Week");
    }
    if days < 30 || days % 30 != 0 {
        return Err(serde::ser::Error::custom("Invalid timeframe duration"));
    }

    // Months are approximated as 30 days.
    let months = days / 30;
    if [1, 2, 3, 4, 6, 12].contains(&months) {
        return serializer.serialize_str(&format!("{months}Month"));
    };

    Err(serde::ser::Error::custom("Invalid timeframe duration"))
}
fn remove_slash(pair: &str) -> String {
    pair.replace('/', "")
}

/// Serializes a list of symbols as a single comma-separated string.
pub fn join_symbols<S>(symbols: &[String], serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let string = symbols.join(",");
    serializer.serialize_str(&string)
}

/// Like `join_symbols`, but serializes `None` as a none value.
pub fn join_symbols_option<S>(
    symbols: &Option<Vec<String>>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    match symbols {
        Some(symbols) => join_symbols(symbols, serializer),
        None => serializer.serialize_none(),
    }
}

/// Serializes symbols as a sequence, stripping the slash from each element.
pub fn remove_slash_from_symbols<'a, S, I>(symbols: I, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
    I: IntoIterator<Item = &'a String>,
{
    let symbols = symbols
        .into_iter()
        .map(|pair| remove_slash(pair))
        .collect::<Vec<_>>();

    let mut seq = serializer.serialize_seq(Some(symbols.len()))?;
    for symbol in symbols {
        seq.serialize_element(&symbol)?;
    }
    seq.end()
}

/// Strips the slash from each symbol, then joins them with commas.
pub fn remove_slash_and_join_symbols<'a, S, I>(symbols: I, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
    I: IntoIterator<Item = &'a String>,
{
    let symbols = symbols
        .into_iter()
        .map(|symbol| remove_slash(symbol))
        .collect::<Vec<_>>();
    join_symbols(&symbols, serializer)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;
    use serde_test::{assert_ser_tokens, assert_ser_tokens_error, Token};

    #[derive(Serialize)]
    #[serde(transparent)]
    struct Timeframe {
        #[serde(serialize_with = "timeframe")]
        duration: Duration,
    }

    #[test]
    fn test_timeframe_30_mins() {
        let timeframe = Timeframe {
            duration: Duration::from_secs(60 * 30),
        };
        assert_ser_tokens(&timeframe, &[Token::Str("30Min")]);
    }

    #[test]
    fn test_timeframe_2_hours() {
        let timeframe = Timeframe {
            duration: Duration::from_secs(60 * 60 * 2),
        };
        assert_ser_tokens(&timeframe, &[Token::Str("2Hour")]);
    }

    #[test]
    fn test_timeframe_1_day() {
        let timeframe = Timeframe {
            duration: Duration::from_secs(60 * 60 * 24),
        };
        assert_ser_tokens(&timeframe, &[Token::Str("1Day")]);
    }

    #[test]
    fn test_timeframe_1_week() {
        let timeframe = Timeframe {
            duration: Duration::from_secs(60 * 60 * 24 * 7),
        };
        assert_ser_tokens(&timeframe, &[Token::Str("1Week")]);
    }

    #[test]
    fn test_timeframe_6_months() {
        let timeframe = Timeframe {
            duration: Duration::from_secs(60 * 60 * 24 * 30 * 6),
        };
        assert_ser_tokens(&timeframe, &[Token::Str("6Month")]);
    }

    #[test]
    fn test_timeframe_invalid_1_second() {
        let timeframe = Timeframe {
            duration: Duration::from_secs(1),
        };
        assert_ser_tokens_error(&timeframe, &[], "Invalid timeframe duration");
    }

    #[test]
    fn test_timeframe_invalid_61_seconds() {
        let timeframe = Timeframe {
            duration: Duration::from_secs(61),
        };
        assert_ser_tokens_error(&timeframe, &[], "Invalid timeframe duration");
    }

    #[test]
    fn test_timeframe_invalid_6_days() {
        let timeframe = Timeframe {
            duration: Duration::from_secs(60 * 60 * 24 * 6),
        };
        assert_ser_tokens_error(&timeframe, &[], "Invalid timeframe duration");
    }

    #[test]
    fn test_remove_slash() {
        let pair = "BTC/USDT";
        assert_eq!(remove_slash(pair), "BTCUSDT");
    }

    #[derive(Serialize)]
    #[serde(transparent)]
    struct JoinSymbols {
        #[serde(serialize_with = "join_symbols")]
        symbols: Vec<String>,
    }

    #[test]
    fn test_join_symbols() {
        let symbols = JoinSymbols {
            symbols: vec![String::from("BTC/USD"), String::from("ETH/USD")],
        };
        assert_ser_tokens(&symbols, &[Token::Str("BTC/USD,ETH/USD")]);
    }

    #[derive(Serialize)]
    #[serde(transparent)]
    struct JoinSymbolsOption {
        #[serde(serialize_with = "join_symbols_option")]
        symbols: Option<Vec<String>>,
    }

    #[test]
    fn test_join_symbols_option_some() {
        let symbols = JoinSymbolsOption {
            symbols: Some(vec![String::from("BTC/USD"), String::from("ETH/USD")]),
        };
        assert_ser_tokens(&symbols, &[Token::Str("BTC/USD,ETH/USD")]);
    }

    #[test]
    fn test_join_symbols_option_none() {
        let symbols = JoinSymbolsOption { symbols: None };
        assert_ser_tokens(&symbols, &[Token::None]);
    }

    #[derive(Serialize)]
    #[serde(transparent)]
    struct RemoveSlashFromSymbols {
        #[serde(serialize_with = "remove_slash_from_symbols")]
        symbols: Vec<String>,
    }

    #[test]
    fn test_remove_slash_from_symbols() {
        let symbols = RemoveSlashFromSymbols {
            symbols: vec![String::from("BTC/USD"), String::from("ETH/USD")],
        };
        assert_ser_tokens(
            &symbols,
            &[
                Token::Seq { len: Some(2) },
                Token::Str("BTCUSD"),
                Token::Str("ETHUSD"),
                Token::SeqEnd,
            ],
        );
    }

    #[derive(Serialize)]
    #[serde(transparent)]
    struct RemoveSlashAndJoinSymbols {
        #[serde(serialize_with = "remove_slash_and_join_symbols")]
        symbols: Vec<String>,
    }

    #[test]
    fn test_remove_slash_and_join_symbols() {
        let symbols = RemoveSlashAndJoinSymbols {
            symbols: vec![String::from("BTC/USD"), String::from("ETH/USD")],
        };
        assert_ser_tokens(&symbols, &[Token::Str("BTCUSD,ETHUSD")]);
    }
}
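A matching serialization sketch for the other direction (again illustrative: `BarsQuery` is a hypothetical struct, and serde_json is assumed):

use serde::Serialize;

#[derive(Serialize)]
struct BarsQuery {
    // Rendered as "30Min", "2Hour", "1Day", and so on.
    #[serde(serialize_with = "timeframe")]
    timeframe: Duration,
    // Rendered as a single comma-separated string without slashes.
    #[serde(serialize_with = "remove_slash_and_join_symbols")]
    symbols: Vec<String>,
}

#[test]
fn test_bars_query_serialization() {
    let query = BarsQuery {
        timeframe: Duration::from_secs(60 * 30),
        symbols: vec!["BTC/USD".into(), "ETH/USD".into()],
    };
    assert_eq!(
        serde_json::to_string(&query).unwrap(),
        r#"{"timeframe":"30Min","symbols":"BTCUSD,ETHUSD"}"#
    );
}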

Some files were not shown because too many files have changed in this diff.