Update and fix bugs

It's good to be back

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-05-10 17:49:16 +01:00
parent d7e9350257
commit 90b7f10a77
9 changed files with 430 additions and 427 deletions

597
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -24,9 +24,9 @@ codegen-units = 1
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
axum = "0.7.4" axum = "0.7.5"
dotenv = "0.15.0" dotenv = "0.15.0"
tokio = { version = "1.32.0", features = [ tokio = { version = "1.37.0", features = [
"macros", "macros",
"rt-multi-thread", "rt-multi-thread",
] } ] }
@@ -34,29 +34,29 @@ tokio-tungstenite = { version = "0.21.0", features = [
"tokio-native-tls", "tokio-native-tls",
"native-tls", "native-tls",
] } ] }
log = "0.4.20" log = "0.4.21"
log4rs = "1.2.0" log4rs = "1.3.0"
serde = "1.0.188" serde = "1.0.201"
serde_json = "1.0.105" serde_json = "1.0.117"
serde_repr = "0.1.18" serde_repr = "0.1.19"
serde_with = "3.6.1" serde_with = "3.8.1"
serde-aux = "4.4.0" serde-aux = "4.5.0"
futures-util = "0.3.28" futures-util = "0.3.30"
reqwest = { version = "0.12.0", features = [ reqwest = { version = "0.12.4", features = [
"json", "json",
] } ] }
http = "1.0.0" http = "1.1.0"
governor = "0.6.0" governor = "0.6.3"
clickhouse = { version = "0.11.6", features = [ clickhouse = { version = "0.11.6", features = [
"watch", "watch",
"time", "time",
"uuid", "uuid",
] } ] }
uuid = { version = "1.6.1", features = [ uuid = { version = "1.8.0", features = [
"serde", "serde",
"v4", "v4",
] } ] }
time = { version = "0.3.31", features = [ time = { version = "0.3.36", features = [
"serde", "serde",
"serde-well-known", "serde-well-known",
"serde-human-readable", "serde-human-readable",
@@ -67,19 +67,20 @@ time = { version = "0.3.31", features = [
backoff = { version = "0.4.0", features = [ backoff = { version = "0.4.0", features = [
"tokio", "tokio",
] } ] }
regex = "1.10.3" regex = "1.10.4"
async-trait = "0.1.77" async-trait = "0.1.80"
itertools = "0.12.1" itertools = "0.12.1"
lazy_static = "1.4.0" lazy_static = "1.4.0"
nonempty = { version = "0.10.0", features = [ nonempty = { version = "0.10.0", features = [
"serialize", "serialize",
] } ] }
rand = "0.8.5" rand = "0.8.5"
rayon = "1.9.0" rayon = "1.10.0"
burn = { version = "0.12.1", features = [ burn = { version = "0.13.2", features = [
"wgpu", "wgpu",
"cuda", "cuda",
"tui", "tui",
"metrics",
"train", "train",
] } ] }

View File

@@ -8,7 +8,10 @@ use qrust::{
types::{alpaca::websocket, News}, types::{alpaca::websocket, News},
utils::ONE_SECOND, utils::ONE_SECOND,
}; };
use std::{collections::HashMap, sync::Arc}; use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
pub struct Handler { pub struct Handler {
@@ -37,6 +40,7 @@ impl super::Handler for Handler {
unreachable!() unreachable!()
}; };
let symbols = symbols.into_iter().collect::<HashSet<_>>();
let mut state = state.write().await; let mut state = state.write().await;
let newly_subscribed = state let newly_subscribed = state

View File

@@ -5,6 +5,7 @@ use burn::{
}; };
use rayon::iter::{IntoParallelIterator, ParallelIterator}; use rayon::iter::{IntoParallelIterator, ParallelIterator};
#[derive(Clone, Debug)]
pub struct BarWindowBatcher<B: Backend> { pub struct BarWindowBatcher<B: Backend> {
pub device: B::Device, pub device: B::Device,
} }

View File

@@ -1,12 +1,11 @@
use crate::types::{ use crate::types::{
ta::{calculate_indicators, HEAD_SIZE, NUMERICAL_FIELD_COUNT}, ta::{calculate_indicators, IndicatedBar, HEAD_SIZE, NUMERICAL_FIELD_COUNT},
Bar, Bar,
}; };
use burn::{ use burn::{
data::dataset::{transform::ComposedDataset, Dataset}, data::dataset::{transform::ComposedDataset, Dataset},
tensor::Data, tensor::Data,
}; };
use itertools::Itertools;
pub const WINDOW_SIZE: usize = 48; pub const WINDOW_SIZE: usize = 48;
@@ -28,8 +27,11 @@ struct SingleSymbolDataset {
impl SingleSymbolDataset { impl SingleSymbolDataset {
#[allow(clippy::cast_possible_truncation)] #[allow(clippy::cast_possible_truncation)]
pub fn new(bars: Vec<Bar>) -> Self { pub fn new(bars: Vec<IndicatedBar>) -> Self {
let bars = calculate_indicators(&bars); if !bars.is_empty() {
let symbol = &bars[0].symbol;
assert!(bars.iter().all(|bar| bar.symbol == *symbol));
}
let (hours, days, numerical, targets) = bars.windows(2).skip(HEAD_SIZE - 1).fold( let (hours, days, numerical, targets) = bars.windows(2).skip(HEAD_SIZE - 1).fold(
( (
@@ -54,48 +56,27 @@ impl SingleSymbolDataset {
(bar[0].volume_pct as f32).min(f32::MAX), (bar[0].volume_pct as f32).min(f32::MAX),
bar[0].trades as f32, bar[0].trades as f32,
(bar[0].trades_pct as f32).min(f32::MAX), (bar[0].trades_pct as f32).min(f32::MAX),
bar[0].close_deriv as f32,
(bar[0].close_deriv_pct as f32).min(f32::MAX),
bar[0].sma_3 as f32, bar[0].sma_3 as f32,
(bar[0].sma_3_pct as f32).min(f32::MAX),
bar[0].sma_6 as f32, bar[0].sma_6 as f32,
(bar[0].sma_6_pct as f32).min(f32::MAX),
bar[0].sma_12 as f32, bar[0].sma_12 as f32,
(bar[0].sma_12_pct as f32).min(f32::MAX),
bar[0].sma_24 as f32, bar[0].sma_24 as f32,
(bar[0].sma_24_pct as f32).min(f32::MAX),
bar[0].sma_48 as f32, bar[0].sma_48 as f32,
(bar[0].sma_48_pct as f32).min(f32::MAX),
bar[0].sma_72 as f32, bar[0].sma_72 as f32,
(bar[0].sma_72_pct as f32).min(f32::MAX),
bar[0].ema_3 as f32, bar[0].ema_3 as f32,
(bar[0].ema_3_pct as f32).min(f32::MAX),
bar[0].ema_6 as f32, bar[0].ema_6 as f32,
(bar[0].ema_6_pct as f32).min(f32::MAX),
bar[0].ema_12 as f32, bar[0].ema_12 as f32,
(bar[0].ema_12_pct as f32).min(f32::MAX),
bar[0].ema_24 as f32, bar[0].ema_24 as f32,
(bar[0].ema_24_pct as f32).min(f32::MAX),
bar[0].ema_48 as f32, bar[0].ema_48 as f32,
(bar[0].ema_48_pct as f32).min(f32::MAX),
bar[0].ema_72 as f32, bar[0].ema_72 as f32,
(bar[0].ema_72_pct as f32).min(f32::MAX),
bar[0].macd as f32, bar[0].macd as f32,
(bar[0].macd_pct as f32).min(f32::MAX),
bar[0].macd_signal as f32, bar[0].macd_signal as f32,
(bar[0].macd_signal_pct as f32).min(f32::MAX),
bar[0].obv as f32, bar[0].obv as f32,
(bar[0].obv_pct as f32).min(f32::MAX),
bar[0].rsi as f32, bar[0].rsi as f32,
(bar[0].rsi_pct as f32).min(f32::MAX),
bar[0].bbands_lower as f32, bar[0].bbands_lower as f32,
(bar[0].bbands_lower_pct as f32).min(f32::MAX),
bar[0].bbands_mean as f32, bar[0].bbands_mean as f32,
(bar[0].bbands_mean_pct as f32).min(f32::MAX),
bar[0].bbands_upper as f32, bar[0].bbands_upper as f32,
(bar[0].bbands_upper_pct as f32).min(f32::MAX),
]); ]);
targets.push(bar[1].close as f32); targets.push(bar[1].close_pct as f32);
(hours, days, numerical, targets) (hours, days, numerical, targets)
}, },
); );
@@ -141,13 +122,10 @@ pub struct MultipleSymbolDataset {
impl MultipleSymbolDataset { impl MultipleSymbolDataset {
pub fn new(bars: Vec<Bar>) -> Self { pub fn new(bars: Vec<Bar>) -> Self {
let groups = bars let groups = calculate_indicators(bars)
.into_iter() .into_iter()
.group_by(|bar| bar.symbol.clone())
.into_iter()
.map(|(_, group)| group.collect::<Vec<_>>())
.map(SingleSymbolDataset::new) .map(SingleSymbolDataset::new)
.collect(); .collect::<Vec<_>>();
Self { Self {
composed_dataset: ComposedDataset::new(groups), composed_dataset: ComposedDataset::new(groups),
@@ -174,7 +152,7 @@ mod tests {
}; };
use time::OffsetDateTime; use time::OffsetDateTime;
fn generate_random_dataset(length: usize) -> SingleSymbolDataset { fn generate_random_dataset(length: usize) -> MultipleSymbolDataset {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let uniform = Uniform::new(1.0, 100.0); let uniform = Uniform::new(1.0, 100.0);
let mut bars = Vec::with_capacity(length); let mut bars = Vec::with_capacity(length);
@@ -192,7 +170,7 @@ mod tests {
}); });
} }
SingleSymbolDataset::new(bars) MultipleSymbolDataset::new(bars)
} }
#[test] #[test]
@@ -210,8 +188,6 @@ mod tests {
let item = dataset.get(0).unwrap(); let item = dataset.get(0).unwrap();
assert_eq!(item.hours.shape.dims, [WINDOW_SIZE]);
assert_eq!(item.days.shape.dims, [WINDOW_SIZE]);
assert_eq!( assert_eq!(
item.numerical.shape.dims, item.numerical.shape.dims,
[WINDOW_SIZE, NUMERICAL_FIELD_COUNT] [WINDOW_SIZE, NUMERICAL_FIELD_COUNT]
@@ -226,8 +202,6 @@ mod tests {
let item = dataset.get(dataset.len() - 1).unwrap(); let item = dataset.get(dataset.len() - 1).unwrap();
assert_eq!(item.hours.shape.dims, [WINDOW_SIZE]);
assert_eq!(item.days.shape.dims, [WINDOW_SIZE]);
assert_eq!( assert_eq!(
item.numerical.shape.dims, item.numerical.shape.dims,
[WINDOW_SIZE, NUMERICAL_FIELD_COUNT] [WINDOW_SIZE, NUMERICAL_FIELD_COUNT]

View File

@@ -4,7 +4,7 @@ use burn::{
config::Config, config::Config,
module::Module, module::Module,
nn::{ nn::{
loss::{MSELoss, Reduction}, loss::{MseLoss, Reduction},
Dropout, DropoutConfig, Embedding, EmbeddingConfig, Linear, LinearConfig, Lstm, LstmConfig, Dropout, DropoutConfig, Embedding, EmbeddingConfig, Linear, LinearConfig, Lstm, LstmConfig,
}, },
tensor::{ tensor::{
@@ -79,13 +79,13 @@ impl<B: Backend> Model<B> {
let x = Tensor::cat(vec![hour, day, numerical], 2); let x = Tensor::cat(vec![hour, day, numerical], 2);
let (x, _) = self.lstm_1.forward(x, None); let (_, x) = self.lstm_1.forward(x, None);
let x = self.dropout_1.forward(x); let x = self.dropout_1.forward(x);
let (x, _) = self.lstm_2.forward(x, None); let (_, x) = self.lstm_2.forward(x, None);
let x = self.dropout_2.forward(x); let x = self.dropout_2.forward(x);
let (x, _) = self.lstm_3.forward(x, None); let (_, x) = self.lstm_3.forward(x, None);
let x = self.dropout_3.forward(x); let x = self.dropout_3.forward(x);
let (x, _) = self.lstm_4.forward(x, None); let (_, x) = self.lstm_4.forward(x, None);
let x = self.dropout_4.forward(x); let x = self.dropout_4.forward(x);
let [batch_size, window_size, features] = x.shape().dims; let [batch_size, window_size, features] = x.shape().dims;
@@ -104,7 +104,7 @@ impl<B: Backend> Model<B> {
target: Tensor<B, 2>, target: Tensor<B, 2>,
) -> RegressionOutput<B> { ) -> RegressionOutput<B> {
let output = self.forward(hour, day, numerical); let output = self.forward(hour, day, numerical);
let loss = MSELoss::new().forward(output.clone(), target.clone(), Reduction::Mean); let loss = MseLoss::new().forward(output.clone(), target.clone(), Reduction::Mean);
RegressionOutput::new(loss, output, target) RegressionOutput::new(loss, output, target)
} }

View File

@@ -6,13 +6,13 @@ use serde::Deserialize;
pub enum Message { pub enum Message {
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
Market { Market {
trades: Vec<String>,
quotes: Vec<String>,
bars: Vec<String>, bars: Vec<String>,
updated_bars: Vec<String>, updated_bars: Vec<String>,
daily_bars: Vec<String>, statuses: Vec<String>,
trades: Option<Vec<String>>,
quotes: Option<Vec<String>>,
daily_bars: Option<Vec<String>>,
orderbooks: Option<Vec<String>>, orderbooks: Option<Vec<String>>,
statuses: Option<Vec<String>>,
lulds: Option<Vec<String>>, lulds: Option<Vec<String>>,
cancel_errors: Option<Vec<String>>, cancel_errors: Option<Vec<String>>,
}, },

View File

@@ -1,15 +1,21 @@
use super::Bar; use super::Bar;
use crate::ta::{Bbands, Deriv, Ema, Macd, Obv, Pct, Rsi, Sma}; use crate::ta::{Bbands, Deriv, Ema, Macd, Obv, Pct, Rsi, Sma};
use clickhouse::Row;
use itertools::Itertools;
use rayon::scope; use rayon::scope;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::num::NonZeroUsize; use std::num::NonZeroUsize;
use time::OffsetDateTime;
pub const HEAD_SIZE: usize = 72; pub const HEAD_SIZE: usize = 72;
pub const FIELD_COUNT: usize = 54; pub const FIELD_COUNT: usize = 33;
pub const NUMERICAL_FIELD_COUNT: usize = FIELD_COUNT - 2; pub const NUMERICAL_FIELD_COUNT: usize = FIELD_COUNT - 2;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
pub struct IndicatedBar { pub struct IndicatedBar {
pub symbol: String,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time: OffsetDateTime,
pub hour: u8, pub hour: u8,
pub day: u8, pub day: u8,
pub open: f64, pub open: f64,
@@ -24,53 +30,32 @@ pub struct IndicatedBar {
pub volume_pct: f64, pub volume_pct: f64,
pub trades: f64, pub trades: f64,
pub trades_pct: f64, pub trades_pct: f64,
pub close_deriv: f64,
pub close_deriv_pct: f64,
pub sma_3: f64, pub sma_3: f64,
pub sma_3_pct: f64,
pub sma_6: f64, pub sma_6: f64,
pub sma_6_pct: f64,
pub sma_12: f64, pub sma_12: f64,
pub sma_12_pct: f64,
pub sma_24: f64, pub sma_24: f64,
pub sma_24_pct: f64,
pub sma_48: f64, pub sma_48: f64,
pub sma_48_pct: f64,
pub sma_72: f64, pub sma_72: f64,
pub sma_72_pct: f64,
pub ema_3: f64, pub ema_3: f64,
pub ema_3_pct: f64,
pub ema_6: f64, pub ema_6: f64,
pub ema_6_pct: f64,
pub ema_12: f64, pub ema_12: f64,
pub ema_12_pct: f64,
pub ema_24: f64, pub ema_24: f64,
pub ema_24_pct: f64,
pub ema_48: f64, pub ema_48: f64,
pub ema_48_pct: f64,
pub ema_72: f64, pub ema_72: f64,
pub ema_72_pct: f64,
pub macd: f64, pub macd: f64,
pub macd_pct: f64,
pub macd_signal: f64, pub macd_signal: f64,
pub macd_signal_pct: f64,
pub obv: f64, pub obv: f64,
pub obv_pct: f64,
pub rsi: f64, pub rsi: f64,
pub rsi_pct: f64,
pub bbands_lower: f64, pub bbands_lower: f64,
pub bbands_lower_pct: f64,
pub bbands_mean: f64, pub bbands_mean: f64,
pub bbands_mean_pct: f64,
pub bbands_upper: f64, pub bbands_upper: f64,
pub bbands_upper_pct: f64,
} }
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub fn calculate_indicators(bars: &[Bar]) -> Vec<IndicatedBar> { fn _calculate_indicators(bars: &[Bar]) -> Vec<IndicatedBar> {
let length = bars.len(); let length = bars.len();
let (hour, day, open, high, low, close, volume, trades) = bars.iter().fold( let (symbol, time, hour, day, open, high, low, close, volume, trades) = bars.iter().fold(
( (
Vec::with_capacity(length), Vec::with_capacity(length),
Vec::with_capacity(length), Vec::with_capacity(length),
@@ -80,9 +65,24 @@ pub fn calculate_indicators(bars: &[Bar]) -> Vec<IndicatedBar> {
Vec::with_capacity(length), Vec::with_capacity(length),
Vec::with_capacity(length), Vec::with_capacity(length),
Vec::with_capacity(length), Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
),
|(
mut symbol,
mut time,
mut hour,
mut day,
mut open,
mut high,
mut low,
mut close,
mut volume,
mut trades,
), ),
|(mut hour, mut day, mut open, mut high, mut low, mut close, mut volume, mut trades),
bar| { bar| {
symbol.push(bar.symbol.clone());
time.push(bar.time);
hour.push(bar.time.hour()); hour.push(bar.time.hour());
day.push(bar.time.day()); day.push(bar.time.day());
open.push(bar.open); open.push(bar.open);
@@ -91,7 +91,9 @@ pub fn calculate_indicators(bars: &[Bar]) -> Vec<IndicatedBar> {
close.push(bar.close); close.push(bar.close);
volume.push(bar.volume); volume.push(bar.volume);
trades.push(bar.trades as f64); trades.push(bar.trades as f64);
(hour, day, open, high, low, close, volume, trades) (
symbol, time, hour, day, open, high, low, close, volume, trades,
)
}, },
); );
@@ -167,26 +169,6 @@ pub fn calculate_indicators(bars: &[Bar]) -> Vec<IndicatedBar> {
let mut close_pct = Vec::with_capacity(length); let mut close_pct = Vec::with_capacity(length);
let mut volume_pct = Vec::with_capacity(length); let mut volume_pct = Vec::with_capacity(length);
let mut trades_pct = Vec::with_capacity(length); let mut trades_pct = Vec::with_capacity(length);
let mut close_deriv_pct = Vec::with_capacity(length);
let mut sma_3_pct = Vec::with_capacity(length);
let mut sma_6_pct = Vec::with_capacity(length);
let mut sma_12_pct = Vec::with_capacity(length);
let mut sma_24_pct = Vec::with_capacity(length);
let mut sma_48_pct = Vec::with_capacity(length);
let mut sma_72_pct = Vec::with_capacity(length);
let mut ema_3_pct = Vec::with_capacity(length);
let mut ema_6_pct = Vec::with_capacity(length);
let mut ema_12_pct = Vec::with_capacity(length);
let mut ema_24_pct = Vec::with_capacity(length);
let mut ema_48_pct = Vec::with_capacity(length);
let mut ema_72_pct = Vec::with_capacity(length);
let mut macd_pct = Vec::with_capacity(length);
let mut macd_signal_pct = Vec::with_capacity(length);
let mut obv_pct = Vec::with_capacity(length);
let mut rsi_pct = Vec::with_capacity(length);
let mut bbands_upper_pct = Vec::with_capacity(length);
let mut bbands_mean_pct = Vec::with_capacity(length);
let mut bbands_lower_pct = Vec::with_capacity(length);
scope(|s| { scope(|s| {
s.spawn(|_| open_pct.extend(open.iter().pct())); s.spawn(|_| open_pct.extend(open.iter().pct()));
@@ -195,31 +177,13 @@ pub fn calculate_indicators(bars: &[Bar]) -> Vec<IndicatedBar> {
s.spawn(|_| close_pct.extend(close.iter().pct())); s.spawn(|_| close_pct.extend(close.iter().pct()));
s.spawn(|_| volume_pct.extend(volume.iter().pct())); s.spawn(|_| volume_pct.extend(volume.iter().pct()));
s.spawn(|_| trades_pct.extend(trades.iter().pct())); s.spawn(|_| trades_pct.extend(trades.iter().pct()));
s.spawn(|_| close_deriv_pct.extend(close_deriv.iter().pct()));
s.spawn(|_| sma_3_pct.extend(sma_3.iter().pct()));
s.spawn(|_| sma_6_pct.extend(sma_6.iter().pct()));
s.spawn(|_| sma_12_pct.extend(sma_12.iter().pct()));
s.spawn(|_| sma_24_pct.extend(sma_24.iter().pct()));
s.spawn(|_| sma_48_pct.extend(sma_48.iter().pct()));
s.spawn(|_| sma_72_pct.extend(sma_72.iter().pct()));
s.spawn(|_| ema_3_pct.extend(ema_3.iter().pct()));
s.spawn(|_| ema_6_pct.extend(ema_6.iter().pct()));
s.spawn(|_| ema_12_pct.extend(ema_12.iter().pct()));
s.spawn(|_| ema_24_pct.extend(ema_24.iter().pct()));
s.spawn(|_| ema_48_pct.extend(ema_48.iter().pct()));
s.spawn(|_| ema_72_pct.extend(ema_72.iter().pct()));
s.spawn(|_| macd_pct.extend(macd.iter().pct()));
s.spawn(|_| macd_signal_pct.extend(macd_signal.iter().pct()));
s.spawn(|_| obv_pct.extend(obv.iter().pct()));
s.spawn(|_| rsi_pct.extend(rsi.iter().pct()));
s.spawn(|_| bbands_upper_pct.extend(bbands_upper.iter().pct()));
s.spawn(|_| bbands_mean_pct.extend(bbands_mean.iter().pct()));
s.spawn(|_| bbands_lower_pct.extend(bbands_lower.iter().pct()));
}); });
bars.iter() bars.iter()
.enumerate() .enumerate()
.map(|(i, _)| IndicatedBar { .map(|(i, _)| IndicatedBar {
symbol: symbol[i].clone(),
time: time[i],
hour: hour[i], hour: hour[i],
day: day[i], day: day[i],
open: open[i], open: open[i],
@@ -234,50 +198,49 @@ pub fn calculate_indicators(bars: &[Bar]) -> Vec<IndicatedBar> {
volume_pct: volume_pct[i], volume_pct: volume_pct[i],
trades: trades[i], trades: trades[i],
trades_pct: trades_pct[i], trades_pct: trades_pct[i],
close_deriv: close_deriv[i],
close_deriv_pct: close_deriv_pct[i],
sma_3: sma_3[i], sma_3: sma_3[i],
sma_3_pct: sma_3_pct[i],
sma_6: sma_6[i], sma_6: sma_6[i],
sma_6_pct: sma_6_pct[i],
sma_12: sma_12[i], sma_12: sma_12[i],
sma_12_pct: sma_12_pct[i],
sma_24: sma_24[i], sma_24: sma_24[i],
sma_24_pct: sma_24_pct[i],
sma_48: sma_48[i], sma_48: sma_48[i],
sma_48_pct: sma_48_pct[i],
sma_72: sma_72[i], sma_72: sma_72[i],
sma_72_pct: sma_72_pct[i],
ema_3: ema_3[i], ema_3: ema_3[i],
ema_3_pct: ema_3_pct[i],
ema_6: ema_6[i], ema_6: ema_6[i],
ema_6_pct: ema_6_pct[i],
ema_12: ema_12[i], ema_12: ema_12[i],
ema_12_pct: ema_12_pct[i],
ema_24: ema_24[i], ema_24: ema_24[i],
ema_24_pct: ema_24_pct[i],
ema_48: ema_48[i], ema_48: ema_48[i],
ema_48_pct: ema_48_pct[i],
ema_72: ema_72[i], ema_72: ema_72[i],
ema_72_pct: ema_72_pct[i],
macd: macd[i], macd: macd[i],
macd_pct: macd_pct[i],
macd_signal: macd_signal[i], macd_signal: macd_signal[i],
macd_signal_pct: macd_signal_pct[i],
obv: obv[i], obv: obv[i],
obv_pct: obv_pct[i],
rsi: rsi[i], rsi: rsi[i],
rsi_pct: rsi_pct[i],
bbands_lower: bbands_lower[i], bbands_lower: bbands_lower[i],
bbands_lower_pct: bbands_lower_pct[i],
bbands_mean: bbands_mean[i], bbands_mean: bbands_mean[i],
bbands_mean_pct: bbands_mean_pct[i],
bbands_upper: bbands_upper[i], bbands_upper: bbands_upper[i],
bbands_upper_pct: bbands_upper_pct[i],
}) })
.collect() .collect()
} }
pub fn calculate_indicators<I>(bars: I) -> Vec<Vec<IndicatedBar>>
where
I: IntoIterator<Item = Bar>,
{
bars.into_iter()
.filter(|bar| {
bar.open > 0.0
&& bar.high > 0.0
&& bar.low > 0.0
&& bar.close > 0.0
&& bar.volume > 0.0
&& bar.trades > 0
})
.sorted_by_key(|bar| (bar.symbol.clone(), bar.time))
.group_by(|bar| bar.symbol.clone())
.into_iter()
.map(|(_, group)| _calculate_indicators(&group.collect::<Vec<_>>()))
.collect::<Vec<_>>()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -285,7 +248,6 @@ mod tests {
distributions::{Distribution, Uniform}, distributions::{Distribution, Uniform},
Rng, Rng,
}; };
use time::OffsetDateTime;
#[test] #[test]
fn test_calculate_indicators() { fn test_calculate_indicators() {
@@ -308,8 +270,8 @@ mod tests {
}); });
} }
let indicated_bars = calculate_indicators(&bars); let indicated_bars = calculate_indicators(bars);
assert_eq!(indicated_bars.len(), length); assert_eq!(indicated_bars[0].len(), length);
} }
} }

View File

@@ -1,3 +1,5 @@
CREATE DATABASE IF NOT EXISTS qrust;
CREATE TABLE IF NOT EXISTS qrust.assets ( CREATE TABLE IF NOT EXISTS qrust.assets (
symbol LowCardinality(String), symbol LowCardinality(String),
class Enum('us_equity' = 1, 'crypto' = 2), class Enum('us_equity' = 1, 'crypto' = 2),