Compare commits
	
		
			1 Commits
		
	
	
		
			90b7f10a77
			...
			jupyter
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 2036e5fa32 | 
							
								
								
									
										7
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -2,7 +2,6 @@ | |||||||
| # will have compiled files and executables | # will have compiled files and executables | ||||||
| debug/ | debug/ | ||||||
| target/ | target/ | ||||||
| log/ |  | ||||||
|  |  | ||||||
| # These are backup files generated by rustfmt | # These are backup files generated by rustfmt | ||||||
| **/*.rs.bk | **/*.rs.bk | ||||||
| @@ -11,3 +10,9 @@ log/ | |||||||
| *.pdb | *.pdb | ||||||
|  |  | ||||||
| .env* | .env* | ||||||
|  |  | ||||||
|  | # ML models | ||||||
|  | models/*/rust_model.ot | ||||||
|  | notebooks/models/ | ||||||
|  |  | ||||||
|  | libdevice.10.bc | ||||||
|   | |||||||
| @@ -22,7 +22,7 @@ build: | |||||||
|   cache: |   cache: | ||||||
|     <<: *global_cache |     <<: *global_cache | ||||||
|   script: |   script: | ||||||
|     - cargo +nightly build --workspace |     - cargo +nightly build | ||||||
|  |  | ||||||
| test: | test: | ||||||
|   image: registry.karaolidis.com/karaolidis/qrust/rust |   image: registry.karaolidis.com/karaolidis/qrust/rust | ||||||
| @@ -30,7 +30,7 @@ test: | |||||||
|   cache: |   cache: | ||||||
|     <<: *global_cache |     <<: *global_cache | ||||||
|   script: |   script: | ||||||
|     - cargo +nightly test --workspace |     - cargo +nightly test | ||||||
|  |  | ||||||
| lint: | lint: | ||||||
|   image: registry.karaolidis.com/karaolidis/qrust/rust |   image: registry.karaolidis.com/karaolidis/qrust/rust | ||||||
| @@ -39,7 +39,7 @@ lint: | |||||||
|     <<: *global_cache |     <<: *global_cache | ||||||
|   script: |   script: | ||||||
|     - cargo +nightly fmt --all -- --check |     - cargo +nightly fmt --all -- --check | ||||||
|     - cargo +nightly clippy --workspace --all-targets --all-features |     - cargo +nightly clippy --all-targets --all-features | ||||||
|  |  | ||||||
| depcheck: | depcheck: | ||||||
|   image: registry.karaolidis.com/karaolidis/qrust/rust |   image: registry.karaolidis.com/karaolidis/qrust/rust | ||||||
| @@ -48,7 +48,7 @@ depcheck: | |||||||
|     <<: *global_cache |     <<: *global_cache | ||||||
|   script: |   script: | ||||||
|     - cargo +nightly outdated |     - cargo +nightly outdated | ||||||
|     - cargo +nightly udeps --workspace --all-targets |     - cargo +nightly udeps | ||||||
|  |  | ||||||
| build-release: | build-release: | ||||||
|   image: registry.karaolidis.com/karaolidis/qrust/rust |   image: registry.karaolidis.com/karaolidis/qrust/rust | ||||||
|   | |||||||
							
								
								
									
										2751
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										2751
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										65
									
								
								Cargo.toml
									
									
									
									
									
								
							
							
						
						
									
										65
									
								
								Cargo.toml
									
									
									
									
									
								
							| @@ -3,18 +3,6 @@ name    = "qrust" | |||||||
| version = "0.1.0" | version = "0.1.0" | ||||||
| edition = "2021" | edition = "2021" | ||||||
|  |  | ||||||
| [lib] |  | ||||||
| name = "qrust" |  | ||||||
| path = "src/lib/qrust/mod.rs" |  | ||||||
|  |  | ||||||
| [[bin]] |  | ||||||
| name = "qrust" |  | ||||||
| path = "src/bin/qrust/mod.rs" |  | ||||||
|  |  | ||||||
| [[bin]] |  | ||||||
| name = "trainer" |  | ||||||
| path = "src/bin/trainer/mod.rs" |  | ||||||
|  |  | ||||||
| [profile.release] | [profile.release] | ||||||
| panic         = 'abort' | panic         = 'abort' | ||||||
| strip         = true | strip         = true | ||||||
| @@ -24,9 +12,9 @@ codegen-units = 1 | |||||||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||||||
|  |  | ||||||
| [dependencies] | [dependencies] | ||||||
| axum = "0.7.5" | axum = "0.7.4" | ||||||
| dotenv = "0.15.0" | dotenv = "0.15.0" | ||||||
| tokio = { version = "1.37.0", features = [ | tokio = { version = "1.32.0", features = [ | ||||||
|     "macros", |     "macros", | ||||||
|     "rt-multi-thread", |     "rt-multi-thread", | ||||||
| ] } | ] } | ||||||
| @@ -34,29 +22,29 @@ tokio-tungstenite = { version = "0.21.0", features = [ | |||||||
|     "tokio-native-tls", |     "tokio-native-tls", | ||||||
|     "native-tls", |     "native-tls", | ||||||
| ] } | ] } | ||||||
| log = "0.4.21" | log = "0.4.20" | ||||||
| log4rs = "1.3.0" | log4rs = "1.2.0" | ||||||
| serde = "1.0.201" | serde = "1.0.188" | ||||||
| serde_json = "1.0.117" | serde_json = "1.0.105" | ||||||
| serde_repr = "0.1.19" | serde_repr = "0.1.18" | ||||||
| serde_with = "3.8.1" | serde_with = "3.6.1" | ||||||
| serde-aux = "4.5.0" | serde-aux = "4.4.0" | ||||||
| futures-util = "0.3.30" | futures-util = "0.3.28" | ||||||
| reqwest = { version = "0.12.4", features = [ | reqwest = { version = "0.11.20", features = [ | ||||||
|     "json", |     "json", | ||||||
|  |     "serde_json", | ||||||
| ] } | ] } | ||||||
| http = "1.1.0" | http = "1.0.0" | ||||||
| governor = "0.6.3" | governor = "0.6.0" | ||||||
| clickhouse = { version = "0.11.6", features = [ | clickhouse = { version = "0.11.6", features = [ | ||||||
|     "watch", |     "watch", | ||||||
|     "time", |     "time", | ||||||
|     "uuid", |     "uuid", | ||||||
| ] } | ] } | ||||||
| uuid = { version = "1.8.0", features = [ | uuid = { version = "1.6.1", features = [ | ||||||
|     "serde", |     "serde", | ||||||
|     "v4", |  | ||||||
| ] } | ] } | ||||||
| time = { version = "0.3.36", features = [ | time = { version = "0.3.31", features = [ | ||||||
|     "serde", |     "serde", | ||||||
|     "serde-well-known", |     "serde-well-known", | ||||||
|     "serde-human-readable", |     "serde-human-readable", | ||||||
| @@ -67,22 +55,9 @@ time = { version = "0.3.36", features = [ | |||||||
| backoff = { version = "0.4.0", features = [ | backoff = { version = "0.4.0", features = [ | ||||||
|     "tokio", |     "tokio", | ||||||
| ] } | ] } | ||||||
| regex = "1.10.4" | regex = "1.10.3" | ||||||
| async-trait = "0.1.80" | html-escape = "0.2.13" | ||||||
|  | rust-bert = "0.22.0" | ||||||
|  | async-trait = "0.1.77" | ||||||
| itertools = "0.12.1" | itertools = "0.12.1" | ||||||
| lazy_static = "1.4.0" | lazy_static = "1.4.0" | ||||||
| nonempty = { version = "0.10.0", features = [ |  | ||||||
|     "serialize", |  | ||||||
| ] } |  | ||||||
| rand = "0.8.5" |  | ||||||
| rayon = "1.10.0" |  | ||||||
| burn = { version = "0.13.2", features = [ |  | ||||||
|     "wgpu", |  | ||||||
|     "cuda", |  | ||||||
|     "tui", |  | ||||||
|     "metrics", |  | ||||||
|     "train", |  | ||||||
| ] } |  | ||||||
|  |  | ||||||
| [dev-dependencies] |  | ||||||
| serde_test = "1.0.176" |  | ||||||
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| # qrust | # qrust | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| `qrust` (/kɹʌst/, QuantitativeRust) is an algorithmic trading library written in Rust. | `qrust` (/kɹʌst/, QuantitativeRust) is an algorithmic trading library written in Rust. | ||||||
|   | |||||||
| @@ -4,14 +4,7 @@ appenders: | |||||||
|     encoder: |     encoder: | ||||||
|       pattern: "{d} {h({l})} {M}::{L} - {m}{n}" |       pattern: "{d} {h({l})} {M}::{L} - {m}{n}" | ||||||
|  |  | ||||||
|   file: |  | ||||||
|     kind: file |  | ||||||
|     path: "./log/output.log" |  | ||||||
|     encoder: |  | ||||||
|       pattern: "{d} {l} {M}::{L} - {m}{n}" |  | ||||||
|  |  | ||||||
| root: | root: | ||||||
|   level: info |   level: info | ||||||
|   appenders: |   appenders: | ||||||
|     - stdout |     - stdout | ||||||
|     - file |  | ||||||
|   | |||||||
							
								
								
									
										32
									
								
								models/finbert/config.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								models/finbert/config.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | |||||||
|  | { | ||||||
|  |   "_name_or_path": "/home/ubuntu/finbert/models/language_model/finbertTRC2", | ||||||
|  |   "architectures": [ | ||||||
|  |     "BertForSequenceClassification" | ||||||
|  |   ], | ||||||
|  |   "attention_probs_dropout_prob": 0.1, | ||||||
|  |   "gradient_checkpointing": false, | ||||||
|  |   "hidden_act": "gelu", | ||||||
|  |   "hidden_dropout_prob": 0.1, | ||||||
|  |   "hidden_size": 768, | ||||||
|  |   "id2label": { | ||||||
|  |     "0": "positive", | ||||||
|  |     "1": "negative", | ||||||
|  |     "2": "neutral" | ||||||
|  |   }, | ||||||
|  |   "initializer_range": 0.02, | ||||||
|  |   "intermediate_size": 3072, | ||||||
|  |   "label2id": { | ||||||
|  |     "positive": 0, | ||||||
|  |     "negative": 1, | ||||||
|  |     "neutral": 2 | ||||||
|  |   }, | ||||||
|  |   "layer_norm_eps": 1e-12, | ||||||
|  |   "max_position_embeddings": 512, | ||||||
|  |   "model_type": "bert", | ||||||
|  |   "num_attention_heads": 12, | ||||||
|  |   "num_hidden_layers": 12, | ||||||
|  |   "pad_token_id": 0, | ||||||
|  |   "position_embedding_type": "absolute", | ||||||
|  |   "type_vocab_size": 2, | ||||||
|  |   "vocab_size": 30522 | ||||||
|  | } | ||||||
							
								
								
									
										1
									
								
								models/finbert/special_tokens_map.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								models/finbert/special_tokens_map.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"} | ||||||
							
								
								
									
										1
									
								
								models/finbert/tokenizer_config.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								models/finbert/tokenizer_config.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "bert-base-uncased"} | ||||||
							
								
								
									
										30522
									
								
								models/finbert/vocab.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30522
									
								
								models/finbert/vocab.txt
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										3868
									
								
								notebooks/lstm.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3868
									
								
								notebooks/lstm.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -1,86 +0,0 @@ | |||||||
| use governor::{DefaultDirectRateLimiter, Quota, RateLimiter}; |  | ||||||
| use lazy_static::lazy_static; |  | ||||||
| use qrust::types::alpaca::shared::{Mode, Source}; |  | ||||||
| use reqwest::{ |  | ||||||
|     header::{HeaderMap, HeaderName, HeaderValue}, |  | ||||||
|     Client, |  | ||||||
| }; |  | ||||||
| use std::{env, num::NonZeroU32, sync::Arc}; |  | ||||||
| use tokio::sync::Semaphore; |  | ||||||
|  |  | ||||||
| lazy_static! { |  | ||||||
|     pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE") |  | ||||||
|         .expect("ALPACA_MODE must be set.") |  | ||||||
|         .parse() |  | ||||||
|         .expect("ALPACA_MODE must be 'live' or 'paper'"); |  | ||||||
|     pub static ref ALPACA_API_BASE: String = match *ALPACA_MODE { |  | ||||||
|         Mode::Live => String::from("api"), |  | ||||||
|         Mode::Paper => String::from("paper-api"), |  | ||||||
|     }; |  | ||||||
|     pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE") |  | ||||||
|         .expect("ALPACA_SOURCE must be set.") |  | ||||||
|         .parse() |  | ||||||
|         .expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'"); |  | ||||||
|     pub static ref ALPACA_API_KEY: String = |  | ||||||
|         env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set."); |  | ||||||
|     pub static ref ALPACA_API_SECRET: String = |  | ||||||
|         env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set."); |  | ||||||
|     pub static ref CLICKHOUSE_BATCH_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE") |  | ||||||
|         .expect("BATCH_BACKFILL_BARS_SIZE must be set.") |  | ||||||
|         .parse() |  | ||||||
|         .expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer."); |  | ||||||
|     pub static ref CLICKHOUSE_BATCH_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE") |  | ||||||
|         .expect("BATCH_BACKFILL_NEWS_SIZE must be set.") |  | ||||||
|         .parse() |  | ||||||
|         .expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer."); |  | ||||||
|     pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS") |  | ||||||
|         .expect("CLICKHOUSE_MAX_CONNECTIONS must be set.") |  | ||||||
|         .parse() |  | ||||||
|         .expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer."); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub struct Config { |  | ||||||
|     pub alpaca_client: Client, |  | ||||||
|     pub alpaca_rate_limiter: DefaultDirectRateLimiter, |  | ||||||
|     pub clickhouse_client: clickhouse::Client, |  | ||||||
|     pub clickhouse_concurrency_limiter: Arc<Semaphore>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl Config { |  | ||||||
|     pub fn from_env() -> Self { |  | ||||||
|         Self { |  | ||||||
|             alpaca_client: Client::builder() |  | ||||||
|                 .default_headers(HeaderMap::from_iter([ |  | ||||||
|                     ( |  | ||||||
|                         HeaderName::from_static("apca-api-key-id"), |  | ||||||
|                         HeaderValue::from_str(&ALPACA_API_KEY) |  | ||||||
|                             .expect("Alpaca API key must not contain invalid characters."), |  | ||||||
|                     ), |  | ||||||
|                     ( |  | ||||||
|                         HeaderName::from_static("apca-api-secret-key"), |  | ||||||
|                         HeaderValue::from_str(&ALPACA_API_SECRET) |  | ||||||
|                             .expect("Alpaca API secret must not contain invalid characters."), |  | ||||||
|                     ), |  | ||||||
|                 ])) |  | ||||||
|                 .build() |  | ||||||
|                 .unwrap(), |  | ||||||
|             alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE { |  | ||||||
|                 Source::Iex => NonZeroU32::new(200).unwrap(), |  | ||||||
|                 Source::Sip => NonZeroU32::new(10_000).unwrap(), |  | ||||||
|                 Source::Otc => unimplemented!("OTC rate limit not implemented."), |  | ||||||
|             })), |  | ||||||
|             clickhouse_client: clickhouse::Client::default() |  | ||||||
|                 .with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set.")) |  | ||||||
|                 .with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set.")) |  | ||||||
|                 .with_password( |  | ||||||
|                     env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."), |  | ||||||
|                 ) |  | ||||||
|                 .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")), |  | ||||||
|             clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     pub fn arc_from_env() -> Arc<Self> { |  | ||||||
|         Arc::new(Self::from_env()) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,115 +0,0 @@ | |||||||
| #![warn(clippy::all, clippy::pedantic, clippy::nursery)] |  | ||||||
| #![allow(clippy::missing_docs_in_private_items)] |  | ||||||
| #![feature(hash_extract_if)] |  | ||||||
|  |  | ||||||
| mod config; |  | ||||||
| mod init; |  | ||||||
| mod routes; |  | ||||||
| mod threads; |  | ||||||
|  |  | ||||||
| use config::{ |  | ||||||
|     Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE, |  | ||||||
|     CLICKHOUSE_BATCH_NEWS_SIZE, CLICKHOUSE_MAX_CONNECTIONS, |  | ||||||
| }; |  | ||||||
| use dotenv::dotenv; |  | ||||||
| use log::info; |  | ||||||
| use log4rs::config::Deserializers; |  | ||||||
| use nonempty::NonEmpty; |  | ||||||
| use qrust::{create_send_await, database}; |  | ||||||
| use tokio::{join, spawn, sync::mpsc, try_join}; |  | ||||||
|  |  | ||||||
| #[tokio::main] |  | ||||||
| async fn main() { |  | ||||||
|     dotenv().ok(); |  | ||||||
|     log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap(); |  | ||||||
|     let config = Config::arc_from_env(); |  | ||||||
|  |  | ||||||
|     let _ = *ALPACA_MODE; |  | ||||||
|     let _ = *ALPACA_API_BASE; |  | ||||||
|     let _ = *ALPACA_SOURCE; |  | ||||||
|     let _ = *CLICKHOUSE_BATCH_BARS_SIZE; |  | ||||||
|     let _ = *CLICKHOUSE_BATCH_NEWS_SIZE; |  | ||||||
|     let _ = *CLICKHOUSE_MAX_CONNECTIONS; |  | ||||||
|  |  | ||||||
|     info!("Marking all assets as stale."); |  | ||||||
|  |  | ||||||
|     let assets = database::assets::select( |  | ||||||
|         &config.clickhouse_client, |  | ||||||
|         &config.clickhouse_concurrency_limiter, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .unwrap() |  | ||||||
|     .into_iter() |  | ||||||
|     .map(|asset| (asset.symbol, asset.class)) |  | ||||||
|     .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|     let symbols = assets.iter().map(|(symbol, _)| symbol).collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|     try_join!( |  | ||||||
|         database::backfills_bars::set_fresh_where_symbols( |  | ||||||
|             &config.clickhouse_client, |  | ||||||
|             &config.clickhouse_concurrency_limiter, |  | ||||||
|             false, |  | ||||||
|             &symbols |  | ||||||
|         ), |  | ||||||
|         database::backfills_news::set_fresh_where_symbols( |  | ||||||
|             &config.clickhouse_client, |  | ||||||
|             &config.clickhouse_concurrency_limiter, |  | ||||||
|             false, |  | ||||||
|             &symbols |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     .unwrap(); |  | ||||||
|  |  | ||||||
|     info!("Cleaning up database."); |  | ||||||
|  |  | ||||||
|     database::cleanup_all( |  | ||||||
|         &config.clickhouse_client, |  | ||||||
|         &config.clickhouse_concurrency_limiter, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .unwrap(); |  | ||||||
|  |  | ||||||
|     info!("Optimizing database."); |  | ||||||
|  |  | ||||||
|     database::optimize_all( |  | ||||||
|         &config.clickhouse_client, |  | ||||||
|         &config.clickhouse_concurrency_limiter, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .unwrap(); |  | ||||||
|  |  | ||||||
|     info!("Rehydrating account data."); |  | ||||||
|  |  | ||||||
|     init::check_account(&config).await; |  | ||||||
|     join!( |  | ||||||
|         init::rehydrate_orders(&config), |  | ||||||
|         init::rehydrate_positions(&config) |  | ||||||
|     ); |  | ||||||
|  |  | ||||||
|     info!("Starting threads."); |  | ||||||
|  |  | ||||||
|     spawn(threads::trading::run(config.clone())); |  | ||||||
|  |  | ||||||
|     let (data_sender, data_receiver) = mpsc::channel::<threads::data::Message>(100); |  | ||||||
|     let (clock_sender, clock_receiver) = mpsc::channel::<threads::clock::Message>(1); |  | ||||||
|  |  | ||||||
|     spawn(threads::data::run( |  | ||||||
|         config.clone(), |  | ||||||
|         data_receiver, |  | ||||||
|         clock_receiver, |  | ||||||
|     )); |  | ||||||
|  |  | ||||||
|     spawn(threads::clock::run(config.clone(), clock_sender)); |  | ||||||
|  |  | ||||||
|     if let Some(assets) = NonEmpty::from_vec(assets) { |  | ||||||
|         create_send_await!( |  | ||||||
|             data_sender, |  | ||||||
|             threads::data::Message::new, |  | ||||||
|             threads::data::Action::Enable, |  | ||||||
|             assets |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     routes::run(config, data_sender).await; |  | ||||||
| } |  | ||||||
| @@ -1,197 +0,0 @@ | |||||||
| use crate::{ |  | ||||||
|     config::{Config, ALPACA_API_BASE}, |  | ||||||
|     create_send_await, database, threads, |  | ||||||
| }; |  | ||||||
| use axum::{extract::Path, Extension, Json}; |  | ||||||
| use http::StatusCode; |  | ||||||
| use nonempty::{nonempty, NonEmpty}; |  | ||||||
| use qrust::{ |  | ||||||
|     alpaca, |  | ||||||
|     types::{self, Asset}, |  | ||||||
| }; |  | ||||||
| use serde::{Deserialize, Serialize}; |  | ||||||
| use std::{ |  | ||||||
|     collections::{HashMap, HashSet}, |  | ||||||
|     sync::Arc, |  | ||||||
| }; |  | ||||||
| use tokio::sync::mpsc; |  | ||||||
|  |  | ||||||
| pub async fn get( |  | ||||||
|     Extension(config): Extension<Arc<Config>>, |  | ||||||
| ) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> { |  | ||||||
|     let assets = database::assets::select( |  | ||||||
|         &config.clickhouse_client, |  | ||||||
|         &config.clickhouse_concurrency_limiter, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; |  | ||||||
|  |  | ||||||
|     Ok((StatusCode::OK, Json(assets))) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn get_where_symbol( |  | ||||||
|     Extension(config): Extension<Arc<Config>>, |  | ||||||
|     Path(symbol): Path<String>, |  | ||||||
| ) -> Result<(StatusCode, Json<Asset>), StatusCode> { |  | ||||||
|     let asset = database::assets::select_where_symbol( |  | ||||||
|         &config.clickhouse_client, |  | ||||||
|         &config.clickhouse_concurrency_limiter, |  | ||||||
|         &symbol, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; |  | ||||||
|  |  | ||||||
|     asset.map_or(Err(StatusCode::NOT_FOUND), |asset| { |  | ||||||
|         Ok((StatusCode::OK, Json(asset))) |  | ||||||
|     }) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[derive(Deserialize)] |  | ||||||
| pub struct AddAssetsRequest { |  | ||||||
|     symbols: Vec<String>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[derive(Serialize)] |  | ||||||
| pub struct AddAssetsResponse { |  | ||||||
|     added: Vec<String>, |  | ||||||
|     skipped: Vec<String>, |  | ||||||
|     failed: Vec<String>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn add( |  | ||||||
|     Extension(config): Extension<Arc<Config>>, |  | ||||||
|     Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>, |  | ||||||
|     Json(request): Json<AddAssetsRequest>, |  | ||||||
| ) -> Result<(StatusCode, Json<AddAssetsResponse>), StatusCode> { |  | ||||||
|     let database_symbols = database::assets::select( |  | ||||||
|         &config.clickhouse_client, |  | ||||||
|         &config.clickhouse_concurrency_limiter, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? |  | ||||||
|     .into_iter() |  | ||||||
|     .map(|asset| asset.symbol) |  | ||||||
|     .collect::<HashSet<_>>(); |  | ||||||
|  |  | ||||||
|     let mut alpaca_assets = alpaca::assets::get_by_symbols( |  | ||||||
|         &config.alpaca_client, |  | ||||||
|         &config.alpaca_rate_limiter, |  | ||||||
|         &request.symbols, |  | ||||||
|         None, |  | ||||||
|         &ALPACA_API_BASE, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .map_err(|e| e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))? |  | ||||||
|     .into_iter() |  | ||||||
|     .map(|asset| (asset.symbol.clone(), asset)) |  | ||||||
|     .collect::<HashMap<_, _>>(); |  | ||||||
|  |  | ||||||
|     let num_symbols = request.symbols.len(); |  | ||||||
|     let (assets, skipped, failed) = request.symbols.into_iter().fold( |  | ||||||
|         (Vec::with_capacity(num_symbols), vec![], vec![]), |  | ||||||
|         |(mut assets, mut skipped, mut failed), symbol| { |  | ||||||
|             if database_symbols.contains(&symbol) { |  | ||||||
|                 skipped.push(symbol); |  | ||||||
|             } else if let Some(asset) = alpaca_assets.remove(&symbol) { |  | ||||||
|                 if asset.status == types::alpaca::api::incoming::asset::Status::Active |  | ||||||
|                     && asset.tradable |  | ||||||
|                     && asset.fractionable |  | ||||||
|                 { |  | ||||||
|                     assets.push((asset.symbol, asset.class.into())); |  | ||||||
|                 } else { |  | ||||||
|                     failed.push(asset.symbol); |  | ||||||
|                 } |  | ||||||
|             } else { |  | ||||||
|                 failed.push(symbol); |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             (assets, skipped, failed) |  | ||||||
|         }, |  | ||||||
|     ); |  | ||||||
|  |  | ||||||
|     if let Some(assets) = NonEmpty::from_vec(assets.clone()) { |  | ||||||
|         create_send_await!( |  | ||||||
|             data_sender, |  | ||||||
|             threads::data::Message::new, |  | ||||||
|             threads::data::Action::Add, |  | ||||||
|             assets |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     Ok(( |  | ||||||
|         StatusCode::OK, |  | ||||||
|         Json(AddAssetsResponse { |  | ||||||
|             added: assets.into_iter().map(|asset| asset.0).collect(), |  | ||||||
|             skipped, |  | ||||||
|             failed, |  | ||||||
|         }), |  | ||||||
|     )) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn add_symbol( |  | ||||||
|     Extension(config): Extension<Arc<Config>>, |  | ||||||
|     Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>, |  | ||||||
|     Path(symbol): Path<String>, |  | ||||||
| ) -> Result<StatusCode, StatusCode> { |  | ||||||
|     if database::assets::select_where_symbol( |  | ||||||
|         &config.clickhouse_client, |  | ||||||
|         &config.clickhouse_concurrency_limiter, |  | ||||||
|         &symbol, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? |  | ||||||
|     .is_some() |  | ||||||
|     { |  | ||||||
|         return Err(StatusCode::CONFLICT); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     let asset = alpaca::assets::get_by_symbol( |  | ||||||
|         &config.alpaca_client, |  | ||||||
|         &config.alpaca_rate_limiter, |  | ||||||
|         &symbol, |  | ||||||
|         None, |  | ||||||
|         &ALPACA_API_BASE, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .map_err(|e| e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))?; |  | ||||||
|  |  | ||||||
|     if asset.status != types::alpaca::api::incoming::asset::Status::Active |  | ||||||
|         || !asset.tradable |  | ||||||
|         || !asset.fractionable |  | ||||||
|     { |  | ||||||
|         return Err(StatusCode::FORBIDDEN); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     create_send_await!( |  | ||||||
|         data_sender, |  | ||||||
|         threads::data::Message::new, |  | ||||||
|         threads::data::Action::Add, |  | ||||||
|         nonempty![(asset.symbol, asset.class.into())] |  | ||||||
|     ); |  | ||||||
|  |  | ||||||
|     Ok(StatusCode::CREATED) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn delete( |  | ||||||
|     Extension(config): Extension<Arc<Config>>, |  | ||||||
|     Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>, |  | ||||||
|     Path(symbol): Path<String>, |  | ||||||
| ) -> Result<StatusCode, StatusCode> { |  | ||||||
|     let asset = database::assets::select_where_symbol( |  | ||||||
|         &config.clickhouse_client, |  | ||||||
|         &config.clickhouse_concurrency_limiter, |  | ||||||
|         &symbol, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? |  | ||||||
|     .ok_or(StatusCode::NOT_FOUND)?; |  | ||||||
|  |  | ||||||
|     create_send_await!( |  | ||||||
|         data_sender, |  | ||||||
|         threads::data::Message::new, |  | ||||||
|         threads::data::Action::Remove, |  | ||||||
|         nonempty![(asset.symbol, asset.class)] |  | ||||||
|     ); |  | ||||||
|  |  | ||||||
|     Ok(StatusCode::NO_CONTENT) |  | ||||||
| } |  | ||||||
| @@ -1,238 +0,0 @@ | |||||||
| use super::Job; |  | ||||||
| use crate::{ |  | ||||||
|     config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE}, |  | ||||||
|     database, |  | ||||||
|     threads::data::ThreadType, |  | ||||||
| }; |  | ||||||
| use async_trait::async_trait; |  | ||||||
| use log::{error, info}; |  | ||||||
| use nonempty::NonEmpty; |  | ||||||
| use qrust::{ |  | ||||||
|     alpaca, |  | ||||||
|     types::{ |  | ||||||
|         self, |  | ||||||
|         alpaca::{ |  | ||||||
|             api::{ALPACA_CRYPTO_DATA_API_URL, ALPACA_US_EQUITY_DATA_API_URL}, |  | ||||||
|             shared::{Sort, Source}, |  | ||||||
|         }, |  | ||||||
|         Backfill, Bar, Class, |  | ||||||
|     }, |  | ||||||
|     utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE}, |  | ||||||
| }; |  | ||||||
| use std::{collections::HashMap, sync::Arc}; |  | ||||||
| use time::OffsetDateTime; |  | ||||||
| use tokio::time::sleep; |  | ||||||
|  |  | ||||||
| pub struct Handler { |  | ||||||
|     pub config: Arc<Config>, |  | ||||||
|     pub data_url: &'static str, |  | ||||||
|     pub api_query_constructor: fn( |  | ||||||
|         symbols: Vec<String>, |  | ||||||
|         fetch_from: OffsetDateTime, |  | ||||||
|         fetch_to: OffsetDateTime, |  | ||||||
|         next_page_token: Option<String>, |  | ||||||
|     ) -> types::alpaca::api::outgoing::bar::Bar, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn us_equity_query_constructor( |  | ||||||
|     symbols: Vec<String>, |  | ||||||
|     fetch_from: OffsetDateTime, |  | ||||||
|     fetch_to: OffsetDateTime, |  | ||||||
|     next_page_token: Option<String>, |  | ||||||
| ) -> types::alpaca::api::outgoing::bar::Bar { |  | ||||||
|     types::alpaca::api::outgoing::bar::Bar::UsEquity(types::alpaca::api::outgoing::bar::UsEquity { |  | ||||||
|         symbols, |  | ||||||
|         start: Some(fetch_from), |  | ||||||
|         end: Some(fetch_to), |  | ||||||
|         page_token: next_page_token, |  | ||||||
|         sort: Some(Sort::Asc), |  | ||||||
|         feed: Some(*ALPACA_SOURCE), |  | ||||||
|         ..Default::default() |  | ||||||
|     }) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn crypto_query_constructor( |  | ||||||
|     symbols: Vec<String>, |  | ||||||
|     fetch_from: OffsetDateTime, |  | ||||||
|     fetch_to: OffsetDateTime, |  | ||||||
|     next_page_token: Option<String>, |  | ||||||
| ) -> types::alpaca::api::outgoing::bar::Bar { |  | ||||||
|     types::alpaca::api::outgoing::bar::Bar::Crypto(types::alpaca::api::outgoing::bar::Crypto { |  | ||||||
|         symbols, |  | ||||||
|         start: Some(fetch_from), |  | ||||||
|         end: Some(fetch_to), |  | ||||||
|         page_token: next_page_token, |  | ||||||
|         sort: Some(Sort::Asc), |  | ||||||
|         ..Default::default() |  | ||||||
|     }) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[async_trait] |  | ||||||
| impl super::Handler for Handler { |  | ||||||
|     async fn select_latest_backfills( |  | ||||||
|         &self, |  | ||||||
|         symbols: &[String], |  | ||||||
|     ) -> Result<Vec<Backfill>, clickhouse::error::Error> { |  | ||||||
|         database::backfills_bars::select_where_symbols( |  | ||||||
|             &self.config.clickhouse_client, |  | ||||||
|             &self.config.clickhouse_concurrency_limiter, |  | ||||||
|             symbols, |  | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { |  | ||||||
|         database::backfills_bars::delete_where_symbols( |  | ||||||
|             &self.config.clickhouse_client, |  | ||||||
|             &self.config.clickhouse_concurrency_limiter, |  | ||||||
|             symbols, |  | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { |  | ||||||
|         database::bars::delete_where_symbols( |  | ||||||
|             &self.config.clickhouse_client, |  | ||||||
|             &self.config.clickhouse_concurrency_limiter, |  | ||||||
|             symbols, |  | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     async fn queue_backfill(&self, jobs: &NonEmpty<Job>) { |  | ||||||
|         if *ALPACA_SOURCE == Source::Sip { |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to; |  | ||||||
|         let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE); |  | ||||||
|         let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         info!("Queing bar backfill for {:?} in {:?}.", symbols, run_delay); |  | ||||||
|         sleep(run_delay).await; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     async fn backfill(&self, jobs: NonEmpty<Job>) { |  | ||||||
|         let symbols = Vec::from(jobs.clone().map(|job| job.symbol)); |  | ||||||
|         let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from; |  | ||||||
|         let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to; |  | ||||||
|         let freshness = jobs |  | ||||||
|             .into_iter() |  | ||||||
|             .map(|job| (job.symbol, job.fresh)) |  | ||||||
|             .collect::<HashMap<_, _>>(); |  | ||||||
|  |  | ||||||
|         let mut bars = Vec::with_capacity(*CLICKHOUSE_BATCH_BARS_SIZE); |  | ||||||
|         let mut last_times = HashMap::new(); |  | ||||||
|         let mut next_page_token = None; |  | ||||||
|  |  | ||||||
|         info!("Backfilling bars for {:?}.", symbols); |  | ||||||
|  |  | ||||||
|         loop { |  | ||||||
|             let message = alpaca::bars::get( |  | ||||||
|                 &self.config.alpaca_client, |  | ||||||
|                 &self.config.alpaca_rate_limiter, |  | ||||||
|                 self.data_url, |  | ||||||
|                 &(self.api_query_constructor)( |  | ||||||
|                     symbols.clone(), |  | ||||||
|                     fetch_from, |  | ||||||
|                     fetch_to, |  | ||||||
|                     next_page_token.clone(), |  | ||||||
|                 ), |  | ||||||
|                 None, |  | ||||||
|             ) |  | ||||||
|             .await; |  | ||||||
|  |  | ||||||
|             if let Err(err) = message { |  | ||||||
|                 error!("Failed to backfill bars for {:?}: {:?}.", symbols, err); |  | ||||||
|                 return; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             let message = message.unwrap(); |  | ||||||
|  |  | ||||||
|             for (symbol, bars_vec) in message.bars { |  | ||||||
|                 if let Some(last) = bars_vec.last() { |  | ||||||
|                     last_times.insert(symbol.clone(), last.time); |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 for bar in bars_vec { |  | ||||||
|                     bars.push(Bar::from((bar, symbol.clone()))); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             if bars.len() < *CLICKHOUSE_BATCH_BARS_SIZE && message.next_page_token.is_some() { |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             database::bars::upsert_batch( |  | ||||||
|                 &self.config.clickhouse_client, |  | ||||||
|                 &self.config.clickhouse_concurrency_limiter, |  | ||||||
|                 &bars, |  | ||||||
|             ) |  | ||||||
|             .await |  | ||||||
|             .unwrap(); |  | ||||||
|  |  | ||||||
|             let backfilled = last_times |  | ||||||
|                 .drain() |  | ||||||
|                 .map(|(symbol, time)| Backfill { |  | ||||||
|                     fresh: freshness[&symbol], |  | ||||||
|                     symbol, |  | ||||||
|                     time, |  | ||||||
|                 }) |  | ||||||
|                 .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|             database::backfills_bars::upsert_batch( |  | ||||||
|                 &self.config.clickhouse_client, |  | ||||||
|                 &self.config.clickhouse_concurrency_limiter, |  | ||||||
|                 &backfilled, |  | ||||||
|             ) |  | ||||||
|             .await |  | ||||||
|             .unwrap(); |  | ||||||
|  |  | ||||||
|             if message.next_page_token.is_none() { |  | ||||||
|                 break; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             next_page_token = message.next_page_token; |  | ||||||
|             bars.clear(); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         database::backfills_bars::set_fresh_where_symbols( |  | ||||||
|             &self.config.clickhouse_client, |  | ||||||
|             &self.config.clickhouse_concurrency_limiter, |  | ||||||
|             true, |  | ||||||
|             &symbols, |  | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|         .unwrap(); |  | ||||||
|  |  | ||||||
|         info!("Backfilled bars for {:?}.", symbols); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     fn max_limit(&self) -> i64 { |  | ||||||
|         alpaca::bars::MAX_LIMIT |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     fn log_string(&self) -> &'static str { |  | ||||||
|         "bars" |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> { |  | ||||||
|     let data_url = match thread_type { |  | ||||||
|         ThreadType::Bars(Class::UsEquity) => ALPACA_US_EQUITY_DATA_API_URL, |  | ||||||
|         ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_API_URL, |  | ||||||
|         _ => unreachable!(), |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let api_query_constructor = match thread_type { |  | ||||||
|         ThreadType::Bars(Class::UsEquity) => us_equity_query_constructor, |  | ||||||
|         ThreadType::Bars(Class::Crypto) => crypto_query_constructor, |  | ||||||
|         _ => unreachable!(), |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     Box::new(Handler { |  | ||||||
|         config, |  | ||||||
|         data_url, |  | ||||||
|         api_query_constructor, |  | ||||||
|     }) |  | ||||||
| } |  | ||||||
| @@ -1,243 +0,0 @@ | |||||||
| pub mod bars; |  | ||||||
| pub mod news; |  | ||||||
|  |  | ||||||
| use async_trait::async_trait; |  | ||||||
| use itertools::Itertools; |  | ||||||
| use log::{info, warn}; |  | ||||||
| use nonempty::{nonempty, NonEmpty}; |  | ||||||
| use qrust::{ |  | ||||||
|     types::Backfill, |  | ||||||
|     utils::{last_minute, ONE_SECOND}, |  | ||||||
| }; |  | ||||||
| use std::{collections::HashMap, hash::Hash, sync::Arc}; |  | ||||||
| use time::OffsetDateTime; |  | ||||||
| use tokio::{ |  | ||||||
|     spawn, |  | ||||||
|     sync::{mpsc, oneshot, Mutex}, |  | ||||||
|     task::JoinHandle, |  | ||||||
|     try_join, |  | ||||||
| }; |  | ||||||
| use uuid::Uuid; |  | ||||||
|  |  | ||||||
| pub enum Action { |  | ||||||
|     Backfill, |  | ||||||
|     Purge, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub struct Message { |  | ||||||
|     pub action: Action, |  | ||||||
|     pub symbols: NonEmpty<String>, |  | ||||||
|     pub response: oneshot::Sender<()>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl Message { |  | ||||||
|     pub fn new(action: Action, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) { |  | ||||||
|         let (sender, receiver) = oneshot::channel::<()>(); |  | ||||||
|         ( |  | ||||||
|             Self { |  | ||||||
|                 action, |  | ||||||
|                 symbols, |  | ||||||
|                 response: sender, |  | ||||||
|             }, |  | ||||||
|             receiver, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[derive(Clone)] |  | ||||||
| pub struct Job { |  | ||||||
|     pub symbol: String, |  | ||||||
|     pub fetch_from: OffsetDateTime, |  | ||||||
|     pub fetch_to: OffsetDateTime, |  | ||||||
|     pub fresh: bool, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[async_trait] |  | ||||||
| pub trait Handler: Send + Sync { |  | ||||||
|     async fn select_latest_backfills( |  | ||||||
|         &self, |  | ||||||
|         symbols: &[String], |  | ||||||
|     ) -> Result<Vec<Backfill>, clickhouse::error::Error>; |  | ||||||
|     async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>; |  | ||||||
|     async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>; |  | ||||||
|     async fn queue_backfill(&self, jobs: &NonEmpty<Job>); |  | ||||||
|     async fn backfill(&self, jobs: NonEmpty<Job>); |  | ||||||
|     fn max_limit(&self) -> i64; |  | ||||||
|     fn log_string(&self) -> &'static str; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub struct Jobs { |  | ||||||
|     pub symbol_to_uuid: HashMap<String, Uuid>, |  | ||||||
|     pub uuid_to_job: HashMap<Uuid, JoinHandle<()>>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl Jobs { |  | ||||||
|     pub fn insert(&mut self, jobs: Vec<String>, fut: JoinHandle<()>) { |  | ||||||
|         let uuid = Uuid::new_v4(); |  | ||||||
|         for symbol in jobs { |  | ||||||
|             self.symbol_to_uuid.insert(symbol.clone(), uuid); |  | ||||||
|         } |  | ||||||
|         self.uuid_to_job.insert(uuid, fut); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     pub fn contains_key(&self, symbol: &str) -> bool { |  | ||||||
|         self.symbol_to_uuid.contains_key(symbol) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     pub fn remove(&mut self, symbol: &str) -> Option<JoinHandle<()>> { |  | ||||||
|         self.symbol_to_uuid |  | ||||||
|             .remove(symbol) |  | ||||||
|             .and_then(|uuid| self.uuid_to_job.remove(&uuid)) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     pub fn remove_many<T>(&mut self, symbols: &[T]) |  | ||||||
|     where |  | ||||||
|         T: AsRef<str> + Hash + Eq, |  | ||||||
|     { |  | ||||||
|         for symbol in symbols { |  | ||||||
|             self.symbol_to_uuid |  | ||||||
|                 .remove(symbol.as_ref()) |  | ||||||
|                 .and_then(|uuid| self.uuid_to_job.remove(&uuid)); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     pub fn len(&self) -> usize { |  | ||||||
|         self.symbol_to_uuid.len() |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) { |  | ||||||
|     let backfill_jobs = Arc::new(Mutex::new(Jobs { |  | ||||||
|         symbol_to_uuid: HashMap::new(), |  | ||||||
|         uuid_to_job: HashMap::new(), |  | ||||||
|     })); |  | ||||||
|  |  | ||||||
|     loop { |  | ||||||
|         let message = receiver.recv().await.unwrap(); |  | ||||||
|         spawn(handle_message( |  | ||||||
|             handler.clone(), |  | ||||||
|             backfill_jobs.clone(), |  | ||||||
|             message, |  | ||||||
|         )); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| async fn handle_message( |  | ||||||
|     handler: Arc<Box<dyn Handler>>, |  | ||||||
|     backfill_jobs: Arc<Mutex<Jobs>>, |  | ||||||
|     message: Message, |  | ||||||
| ) { |  | ||||||
|     let backfill_jobs_clone = backfill_jobs.clone(); |  | ||||||
|     let mut backfill_jobs = backfill_jobs.lock().await; |  | ||||||
|     let symbols = Vec::from(message.symbols); |  | ||||||
|  |  | ||||||
|     match message.action { |  | ||||||
|         Action::Backfill => { |  | ||||||
|             let log_string = handler.log_string(); |  | ||||||
|             let max_limit = handler.max_limit(); |  | ||||||
|  |  | ||||||
|             let backfills = handler |  | ||||||
|                 .select_latest_backfills(&symbols) |  | ||||||
|                 .await |  | ||||||
|                 .unwrap() |  | ||||||
|                 .into_iter() |  | ||||||
|                 .map(|backfill| (backfill.symbol.clone(), backfill)) |  | ||||||
|                 .collect::<HashMap<_, _>>(); |  | ||||||
|  |  | ||||||
|             let mut jobs = Vec::with_capacity(symbols.len()); |  | ||||||
|  |  | ||||||
|             for symbol in symbols { |  | ||||||
|                 if backfill_jobs.contains_key(&symbol) { |  | ||||||
|                     warn!( |  | ||||||
|                         "Backfill for {} {} is already running, skipping.", |  | ||||||
|                         symbol, log_string |  | ||||||
|                     ); |  | ||||||
|                     continue; |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 let backfill = backfills.get(&symbol); |  | ||||||
|  |  | ||||||
|                 let fetch_from = backfill.map_or(OffsetDateTime::UNIX_EPOCH, |backfill| { |  | ||||||
|                     backfill.time + ONE_SECOND |  | ||||||
|                 }); |  | ||||||
|  |  | ||||||
|                 let fetch_to = last_minute(); |  | ||||||
|  |  | ||||||
|                 if fetch_from > fetch_to { |  | ||||||
|                     info!("No need to backfill {} {}.", symbol, log_string,); |  | ||||||
|                     return; |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 let fresh = backfill.map_or(false, |backfill| backfill.fresh); |  | ||||||
|  |  | ||||||
|                 jobs.push(Job { |  | ||||||
|                     symbol, |  | ||||||
|                     fetch_from, |  | ||||||
|                     fetch_to, |  | ||||||
|                     fresh, |  | ||||||
|                 }); |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             let mut current_minutes = 0; |  | ||||||
|             let job_groups = jobs |  | ||||||
|                 .into_iter() |  | ||||||
|                 .sorted_unstable_by_key(|job| job.fetch_from) |  | ||||||
|                 .fold(Vec::<NonEmpty<Job>>::new(), |mut job_groups, job| { |  | ||||||
|                     let minutes = (job.fetch_to - job.fetch_from).whole_minutes(); |  | ||||||
|  |  | ||||||
|                     if let Some(job_group) = job_groups.last_mut() { |  | ||||||
|                         if current_minutes + minutes <= max_limit { |  | ||||||
|                             job_group.push(job); |  | ||||||
|                             current_minutes += minutes; |  | ||||||
|                             return job_groups; |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     job_groups.push(nonempty![job]); |  | ||||||
|                     current_minutes = minutes; |  | ||||||
|                     job_groups |  | ||||||
|                 }); |  | ||||||
|  |  | ||||||
|             for job_group in job_groups { |  | ||||||
|                 let symbols = job_group |  | ||||||
|                     .iter() |  | ||||||
|                     .map(|job| job.symbol.clone()) |  | ||||||
|                     .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|                 let handler = handler.clone(); |  | ||||||
|                 let symbols_clone = symbols.clone(); |  | ||||||
|                 let backfill_jobs_clone = backfill_jobs_clone.clone(); |  | ||||||
|  |  | ||||||
|                 let fut = spawn(async move { |  | ||||||
|                     handler.queue_backfill(&job_group).await; |  | ||||||
|                     handler.backfill(job_group).await; |  | ||||||
|  |  | ||||||
|                     let mut backfill_jobs = backfill_jobs_clone.lock().await; |  | ||||||
|                     backfill_jobs.remove_many(&symbols_clone); |  | ||||||
|                     let remaining = backfill_jobs.len(); |  | ||||||
|                     drop(backfill_jobs); |  | ||||||
|  |  | ||||||
|                     info!("{} {} backfills remaining.", remaining, log_string); |  | ||||||
|                 }); |  | ||||||
|  |  | ||||||
|                 backfill_jobs.insert(symbols, fut); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         Action::Purge => { |  | ||||||
|             for symbol in &symbols { |  | ||||||
|                 if let Some(job) = backfill_jobs.remove(symbol) { |  | ||||||
|                     job.abort(); |  | ||||||
|                     let _ = job.await; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             try_join!( |  | ||||||
|                 handler.delete_backfills(&symbols), |  | ||||||
|                 handler.delete_data(&symbols) |  | ||||||
|             ) |  | ||||||
|             .unwrap(); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     message.response.send(()).unwrap(); |  | ||||||
| } |  | ||||||
| @@ -1,186 +0,0 @@ | |||||||
| use super::Job; |  | ||||||
| use crate::{ |  | ||||||
|     config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_NEWS_SIZE}, |  | ||||||
|     database, |  | ||||||
| }; |  | ||||||
| use async_trait::async_trait; |  | ||||||
| use log::{error, info}; |  | ||||||
| use nonempty::NonEmpty; |  | ||||||
| use qrust::{ |  | ||||||
|     alpaca, |  | ||||||
|     types::{ |  | ||||||
|         self, |  | ||||||
|         alpaca::shared::{Sort, Source}, |  | ||||||
|         Backfill, News, |  | ||||||
|     }, |  | ||||||
|     utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE}, |  | ||||||
| }; |  | ||||||
| use std::{ |  | ||||||
|     collections::{HashMap, HashSet}, |  | ||||||
|     sync::Arc, |  | ||||||
| }; |  | ||||||
| use tokio::time::sleep; |  | ||||||
|  |  | ||||||
| pub struct Handler { |  | ||||||
|     pub config: Arc<Config>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[async_trait] |  | ||||||
| impl super::Handler for Handler { |  | ||||||
|     async fn select_latest_backfills( |  | ||||||
|         &self, |  | ||||||
|         symbols: &[String], |  | ||||||
|     ) -> Result<Vec<Backfill>, clickhouse::error::Error> { |  | ||||||
|         database::backfills_news::select_where_symbols( |  | ||||||
|             &self.config.clickhouse_client, |  | ||||||
|             &self.config.clickhouse_concurrency_limiter, |  | ||||||
|             symbols, |  | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { |  | ||||||
|         database::backfills_news::delete_where_symbols( |  | ||||||
|             &self.config.clickhouse_client, |  | ||||||
|             &self.config.clickhouse_concurrency_limiter, |  | ||||||
|             symbols, |  | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { |  | ||||||
|         database::news::delete_where_symbols( |  | ||||||
|             &self.config.clickhouse_client, |  | ||||||
|             &self.config.clickhouse_concurrency_limiter, |  | ||||||
|             symbols, |  | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     async fn queue_backfill(&self, jobs: &NonEmpty<Job>) { |  | ||||||
|         if *ALPACA_SOURCE == Source::Sip { |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to; |  | ||||||
|         let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE); |  | ||||||
|         let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         info!("Queing news backfill for {:?} in {:?}.", symbols, run_delay); |  | ||||||
|         sleep(run_delay).await; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[allow(clippy::too_many_lines)] |  | ||||||
|     #[allow(clippy::iter_with_drain)] |  | ||||||
|     async fn backfill(&self, jobs: NonEmpty<Job>) { |  | ||||||
|         let symbols = Vec::from(jobs.clone().map(|job| job.symbol)); |  | ||||||
|         let symbols_set = symbols.clone().into_iter().collect::<HashSet<_>>(); |  | ||||||
|         let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from; |  | ||||||
|         let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to; |  | ||||||
|         let freshness = jobs |  | ||||||
|             .into_iter() |  | ||||||
|             .map(|job| (job.symbol, job.fresh)) |  | ||||||
|             .collect::<HashMap<_, _>>(); |  | ||||||
|  |  | ||||||
|         let mut news = Vec::with_capacity(*CLICKHOUSE_BATCH_NEWS_SIZE); |  | ||||||
|         let mut last_times = HashMap::new(); |  | ||||||
|         let mut next_page_token = None; |  | ||||||
|  |  | ||||||
|         info!("Backfilling news for {:?}.", symbols); |  | ||||||
|  |  | ||||||
|         loop { |  | ||||||
|             let message = alpaca::news::get( |  | ||||||
|                 &self.config.alpaca_client, |  | ||||||
|                 &self.config.alpaca_rate_limiter, |  | ||||||
|                 &types::alpaca::api::outgoing::news::News { |  | ||||||
|                     symbols: symbols.clone(), |  | ||||||
|                     start: Some(fetch_from), |  | ||||||
|                     end: Some(fetch_to), |  | ||||||
|                     page_token: next_page_token.clone(), |  | ||||||
|                     sort: Some(Sort::Asc), |  | ||||||
|                     ..Default::default() |  | ||||||
|                 }, |  | ||||||
|                 None, |  | ||||||
|             ) |  | ||||||
|             .await; |  | ||||||
|  |  | ||||||
|             if let Err(err) = message { |  | ||||||
|                 error!("Failed to backfill news for {:?}: {:?}.", symbols, err); |  | ||||||
|                 return; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             let message = message.unwrap(); |  | ||||||
|  |  | ||||||
|             for news_item in message.news { |  | ||||||
|                 let news_item = News::from(news_item); |  | ||||||
|  |  | ||||||
|                 for symbol in &news_item.symbols { |  | ||||||
|                     if symbols_set.contains(symbol) { |  | ||||||
|                         last_times.insert(symbol.clone(), news_item.time_created); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 news.push(news_item); |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             if news.len() < *CLICKHOUSE_BATCH_NEWS_SIZE && message.next_page_token.is_some() { |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             database::news::upsert_batch( |  | ||||||
|                 &self.config.clickhouse_client, |  | ||||||
|                 &self.config.clickhouse_concurrency_limiter, |  | ||||||
|                 &news, |  | ||||||
|             ) |  | ||||||
|             .await |  | ||||||
|             .unwrap(); |  | ||||||
|  |  | ||||||
|             let backfilled = last_times |  | ||||||
|                 .drain() |  | ||||||
|                 .map(|(symbol, time)| Backfill { |  | ||||||
|                     fresh: freshness[&symbol], |  | ||||||
|                     symbol, |  | ||||||
|                     time, |  | ||||||
|                 }) |  | ||||||
|                 .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|             database::backfills_news::upsert_batch( |  | ||||||
|                 &self.config.clickhouse_client, |  | ||||||
|                 &self.config.clickhouse_concurrency_limiter, |  | ||||||
|                 &backfilled, |  | ||||||
|             ) |  | ||||||
|             .await |  | ||||||
|             .unwrap(); |  | ||||||
|  |  | ||||||
|             if message.next_page_token.is_none() { |  | ||||||
|                 break; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             next_page_token = message.next_page_token; |  | ||||||
|             news.clear(); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         database::backfills_news::set_fresh_where_symbols( |  | ||||||
|             &self.config.clickhouse_client, |  | ||||||
|             &self.config.clickhouse_concurrency_limiter, |  | ||||||
|             true, |  | ||||||
|             &symbols, |  | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|         .unwrap(); |  | ||||||
|  |  | ||||||
|         info!("Backfilled news for {:?}.", symbols); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     fn max_limit(&self) -> i64 { |  | ||||||
|         alpaca::news::MAX_LIMIT |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     fn log_string(&self) -> &'static str { |  | ||||||
|         "news" |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn create_handler(config: Arc<Config>) -> Box<dyn super::Handler> { |  | ||||||
|     Box::new(Handler { config }) |  | ||||||
| } |  | ||||||
| @@ -1,391 +0,0 @@ | |||||||
| mod backfill; |  | ||||||
| mod websocket; |  | ||||||
|  |  | ||||||
| use super::clock; |  | ||||||
| use crate::{ |  | ||||||
|     config::{Config, ALPACA_API_BASE, ALPACA_SOURCE}, |  | ||||||
|     create_send_await, database, |  | ||||||
| }; |  | ||||||
| use itertools::{Either, Itertools}; |  | ||||||
| use log::error; |  | ||||||
| use nonempty::NonEmpty; |  | ||||||
| use qrust::{ |  | ||||||
|     alpaca, |  | ||||||
|     types::{ |  | ||||||
|         alpaca::websocket::{ |  | ||||||
|             ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL, |  | ||||||
|             ALPACA_US_EQUITY_DATA_WEBSOCKET_URL, |  | ||||||
|         }, |  | ||||||
|         Asset, Class, |  | ||||||
|     }, |  | ||||||
| }; |  | ||||||
| use std::{collections::HashMap, sync::Arc}; |  | ||||||
| use tokio::{ |  | ||||||
|     join, select, spawn, |  | ||||||
|     sync::{mpsc, oneshot}, |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| #[derive(Clone, Copy, PartialEq, Eq)] |  | ||||||
| #[allow(dead_code)] |  | ||||||
| pub enum Action { |  | ||||||
|     Add, |  | ||||||
|     Enable, |  | ||||||
|     Remove, |  | ||||||
|     Disable, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub struct Message { |  | ||||||
|     pub action: Action, |  | ||||||
|     pub assets: NonEmpty<(String, Class)>, |  | ||||||
|     pub response: oneshot::Sender<()>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl Message { |  | ||||||
|     pub fn new(action: Action, assets: NonEmpty<(String, Class)>) -> (Self, oneshot::Receiver<()>) { |  | ||||||
|         let (sender, receiver) = oneshot::channel(); |  | ||||||
|         ( |  | ||||||
|             Self { |  | ||||||
|                 action, |  | ||||||
|                 assets, |  | ||||||
|                 response: sender, |  | ||||||
|             }, |  | ||||||
|             receiver, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[derive(Clone, Copy)] |  | ||||||
| pub enum ThreadType { |  | ||||||
|     Bars(Class), |  | ||||||
|     News, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn run( |  | ||||||
|     config: Arc<Config>, |  | ||||||
|     mut receiver: mpsc::Receiver<Message>, |  | ||||||
|     mut clock_receiver: mpsc::Receiver<clock::Message>, |  | ||||||
| ) { |  | ||||||
|     let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) = |  | ||||||
|         init_thread(config.clone(), ThreadType::Bars(Class::UsEquity)); |  | ||||||
|     let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) = |  | ||||||
|         init_thread(config.clone(), ThreadType::Bars(Class::Crypto)); |  | ||||||
|     let (news_websocket_sender, news_backfill_sender) = |  | ||||||
|         init_thread(config.clone(), ThreadType::News); |  | ||||||
|  |  | ||||||
|     loop { |  | ||||||
|         select! { |  | ||||||
|             Some(message) = receiver.recv() => { |  | ||||||
|                 spawn(handle_message( |  | ||||||
|                     config.clone(), |  | ||||||
|                     bars_us_equity_websocket_sender.clone(), |  | ||||||
|                     bars_us_equity_backfill_sender.clone(), |  | ||||||
|                     bars_crypto_websocket_sender.clone(), |  | ||||||
|                     bars_crypto_backfill_sender.clone(), |  | ||||||
|                     news_websocket_sender.clone(), |  | ||||||
|                     news_backfill_sender.clone(), |  | ||||||
|                     message, |  | ||||||
|                 )); |  | ||||||
|             } |  | ||||||
|             Some(message) = clock_receiver.recv() => { |  | ||||||
|                 spawn(handle_clock_message( |  | ||||||
|                     config.clone(), |  | ||||||
|                     bars_us_equity_backfill_sender.clone(), |  | ||||||
|                     bars_crypto_backfill_sender.clone(), |  | ||||||
|                     news_backfill_sender.clone(), |  | ||||||
|                     message, |  | ||||||
|                 )); |  | ||||||
|             } |  | ||||||
|             else => panic!("Communication channel unexpectedly closed.") |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| fn init_thread( |  | ||||||
|     config: Arc<Config>, |  | ||||||
|     thread_type: ThreadType, |  | ||||||
| ) -> ( |  | ||||||
|     mpsc::Sender<websocket::Message>, |  | ||||||
|     mpsc::Sender<backfill::Message>, |  | ||||||
| ) { |  | ||||||
|     let websocket_url = match thread_type { |  | ||||||
|         ThreadType::Bars(Class::UsEquity) => { |  | ||||||
|             format!("{}/{}", ALPACA_US_EQUITY_DATA_WEBSOCKET_URL, *ALPACA_SOURCE) |  | ||||||
|         } |  | ||||||
|         ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_WEBSOCKET_URL.into(), |  | ||||||
|         ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(), |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let backfill_handler = match thread_type { |  | ||||||
|         ThreadType::Bars(_) => backfill::bars::create_handler(config.clone(), thread_type), |  | ||||||
|         ThreadType::News => backfill::news::create_handler(config.clone()), |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let (backfill_sender, backfill_receiver) = mpsc::channel(100); |  | ||||||
|  |  | ||||||
|     spawn(backfill::run(backfill_handler.into(), backfill_receiver)); |  | ||||||
|  |  | ||||||
|     let websocket_handler = match thread_type { |  | ||||||
|         ThreadType::Bars(_) => websocket::bars::create_handler(config, thread_type), |  | ||||||
|         ThreadType::News => websocket::news::create_handler(&config), |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let (websocket_sender, websocket_receiver) = mpsc::channel(100); |  | ||||||
|  |  | ||||||
|     spawn(websocket::run( |  | ||||||
|         websocket_handler.into(), |  | ||||||
|         websocket_receiver, |  | ||||||
|         websocket_url, |  | ||||||
|     )); |  | ||||||
|  |  | ||||||
|     (websocket_sender, backfill_sender) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::too_many_arguments)] |  | ||||||
| #[allow(clippy::too_many_lines)] |  | ||||||
| async fn handle_message( |  | ||||||
|     config: Arc<Config>, |  | ||||||
|     bars_us_equity_websocket_sender: mpsc::Sender<websocket::Message>, |  | ||||||
|     bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>, |  | ||||||
|     bars_crypto_websocket_sender: mpsc::Sender<websocket::Message>, |  | ||||||
|     bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>, |  | ||||||
|     news_websocket_sender: mpsc::Sender<websocket::Message>, |  | ||||||
|     news_backfill_sender: mpsc::Sender<backfill::Message>, |  | ||||||
|     message: Message, |  | ||||||
| ) { |  | ||||||
|     let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = message |  | ||||||
|         .assets |  | ||||||
|         .clone() |  | ||||||
|         .into_iter() |  | ||||||
|         .partition_map(|asset| match asset.1 { |  | ||||||
|             Class::UsEquity => Either::Left(asset.0), |  | ||||||
|             Class::Crypto => Either::Right(asset.0), |  | ||||||
|         }); |  | ||||||
|  |  | ||||||
|     let symbols = message.assets.map(|(symbol, _)| symbol); |  | ||||||
|  |  | ||||||
|     let bars_us_equity_future = async { |  | ||||||
|         if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols.clone()) { |  | ||||||
|             create_send_await!( |  | ||||||
|                 bars_us_equity_websocket_sender, |  | ||||||
|                 websocket::Message::new, |  | ||||||
|                 message.action.into(), |  | ||||||
|                 us_equity_symbols |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let bars_crypto_future = async { |  | ||||||
|         if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols.clone()) { |  | ||||||
|             create_send_await!( |  | ||||||
|                 bars_crypto_websocket_sender, |  | ||||||
|                 websocket::Message::new, |  | ||||||
|                 message.action.into(), |  | ||||||
|                 crypto_symbols |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let news_future = async { |  | ||||||
|         create_send_await!( |  | ||||||
|             news_websocket_sender, |  | ||||||
|             websocket::Message::new, |  | ||||||
|             message.action.into(), |  | ||||||
|             symbols.clone() |  | ||||||
|         ); |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     join!(bars_us_equity_future, bars_crypto_future, news_future); |  | ||||||
|  |  | ||||||
|     if message.action == Action::Disable { |  | ||||||
|         message.response.send(()).unwrap(); |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     match message.action { |  | ||||||
|         Action::Add | Action::Enable => { |  | ||||||
|             let symbols = Vec::from(symbols.clone()); |  | ||||||
|  |  | ||||||
|             let assets = async { |  | ||||||
|                 alpaca::assets::get_by_symbols( |  | ||||||
|                     &config.alpaca_client, |  | ||||||
|                     &config.alpaca_rate_limiter, |  | ||||||
|                     &symbols, |  | ||||||
|                     None, |  | ||||||
|                     &ALPACA_API_BASE, |  | ||||||
|                 ) |  | ||||||
|                 .await |  | ||||||
|                 .unwrap() |  | ||||||
|                 .into_iter() |  | ||||||
|                 .map(|asset| (asset.symbol.clone(), asset)) |  | ||||||
|                 .collect::<HashMap<_, _>>() |  | ||||||
|             }; |  | ||||||
|  |  | ||||||
|             let positions = async { |  | ||||||
|                 alpaca::positions::get_by_symbols( |  | ||||||
|                     &config.alpaca_client, |  | ||||||
|                     &config.alpaca_rate_limiter, |  | ||||||
|                     &symbols, |  | ||||||
|                     None, |  | ||||||
|                     &ALPACA_API_BASE, |  | ||||||
|                 ) |  | ||||||
|                 .await |  | ||||||
|                 .unwrap() |  | ||||||
|                 .into_iter() |  | ||||||
|                 .map(|position| (position.symbol.clone(), position)) |  | ||||||
|                 .collect::<HashMap<_, _>>() |  | ||||||
|             }; |  | ||||||
|  |  | ||||||
|             let (mut assets, mut positions) = join!(assets, positions); |  | ||||||
|  |  | ||||||
|             let batch = |  | ||||||
|                 symbols |  | ||||||
|                     .iter() |  | ||||||
|                     .fold(Vec::with_capacity(symbols.len()), |mut batch, symbol| { |  | ||||||
|                         if let Some(asset) = assets.remove(symbol) { |  | ||||||
|                             let position = positions.remove(symbol); |  | ||||||
|                             batch.push(Asset::from((asset, position))); |  | ||||||
|                         } else { |  | ||||||
|                             error!("Failed to find asset for symbol: {}.", symbol); |  | ||||||
|                         } |  | ||||||
|                         batch |  | ||||||
|                     }); |  | ||||||
|  |  | ||||||
|             database::assets::upsert_batch( |  | ||||||
|                 &config.clickhouse_client, |  | ||||||
|                 &config.clickhouse_concurrency_limiter, |  | ||||||
|                 &batch, |  | ||||||
|             ) |  | ||||||
|             .await |  | ||||||
|             .unwrap(); |  | ||||||
|         } |  | ||||||
|         Action::Remove => { |  | ||||||
|             database::assets::delete_where_symbols( |  | ||||||
|                 &config.clickhouse_client, |  | ||||||
|                 &config.clickhouse_concurrency_limiter, |  | ||||||
|                 &Vec::from(symbols.clone()), |  | ||||||
|             ) |  | ||||||
|             .await |  | ||||||
|             .unwrap(); |  | ||||||
|         } |  | ||||||
|         Action::Disable => unreachable!(), |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     let bars_us_equity_future = async { |  | ||||||
|         if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) { |  | ||||||
|             create_send_await!( |  | ||||||
|                 bars_us_equity_backfill_sender, |  | ||||||
|                 backfill::Message::new, |  | ||||||
|                 match message.action { |  | ||||||
|                     Action::Add | Action::Enable => backfill::Action::Backfill, |  | ||||||
|                     Action::Remove => backfill::Action::Purge, |  | ||||||
|                     Action::Disable => unreachable!(), |  | ||||||
|                 }, |  | ||||||
|                 us_equity_symbols |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let bars_crypto_future = async { |  | ||||||
|         if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) { |  | ||||||
|             create_send_await!( |  | ||||||
|                 bars_crypto_backfill_sender, |  | ||||||
|                 backfill::Message::new, |  | ||||||
|                 match message.action { |  | ||||||
|                     Action::Add | Action::Enable => backfill::Action::Backfill, |  | ||||||
|                     Action::Remove => backfill::Action::Purge, |  | ||||||
|                     Action::Disable => unreachable!(), |  | ||||||
|                 }, |  | ||||||
|                 crypto_symbols |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let news_future = async { |  | ||||||
|         create_send_await!( |  | ||||||
|             news_backfill_sender, |  | ||||||
|             backfill::Message::new, |  | ||||||
|             match message.action { |  | ||||||
|                 Action::Add | Action::Enable => backfill::Action::Backfill, |  | ||||||
|                 Action::Remove => backfill::Action::Purge, |  | ||||||
|                 Action::Disable => unreachable!(), |  | ||||||
|             }, |  | ||||||
|             symbols |  | ||||||
|         ); |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     join!(bars_us_equity_future, bars_crypto_future, news_future); |  | ||||||
|  |  | ||||||
|     message.response.send(()).unwrap(); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| async fn handle_clock_message( |  | ||||||
|     config: Arc<Config>, |  | ||||||
|     bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>, |  | ||||||
|     bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>, |  | ||||||
|     news_backfill_sender: mpsc::Sender<backfill::Message>, |  | ||||||
|     message: clock::Message, |  | ||||||
| ) { |  | ||||||
|     if message.status == clock::Status::Closed { |  | ||||||
|         database::cleanup_all( |  | ||||||
|             &config.clickhouse_client, |  | ||||||
|             &config.clickhouse_concurrency_limiter, |  | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|         .unwrap(); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     let assets = database::assets::select( |  | ||||||
|         &config.clickhouse_client, |  | ||||||
|         &config.clickhouse_concurrency_limiter, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .unwrap(); |  | ||||||
|  |  | ||||||
|     let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets |  | ||||||
|         .clone() |  | ||||||
|         .into_iter() |  | ||||||
|         .partition_map(|asset| match asset.class { |  | ||||||
|             Class::UsEquity => Either::Left(asset.symbol), |  | ||||||
|             Class::Crypto => Either::Right(asset.symbol), |  | ||||||
|         }); |  | ||||||
|  |  | ||||||
|     let symbols = assets |  | ||||||
|         .into_iter() |  | ||||||
|         .map(|asset| asset.symbol) |  | ||||||
|         .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|     let bars_us_equity_future = async { |  | ||||||
|         if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) { |  | ||||||
|             create_send_await!( |  | ||||||
|                 bars_us_equity_backfill_sender, |  | ||||||
|                 backfill::Message::new, |  | ||||||
|                 backfill::Action::Backfill, |  | ||||||
|                 us_equity_symbols |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let bars_crypto_future = async { |  | ||||||
|         if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) { |  | ||||||
|             create_send_await!( |  | ||||||
|                 bars_crypto_backfill_sender, |  | ||||||
|                 backfill::Message::new, |  | ||||||
|                 backfill::Action::Backfill, |  | ||||||
|                 crypto_symbols |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let news_future = async { |  | ||||||
|         if let Some(symbols) = NonEmpty::from_vec(symbols) { |  | ||||||
|             create_send_await!( |  | ||||||
|                 news_backfill_sender, |  | ||||||
|                 backfill::Message::new, |  | ||||||
|                 backfill::Action::Backfill, |  | ||||||
|                 symbols |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     join!(bars_us_equity_future, bars_crypto_future, news_future); |  | ||||||
| } |  | ||||||
| @@ -1,171 +0,0 @@ | |||||||
| use super::State; |  | ||||||
| use crate::{ |  | ||||||
|     config::{Config, CLICKHOUSE_BATCH_BARS_SIZE}, |  | ||||||
|     database, |  | ||||||
|     threads::data::ThreadType, |  | ||||||
| }; |  | ||||||
| use async_trait::async_trait; |  | ||||||
| use clickhouse::inserter::Inserter; |  | ||||||
| use log::{debug, error, info}; |  | ||||||
| use nonempty::NonEmpty; |  | ||||||
| use qrust::{ |  | ||||||
|     types::{alpaca::websocket, Bar, Class}, |  | ||||||
|     utils::ONE_SECOND, |  | ||||||
| }; |  | ||||||
| use std::{ |  | ||||||
|     collections::{HashMap, HashSet}, |  | ||||||
|     sync::Arc, |  | ||||||
| }; |  | ||||||
| use tokio::sync::{Mutex, RwLock}; |  | ||||||
|  |  | ||||||
| pub struct Handler { |  | ||||||
|     pub config: Arc<Config>, |  | ||||||
|     pub inserter: Arc<Mutex<Inserter<Bar>>>, |  | ||||||
|     pub subscription_message_constructor: |  | ||||||
|         fn(NonEmpty<String>) -> websocket::data::outgoing::subscribe::Message, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[async_trait] |  | ||||||
| impl super::Handler for Handler { |  | ||||||
|     fn create_subscription_message( |  | ||||||
|         &self, |  | ||||||
|         symbols: NonEmpty<String>, |  | ||||||
|     ) -> websocket::data::outgoing::subscribe::Message { |  | ||||||
|         (self.subscription_message_constructor)(symbols) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     async fn handle_websocket_message( |  | ||||||
|         &self, |  | ||||||
|         state: Arc<RwLock<State>>, |  | ||||||
|         message: websocket::data::incoming::Message, |  | ||||||
|     ) { |  | ||||||
|         match message { |  | ||||||
|             websocket::data::incoming::Message::Subscription(message) => { |  | ||||||
|                 let websocket::data::incoming::subscription::Message::Market { |  | ||||||
|                     bars: symbols, .. |  | ||||||
|                 } = message |  | ||||||
|                 else { |  | ||||||
|                     unreachable!() |  | ||||||
|                 }; |  | ||||||
|  |  | ||||||
|                 let symbols = symbols.into_iter().collect::<HashSet<_>>(); |  | ||||||
|                 let mut state = state.write().await; |  | ||||||
|  |  | ||||||
|                 let newly_subscribed = state |  | ||||||
|                     .pending_subscriptions |  | ||||||
|                     .extract_if(|symbol, _| symbols.contains(symbol)) |  | ||||||
|                     .collect::<HashMap<_, _>>(); |  | ||||||
|  |  | ||||||
|                 let newly_unsubscribed = state |  | ||||||
|                     .pending_unsubscriptions |  | ||||||
|                     .extract_if(|symbol, _| !symbols.contains(symbol)) |  | ||||||
|                     .collect::<HashMap<_, _>>(); |  | ||||||
|  |  | ||||||
|                 state |  | ||||||
|                     .active_subscriptions |  | ||||||
|                     .extend(newly_subscribed.keys().cloned()); |  | ||||||
|  |  | ||||||
|                 drop(state); |  | ||||||
|  |  | ||||||
|                 if !newly_subscribed.is_empty() { |  | ||||||
|                     info!( |  | ||||||
|                         "Subscribed to bars for {:?}.", |  | ||||||
|                         newly_subscribed.keys().collect::<Vec<_>>() |  | ||||||
|                     ); |  | ||||||
|  |  | ||||||
|                     for sender in newly_subscribed.into_values() { |  | ||||||
|                         sender.send(()).unwrap(); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 if !newly_unsubscribed.is_empty() { |  | ||||||
|                     info!( |  | ||||||
|                         "Unsubscribed from bars for {:?}.", |  | ||||||
|                         newly_unsubscribed.keys().collect::<Vec<_>>() |  | ||||||
|                     ); |  | ||||||
|  |  | ||||||
|                     for sender in newly_unsubscribed.into_values() { |  | ||||||
|                         sender.send(()).unwrap(); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             websocket::data::incoming::Message::Bar(message) |  | ||||||
|             | websocket::data::incoming::Message::UpdatedBar(message) => { |  | ||||||
|                 let bar = Bar::from(message); |  | ||||||
|                 debug!("Received bar for {}: {}.", bar.symbol, bar.time); |  | ||||||
|                 self.inserter.lock().await.write(&bar).await.unwrap(); |  | ||||||
|             } |  | ||||||
|             websocket::data::incoming::Message::Status(message) => { |  | ||||||
|                 debug!( |  | ||||||
|                     "Received status message for {}: {:?}.", |  | ||||||
|                     message.symbol, message.status |  | ||||||
|                 ); |  | ||||||
|  |  | ||||||
|                 match message.status { |  | ||||||
|                     websocket::data::incoming::status::Status::TradingHalt(_) |  | ||||||
|                     | websocket::data::incoming::status::Status::VolatilityTradingPause(_) => { |  | ||||||
|                         database::assets::update_status_where_symbol( |  | ||||||
|                             &self.config.clickhouse_client, |  | ||||||
|                             &self.config.clickhouse_concurrency_limiter, |  | ||||||
|                             &message.symbol, |  | ||||||
|                             false, |  | ||||||
|                         ) |  | ||||||
|                         .await |  | ||||||
|                         .unwrap(); |  | ||||||
|                     } |  | ||||||
|                     websocket::data::incoming::status::Status::Resume(_) |  | ||||||
|                     | websocket::data::incoming::status::Status::TradingResumption(_) => { |  | ||||||
|                         database::assets::update_status_where_symbol( |  | ||||||
|                             &self.config.clickhouse_client, |  | ||||||
|                             &self.config.clickhouse_concurrency_limiter, |  | ||||||
|                             &message.symbol, |  | ||||||
|                             true, |  | ||||||
|                         ) |  | ||||||
|                         .await |  | ||||||
|                         .unwrap(); |  | ||||||
|                     } |  | ||||||
|                     _ => {} |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             websocket::data::incoming::Message::Error(message) => { |  | ||||||
|                 error!("Received error message: {}.", message.message); |  | ||||||
|             } |  | ||||||
|             _ => unreachable!(), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     fn log_string(&self) -> &'static str { |  | ||||||
|         "bars" |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     async fn run_inserter(&self) { |  | ||||||
|         super::run_inserter(self.inserter.clone()).await; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> { |  | ||||||
|     let inserter = Arc::new(Mutex::new( |  | ||||||
|         config |  | ||||||
|             .clickhouse_client |  | ||||||
|             .inserter("bars") |  | ||||||
|             .unwrap() |  | ||||||
|             .with_period(Some(ONE_SECOND)) |  | ||||||
|             .with_max_entries((*CLICKHOUSE_BATCH_BARS_SIZE).try_into().unwrap()), |  | ||||||
|     )); |  | ||||||
|  |  | ||||||
|     let subscription_message_constructor = match thread_type { |  | ||||||
|         ThreadType::Bars(Class::UsEquity) => { |  | ||||||
|             websocket::data::outgoing::subscribe::Message::new_market_us_equity |  | ||||||
|         } |  | ||||||
|         ThreadType::Bars(Class::Crypto) => { |  | ||||||
|             websocket::data::outgoing::subscribe::Message::new_market_crypto |  | ||||||
|         } |  | ||||||
|         _ => unreachable!(), |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     Box::new(Handler { |  | ||||||
|         config, |  | ||||||
|         inserter, |  | ||||||
|         subscription_message_constructor, |  | ||||||
|     }) |  | ||||||
| } |  | ||||||
| @@ -1,353 +0,0 @@ | |||||||
| pub mod bars; |  | ||||||
| pub mod news; |  | ||||||
|  |  | ||||||
| use crate::config::{ALPACA_API_KEY, ALPACA_API_SECRET}; |  | ||||||
| use async_trait::async_trait; |  | ||||||
| use backoff::{future::retry_notify, ExponentialBackoff}; |  | ||||||
| use clickhouse::{inserter::Inserter, Row}; |  | ||||||
| use futures_util::{future::join_all, SinkExt, StreamExt}; |  | ||||||
| use log::error; |  | ||||||
| use nonempty::NonEmpty; |  | ||||||
| use qrust::types::alpaca::{self, websocket}; |  | ||||||
| use serde::Serialize; |  | ||||||
| use serde_json::{from_str, to_string}; |  | ||||||
| use std::{ |  | ||||||
|     collections::{HashMap, HashSet}, |  | ||||||
|     sync::Arc, |  | ||||||
|     time::Duration, |  | ||||||
| }; |  | ||||||
| use tokio::{ |  | ||||||
|     net::TcpStream, |  | ||||||
|     select, spawn, |  | ||||||
|     sync::{mpsc, oneshot, Mutex, RwLock}, |  | ||||||
| }; |  | ||||||
| use tokio_tungstenite::{connect_async, tungstenite, MaybeTlsStream, WebSocketStream}; |  | ||||||
|  |  | ||||||
| pub enum Action { |  | ||||||
|     Subscribe, |  | ||||||
|     Unsubscribe, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl From<super::Action> for Option<Action> { |  | ||||||
|     fn from(action: super::Action) -> Self { |  | ||||||
|         match action { |  | ||||||
|             super::Action::Add | super::Action::Enable => Some(Action::Subscribe), |  | ||||||
|             super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub struct Message { |  | ||||||
|     pub action: Option<Action>, |  | ||||||
|     pub symbols: NonEmpty<String>, |  | ||||||
|     pub response: oneshot::Sender<()>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl Message { |  | ||||||
|     pub fn new(action: Option<Action>, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) { |  | ||||||
|         let (sender, receiver) = oneshot::channel(); |  | ||||||
|         ( |  | ||||||
|             Self { |  | ||||||
|                 action, |  | ||||||
|                 symbols, |  | ||||||
|                 response: sender, |  | ||||||
|             }, |  | ||||||
|             receiver, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub struct State { |  | ||||||
|     pub active_subscriptions: HashSet<String>, |  | ||||||
|     pub pending_subscriptions: HashMap<String, oneshot::Sender<()>>, |  | ||||||
|     pub pending_unsubscriptions: HashMap<String, oneshot::Sender<()>>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[async_trait] |  | ||||||
| pub trait Handler: Send + Sync + 'static { |  | ||||||
|     fn create_subscription_message( |  | ||||||
|         &self, |  | ||||||
|         symbols: NonEmpty<String>, |  | ||||||
|     ) -> websocket::data::outgoing::subscribe::Message; |  | ||||||
|     async fn handle_websocket_message( |  | ||||||
|         &self, |  | ||||||
|         state: Arc<RwLock<State>>, |  | ||||||
|         message: websocket::data::incoming::Message, |  | ||||||
|     ); |  | ||||||
|     fn log_string(&self) -> &'static str; |  | ||||||
|     async fn run_inserter(&self); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn run( |  | ||||||
|     handler: Arc<Box<dyn Handler>>, |  | ||||||
|     mut receiver: mpsc::Receiver<Message>, |  | ||||||
|     websocket_url: String, |  | ||||||
| ) { |  | ||||||
|     let state = Arc::new(RwLock::new(State { |  | ||||||
|         active_subscriptions: HashSet::new(), |  | ||||||
|         pending_subscriptions: HashMap::new(), |  | ||||||
|         pending_unsubscriptions: HashMap::new(), |  | ||||||
|     })); |  | ||||||
|  |  | ||||||
|     let handler_clone = handler.clone(); |  | ||||||
|     spawn(async move { handler_clone.run_inserter().await }); |  | ||||||
|  |  | ||||||
|     let (sink_sender, sink_receiver) = mpsc::channel(100); |  | ||||||
|     let (stream_sender, mut stream_receiver) = mpsc::channel(10_000); |  | ||||||
|  |  | ||||||
|     spawn(run_connection( |  | ||||||
|         handler.clone(), |  | ||||||
|         sink_receiver, |  | ||||||
|         stream_sender, |  | ||||||
|         websocket_url.clone(), |  | ||||||
|         state.clone(), |  | ||||||
|     )); |  | ||||||
|  |  | ||||||
|     loop { |  | ||||||
|         select! { |  | ||||||
|             Some(message) = receiver.recv() => { |  | ||||||
|                 spawn(handle_message( |  | ||||||
|                     handler.clone(), |  | ||||||
|                     state.clone(), |  | ||||||
|                     sink_sender.clone(), |  | ||||||
|                     message, |  | ||||||
|                 )); |  | ||||||
|             } |  | ||||||
|             Some(message) = stream_receiver.recv() => { |  | ||||||
|                 match message { |  | ||||||
|                     tungstenite::Message::Text(message) => { |  | ||||||
|                         let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message); |  | ||||||
|  |  | ||||||
|                         if parsed_message.is_err() { |  | ||||||
|                             error!("Failed to deserialize websocket message: {:?}.", message); |  | ||||||
|                             continue; |  | ||||||
|                         } |  | ||||||
|  |  | ||||||
|                         for message in parsed_message.unwrap() { |  | ||||||
|                             let handler = handler.clone(); |  | ||||||
|                             let state = state.clone(); |  | ||||||
|                             spawn(async move { |  | ||||||
|                                 handler.handle_websocket_message(state, message).await; |  | ||||||
|                             }); |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                     tungstenite::Message::Ping(_) => {} |  | ||||||
|                     _ => error!("Unexpected websocket message: {:?}.", message), |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             else => panic!("Communication channel unexpectedly closed.") |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::too_many_lines)] |  | ||||||
| async fn run_connection( |  | ||||||
|     handler: Arc<Box<dyn Handler>>, |  | ||||||
|     mut sink_receiver: mpsc::Receiver<tungstenite::Message>, |  | ||||||
|     stream_sender: mpsc::Sender<tungstenite::Message>, |  | ||||||
|     websocket_url: String, |  | ||||||
|     state: Arc<RwLock<State>>, |  | ||||||
| ) { |  | ||||||
|     let mut peek = None; |  | ||||||
|  |  | ||||||
|     'connection: loop { |  | ||||||
|         let (websocket, _): (WebSocketStream<MaybeTlsStream<TcpStream>>, _) = retry_notify( |  | ||||||
|             ExponentialBackoff::default(), |  | ||||||
|             || async { |  | ||||||
|                 connect_async(websocket_url.clone()) |  | ||||||
|                     .await |  | ||||||
|                     .map_err(Into::into) |  | ||||||
|             }, |  | ||||||
|             |e, duration: Duration| { |  | ||||||
|                 error!( |  | ||||||
|                     "Failed to connect to {} websocket, will retry in {} seconds: {}.", |  | ||||||
|                     handler.log_string(), |  | ||||||
|                     duration.as_secs(), |  | ||||||
|                     e |  | ||||||
|                 ); |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|         .unwrap(); |  | ||||||
|  |  | ||||||
|         let (mut sink, mut stream) = websocket.split(); |  | ||||||
|         alpaca::websocket::data::authenticate( |  | ||||||
|             &mut sink, |  | ||||||
|             &mut stream, |  | ||||||
|             (*ALPACA_API_KEY).to_string(), |  | ||||||
|             (*ALPACA_API_SECRET).to_string(), |  | ||||||
|         ) |  | ||||||
|         .await; |  | ||||||
|  |  | ||||||
|         let mut state = state.write().await; |  | ||||||
|  |  | ||||||
|         state |  | ||||||
|             .pending_unsubscriptions |  | ||||||
|             .drain() |  | ||||||
|             .for_each(|(_, sender)| { |  | ||||||
|                 sender.send(()).unwrap(); |  | ||||||
|             }); |  | ||||||
|  |  | ||||||
|         let (recovered_subscriptions, receivers) = state |  | ||||||
|             .active_subscriptions |  | ||||||
|             .iter() |  | ||||||
|             .map(|symbol| { |  | ||||||
|                 let (sender, receiver) = oneshot::channel(); |  | ||||||
|                 ((symbol.clone(), sender), receiver) |  | ||||||
|             }) |  | ||||||
|             .unzip::<_, _, Vec<_>, Vec<_>>(); |  | ||||||
|  |  | ||||||
|         state.pending_subscriptions.extend(recovered_subscriptions); |  | ||||||
|  |  | ||||||
|         let pending_subscriptions = state |  | ||||||
|             .pending_subscriptions |  | ||||||
|             .keys() |  | ||||||
|             .cloned() |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         drop(state); |  | ||||||
|  |  | ||||||
|         if let Some(pending_subscriptions) = NonEmpty::from_vec(pending_subscriptions) { |  | ||||||
|             if let Err(err) = sink |  | ||||||
|                 .send(tungstenite::Message::Text( |  | ||||||
|                     to_string(&websocket::data::outgoing::Message::Subscribe( |  | ||||||
|                         handler.create_subscription_message(pending_subscriptions), |  | ||||||
|                     )) |  | ||||||
|                     .unwrap(), |  | ||||||
|                 )) |  | ||||||
|                 .await |  | ||||||
|             { |  | ||||||
|                 error!("Failed to send websocket message: {:?}.", err); |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         join_all(receivers).await; |  | ||||||
|  |  | ||||||
|         if peek.is_some() { |  | ||||||
|             if let Err(err) = sink.send(peek.clone().unwrap()).await { |  | ||||||
|                 error!("Failed to send websocket message: {:?}.", err); |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
|             peek = None; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         loop { |  | ||||||
|             select! { |  | ||||||
|                 Some(message) = sink_receiver.recv() => { |  | ||||||
|                     peek = Some(message.clone()); |  | ||||||
|  |  | ||||||
|                     if let Err(err) = sink.send(message).await { |  | ||||||
|                         error!("Failed to send websocket message: {:?}.", err); |  | ||||||
|                         continue 'connection; |  | ||||||
|                     }; |  | ||||||
|  |  | ||||||
|                     peek = None; |  | ||||||
|                 } |  | ||||||
|                 message = stream.next() => { |  | ||||||
|                     if message.is_none() { |  | ||||||
|                         error!("Websocket stream unexpectedly closed."); |  | ||||||
|                         continue 'connection; |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     let message = message.unwrap(); |  | ||||||
|  |  | ||||||
|                     if let Err(err) = message { |  | ||||||
|                         error!("Failed to receive websocket message: {:?}.", err); |  | ||||||
|                         continue 'connection; |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     let message = message.unwrap(); |  | ||||||
|  |  | ||||||
|                     if message.is_close() { |  | ||||||
|                         error!("Websocket connection closed."); |  | ||||||
|                         continue 'connection; |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     stream_sender.send(message).await.unwrap(); |  | ||||||
|                 } |  | ||||||
|                 else => error!("Communication channel unexpectedly closed.") |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| async fn handle_message( |  | ||||||
|     handler: Arc<Box<dyn Handler>>, |  | ||||||
|     pending: Arc<RwLock<State>>, |  | ||||||
|     sink_sender: mpsc::Sender<tungstenite::Message>, |  | ||||||
|     message: Message, |  | ||||||
| ) { |  | ||||||
|     match message.action { |  | ||||||
|         Some(Action::Subscribe) => { |  | ||||||
|             let (pending_subscriptions, receivers) = message |  | ||||||
|                 .symbols |  | ||||||
|                 .iter() |  | ||||||
|                 .map(|symbol| { |  | ||||||
|                     let (sender, receiver) = oneshot::channel(); |  | ||||||
|                     ((symbol.clone(), sender), receiver) |  | ||||||
|                 }) |  | ||||||
|                 .unzip::<_, _, Vec<_>, Vec<_>>(); |  | ||||||
|  |  | ||||||
|             pending |  | ||||||
|                 .write() |  | ||||||
|                 .await |  | ||||||
|                 .pending_subscriptions |  | ||||||
|                 .extend(pending_subscriptions); |  | ||||||
|  |  | ||||||
|             sink_sender |  | ||||||
|                 .send(tungstenite::Message::Text( |  | ||||||
|                     to_string(&websocket::data::outgoing::Message::Subscribe( |  | ||||||
|                         handler.create_subscription_message(message.symbols), |  | ||||||
|                     )) |  | ||||||
|                     .unwrap(), |  | ||||||
|                 )) |  | ||||||
|                 .await |  | ||||||
|                 .unwrap(); |  | ||||||
|  |  | ||||||
|             join_all(receivers).await; |  | ||||||
|         } |  | ||||||
|         Some(Action::Unsubscribe) => { |  | ||||||
|             let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message |  | ||||||
|                 .symbols |  | ||||||
|                 .iter() |  | ||||||
|                 .map(|symbol| { |  | ||||||
|                     let (sender, receiver) = oneshot::channel(); |  | ||||||
|                     ((symbol.clone(), sender), receiver) |  | ||||||
|                 }) |  | ||||||
|                 .unzip(); |  | ||||||
|  |  | ||||||
|             pending |  | ||||||
|                 .write() |  | ||||||
|                 .await |  | ||||||
|                 .pending_unsubscriptions |  | ||||||
|                 .extend(pending_unsubscriptions); |  | ||||||
|  |  | ||||||
|             sink_sender |  | ||||||
|                 .send(tungstenite::Message::Text( |  | ||||||
|                     to_string(&websocket::data::outgoing::Message::Unsubscribe( |  | ||||||
|                         handler.create_subscription_message(message.symbols.clone()), |  | ||||||
|                     )) |  | ||||||
|                     .unwrap(), |  | ||||||
|                 )) |  | ||||||
|                 .await |  | ||||||
|                 .unwrap(); |  | ||||||
|  |  | ||||||
|             join_all(receivers).await; |  | ||||||
|         } |  | ||||||
|         None => {} |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     message.response.send(()).unwrap(); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| async fn run_inserter<T>(inserter: Arc<Mutex<Inserter<T>>>) |  | ||||||
| where |  | ||||||
|     T: Row + Serialize, |  | ||||||
| { |  | ||||||
|     loop { |  | ||||||
|         let time_left = inserter.lock().await.time_left().unwrap(); |  | ||||||
|         tokio::time::sleep(time_left).await; |  | ||||||
|         inserter.lock().await.commit().await.unwrap(); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,119 +0,0 @@ | |||||||
| use super::State; |  | ||||||
| use crate::config::{Config, CLICKHOUSE_BATCH_NEWS_SIZE}; |  | ||||||
| use async_trait::async_trait; |  | ||||||
| use clickhouse::inserter::Inserter; |  | ||||||
| use log::{debug, error, info}; |  | ||||||
| use nonempty::NonEmpty; |  | ||||||
| use qrust::{ |  | ||||||
|     types::{alpaca::websocket, News}, |  | ||||||
|     utils::ONE_SECOND, |  | ||||||
| }; |  | ||||||
| use std::{ |  | ||||||
|     collections::{HashMap, HashSet}, |  | ||||||
|     sync::Arc, |  | ||||||
| }; |  | ||||||
| use tokio::sync::{Mutex, RwLock}; |  | ||||||
|  |  | ||||||
| pub struct Handler { |  | ||||||
|     pub inserter: Arc<Mutex<Inserter<News>>>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[async_trait] |  | ||||||
| impl super::Handler for Handler { |  | ||||||
|     fn create_subscription_message( |  | ||||||
|         &self, |  | ||||||
|         symbols: NonEmpty<String>, |  | ||||||
|     ) -> websocket::data::outgoing::subscribe::Message { |  | ||||||
|         websocket::data::outgoing::subscribe::Message::new_news(symbols) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     async fn handle_websocket_message( |  | ||||||
|         &self, |  | ||||||
|         state: Arc<RwLock<State>>, |  | ||||||
|         message: websocket::data::incoming::Message, |  | ||||||
|     ) { |  | ||||||
|         match message { |  | ||||||
|             websocket::data::incoming::Message::Subscription(message) => { |  | ||||||
|                 let websocket::data::incoming::subscription::Message::News { news: symbols } = |  | ||||||
|                     message |  | ||||||
|                 else { |  | ||||||
|                     unreachable!() |  | ||||||
|                 }; |  | ||||||
|  |  | ||||||
|                 let symbols = symbols.into_iter().collect::<HashSet<_>>(); |  | ||||||
|                 let mut state = state.write().await; |  | ||||||
|  |  | ||||||
|                 let newly_subscribed = state |  | ||||||
|                     .pending_subscriptions |  | ||||||
|                     .extract_if(|symbol, _| symbols.contains(symbol)) |  | ||||||
|                     .collect::<HashMap<_, _>>(); |  | ||||||
|  |  | ||||||
|                 let newly_unsubscribed = state |  | ||||||
|                     .pending_unsubscriptions |  | ||||||
|                     .extract_if(|symbol, _| !symbols.contains(symbol)) |  | ||||||
|                     .collect::<HashMap<_, _>>(); |  | ||||||
|  |  | ||||||
|                 state |  | ||||||
|                     .active_subscriptions |  | ||||||
|                     .extend(newly_subscribed.keys().cloned()); |  | ||||||
|  |  | ||||||
|                 drop(state); |  | ||||||
|  |  | ||||||
|                 if !newly_subscribed.is_empty() { |  | ||||||
|                     info!( |  | ||||||
|                         "Subscribed to news for {:?}.", |  | ||||||
|                         newly_subscribed.keys().collect::<Vec<_>>() |  | ||||||
|                     ); |  | ||||||
|  |  | ||||||
|                     for sender in newly_subscribed.into_values() { |  | ||||||
|                         sender.send(()).unwrap(); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 if !newly_unsubscribed.is_empty() { |  | ||||||
|                     info!( |  | ||||||
|                         "Unsubscribed from news for {:?}.", |  | ||||||
|                         newly_unsubscribed.keys().collect::<Vec<_>>() |  | ||||||
|                     ); |  | ||||||
|  |  | ||||||
|                     for sender in newly_unsubscribed.into_values() { |  | ||||||
|                         sender.send(()).unwrap(); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             websocket::data::incoming::Message::News(message) => { |  | ||||||
|                 let news = News::from(message); |  | ||||||
|                 debug!( |  | ||||||
|                     "Received news for {:?}: {}.", |  | ||||||
|                     news.symbols, news.time_created |  | ||||||
|                 ); |  | ||||||
|                 self.inserter.lock().await.write(&news).await.unwrap(); |  | ||||||
|             } |  | ||||||
|             websocket::data::incoming::Message::Error(message) => { |  | ||||||
|                 error!("Received error message: {}.", message.message); |  | ||||||
|             } |  | ||||||
|             _ => unreachable!(), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     fn log_string(&self) -> &'static str { |  | ||||||
|         "news" |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     async fn run_inserter(&self) { |  | ||||||
|         super::run_inserter(self.inserter.clone()).await; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn create_handler(config: &Arc<Config>) -> Box<dyn super::Handler> { |  | ||||||
|     let inserter = Arc::new(Mutex::new( |  | ||||||
|         config |  | ||||||
|             .clickhouse_client |  | ||||||
|             .inserter("news") |  | ||||||
|             .unwrap() |  | ||||||
|             .with_period(Some(ONE_SECOND)) |  | ||||||
|             .with_max_entries((*CLICKHOUSE_BATCH_NEWS_SIZE).try_into().unwrap()), |  | ||||||
|     )); |  | ||||||
|  |  | ||||||
|     Box::new(Handler { inserter }) |  | ||||||
| } |  | ||||||
| @@ -1,27 +0,0 @@ | |||||||
| mod websocket; |  | ||||||
|  |  | ||||||
| use crate::config::{Config, ALPACA_API_BASE, ALPACA_API_KEY, ALPACA_API_SECRET}; |  | ||||||
| use futures_util::StreamExt; |  | ||||||
| use qrust::types::alpaca; |  | ||||||
| use std::sync::Arc; |  | ||||||
| use tokio::spawn; |  | ||||||
| use tokio_tungstenite::connect_async; |  | ||||||
|  |  | ||||||
| pub async fn run(config: Arc<Config>) { |  | ||||||
|     let (websocket, _) = |  | ||||||
|         connect_async(&format!("wss://{}.alpaca.markets/stream", *ALPACA_API_BASE)) |  | ||||||
|             .await |  | ||||||
|             .unwrap(); |  | ||||||
|     let (mut websocket_sink, mut websocket_stream) = websocket.split(); |  | ||||||
|  |  | ||||||
|     alpaca::websocket::trading::authenticate( |  | ||||||
|         &mut websocket_sink, |  | ||||||
|         &mut websocket_stream, |  | ||||||
|         (*ALPACA_API_KEY).to_string(), |  | ||||||
|         (*ALPACA_API_SECRET).to_string(), |  | ||||||
|     ) |  | ||||||
|     .await; |  | ||||||
|     alpaca::websocket::trading::subscribe(&mut websocket_sink, &mut websocket_stream).await; |  | ||||||
|  |  | ||||||
|     spawn(websocket::run(config, websocket_stream)); |  | ||||||
| } |  | ||||||
| @@ -1,133 +0,0 @@ | |||||||
| #![warn(clippy::all, clippy::pedantic, clippy::nursery)] |  | ||||||
|  |  | ||||||
| use burn::{ |  | ||||||
|     config::Config, |  | ||||||
|     data::{ |  | ||||||
|         dataloader::{DataLoaderBuilder, Dataset}, |  | ||||||
|         dataset::transform::{PartialDataset, ShuffledDataset}, |  | ||||||
|     }, |  | ||||||
|     module::Module, |  | ||||||
|     optim::AdamConfig, |  | ||||||
|     record::CompactRecorder, |  | ||||||
|     tensor::backend::AutodiffBackend, |  | ||||||
|     train::LearnerBuilder, |  | ||||||
| }; |  | ||||||
| use dotenv::dotenv; |  | ||||||
| use log::info; |  | ||||||
| use qrust::{ |  | ||||||
|     database, |  | ||||||
|     ml::{ |  | ||||||
|         BarWindow, BarWindowBatcher, ModelConfig, MultipleSymbolDataset, MyAutodiffBackend, DEVICE, |  | ||||||
|     }, |  | ||||||
|     types::Bar, |  | ||||||
| }; |  | ||||||
| use std::{env, fs, path::Path, sync::Arc}; |  | ||||||
| use tokio::sync::Semaphore; |  | ||||||
|  |  | ||||||
| #[derive(Config)] |  | ||||||
| pub struct TrainingConfig { |  | ||||||
|     pub model: ModelConfig, |  | ||||||
|     pub optimizer: AdamConfig, |  | ||||||
|     #[config(default = 100)] |  | ||||||
|     pub epochs: usize, |  | ||||||
|     #[config(default = 256)] |  | ||||||
|     pub batch_size: usize, |  | ||||||
|     #[config(default = 16)] |  | ||||||
|     pub num_workers: usize, |  | ||||||
|     #[config(default = 0)] |  | ||||||
|     pub seed: u64, |  | ||||||
|     #[config(default = 0.2)] |  | ||||||
|     pub valid_pct: f64, |  | ||||||
|     #[config(default = 1.0e-4)] |  | ||||||
|     pub learning_rate: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[tokio::main] |  | ||||||
| async fn main() { |  | ||||||
|     dotenv().ok(); |  | ||||||
|     let dir = Path::new(file!()).parent().unwrap(); |  | ||||||
|  |  | ||||||
|     let model_config = ModelConfig::new(); |  | ||||||
|     let optimizer = AdamConfig::new(); |  | ||||||
|  |  | ||||||
|     let training_config = TrainingConfig::new(model_config, optimizer); |  | ||||||
|  |  | ||||||
|     let clickhouse_client = clickhouse::Client::default() |  | ||||||
|         .with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set.")) |  | ||||||
|         .with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set.")) |  | ||||||
|         .with_password(env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set.")) |  | ||||||
|         .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")); |  | ||||||
|  |  | ||||||
|     let clickhouse_concurrency_limiter = Arc::new(Semaphore::new(Semaphore::MAX_PERMITS)); |  | ||||||
|  |  | ||||||
|     let bars = database::ta::select(&clickhouse_client, &clickhouse_concurrency_limiter) |  | ||||||
|         .await |  | ||||||
|         .unwrap(); |  | ||||||
|  |  | ||||||
|     info!("Loaded {} bars.", bars.len()); |  | ||||||
|  |  | ||||||
|     train::<MyAutodiffBackend>( |  | ||||||
|         bars, |  | ||||||
|         &training_config, |  | ||||||
|         dir.join("artifacts").to_str().unwrap(), |  | ||||||
|         &DEVICE, |  | ||||||
|     ); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::cast_possible_truncation)] |  | ||||||
| #[allow(clippy::cast_sign_loss)] |  | ||||||
| #[allow(clippy::cast_precision_loss)] |  | ||||||
| fn train<B: AutodiffBackend<FloatElem = f32, IntElem = i32>>( |  | ||||||
|     bars: Vec<Bar>, |  | ||||||
|     config: &TrainingConfig, |  | ||||||
|     dir: &str, |  | ||||||
|     device: &B::Device, |  | ||||||
| ) { |  | ||||||
|     B::seed(config.seed); |  | ||||||
|  |  | ||||||
|     fs::create_dir_all(dir).unwrap(); |  | ||||||
|  |  | ||||||
|     let dataset = MultipleSymbolDataset::new(bars); |  | ||||||
|     let dataset = ShuffledDataset::with_seed(dataset, config.seed); |  | ||||||
|     let dataset = Arc::new(dataset); |  | ||||||
|  |  | ||||||
|     let split = (dataset.len() as f64 * (1.0 - config.valid_pct)) as usize; |  | ||||||
|  |  | ||||||
|     let train: PartialDataset<Arc<ShuffledDataset<MultipleSymbolDataset, BarWindow>>, BarWindow> = |  | ||||||
|         PartialDataset::new(dataset.clone(), 0, split); |  | ||||||
|  |  | ||||||
|     let batcher_train = BarWindowBatcher::<B> { |  | ||||||
|         device: device.clone(), |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let dataloader_train = DataLoaderBuilder::new(batcher_train) |  | ||||||
|         .batch_size(config.batch_size) |  | ||||||
|         .num_workers(config.num_workers) |  | ||||||
|         .build(train); |  | ||||||
|  |  | ||||||
|     let valid: PartialDataset<Arc<ShuffledDataset<MultipleSymbolDataset, BarWindow>>, BarWindow> = |  | ||||||
|         PartialDataset::new(dataset.clone(), split, dataset.len()); |  | ||||||
|  |  | ||||||
|     let batcher_valid = BarWindowBatcher::<B::InnerBackend> { |  | ||||||
|         device: device.clone(), |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let dataloader_valid = DataLoaderBuilder::new(batcher_valid) |  | ||||||
|         .batch_size(config.batch_size) |  | ||||||
|         .num_workers(config.num_workers) |  | ||||||
|         .build(valid); |  | ||||||
|  |  | ||||||
|     let learner = LearnerBuilder::new(dir) |  | ||||||
|         .with_file_checkpointer(CompactRecorder::new()) |  | ||||||
|         .devices(vec![device.clone()]) |  | ||||||
|         .num_epochs(config.epochs) |  | ||||||
|         .build( |  | ||||||
|             config.model.init::<B>(device), |  | ||||||
|             config.optimizer.init(), |  | ||||||
|             config.learning_rate, |  | ||||||
|         ); |  | ||||||
|  |  | ||||||
|     let trained = learner.fit(dataloader_train, dataloader_valid); |  | ||||||
|  |  | ||||||
|     trained.save_file(dir, &CompactRecorder::new()).unwrap(); |  | ||||||
| } |  | ||||||
							
								
								
									
										123
									
								
								src/config.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								src/config.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,123 @@ | |||||||
|  | use crate::types::alpaca::shared::{Mode, Source}; | ||||||
|  | use governor::{DefaultDirectRateLimiter, Quota, RateLimiter}; | ||||||
|  | use lazy_static::lazy_static; | ||||||
|  | use reqwest::{ | ||||||
|  |     header::{HeaderMap, HeaderName, HeaderValue}, | ||||||
|  |     Client, | ||||||
|  | }; | ||||||
|  | use rust_bert::{ | ||||||
|  |     pipelines::{ | ||||||
|  |         common::{ModelResource, ModelType}, | ||||||
|  |         sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel}, | ||||||
|  |     }, | ||||||
|  |     resources::LocalResource, | ||||||
|  | }; | ||||||
|  | use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc}; | ||||||
|  | use tokio::sync::Mutex; | ||||||
|  |  | ||||||
|  | pub const ALPACA_STOCK_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars"; | ||||||
|  | pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars"; | ||||||
|  | pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news"; | ||||||
|  |  | ||||||
|  | pub const ALPACA_STOCK_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2"; | ||||||
|  | pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str = | ||||||
|  |     "wss://stream.data.alpaca.markets/v1beta3/crypto/us"; | ||||||
|  | pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news"; | ||||||
|  |  | ||||||
|  | lazy_static! { | ||||||
|  |     pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE") | ||||||
|  |         .expect("ALPACA_MODE must be set.") | ||||||
|  |         .parse() | ||||||
|  |         .expect("ALPACA_MODE must be 'live' or 'paper'"); | ||||||
|  |     pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE") | ||||||
|  |         .expect("ALPACA_SOURCE must be set.") | ||||||
|  |         .parse() | ||||||
|  |         .expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'"); | ||||||
|  |     pub static ref ALPACA_API_KEY: String = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set."); | ||||||
|  |     pub static ref ALPACA_API_SECRET: String = env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set."); | ||||||
|  |     #[derive(Debug)] | ||||||
|  |     pub static ref ALPACA_API_URL: String = format!( | ||||||
|  |         "https://{}.alpaca.markets/v2", | ||||||
|  |         match *ALPACA_MODE { | ||||||
|  |             Mode::Live => String::from("api"), | ||||||
|  |             Mode::Paper => String::from("paper-api"), | ||||||
|  |         } | ||||||
|  |     ); | ||||||
|  |     #[derive(Debug)] | ||||||
|  |     pub static ref ALPACA_WEBSOCKET_URL: String = format!( | ||||||
|  |         "wss://{}.alpaca.markets/stream", | ||||||
|  |         match *ALPACA_MODE { | ||||||
|  |             Mode::Live => String::from("api"), | ||||||
|  |             Mode::Paper => String::from("paper-api"), | ||||||
|  |         } | ||||||
|  |     ); | ||||||
|  |     pub static ref MAX_BERT_INPUTS: usize = env::var("MAX_BERT_INPUTS") | ||||||
|  |         .expect("MAX_BERT_INPUTS must be set.") | ||||||
|  |         .parse() | ||||||
|  |         .expect("MAX_BERT_INPUTS must be a positive integer."); | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub struct Config { | ||||||
|  |     pub alpaca_client: Client, | ||||||
|  |     pub alpaca_rate_limiter: DefaultDirectRateLimiter, | ||||||
|  |     pub clickhouse_client: clickhouse::Client, | ||||||
|  |     pub sequence_classifier: Mutex<SequenceClassificationModel>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Config { | ||||||
|  |     pub fn from_env() -> Self { | ||||||
|  |         Self { | ||||||
|  |             alpaca_client: Client::builder() | ||||||
|  |                 .default_headers(HeaderMap::from_iter([ | ||||||
|  |                     ( | ||||||
|  |                         HeaderName::from_static("apca-api-key-id"), | ||||||
|  |                         HeaderValue::from_str(&ALPACA_API_KEY) | ||||||
|  |                             .expect("Alpaca API key must not contain invalid characters."), | ||||||
|  |                     ), | ||||||
|  |                     ( | ||||||
|  |                         HeaderName::from_static("apca-api-secret-key"), | ||||||
|  |                         HeaderValue::from_str(&ALPACA_API_SECRET) | ||||||
|  |                             .expect("Alpaca API secret must not contain invalid characters."), | ||||||
|  |                     ), | ||||||
|  |                 ])) | ||||||
|  |                 .build() | ||||||
|  |                 .unwrap(), | ||||||
|  |             alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE { | ||||||
|  |                 Source::Iex => unsafe { NonZeroU32::new_unchecked(200) }, | ||||||
|  |                 Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) }, | ||||||
|  |                 Source::Otc => unimplemented!("OTC rate limit not implemented."), | ||||||
|  |             })), | ||||||
|  |             clickhouse_client: clickhouse::Client::default() | ||||||
|  |                 .with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set.")) | ||||||
|  |                 .with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set.")) | ||||||
|  |                 .with_password( | ||||||
|  |                     env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."), | ||||||
|  |                 ) | ||||||
|  |                 .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")), | ||||||
|  |             sequence_classifier: Mutex::new( | ||||||
|  |                 SequenceClassificationModel::new(SequenceClassificationConfig::new( | ||||||
|  |                     ModelType::Bert, | ||||||
|  |                     ModelResource::Torch(Box::new(LocalResource { | ||||||
|  |                         local_path: PathBuf::from("./models/finbert/rust_model.ot"), | ||||||
|  |                     })), | ||||||
|  |                     LocalResource { | ||||||
|  |                         local_path: PathBuf::from("./models/finbert/config.json"), | ||||||
|  |                     }, | ||||||
|  |                     LocalResource { | ||||||
|  |                         local_path: PathBuf::from("./models/finbert/vocab.txt"), | ||||||
|  |                     }, | ||||||
|  |                     None, | ||||||
|  |                     true, | ||||||
|  |                     None, | ||||||
|  |                     None, | ||||||
|  |                 )) | ||||||
|  |                 .unwrap(), | ||||||
|  |             ), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn arc_from_env() -> Arc<Self> { | ||||||
|  |         Arc::new(Self::from_env()) | ||||||
|  |     } | ||||||
|  | } | ||||||
| @@ -1,11 +1,8 @@ | |||||||
| use std::sync::Arc; |  | ||||||
| 
 |  | ||||||
| use crate::{ | use crate::{ | ||||||
|     delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch, |     delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch, | ||||||
| }; | }; | ||||||
| use clickhouse::{error::Error, Client}; | use clickhouse::{error::Error, Client}; | ||||||
| use serde::Serialize; | use serde::Serialize; | ||||||
| use tokio::sync::Semaphore; |  | ||||||
| 
 | 
 | ||||||
| select!(Asset, "assets"); | select!(Asset, "assets"); | ||||||
| select_where_symbol!(Asset, "assets"); | select_where_symbol!(Asset, "assets"); | ||||||
| @@ -14,16 +11,14 @@ delete_where_symbols!("assets"); | |||||||
| optimize!("assets"); | optimize!("assets"); | ||||||
| 
 | 
 | ||||||
| pub async fn update_status_where_symbol<T>( | pub async fn update_status_where_symbol<T>( | ||||||
|     client: &Client, |     clickhouse_client: &Client, | ||||||
|     concurrency_limiter: &Arc<Semaphore>, |  | ||||||
|     symbol: &T, |     symbol: &T, | ||||||
|     status: bool, |     status: bool, | ||||||
| ) -> Result<(), Error> | ) -> Result<(), Error> | ||||||
| where | where | ||||||
|     T: AsRef<str> + Serialize + Send + Sync, |     T: AsRef<str> + Serialize + Send + Sync, | ||||||
| { | { | ||||||
|     let _ = concurrency_limiter.acquire().await.unwrap(); |     clickhouse_client | ||||||
|     client |  | ||||||
|         .query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?") |         .query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?") | ||||||
|         .bind(status) |         .bind(status) | ||||||
|         .bind(symbol) |         .bind(symbol) | ||||||
| @@ -32,16 +27,14 @@ where | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub async fn update_qty_where_symbol<T>( | pub async fn update_qty_where_symbol<T>( | ||||||
|     client: &Client, |     clickhouse_client: &Client, | ||||||
|     concurrency_limiter: &Arc<Semaphore>, |  | ||||||
|     symbol: &T, |     symbol: &T, | ||||||
|     qty: f64, |     qty: f64, | ||||||
| ) -> Result<(), Error> | ) -> Result<(), Error> | ||||||
| where | where | ||||||
|     T: AsRef<str> + Serialize + Send + Sync, |     T: AsRef<str> + Serialize + Send + Sync, | ||||||
| { | { | ||||||
|     let _ = concurrency_limiter.acquire().await.unwrap(); |     clickhouse_client | ||||||
|     client |  | ||||||
|         .query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?") |         .query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?") | ||||||
|         .bind(qty) |         .bind(qty) | ||||||
|         .bind(symbol) |         .bind(symbol) | ||||||
							
								
								
									
										17
									
								
								src/database/backfills_bars.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								src/database/backfills_bars.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | |||||||
|  | use crate::{ | ||||||
|  |     cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert, | ||||||
|  | }; | ||||||
|  | use clickhouse::{error::Error, Client}; | ||||||
|  |  | ||||||
|  | select_where_symbol!(Backfill, "backfills_bars"); | ||||||
|  | upsert!(Backfill, "backfills_bars"); | ||||||
|  | delete_where_symbols!("backfills_bars"); | ||||||
|  | cleanup!("backfills_bars"); | ||||||
|  | optimize!("backfills_bars"); | ||||||
|  |  | ||||||
|  | pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> { | ||||||
|  |     clickhouse_client | ||||||
|  |         .query("ALTER TABLE backfills_bars UPDATE fresh = false WHERE true") | ||||||
|  |         .execute() | ||||||
|  |         .await | ||||||
|  | } | ||||||
							
								
								
									
										17
									
								
								src/database/backfills_news.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								src/database/backfills_news.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | |||||||
|  | use crate::{ | ||||||
|  |     cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert, | ||||||
|  | }; | ||||||
|  | use clickhouse::{error::Error, Client}; | ||||||
|  |  | ||||||
|  | select_where_symbol!(Backfill, "backfills_news"); | ||||||
|  | upsert!(Backfill, "backfills_news"); | ||||||
|  | delete_where_symbols!("backfills_news"); | ||||||
|  | cleanup!("backfills_news"); | ||||||
|  | optimize!("backfills_news"); | ||||||
|  |  | ||||||
|  | pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> { | ||||||
|  |     clickhouse_client | ||||||
|  |         .query("ALTER TABLE backfills_news UPDATE fresh = false WHERE true") | ||||||
|  |         .execute() | ||||||
|  |         .await | ||||||
|  | } | ||||||
							
								
								
									
										7
									
								
								src/database/bars.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								src/database/bars.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | |||||||
|  | use crate::{cleanup, delete_where_symbols, optimize, types::Bar, upsert, upsert_batch}; | ||||||
|  |  | ||||||
|  | upsert!(Bar, "bars"); | ||||||
|  | upsert_batch!(Bar, "bars"); | ||||||
|  | delete_where_symbols!("bars"); | ||||||
|  | cleanup!("bars"); | ||||||
|  | optimize!("bars"); | ||||||
| @@ -1,19 +1,16 @@ | |||||||
| use std::sync::Arc; |  | ||||||
| 
 |  | ||||||
| use crate::{optimize, types::Calendar}; | use crate::{optimize, types::Calendar}; | ||||||
| use clickhouse::{error::Error, Client}; | use clickhouse::error::Error; | ||||||
| use tokio::{sync::Semaphore, try_join}; | use tokio::try_join; | ||||||
| 
 | 
 | ||||||
| optimize!("calendar"); | optimize!("calendar"); | ||||||
| 
 | 
 | ||||||
| pub async fn upsert_batch_and_delete<'a, I>( | pub async fn upsert_batch_and_delete<'a, T>( | ||||||
|     client: &Client, |     client: &clickhouse::Client, | ||||||
|     concurrency_limiter: &Arc<Semaphore>, |     records: T, | ||||||
|     records: I, |  | ||||||
| ) -> Result<(), Error> | ) -> Result<(), Error> | ||||||
| where | where | ||||||
|     I: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone, |     T: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone, | ||||||
|     I::IntoIter: Send, |     T::IntoIter: Send, | ||||||
| { | { | ||||||
|     let upsert_future = async { |     let upsert_future = async { | ||||||
|         let mut insert = client.insert("calendar")?; |         let mut insert = client.insert("calendar")?; | ||||||
| @@ -37,6 +34,5 @@ where | |||||||
|             .await |             .await | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     let _ = concurrency_limiter.acquire_many(2).await.unwrap(); |  | ||||||
|     try_join!(upsert_future, delete_future).map(|_| ()) |     try_join!(upsert_future, delete_future).map(|_| ()) | ||||||
| } | } | ||||||
							
								
								
									
										152
									
								
								src/database/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										152
									
								
								src/database/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,152 @@ | |||||||
|  | pub mod assets; | ||||||
|  | pub mod backfills_bars; | ||||||
|  | pub mod backfills_news; | ||||||
|  | pub mod bars; | ||||||
|  | pub mod calendar; | ||||||
|  | pub mod news; | ||||||
|  | pub mod orders; | ||||||
|  |  | ||||||
|  | use clickhouse::{error::Error, Client}; | ||||||
|  | use log::info; | ||||||
|  | use tokio::try_join; | ||||||
|  |  | ||||||
|  | #[macro_export] | ||||||
|  | macro_rules! select { | ||||||
|  |     ($record:ty, $table_name:expr) => { | ||||||
|  |         pub async fn select( | ||||||
|  |             client: &clickhouse::Client, | ||||||
|  |         ) -> Result<Vec<$record>, clickhouse::error::Error> { | ||||||
|  |             client | ||||||
|  |                 .query(&format!("SELECT ?fields FROM {} FINAL", $table_name)) | ||||||
|  |                 .fetch_all::<$record>() | ||||||
|  |                 .await | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[macro_export] | ||||||
|  | macro_rules! select_where_symbol { | ||||||
|  |     ($record:ty, $table_name:expr) => { | ||||||
|  |         pub async fn select_where_symbol<T>( | ||||||
|  |             client: &clickhouse::Client, | ||||||
|  |             symbol: &T, | ||||||
|  |         ) -> Result<Option<$record>, clickhouse::error::Error> | ||||||
|  |         where | ||||||
|  |             T: AsRef<str> + serde::Serialize + Send + Sync, | ||||||
|  |         { | ||||||
|  |             client | ||||||
|  |                 .query(&format!( | ||||||
|  |                     "SELECT ?fields FROM {} FINAL WHERE symbol = ?", | ||||||
|  |                     $table_name | ||||||
|  |                 )) | ||||||
|  |                 .bind(symbol) | ||||||
|  |                 .fetch_optional::<$record>() | ||||||
|  |                 .await | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[macro_export] | ||||||
|  | macro_rules! upsert { | ||||||
|  |     ($record:ty, $table_name:expr) => { | ||||||
|  |         pub async fn upsert( | ||||||
|  |             client: &clickhouse::Client, | ||||||
|  |             record: &$record, | ||||||
|  |         ) -> Result<(), clickhouse::error::Error> { | ||||||
|  |             let mut insert = client.insert($table_name)?; | ||||||
|  |             insert.write(record).await?; | ||||||
|  |             insert.end().await | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[macro_export] | ||||||
|  | macro_rules! upsert_batch { | ||||||
|  |     ($record:ty, $table_name:expr) => { | ||||||
|  |         pub async fn upsert_batch<'a, T>( | ||||||
|  |             client: &clickhouse::Client, | ||||||
|  |             records: T, | ||||||
|  |         ) -> Result<(), clickhouse::error::Error> | ||||||
|  |         where | ||||||
|  |             T: IntoIterator<Item = &'a $record> + Send + Sync, | ||||||
|  |             T::IntoIter: Send, | ||||||
|  |         { | ||||||
|  |             let mut insert = client.insert($table_name)?; | ||||||
|  |             for record in records { | ||||||
|  |                 insert.write(record).await?; | ||||||
|  |             } | ||||||
|  |             insert.end().await | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[macro_export] | ||||||
|  | macro_rules! delete_where_symbols { | ||||||
|  |     ($table_name:expr) => { | ||||||
|  |         pub async fn delete_where_symbols<T>( | ||||||
|  |             client: &clickhouse::Client, | ||||||
|  |             symbols: &[T], | ||||||
|  |         ) -> Result<(), clickhouse::error::Error> | ||||||
|  |         where | ||||||
|  |             T: AsRef<str> + serde::Serialize + Send + Sync, | ||||||
|  |         { | ||||||
|  |             client | ||||||
|  |                 .query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name)) | ||||||
|  |                 .bind(symbols) | ||||||
|  |                 .execute() | ||||||
|  |                 .await | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[macro_export] | ||||||
|  | macro_rules! cleanup { | ||||||
|  |     ($table_name:expr) => { | ||||||
|  |         pub async fn cleanup(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> { | ||||||
|  |             client | ||||||
|  |                 .query(&format!( | ||||||
|  |                     "DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)", | ||||||
|  |                     $table_name | ||||||
|  |                 )) | ||||||
|  |                 .execute() | ||||||
|  |                 .await | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[macro_export] | ||||||
|  | macro_rules! optimize { | ||||||
|  |     ($table_name:expr) => { | ||||||
|  |         pub async fn optimize(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> { | ||||||
|  |             client | ||||||
|  |                 .query(&format!("OPTIMIZE TABLE {} FINAL", $table_name)) | ||||||
|  |                 .execute() | ||||||
|  |                 .await | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn cleanup_all(clickhouse_client: &Client) -> Result<(), Error> { | ||||||
|  |     info!("Cleaning up database."); | ||||||
|  |     try_join!( | ||||||
|  |         bars::cleanup(clickhouse_client), | ||||||
|  |         news::cleanup(clickhouse_client), | ||||||
|  |         backfills_bars::cleanup(clickhouse_client), | ||||||
|  |         backfills_news::cleanup(clickhouse_client) | ||||||
|  |     ) | ||||||
|  |     .map(|_| ()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn optimize_all(clickhouse_client: &Client) -> Result<(), Error> { | ||||||
|  |     info!("Optimizing database."); | ||||||
|  |     try_join!( | ||||||
|  |         assets::optimize(clickhouse_client), | ||||||
|  |         bars::optimize(clickhouse_client), | ||||||
|  |         news::optimize(clickhouse_client), | ||||||
|  |         backfills_bars::optimize(clickhouse_client), | ||||||
|  |         backfills_news::optimize(clickhouse_client), | ||||||
|  |         orders::optimize(clickhouse_client), | ||||||
|  |         calendar::optimize(clickhouse_client) | ||||||
|  |     ) | ||||||
|  |     .map(|_| ()) | ||||||
|  | } | ||||||
| @@ -1,33 +1,24 @@ | |||||||
| use std::sync::Arc; |  | ||||||
| 
 |  | ||||||
| use crate::{optimize, types::News, upsert, upsert_batch}; | use crate::{optimize, types::News, upsert, upsert_batch}; | ||||||
| use clickhouse::{error::Error, Client}; | use clickhouse::{error::Error, Client}; | ||||||
| use serde::Serialize; | use serde::Serialize; | ||||||
| use tokio::sync::Semaphore; |  | ||||||
| 
 | 
 | ||||||
| upsert!(News, "news"); | upsert!(News, "news"); | ||||||
| upsert_batch!(News, "news"); | upsert_batch!(News, "news"); | ||||||
| optimize!("news"); | optimize!("news"); | ||||||
| 
 | 
 | ||||||
| pub async fn delete_where_symbols<T>( | pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error> | ||||||
|     client: &Client, |  | ||||||
|     concurrency_limiter: &Arc<Semaphore>, |  | ||||||
|     symbols: &[T], |  | ||||||
| ) -> Result<(), Error> |  | ||||||
| where | where | ||||||
|     T: AsRef<str> + Serialize + Send + Sync, |     T: AsRef<str> + Serialize + Send + Sync, | ||||||
| { | { | ||||||
|     let _ = concurrency_limiter.acquire().await.unwrap(); |     clickhouse_client | ||||||
|     client |  | ||||||
|         .query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))") |         .query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))") | ||||||
|         .bind(symbols) |         .bind(symbols) | ||||||
|         .execute() |         .execute() | ||||||
|         .await |         .await | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub async fn cleanup(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> { | pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> { | ||||||
|     let _ = concurrency_limiter.acquire().await.unwrap(); |     clickhouse_client | ||||||
|     client |  | ||||||
|         .query( |         .query( | ||||||
|             "DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))", |             "DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))", | ||||||
|         ) |         ) | ||||||
| @@ -1,25 +1,24 @@ | |||||||
| use crate::{ | use crate::{ | ||||||
|     config::{Config, ALPACA_API_BASE}, |     config::{Config, ALPACA_MODE}, | ||||||
|     database, |     database, | ||||||
|  |     types::alpaca, | ||||||
| }; | }; | ||||||
| use log::{info, warn}; | use log::{info, warn}; | ||||||
| use qrust::{alpaca, types}; |  | ||||||
| use std::{collections::HashMap, sync::Arc}; | use std::{collections::HashMap, sync::Arc}; | ||||||
| use time::OffsetDateTime; | use time::OffsetDateTime; | ||||||
| use tokio::join; | use tokio::join; | ||||||
| 
 | 
 | ||||||
| pub async fn check_account(config: &Arc<Config>) { | pub async fn check_account(config: &Arc<Config>) { | ||||||
|     let account = alpaca::account::get( |     let account = alpaca::api::incoming::account::get( | ||||||
|         &config.alpaca_client, |         &config.alpaca_client, | ||||||
|         &config.alpaca_rate_limiter, |         &config.alpaca_rate_limiter, | ||||||
|         None, |         None, | ||||||
|         &ALPACA_API_BASE, |  | ||||||
|     ) |     ) | ||||||
|     .await |     .await | ||||||
|     .unwrap(); |     .unwrap(); | ||||||
| 
 | 
 | ||||||
|     assert!( |     assert!( | ||||||
|         !(account.status != types::alpaca::api::incoming::account::Status::Active), |         !(account.status != alpaca::api::incoming::account::Status::Active), | ||||||
|         "Account status is not active: {:?}.", |         "Account status is not active: {:?}.", | ||||||
|         account.status |         account.status | ||||||
|     ); |     ); | ||||||
| @@ -34,60 +33,56 @@ pub async fn check_account(config: &Arc<Config>) { | |||||||
|         warn!("Account cash is zero, qrust will not be able to trade."); |         warn!("Account cash is zero, qrust will not be able to trade."); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     info!( |     warn!( | ||||||
|         "qrust running on {} account with {} {}, avoid transferring funds without shutting down.", |         "qrust active on {} account with {} {}, avoid transferring funds without shutting down.", | ||||||
|         *ALPACA_API_BASE, account.currency, account.cash |         *ALPACA_MODE, account.currency, account.cash | ||||||
|     ); |     ); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub async fn rehydrate_orders(config: &Arc<Config>) { | pub async fn rehydrate_orders(config: &Arc<Config>) { | ||||||
|  |     info!("Rehydrating order data."); | ||||||
|  | 
 | ||||||
|     let mut orders = vec![]; |     let mut orders = vec![]; | ||||||
|     let mut after = OffsetDateTime::UNIX_EPOCH; |     let mut after = OffsetDateTime::UNIX_EPOCH; | ||||||
| 
 | 
 | ||||||
|     loop { |     while let Some(message) = alpaca::api::incoming::order::get( | ||||||
|         let message = alpaca::orders::get( |  | ||||||
|         &config.alpaca_client, |         &config.alpaca_client, | ||||||
|         &config.alpaca_rate_limiter, |         &config.alpaca_rate_limiter, | ||||||
|             &types::alpaca::api::outgoing::order::Order { |         &alpaca::api::outgoing::order::Order { | ||||||
|                 status: Some(types::alpaca::api::outgoing::order::Status::All), |             status: Some(alpaca::api::outgoing::order::Status::All), | ||||||
|             after: Some(after), |             after: Some(after), | ||||||
|             ..Default::default() |             ..Default::default() | ||||||
|         }, |         }, | ||||||
|         None, |         None, | ||||||
|             &ALPACA_API_BASE, |  | ||||||
|     ) |     ) | ||||||
|     .await |     .await | ||||||
|         .unwrap(); |     .ok() | ||||||
| 
 |     .filter(|message| !message.is_empty()) | ||||||
|         if message.is_empty() { |     { | ||||||
|             break; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         orders.extend(message); |         orders.extend(message); | ||||||
|         after = orders.last().unwrap().submitted_at; |         after = orders.last().unwrap().submitted_at; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     let orders = orders |     let orders = orders | ||||||
|         .into_iter() |         .into_iter() | ||||||
|         .flat_map(&types::alpaca::api::incoming::order::Order::normalize) |         .flat_map(&alpaca::api::incoming::order::Order::normalize) | ||||||
|         .collect::<Vec<_>>(); |         .collect::<Vec<_>>(); | ||||||
| 
 | 
 | ||||||
|     database::orders::upsert_batch( |     database::orders::upsert_batch(&config.clickhouse_client, &orders) | ||||||
|         &config.clickhouse_client, |  | ||||||
|         &config.clickhouse_concurrency_limiter, |  | ||||||
|         &orders, |  | ||||||
|     ) |  | ||||||
|         .await |         .await | ||||||
|         .unwrap(); |         .unwrap(); | ||||||
|  | 
 | ||||||
|  |     info!("Rehydrated order data."); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub async fn rehydrate_positions(config: &Arc<Config>) { | pub async fn rehydrate_positions(config: &Arc<Config>) { | ||||||
|  |     info!("Rehydrating position data."); | ||||||
|  | 
 | ||||||
|     let positions_future = async { |     let positions_future = async { | ||||||
|         alpaca::positions::get( |         alpaca::api::incoming::position::get( | ||||||
|             &config.alpaca_client, |             &config.alpaca_client, | ||||||
|             &config.alpaca_rate_limiter, |             &config.alpaca_rate_limiter, | ||||||
|             None, |             None, | ||||||
|             &ALPACA_API_BASE, |  | ||||||
|         ) |         ) | ||||||
|         .await |         .await | ||||||
|         .unwrap() |         .unwrap() | ||||||
| @@ -97,10 +92,7 @@ pub async fn rehydrate_positions(config: &Arc<Config>) { | |||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     let assets_future = async { |     let assets_future = async { | ||||||
|         database::assets::select( |         database::assets::select(&config.clickhouse_client) | ||||||
|             &config.clickhouse_client, |  | ||||||
|             &config.clickhouse_concurrency_limiter, |  | ||||||
|         ) |  | ||||||
|             .await |             .await | ||||||
|             .unwrap() |             .unwrap() | ||||||
|     }; |     }; | ||||||
| @@ -119,11 +111,7 @@ pub async fn rehydrate_positions(config: &Arc<Config>) { | |||||||
|         }) |         }) | ||||||
|         .collect::<Vec<_>>(); |         .collect::<Vec<_>>(); | ||||||
| 
 | 
 | ||||||
|     database::assets::upsert_batch( |     database::assets::upsert_batch(&config.clickhouse_client, &assets) | ||||||
|         &config.clickhouse_client, |  | ||||||
|         &config.clickhouse_concurrency_limiter, |  | ||||||
|         &assets, |  | ||||||
|     ) |  | ||||||
|         .await |         .await | ||||||
|         .unwrap(); |         .unwrap(); | ||||||
| 
 | 
 | ||||||
| @@ -133,4 +121,6 @@ pub async fn rehydrate_positions(config: &Arc<Config>) { | |||||||
|             position.symbol, position.qty |             position.symbol, position.qty | ||||||
|         ); |         ); | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     info!("Rehydrated position data."); | ||||||
| } | } | ||||||
| @@ -1,39 +0,0 @@ | |||||||
| use super::error_to_backoff; |  | ||||||
| use crate::types::alpaca::api::incoming::account::Account; |  | ||||||
| use backoff::{future::retry_notify, ExponentialBackoff}; |  | ||||||
| use governor::DefaultDirectRateLimiter; |  | ||||||
| use log::warn; |  | ||||||
| use reqwest::{Client, Error}; |  | ||||||
| use std::time::Duration; |  | ||||||
|  |  | ||||||
| pub async fn get( |  | ||||||
|     client: &Client, |  | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |  | ||||||
|     backoff: Option<ExponentialBackoff>, |  | ||||||
|     api_base: &str, |  | ||||||
| ) -> Result<Account, Error> { |  | ||||||
|     retry_notify( |  | ||||||
|         backoff.unwrap_or_default(), |  | ||||||
|         || async { |  | ||||||
|             rate_limiter.until_ready().await; |  | ||||||
|             client |  | ||||||
|                 .get(&format!("https://{}.alpaca.markets/v2/account", api_base)) |  | ||||||
|                 .send() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .error_for_status() |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .json::<Account>() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff) |  | ||||||
|         }, |  | ||||||
|         |e, duration: Duration| { |  | ||||||
|             warn!( |  | ||||||
|                 "Failed to get account, will retry in {} seconds: {}.", |  | ||||||
|                 duration.as_secs(), |  | ||||||
|                 e |  | ||||||
|             ); |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
| } |  | ||||||
| @@ -1,132 +0,0 @@ | |||||||
| use super::error_to_backoff; |  | ||||||
| use crate::types::alpaca::api::{ |  | ||||||
|     incoming::asset::{Asset, Class}, |  | ||||||
|     outgoing, |  | ||||||
| }; |  | ||||||
| use backoff::{future::retry_notify, ExponentialBackoff}; |  | ||||||
| use governor::DefaultDirectRateLimiter; |  | ||||||
| use itertools::Itertools; |  | ||||||
| use log::warn; |  | ||||||
| use reqwest::{Client, Error}; |  | ||||||
| use std::{collections::HashSet, time::Duration}; |  | ||||||
| use tokio::try_join; |  | ||||||
|  |  | ||||||
| pub async fn get( |  | ||||||
|     client: &Client, |  | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |  | ||||||
|     query: &outgoing::asset::Asset, |  | ||||||
|     backoff: Option<ExponentialBackoff>, |  | ||||||
|     api_base: &str, |  | ||||||
| ) -> Result<Vec<Asset>, Error> { |  | ||||||
|     retry_notify( |  | ||||||
|         backoff.unwrap_or_default(), |  | ||||||
|         || async { |  | ||||||
|             rate_limiter.until_ready().await; |  | ||||||
|             client |  | ||||||
|                 .get(&format!("https://{}.alpaca.markets/v2/assets", api_base)) |  | ||||||
|                 .query(query) |  | ||||||
|                 .send() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .error_for_status() |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .json::<Vec<Asset>>() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff) |  | ||||||
|         }, |  | ||||||
|         |e, duration: Duration| { |  | ||||||
|             warn!( |  | ||||||
|                 "Failed to get assets, will retry in {} seconds: {}.", |  | ||||||
|                 duration.as_secs(), |  | ||||||
|                 e |  | ||||||
|             ); |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn get_by_symbol( |  | ||||||
|     client: &Client, |  | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |  | ||||||
|     symbol: &str, |  | ||||||
|     backoff: Option<ExponentialBackoff>, |  | ||||||
|     api_base: &str, |  | ||||||
| ) -> Result<Asset, Error> { |  | ||||||
|     retry_notify( |  | ||||||
|         backoff.unwrap_or_default(), |  | ||||||
|         || async { |  | ||||||
|             rate_limiter.until_ready().await; |  | ||||||
|             client |  | ||||||
|                 .get(&format!( |  | ||||||
|                     "https://{}.alpaca.markets/v2/assets/{}", |  | ||||||
|                     api_base, symbol |  | ||||||
|                 )) |  | ||||||
|                 .send() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .error_for_status() |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .json::<Asset>() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff) |  | ||||||
|         }, |  | ||||||
|         |e, duration: Duration| { |  | ||||||
|             warn!( |  | ||||||
|                 "Failed to get asset, will retry in {} seconds: {}.", |  | ||||||
|                 duration.as_secs(), |  | ||||||
|                 e |  | ||||||
|             ); |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn get_by_symbols( |  | ||||||
|     client: &Client, |  | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |  | ||||||
|     symbols: &[String], |  | ||||||
|     backoff: Option<ExponentialBackoff>, |  | ||||||
|     api_base: &str, |  | ||||||
| ) -> Result<Vec<Asset>, Error> { |  | ||||||
|     if symbols.is_empty() { |  | ||||||
|         return Ok(vec![]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if symbols.len() == 1 { |  | ||||||
|         let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?; |  | ||||||
|         return Ok(vec![asset]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     let symbols = symbols.iter().collect::<HashSet<_>>(); |  | ||||||
|  |  | ||||||
|     let backoff_clone = backoff.clone(); |  | ||||||
|  |  | ||||||
|     let us_equity_query = outgoing::asset::Asset { |  | ||||||
|         class: Some(Class::UsEquity), |  | ||||||
|         ..Default::default() |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let us_equity_assets = get( |  | ||||||
|         client, |  | ||||||
|         rate_limiter, |  | ||||||
|         &us_equity_query, |  | ||||||
|         backoff_clone, |  | ||||||
|         api_base, |  | ||||||
|     ); |  | ||||||
|  |  | ||||||
|     let crypto_query = outgoing::asset::Asset { |  | ||||||
|         class: Some(Class::Crypto), |  | ||||||
|         ..Default::default() |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base); |  | ||||||
|  |  | ||||||
|     let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?; |  | ||||||
|  |  | ||||||
|     Ok(crypto_assets |  | ||||||
|         .into_iter() |  | ||||||
|         .chain(us_equity_assets) |  | ||||||
|         .dedup_by(|a, b| a.symbol == b.symbol) |  | ||||||
|         .filter(|asset| symbols.contains(&asset.symbol)) |  | ||||||
|         .collect()) |  | ||||||
| } |  | ||||||
| @@ -1,50 +0,0 @@ | |||||||
| use super::error_to_backoff; |  | ||||||
| use crate::types::alpaca::api::{incoming::bar::Bar, outgoing}; |  | ||||||
| use backoff::{future::retry_notify, ExponentialBackoff}; |  | ||||||
| use governor::DefaultDirectRateLimiter; |  | ||||||
| use log::warn; |  | ||||||
| use reqwest::{Client, Error}; |  | ||||||
| use serde::Deserialize; |  | ||||||
| use std::{collections::HashMap, time::Duration}; |  | ||||||
|  |  | ||||||
| pub const MAX_LIMIT: i64 = 10_000; |  | ||||||
|  |  | ||||||
| #[derive(Deserialize)] |  | ||||||
| pub struct Message { |  | ||||||
|     pub bars: HashMap<String, Vec<Bar>>, |  | ||||||
|     pub next_page_token: Option<String>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn get( |  | ||||||
|     client: &Client, |  | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |  | ||||||
|     data_url: &str, |  | ||||||
|     query: &outgoing::bar::Bar, |  | ||||||
|     backoff: Option<ExponentialBackoff>, |  | ||||||
| ) -> Result<Message, Error> { |  | ||||||
|     retry_notify( |  | ||||||
|         backoff.unwrap_or_default(), |  | ||||||
|         || async { |  | ||||||
|             rate_limiter.until_ready().await; |  | ||||||
|             client |  | ||||||
|                 .get(data_url) |  | ||||||
|                 .query(query) |  | ||||||
|                 .send() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .error_for_status() |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .json::<Message>() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff) |  | ||||||
|         }, |  | ||||||
|         |e, duration: Duration| { |  | ||||||
|             warn!( |  | ||||||
|                 "Failed to get historical bars, will retry in {} seconds: {}.", |  | ||||||
|                 duration.as_secs(), |  | ||||||
|                 e |  | ||||||
|             ); |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
| } |  | ||||||
| @@ -1,41 +0,0 @@ | |||||||
| use super::error_to_backoff; |  | ||||||
| use crate::types::alpaca::api::{incoming::calendar::Calendar, outgoing}; |  | ||||||
| use backoff::{future::retry_notify, ExponentialBackoff}; |  | ||||||
| use governor::DefaultDirectRateLimiter; |  | ||||||
| use log::warn; |  | ||||||
| use reqwest::{Client, Error}; |  | ||||||
| use std::time::Duration; |  | ||||||
|  |  | ||||||
| pub async fn get( |  | ||||||
|     client: &Client, |  | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |  | ||||||
|     query: &outgoing::calendar::Calendar, |  | ||||||
|     backoff: Option<ExponentialBackoff>, |  | ||||||
|     api_base: &str, |  | ||||||
| ) -> Result<Vec<Calendar>, Error> { |  | ||||||
|     retry_notify( |  | ||||||
|         backoff.unwrap_or_default(), |  | ||||||
|         || async { |  | ||||||
|             rate_limiter.until_ready().await; |  | ||||||
|             client |  | ||||||
|                 .get(&format!("https://{}.alpaca.markets/v2/calendar", api_base)) |  | ||||||
|                 .query(query) |  | ||||||
|                 .send() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .error_for_status() |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .json::<Vec<Calendar>>() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff) |  | ||||||
|         }, |  | ||||||
|         |e, duration: Duration| { |  | ||||||
|             warn!( |  | ||||||
|                 "Failed to get calendar, will retry in {} seconds: {}.", |  | ||||||
|                 duration.as_secs(), |  | ||||||
|                 e |  | ||||||
|             ); |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
| } |  | ||||||
| @@ -1,39 +0,0 @@ | |||||||
| use super::error_to_backoff; |  | ||||||
| use crate::types::alpaca::api::incoming::clock::Clock; |  | ||||||
| use backoff::{future::retry_notify, ExponentialBackoff}; |  | ||||||
| use governor::DefaultDirectRateLimiter; |  | ||||||
| use log::warn; |  | ||||||
| use reqwest::{Client, Error}; |  | ||||||
| use std::time::Duration; |  | ||||||
|  |  | ||||||
| pub async fn get( |  | ||||||
|     client: &Client, |  | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |  | ||||||
|     backoff: Option<ExponentialBackoff>, |  | ||||||
|     api_base: &str, |  | ||||||
| ) -> Result<Clock, Error> { |  | ||||||
|     retry_notify( |  | ||||||
|         backoff.unwrap_or_default(), |  | ||||||
|         || async { |  | ||||||
|             rate_limiter.until_ready().await; |  | ||||||
|             client |  | ||||||
|                 .get(&format!("https://{}.alpaca.markets/v2/clock", api_base)) |  | ||||||
|                 .send() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .error_for_status() |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .json::<Clock>() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff) |  | ||||||
|         }, |  | ||||||
|         |e, duration: Duration| { |  | ||||||
|             warn!( |  | ||||||
|                 "Failed to get clock, will retry in {} seconds: {}.", |  | ||||||
|                 duration.as_secs(), |  | ||||||
|                 e |  | ||||||
|             ); |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
| } |  | ||||||
| @@ -1,27 +0,0 @@ | |||||||
| pub mod account; |  | ||||||
| pub mod assets; |  | ||||||
| pub mod bars; |  | ||||||
| pub mod calendar; |  | ||||||
| pub mod clock; |  | ||||||
| pub mod news; |  | ||||||
| pub mod orders; |  | ||||||
| pub mod positions; |  | ||||||
|  |  | ||||||
| use http::StatusCode; |  | ||||||
|  |  | ||||||
| pub fn error_to_backoff(err: reqwest::Error) -> backoff::Error<reqwest::Error> { |  | ||||||
|     if err.is_status() { |  | ||||||
|         return match err.status() { |  | ||||||
|             Some(StatusCode::BAD_REQUEST | StatusCode::FORBIDDEN | StatusCode::NOT_FOUND) |  | ||||||
|             | None => backoff::Error::Permanent(err), |  | ||||||
|             _ => err.into(), |  | ||||||
|         }; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if err.is_builder() || err.is_request() || err.is_redirect() || err.is_decode() || err.is_body() |  | ||||||
|     { |  | ||||||
|         return backoff::Error::Permanent(err); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     err.into() |  | ||||||
| } |  | ||||||
| @@ -1,49 +0,0 @@ | |||||||
| use super::error_to_backoff; |  | ||||||
| use crate::types::alpaca::api::{incoming::news::News, outgoing, ALPACA_NEWS_DATA_API_URL}; |  | ||||||
| use backoff::{future::retry_notify, ExponentialBackoff}; |  | ||||||
| use governor::DefaultDirectRateLimiter; |  | ||||||
| use log::warn; |  | ||||||
| use reqwest::{Client, Error}; |  | ||||||
| use serde::Deserialize; |  | ||||||
| use std::time::Duration; |  | ||||||
|  |  | ||||||
| pub const MAX_LIMIT: i64 = 50; |  | ||||||
|  |  | ||||||
| #[derive(Deserialize)] |  | ||||||
| pub struct Message { |  | ||||||
|     pub news: Vec<News>, |  | ||||||
|     pub next_page_token: Option<String>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn get( |  | ||||||
|     client: &Client, |  | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |  | ||||||
|     query: &outgoing::news::News, |  | ||||||
|     backoff: Option<ExponentialBackoff>, |  | ||||||
| ) -> Result<Message, Error> { |  | ||||||
|     retry_notify( |  | ||||||
|         backoff.unwrap_or_default(), |  | ||||||
|         || async { |  | ||||||
|             rate_limiter.until_ready().await; |  | ||||||
|             client |  | ||||||
|                 .get(ALPACA_NEWS_DATA_API_URL) |  | ||||||
|                 .query(query) |  | ||||||
|                 .send() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .error_for_status() |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .json::<Message>() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff) |  | ||||||
|         }, |  | ||||||
|         |e, duration: Duration| { |  | ||||||
|             warn!( |  | ||||||
|                 "Failed to get historical news, will retry in {} seconds: {}.", |  | ||||||
|                 duration.as_secs(), |  | ||||||
|                 e |  | ||||||
|             ); |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
| } |  | ||||||
| @@ -1,109 +0,0 @@ | |||||||
| use super::error_to_backoff; |  | ||||||
| use crate::types::alpaca::api::incoming::position::Position; |  | ||||||
| use backoff::{future::retry_notify, ExponentialBackoff}; |  | ||||||
| use governor::DefaultDirectRateLimiter; |  | ||||||
| use http::StatusCode; |  | ||||||
| use log::warn; |  | ||||||
| use reqwest::Client; |  | ||||||
| use std::{collections::HashSet, time::Duration}; |  | ||||||
|  |  | ||||||
| pub async fn get( |  | ||||||
|     client: &Client, |  | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |  | ||||||
|     backoff: Option<ExponentialBackoff>, |  | ||||||
|     api_base: &str, |  | ||||||
| ) -> Result<Vec<Position>, reqwest::Error> { |  | ||||||
|     retry_notify( |  | ||||||
|         backoff.unwrap_or_default(), |  | ||||||
|         || async { |  | ||||||
|             rate_limiter.until_ready().await; |  | ||||||
|             client |  | ||||||
|                 .get(&format!("https://{}.alpaca.markets/v2/positions", api_base)) |  | ||||||
|                 .send() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .error_for_status() |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .json::<Vec<Position>>() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff) |  | ||||||
|         }, |  | ||||||
|         |e, duration: Duration| { |  | ||||||
|             warn!( |  | ||||||
|                 "Failed to get positions, will retry in {} seconds: {}.", |  | ||||||
|                 duration.as_secs(), |  | ||||||
|                 e |  | ||||||
|             ); |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn get_by_symbol( |  | ||||||
|     client: &Client, |  | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |  | ||||||
|     symbol: &str, |  | ||||||
|     backoff: Option<ExponentialBackoff>, |  | ||||||
|     api_base: &str, |  | ||||||
| ) -> Result<Option<Position>, reqwest::Error> { |  | ||||||
|     retry_notify( |  | ||||||
|         backoff.unwrap_or_default(), |  | ||||||
|         || async { |  | ||||||
|             rate_limiter.until_ready().await; |  | ||||||
|             let response = client |  | ||||||
|                 .get(&format!( |  | ||||||
|                     "https://{}.alpaca.markets/v2/positions/{}", |  | ||||||
|                     api_base, symbol |  | ||||||
|                 )) |  | ||||||
|                 .send() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff)?; |  | ||||||
|  |  | ||||||
|             if response.status() == StatusCode::NOT_FOUND { |  | ||||||
|                 return Ok(None); |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             response |  | ||||||
|                 .error_for_status() |  | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .json::<Position>() |  | ||||||
|                 .await |  | ||||||
|                 .map_err(error_to_backoff) |  | ||||||
|                 .map(Some) |  | ||||||
|         }, |  | ||||||
|         |e, duration: Duration| { |  | ||||||
|             warn!( |  | ||||||
|                 "Failed to get position, will retry in {} seconds: {}.", |  | ||||||
|                 duration.as_secs(), |  | ||||||
|                 e |  | ||||||
|             ); |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn get_by_symbols( |  | ||||||
|     client: &Client, |  | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |  | ||||||
|     symbols: &[String], |  | ||||||
|     backoff: Option<ExponentialBackoff>, |  | ||||||
|     api_base: &str, |  | ||||||
| ) -> Result<Vec<Position>, reqwest::Error> { |  | ||||||
|     if symbols.is_empty() { |  | ||||||
|         return Ok(vec![]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if symbols.len() == 1 { |  | ||||||
|         let position = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?; |  | ||||||
|         return Ok(position.into_iter().collect()); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     let symbols = symbols.iter().collect::<HashSet<_>>(); |  | ||||||
|  |  | ||||||
|     let positions = get(client, rate_limiter, backoff, api_base).await?; |  | ||||||
|  |  | ||||||
|     Ok(positions |  | ||||||
|         .into_iter() |  | ||||||
|         .filter(|position| symbols.contains(&position.symbol)) |  | ||||||
|         .collect()) |  | ||||||
| } |  | ||||||
| @@ -1,11 +0,0 @@ | |||||||
| use crate::{ |  | ||||||
|     cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols, |  | ||||||
|     types::Backfill, upsert_batch, |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| select_where_symbols!(Backfill, "backfills_bars"); |  | ||||||
| upsert_batch!(Backfill, "backfills_bars"); |  | ||||||
| delete_where_symbols!("backfills_bars"); |  | ||||||
| cleanup!("backfills_bars"); |  | ||||||
| optimize!("backfills_bars"); |  | ||||||
| set_fresh_where_symbols!("backfills_bars"); |  | ||||||
| @@ -1,11 +0,0 @@ | |||||||
| use crate::{ |  | ||||||
|     cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols, |  | ||||||
|     types::Backfill, upsert_batch, |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| select_where_symbols!(Backfill, "backfills_news"); |  | ||||||
| upsert_batch!(Backfill, "backfills_news"); |  | ||||||
| delete_where_symbols!("backfills_news"); |  | ||||||
| cleanup!("backfills_news"); |  | ||||||
| optimize!("backfills_news"); |  | ||||||
| set_fresh_where_symbols!("backfills_news"); |  | ||||||
| @@ -1,21 +0,0 @@ | |||||||
| use std::sync::Arc; |  | ||||||
|  |  | ||||||
| use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch}; |  | ||||||
| use clickhouse::Client; |  | ||||||
| use tokio::sync::Semaphore; |  | ||||||
|  |  | ||||||
| upsert!(Bar, "bars"); |  | ||||||
| upsert_batch!(Bar, "bars"); |  | ||||||
| delete_where_symbols!("bars"); |  | ||||||
| optimize!("bars"); |  | ||||||
|  |  | ||||||
| pub async fn cleanup( |  | ||||||
|     client: &Client, |  | ||||||
|     concurrency_limiter: &Arc<Semaphore>, |  | ||||||
| ) -> Result<(), clickhouse::error::Error> { |  | ||||||
|     let _ = concurrency_limiter.acquire().await.unwrap(); |  | ||||||
|     client |  | ||||||
|         .query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets) OR symbol NOT IN (SELECT symbol FROM backfills_bars)") |  | ||||||
|         .execute() |  | ||||||
|         .await |  | ||||||
| } |  | ||||||
| @@ -1,224 +0,0 @@ | |||||||
| pub mod assets; |  | ||||||
| pub mod backfills_bars; |  | ||||||
| pub mod backfills_news; |  | ||||||
| pub mod bars; |  | ||||||
| pub mod calendar; |  | ||||||
| pub mod news; |  | ||||||
| pub mod orders; |  | ||||||
| pub mod ta; |  | ||||||
|  |  | ||||||
| use clickhouse::{error::Error, Client}; |  | ||||||
| use tokio::try_join; |  | ||||||
|  |  | ||||||
| #[macro_export] |  | ||||||
| macro_rules! select { |  | ||||||
|     ($record:ty, $table_name:expr) => { |  | ||||||
|         pub async fn select( |  | ||||||
|             client: &clickhouse::Client, |  | ||||||
|             concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>, |  | ||||||
|         ) -> Result<Vec<$record>, clickhouse::error::Error> { |  | ||||||
|             let _ = concurrency_limiter.acquire().await.unwrap(); |  | ||||||
|             client |  | ||||||
|                 .query(&format!("SELECT ?fields FROM {} FINAL", $table_name)) |  | ||||||
|                 .fetch_all::<$record>() |  | ||||||
|                 .await |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[macro_export] |  | ||||||
| macro_rules! select_where_symbol { |  | ||||||
|     ($record:ty, $table_name:expr) => { |  | ||||||
|         pub async fn select_where_symbol<T>( |  | ||||||
|             client: &clickhouse::Client, |  | ||||||
|             concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>, |  | ||||||
|             symbol: &T, |  | ||||||
|         ) -> Result<Option<$record>, clickhouse::error::Error> |  | ||||||
|         where |  | ||||||
|             T: AsRef<str> + serde::Serialize + Send + Sync, |  | ||||||
|         { |  | ||||||
|             let _ = concurrency_limiter.acquire().await.unwrap(); |  | ||||||
|             client |  | ||||||
|                 .query(&format!( |  | ||||||
|                     "SELECT ?fields FROM {} FINAL WHERE symbol = ?", |  | ||||||
|                     $table_name |  | ||||||
|                 )) |  | ||||||
|                 .bind(symbol) |  | ||||||
|                 .fetch_optional::<$record>() |  | ||||||
|                 .await |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[macro_export] |  | ||||||
| macro_rules! select_where_symbols { |  | ||||||
|     ($record:ty, $table_name:expr) => { |  | ||||||
|         pub async fn select_where_symbols<T>( |  | ||||||
|             client: &clickhouse::Client, |  | ||||||
|             concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>, |  | ||||||
|             symbols: &[T], |  | ||||||
|         ) -> Result<Vec<$record>, clickhouse::error::Error> |  | ||||||
|         where |  | ||||||
|             T: AsRef<str> + serde::Serialize + Send + Sync, |  | ||||||
|         { |  | ||||||
|             let _ = concurrency_limiter.acquire().await.unwrap(); |  | ||||||
|             client |  | ||||||
|                 .query(&format!( |  | ||||||
|                     "SELECT ?fields FROM {} FINAL WHERE symbol IN ?", |  | ||||||
|                     $table_name |  | ||||||
|                 )) |  | ||||||
|                 .bind(symbols) |  | ||||||
|                 .fetch_all::<$record>() |  | ||||||
|                 .await |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[macro_export] |  | ||||||
| macro_rules! upsert { |  | ||||||
|     ($record:ty, $table_name:expr) => { |  | ||||||
|         pub async fn upsert( |  | ||||||
|             client: &clickhouse::Client, |  | ||||||
|             concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>, |  | ||||||
|             record: &$record, |  | ||||||
|         ) -> Result<(), clickhouse::error::Error> { |  | ||||||
|             let _ = concurrency_limiter.acquire().await.unwrap(); |  | ||||||
|             let mut insert = client.insert($table_name)?; |  | ||||||
|             insert.write(record).await?; |  | ||||||
|             insert.end().await |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[macro_export] |  | ||||||
| macro_rules! upsert_batch { |  | ||||||
|     ($record:ty, $table_name:expr) => { |  | ||||||
|         pub async fn upsert_batch<'a, I>( |  | ||||||
|             client: &clickhouse::Client, |  | ||||||
|             concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>, |  | ||||||
|             records: I, |  | ||||||
|         ) -> Result<(), clickhouse::error::Error> |  | ||||||
|         where |  | ||||||
|             I: IntoIterator<Item = &'a $record> + Send + Sync, |  | ||||||
|             I::IntoIter: Send, |  | ||||||
|         { |  | ||||||
|             let _ = concurrency_limiter.acquire().await.unwrap(); |  | ||||||
|             let mut insert = client.insert($table_name)?; |  | ||||||
|             for record in records { |  | ||||||
|                 insert.write(record).await?; |  | ||||||
|             } |  | ||||||
|             insert.end().await |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[macro_export] |  | ||||||
| macro_rules! delete_where_symbols { |  | ||||||
|     ($table_name:expr) => { |  | ||||||
|         pub async fn delete_where_symbols<T>( |  | ||||||
|             client: &clickhouse::Client, |  | ||||||
|             concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>, |  | ||||||
|             symbols: &[T], |  | ||||||
|         ) -> Result<(), clickhouse::error::Error> |  | ||||||
|         where |  | ||||||
|             T: AsRef<str> + serde::Serialize + Send + Sync, |  | ||||||
|         { |  | ||||||
|             let _ = concurrency_limiter.acquire().await.unwrap(); |  | ||||||
|             client |  | ||||||
|                 .query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name)) |  | ||||||
|                 .bind(symbols) |  | ||||||
|                 .execute() |  | ||||||
|                 .await |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[macro_export] |  | ||||||
| macro_rules! cleanup { |  | ||||||
|     ($table_name:expr) => { |  | ||||||
|         pub async fn cleanup( |  | ||||||
|             client: &clickhouse::Client, |  | ||||||
|             concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>, |  | ||||||
|         ) -> Result<(), clickhouse::error::Error> { |  | ||||||
|             let _ = concurrency_limiter.acquire().await.unwrap(); |  | ||||||
|             client |  | ||||||
|                 .query(&format!( |  | ||||||
|                     "DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)", |  | ||||||
|                     $table_name |  | ||||||
|                 )) |  | ||||||
|                 .execute() |  | ||||||
|                 .await |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[macro_export] |  | ||||||
| macro_rules! optimize { |  | ||||||
|     ($table_name:expr) => { |  | ||||||
|         pub async fn optimize( |  | ||||||
|             client: &clickhouse::Client, |  | ||||||
|             concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>, |  | ||||||
|         ) -> Result<(), clickhouse::error::Error> { |  | ||||||
|             let _ = concurrency_limiter.acquire().await.unwrap(); |  | ||||||
|             client |  | ||||||
|                 .query(&format!("OPTIMIZE TABLE {} FINAL", $table_name)) |  | ||||||
|                 .execute() |  | ||||||
|                 .await |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[macro_export] |  | ||||||
| macro_rules! set_fresh_where_symbols { |  | ||||||
|     ($table_name:expr) => { |  | ||||||
|         pub async fn set_fresh_where_symbols<T>( |  | ||||||
|             client: &clickhouse::Client, |  | ||||||
|             concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>, |  | ||||||
|             fresh: bool, |  | ||||||
|             symbols: &[T], |  | ||||||
|         ) -> Result<(), clickhouse::error::Error> |  | ||||||
|         where |  | ||||||
|             T: AsRef<str> + serde::Serialize + Send + Sync, |  | ||||||
|         { |  | ||||||
|             let _ = concurrency_limiter.acquire().await.unwrap(); |  | ||||||
|             client |  | ||||||
|                 .query(&format!( |  | ||||||
|                     "ALTER TABLE {} UPDATE fresh = ? WHERE symbol IN ?", |  | ||||||
|                     $table_name |  | ||||||
|                 )) |  | ||||||
|                 .bind(fresh) |  | ||||||
|                 .bind(symbols) |  | ||||||
|                 .execute() |  | ||||||
|                 .await |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn cleanup_all( |  | ||||||
|     clickhouse_client: &Client, |  | ||||||
|     concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>, |  | ||||||
| ) -> Result<(), Error> { |  | ||||||
|     try_join!( |  | ||||||
|         bars::cleanup(clickhouse_client, concurrency_limiter), |  | ||||||
|         news::cleanup(clickhouse_client, concurrency_limiter), |  | ||||||
|         backfills_bars::cleanup(clickhouse_client, concurrency_limiter), |  | ||||||
|         backfills_news::cleanup(clickhouse_client, concurrency_limiter) |  | ||||||
|     ) |  | ||||||
|     .map(|_| ()) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn optimize_all( |  | ||||||
|     clickhouse_client: &Client, |  | ||||||
|     concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>, |  | ||||||
| ) -> Result<(), Error> { |  | ||||||
|     try_join!( |  | ||||||
|         assets::optimize(clickhouse_client, concurrency_limiter), |  | ||||||
|         bars::optimize(clickhouse_client, concurrency_limiter), |  | ||||||
|         news::optimize(clickhouse_client, concurrency_limiter), |  | ||||||
|         backfills_bars::optimize(clickhouse_client, concurrency_limiter), |  | ||||||
|         backfills_news::optimize(clickhouse_client, concurrency_limiter), |  | ||||||
|         orders::optimize(clickhouse_client, concurrency_limiter), |  | ||||||
|         calendar::optimize(clickhouse_client, concurrency_limiter) |  | ||||||
|     ) |  | ||||||
|     .map(|_| ()) |  | ||||||
| } |  | ||||||
| @@ -1,30 +0,0 @@ | |||||||
| use crate::types::Bar; |  | ||||||
| use clickhouse::{error::Error, Client}; |  | ||||||
| use std::sync::Arc; |  | ||||||
| use tokio::sync::Semaphore; |  | ||||||
|  |  | ||||||
| pub async fn select( |  | ||||||
|     client: &Client, |  | ||||||
|     concurrency_limiter: &Arc<Semaphore>, |  | ||||||
| ) -> Result<Vec<Bar>, Error> { |  | ||||||
|     let _ = concurrency_limiter.acquire().await.unwrap(); |  | ||||||
|     client |  | ||||||
|         .query( |  | ||||||
|             " |  | ||||||
|         SELECT symbol, |  | ||||||
|             toStartOfHour(bars.time) AS time, |  | ||||||
|             any(bars.open) AS open, |  | ||||||
|             max(bars.high) AS high, |  | ||||||
|             min(bars.low) AS low, |  | ||||||
|             anyLast(bars.close) AS close, |  | ||||||
|             sum(bars.volume) AS volume, |  | ||||||
|             sum(bars.trades) AS trades |  | ||||||
|         FROM bars FINAL |  | ||||||
|         GROUP BY ALL |  | ||||||
|         ORDER BY symbol, |  | ||||||
|             time |  | ||||||
|             ", |  | ||||||
|         ) |  | ||||||
|         .fetch_all::<Bar>() |  | ||||||
|         .await |  | ||||||
| } |  | ||||||
| @@ -1,75 +0,0 @@ | |||||||
| use super::BarWindow; |  | ||||||
| use burn::{ |  | ||||||
|     data::dataloader::batcher::Batcher, |  | ||||||
|     tensor::{self, backend::Backend, Tensor}, |  | ||||||
| }; |  | ||||||
| use rayon::iter::{IntoParallelIterator, ParallelIterator}; |  | ||||||
|  |  | ||||||
| #[derive(Clone, Debug)] |  | ||||||
| pub struct BarWindowBatcher<B: Backend> { |  | ||||||
|     pub device: B::Device, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[derive(Clone, Debug)] |  | ||||||
| pub struct BarWindowBatch<B: Backend> { |  | ||||||
|     pub hour_tensor: Tensor<B, 2, tensor::Int>, |  | ||||||
|     pub day_tensor: Tensor<B, 2, tensor::Int>, |  | ||||||
|     pub numerical_tensor: Tensor<B, 3>, |  | ||||||
|     pub target_tensor: Tensor<B, 2>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<B: Backend<FloatElem = f32, IntElem = i32>> Batcher<BarWindow, BarWindowBatch<B>> |  | ||||||
|     for BarWindowBatcher<B> |  | ||||||
| { |  | ||||||
|     fn batch(&self, items: Vec<BarWindow>) -> BarWindowBatch<B> { |  | ||||||
|         let batch_size = items.len(); |  | ||||||
|  |  | ||||||
|         let (hour_tensors, day_tensors, numerical_tensors, target_tensors) = items |  | ||||||
|             .into_par_iter() |  | ||||||
|             .fold( |  | ||||||
|                 || { |  | ||||||
|                     ( |  | ||||||
|                         Vec::with_capacity(batch_size), |  | ||||||
|                         Vec::with_capacity(batch_size), |  | ||||||
|                         Vec::with_capacity(batch_size), |  | ||||||
|                         Vec::with_capacity(batch_size), |  | ||||||
|                     ) |  | ||||||
|                 }, |  | ||||||
|                 |(mut hour_tensors, mut day_tensors, mut numerical_tensors, mut target_tensors), |  | ||||||
|                  item| { |  | ||||||
|                     hour_tensors.push(Tensor::from_data(item.hours, &self.device)); |  | ||||||
|                     day_tensors.push(Tensor::from_data(item.days, &self.device)); |  | ||||||
|                     numerical_tensors.push(Tensor::from_data(item.numerical, &self.device)); |  | ||||||
|                     target_tensors.push(Tensor::from_data(item.target, &self.device)); |  | ||||||
|  |  | ||||||
|                     (hour_tensors, day_tensors, numerical_tensors, target_tensors) |  | ||||||
|                 }, |  | ||||||
|             ) |  | ||||||
|             .reduce( |  | ||||||
|                 || { |  | ||||||
|                     ( |  | ||||||
|                         Vec::with_capacity(batch_size), |  | ||||||
|                         Vec::with_capacity(batch_size), |  | ||||||
|                         Vec::with_capacity(batch_size), |  | ||||||
|                         Vec::with_capacity(batch_size), |  | ||||||
|                     ) |  | ||||||
|                 }, |  | ||||||
|                 |(mut hour_tensors, mut day_tensors, mut numerical_tensors, mut target_tensors), |  | ||||||
|                  item| { |  | ||||||
|                     hour_tensors.extend(item.0); |  | ||||||
|                     day_tensors.extend(item.1); |  | ||||||
|                     numerical_tensors.extend(item.2); |  | ||||||
|                     target_tensors.extend(item.3); |  | ||||||
|  |  | ||||||
|                     (hour_tensors, day_tensors, numerical_tensors, target_tensors) |  | ||||||
|                 }, |  | ||||||
|             ); |  | ||||||
|  |  | ||||||
|         BarWindowBatch { |  | ||||||
|             hour_tensor: Tensor::stack(hour_tensors, 0).to_device(&self.device), |  | ||||||
|             day_tensor: Tensor::stack(day_tensors, 0).to_device(&self.device), |  | ||||||
|             numerical_tensor: Tensor::stack(numerical_tensors, 0).to_device(&self.device), |  | ||||||
|             target_tensor: Tensor::stack(target_tensors, 0).to_device(&self.device), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,219 +0,0 @@ | |||||||
| use crate::types::{ |  | ||||||
|     ta::{calculate_indicators, IndicatedBar, HEAD_SIZE, NUMERICAL_FIELD_COUNT}, |  | ||||||
|     Bar, |  | ||||||
| }; |  | ||||||
| use burn::{ |  | ||||||
|     data::dataset::{transform::ComposedDataset, Dataset}, |  | ||||||
|     tensor::Data, |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| pub const WINDOW_SIZE: usize = 48; |  | ||||||
|  |  | ||||||
| #[derive(Clone, Debug)] |  | ||||||
| pub struct BarWindow { |  | ||||||
|     pub hours: Data<i32, 1>, |  | ||||||
|     pub days: Data<i32, 1>, |  | ||||||
|     pub numerical: Data<f32, 2>, |  | ||||||
|     pub target: Data<f32, 1>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[derive(Clone, Debug)] |  | ||||||
| struct SingleSymbolDataset { |  | ||||||
|     hours: Vec<i32>, |  | ||||||
|     days: Vec<i32>, |  | ||||||
|     numerical: Vec<[f32; NUMERICAL_FIELD_COUNT]>, |  | ||||||
|     targets: Vec<f32>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl SingleSymbolDataset { |  | ||||||
|     #[allow(clippy::cast_possible_truncation)] |  | ||||||
|     pub fn new(bars: Vec<IndicatedBar>) -> Self { |  | ||||||
|         if !bars.is_empty() { |  | ||||||
|             let symbol = &bars[0].symbol; |  | ||||||
|             assert!(bars.iter().all(|bar| bar.symbol == *symbol)); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         let (hours, days, numerical, targets) = bars.windows(2).skip(HEAD_SIZE - 1).fold( |  | ||||||
|             ( |  | ||||||
|                 Vec::with_capacity(bars.len() - 1), |  | ||||||
|                 Vec::with_capacity(bars.len() - 1), |  | ||||||
|                 Vec::with_capacity(bars.len() - 1), |  | ||||||
|                 Vec::with_capacity(bars.len() - 1), |  | ||||||
|             ), |  | ||||||
|             |(mut hours, mut days, mut numerical, mut targets), bar| { |  | ||||||
|                 hours.push(i32::from(bar[0].hour)); |  | ||||||
|                 days.push(i32::from(bar[0].day)); |  | ||||||
|                 numerical.push([ |  | ||||||
|                     bar[0].open as f32, |  | ||||||
|                     (bar[0].open_pct as f32).min(f32::MAX), |  | ||||||
|                     bar[0].high as f32, |  | ||||||
|                     (bar[0].high_pct as f32).min(f32::MAX), |  | ||||||
|                     bar[0].low as f32, |  | ||||||
|                     (bar[0].low_pct as f32).min(f32::MAX), |  | ||||||
|                     bar[0].close as f32, |  | ||||||
|                     (bar[0].close_pct as f32).min(f32::MAX), |  | ||||||
|                     bar[0].volume as f32, |  | ||||||
|                     (bar[0].volume_pct as f32).min(f32::MAX), |  | ||||||
|                     bar[0].trades as f32, |  | ||||||
|                     (bar[0].trades_pct as f32).min(f32::MAX), |  | ||||||
|                     bar[0].sma_3 as f32, |  | ||||||
|                     bar[0].sma_6 as f32, |  | ||||||
|                     bar[0].sma_12 as f32, |  | ||||||
|                     bar[0].sma_24 as f32, |  | ||||||
|                     bar[0].sma_48 as f32, |  | ||||||
|                     bar[0].sma_72 as f32, |  | ||||||
|                     bar[0].ema_3 as f32, |  | ||||||
|                     bar[0].ema_6 as f32, |  | ||||||
|                     bar[0].ema_12 as f32, |  | ||||||
|                     bar[0].ema_24 as f32, |  | ||||||
|                     bar[0].ema_48 as f32, |  | ||||||
|                     bar[0].ema_72 as f32, |  | ||||||
|                     bar[0].macd as f32, |  | ||||||
|                     bar[0].macd_signal as f32, |  | ||||||
|                     bar[0].obv as f32, |  | ||||||
|                     bar[0].rsi as f32, |  | ||||||
|                     bar[0].bbands_lower as f32, |  | ||||||
|                     bar[0].bbands_mean as f32, |  | ||||||
|                     bar[0].bbands_upper as f32, |  | ||||||
|                 ]); |  | ||||||
|                 targets.push(bar[1].close_pct as f32); |  | ||||||
|                 (hours, days, numerical, targets) |  | ||||||
|             }, |  | ||||||
|         ); |  | ||||||
|  |  | ||||||
|         Self { |  | ||||||
|             hours, |  | ||||||
|             days, |  | ||||||
|             numerical, |  | ||||||
|             targets, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl Dataset<BarWindow> for SingleSymbolDataset { |  | ||||||
|     fn len(&self) -> usize { |  | ||||||
|         self.targets.len() - WINDOW_SIZE + 1 |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[allow(clippy::single_range_in_vec_init)] |  | ||||||
|     fn get(&self, idx: usize) -> Option<BarWindow> { |  | ||||||
|         if idx >= self.len() { |  | ||||||
|             return None; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         let hours: [i32; WINDOW_SIZE] = self.hours[idx..idx + WINDOW_SIZE].try_into().unwrap(); |  | ||||||
|         let days: [i32; WINDOW_SIZE] = self.days[idx..idx + WINDOW_SIZE].try_into().unwrap(); |  | ||||||
|         let numerical: [[f32; NUMERICAL_FIELD_COUNT]; WINDOW_SIZE] = |  | ||||||
|             self.numerical[idx..idx + WINDOW_SIZE].try_into().unwrap(); |  | ||||||
|         let target: [f32; 1] = [self.targets[idx + WINDOW_SIZE - 1]]; |  | ||||||
|  |  | ||||||
|         Some(BarWindow { |  | ||||||
|             hours: Data::from(hours), |  | ||||||
|             days: Data::from(days), |  | ||||||
|             numerical: Data::from(numerical), |  | ||||||
|             target: Data::from(target), |  | ||||||
|         }) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub struct MultipleSymbolDataset { |  | ||||||
|     composed_dataset: ComposedDataset<SingleSymbolDataset>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl MultipleSymbolDataset { |  | ||||||
|     pub fn new(bars: Vec<Bar>) -> Self { |  | ||||||
|         let groups = calculate_indicators(bars) |  | ||||||
|             .into_iter() |  | ||||||
|             .map(SingleSymbolDataset::new) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         Self { |  | ||||||
|             composed_dataset: ComposedDataset::new(groups), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl Dataset<BarWindow> for MultipleSymbolDataset { |  | ||||||
|     fn len(&self) -> usize { |  | ||||||
|         self.composed_dataset.len() |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     fn get(&self, idx: usize) -> Option<BarWindow> { |  | ||||||
|         self.composed_dataset.get(idx) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::*; |  | ||||||
|     use rand::{ |  | ||||||
|         distributions::{Distribution, Uniform}, |  | ||||||
|         Rng, |  | ||||||
|     }; |  | ||||||
|     use time::OffsetDateTime; |  | ||||||
|  |  | ||||||
|     fn generate_random_dataset(length: usize) -> MultipleSymbolDataset { |  | ||||||
|         let mut rng = rand::thread_rng(); |  | ||||||
|         let uniform = Uniform::new(1.0, 100.0); |  | ||||||
|         let mut bars = Vec::with_capacity(length); |  | ||||||
|  |  | ||||||
|         for _ in 0..=(length + (HEAD_SIZE - 1) + (WINDOW_SIZE - 1)) { |  | ||||||
|             bars.push(Bar { |  | ||||||
|                 symbol: "AAPL".to_string(), |  | ||||||
|                 time: OffsetDateTime::now_utc(), |  | ||||||
|                 open: uniform.sample(&mut rng), |  | ||||||
|                 high: uniform.sample(&mut rng), |  | ||||||
|                 low: uniform.sample(&mut rng), |  | ||||||
|                 close: uniform.sample(&mut rng), |  | ||||||
|                 volume: uniform.sample(&mut rng), |  | ||||||
|                 trades: rng.gen_range(1..100), |  | ||||||
|             }); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         MultipleSymbolDataset::new(bars) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_single_symbol_dataset() { |  | ||||||
|         let length = 100; |  | ||||||
|         let dataset = generate_random_dataset(length); |  | ||||||
|  |  | ||||||
|         assert_eq!(dataset.len(), length); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_single_symbol_dataset_window() { |  | ||||||
|         let length = 100; |  | ||||||
|         let dataset = generate_random_dataset(length); |  | ||||||
|  |  | ||||||
|         let item = dataset.get(0).unwrap(); |  | ||||||
|  |  | ||||||
|         assert_eq!( |  | ||||||
|             item.numerical.shape.dims, |  | ||||||
|             [WINDOW_SIZE, NUMERICAL_FIELD_COUNT] |  | ||||||
|         ); |  | ||||||
|         assert_eq!(item.target.shape.dims, [1]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_single_symbol_dataset_last_window() { |  | ||||||
|         let length = 100; |  | ||||||
|         let dataset = generate_random_dataset(length); |  | ||||||
|  |  | ||||||
|         let item = dataset.get(dataset.len() - 1).unwrap(); |  | ||||||
|  |  | ||||||
|         assert_eq!( |  | ||||||
|             item.numerical.shape.dims, |  | ||||||
|             [WINDOW_SIZE, NUMERICAL_FIELD_COUNT] |  | ||||||
|         ); |  | ||||||
|         assert_eq!(item.target.shape.dims, [1]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_single_symbol_dataset_out_of_bounds() { |  | ||||||
|         let length = 100; |  | ||||||
|         let dataset = generate_random_dataset(length); |  | ||||||
|  |  | ||||||
|         assert!(dataset.get(dataset.len()).is_none()); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,21 +0,0 @@ | |||||||
| pub mod batcher; |  | ||||||
| pub mod dataset; |  | ||||||
| pub mod model; |  | ||||||
|  |  | ||||||
| pub use batcher::{BarWindowBatch, BarWindowBatcher}; |  | ||||||
| pub use dataset::{BarWindow, MultipleSymbolDataset}; |  | ||||||
| pub use model::{Model, ModelConfig}; |  | ||||||
|  |  | ||||||
| use burn::{ |  | ||||||
|     backend::{ |  | ||||||
|         wgpu::{AutoGraphicsApi, WgpuDevice}, |  | ||||||
|         Autodiff, Wgpu, |  | ||||||
|     }, |  | ||||||
|     tensor::backend::Backend, |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| pub type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>; |  | ||||||
| pub type MyAutodiffBackend = Autodiff<MyBackend>; |  | ||||||
| pub type MyDevice = <Autodiff<Wgpu> as Backend>::Device; |  | ||||||
|  |  | ||||||
| pub const DEVICE: MyDevice = WgpuDevice::BestAvailable; |  | ||||||
| @@ -1,160 +0,0 @@ | |||||||
| use super::BarWindowBatch; |  | ||||||
| use crate::types::ta::NUMERICAL_FIELD_COUNT; |  | ||||||
| use burn::{ |  | ||||||
|     config::Config, |  | ||||||
|     module::Module, |  | ||||||
|     nn::{ |  | ||||||
|         loss::{MseLoss, Reduction}, |  | ||||||
|         Dropout, DropoutConfig, Embedding, EmbeddingConfig, Linear, LinearConfig, Lstm, LstmConfig, |  | ||||||
|     }, |  | ||||||
|     tensor::{ |  | ||||||
|         self, |  | ||||||
|         backend::{AutodiffBackend, Backend}, |  | ||||||
|         Tensor, |  | ||||||
|     }, |  | ||||||
|     train::{RegressionOutput, TrainOutput, TrainStep, ValidStep}, |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| #[derive(Module, Debug)] |  | ||||||
| pub struct Model<B: Backend> { |  | ||||||
|     hour_embedding: Embedding<B>, |  | ||||||
|     day_embedding: Embedding<B>, |  | ||||||
|     lstm_1: Lstm<B>, |  | ||||||
|     dropout_1: Dropout, |  | ||||||
|     lstm_2: Lstm<B>, |  | ||||||
|     dropout_2: Dropout, |  | ||||||
|     lstm_3: Lstm<B>, |  | ||||||
|     dropout_3: Dropout, |  | ||||||
|     lstm_4: Lstm<B>, |  | ||||||
|     dropout_4: Dropout, |  | ||||||
|     linear: Linear<B>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[derive(Config, Debug)] |  | ||||||
| pub struct ModelConfig { |  | ||||||
|     #[config(default = "3")] |  | ||||||
|     pub hour_features: usize, |  | ||||||
|     #[config(default = "2")] |  | ||||||
|     pub day_features: usize, |  | ||||||
|     #[config(default = "{NUMERICAL_FIELD_COUNT}")] |  | ||||||
|     pub numerical_features: usize, |  | ||||||
|     #[config(default = "0.2")] |  | ||||||
|     pub dropout: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl ModelConfig { |  | ||||||
|     pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> { |  | ||||||
|         let num_features = self.numerical_features + self.hour_features + self.day_features; |  | ||||||
|  |  | ||||||
|         let lstm_1_hidden_size = 512; |  | ||||||
|         let lstm_2_hidden_size = 256; |  | ||||||
|         let lstm_3_hidden_size = 64; |  | ||||||
|         let lstm_4_hidden_size = 32; |  | ||||||
|  |  | ||||||
|         Model { |  | ||||||
|             hour_embedding: EmbeddingConfig::new(24, self.hour_features).init(device), |  | ||||||
|             day_embedding: EmbeddingConfig::new(7, self.day_features).init(device), |  | ||||||
|             lstm_1: LstmConfig::new(num_features, lstm_1_hidden_size, true).init(device), |  | ||||||
|             dropout_1: DropoutConfig::new(self.dropout).init(), |  | ||||||
|             lstm_2: LstmConfig::new(lstm_1_hidden_size, lstm_2_hidden_size, true).init(device), |  | ||||||
|             dropout_2: DropoutConfig::new(self.dropout).init(), |  | ||||||
|             lstm_3: LstmConfig::new(lstm_2_hidden_size, lstm_3_hidden_size, true).init(device), |  | ||||||
|             dropout_3: DropoutConfig::new(self.dropout).init(), |  | ||||||
|             lstm_4: LstmConfig::new(lstm_3_hidden_size, lstm_4_hidden_size, true).init(device), |  | ||||||
|             dropout_4: DropoutConfig::new(self.dropout).init(), |  | ||||||
|             linear: LinearConfig::new(lstm_4_hidden_size, 1).init(device), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<B: Backend> Model<B> { |  | ||||||
|     pub fn forward( |  | ||||||
|         &self, |  | ||||||
|         hour: Tensor<B, 2, tensor::Int>, |  | ||||||
|         day: Tensor<B, 2, tensor::Int>, |  | ||||||
|         numerical: Tensor<B, 3>, |  | ||||||
|     ) -> Tensor<B, 2> { |  | ||||||
|         let hour = self.hour_embedding.forward(hour); |  | ||||||
|         let day = self.day_embedding.forward(day); |  | ||||||
|  |  | ||||||
|         let x = Tensor::cat(vec![hour, day, numerical], 2); |  | ||||||
|  |  | ||||||
|         let (_, x) = self.lstm_1.forward(x, None); |  | ||||||
|         let x = self.dropout_1.forward(x); |  | ||||||
|         let (_, x) = self.lstm_2.forward(x, None); |  | ||||||
|         let x = self.dropout_2.forward(x); |  | ||||||
|         let (_, x) = self.lstm_3.forward(x, None); |  | ||||||
|         let x = self.dropout_3.forward(x); |  | ||||||
|         let (_, x) = self.lstm_4.forward(x, None); |  | ||||||
|         let x = self.dropout_4.forward(x); |  | ||||||
|  |  | ||||||
|         let [batch_size, window_size, features] = x.shape().dims; |  | ||||||
|  |  | ||||||
|         let x = x.slice([0..batch_size, window_size - 1..window_size, 0..features]); |  | ||||||
|         let x = x.squeeze(1); |  | ||||||
|  |  | ||||||
|         self.linear.forward(x) |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     pub fn forward_regression( |  | ||||||
|         &self, |  | ||||||
|         hour: Tensor<B, 2, tensor::Int>, |  | ||||||
|         day: Tensor<B, 2, tensor::Int>, |  | ||||||
|         numerical: Tensor<B, 3>, |  | ||||||
|         target: Tensor<B, 2>, |  | ||||||
|     ) -> RegressionOutput<B> { |  | ||||||
|         let output = self.forward(hour, day, numerical); |  | ||||||
|         let loss = MseLoss::new().forward(output.clone(), target.clone(), Reduction::Mean); |  | ||||||
|  |  | ||||||
|         RegressionOutput::new(loss, output, target) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<B: AutodiffBackend> TrainStep<BarWindowBatch<B>, RegressionOutput<B>> for Model<B> { |  | ||||||
|     fn step(&self, batch: BarWindowBatch<B>) -> TrainOutput<RegressionOutput<B>> { |  | ||||||
|         let item = self.forward_regression( |  | ||||||
|             batch.hour_tensor, |  | ||||||
|             batch.day_tensor, |  | ||||||
|             batch.numerical_tensor, |  | ||||||
|             batch.target_tensor, |  | ||||||
|         ); |  | ||||||
|  |  | ||||||
|         TrainOutput::new(self, item.loss.backward(), item) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<B: Backend> ValidStep<BarWindowBatch<B>, RegressionOutput<B>> for Model<B> { |  | ||||||
|     fn step(&self, batch: BarWindowBatch<B>) -> RegressionOutput<B> { |  | ||||||
|         self.forward_regression( |  | ||||||
|             batch.hour_tensor, |  | ||||||
|             batch.day_tensor, |  | ||||||
|             batch.numerical_tensor, |  | ||||||
|             batch.target_tensor, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::*; |  | ||||||
|     use burn::{backend::Wgpu, tensor::Distribution}; |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     #[ignore] |  | ||||||
|     fn test_model() { |  | ||||||
|         let device = Default::default(); |  | ||||||
|         let distribution = Distribution::Normal(0.0, 1.0); |  | ||||||
|  |  | ||||||
|         let config = ModelConfig::new().with_numerical_features(7); |  | ||||||
|  |  | ||||||
|         let model = config.init::<Wgpu>(&device); |  | ||||||
|  |  | ||||||
|         let hour = Tensor::ones([2, 10], &device); |  | ||||||
|         let day = Tensor::ones([2, 10], &device); |  | ||||||
|         let numerical = Tensor::random([2, 10, 7], distribution, &device); |  | ||||||
|  |  | ||||||
|         let output = model.forward(hour, day, numerical); |  | ||||||
|  |  | ||||||
|         assert_eq!(output.shape().dims, [2, 1]); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,6 +0,0 @@ | |||||||
| pub mod alpaca; |  | ||||||
| pub mod database; |  | ||||||
| pub mod ml; |  | ||||||
| pub mod ta; |  | ||||||
| pub mod types; |  | ||||||
| pub mod utils; |  | ||||||
| @@ -1,149 +0,0 @@ | |||||||
| use std::{borrow::Borrow, collections::VecDeque, iter::Scan, num::NonZeroUsize}; |  | ||||||
|  |  | ||||||
| pub struct BbandsState { |  | ||||||
|     window: VecDeque<f64>, |  | ||||||
|     sum: f64, |  | ||||||
|     squared_sum: f64, |  | ||||||
|     multiplier: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::type_complexity)] |  | ||||||
| pub trait Bbands<T>: Iterator + Sized { |  | ||||||
|     fn bbands( |  | ||||||
|         self, |  | ||||||
|         period: NonZeroUsize, |  | ||||||
|         multiplier: f64, // Typically 2.0 |  | ||||||
|     ) -> Scan<Self, BbandsState, fn(&mut BbandsState, T) -> Option<(f64, f64, f64)>>; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<I, T> Bbands<T> for I |  | ||||||
| where |  | ||||||
|     I: Iterator<Item = T>, |  | ||||||
|     T: Borrow<f64>, |  | ||||||
| { |  | ||||||
|     fn bbands( |  | ||||||
|         self, |  | ||||||
|         period: NonZeroUsize, |  | ||||||
|         multiplier: f64, |  | ||||||
|     ) -> Scan<Self, BbandsState, fn(&mut BbandsState, T) -> Option<(f64, f64, f64)>> { |  | ||||||
|         self.scan( |  | ||||||
|             BbandsState { |  | ||||||
|                 window: VecDeque::from(vec![0.0; period.get()]), |  | ||||||
|                 sum: 0.0, |  | ||||||
|                 squared_sum: 0.0, |  | ||||||
|                 multiplier, |  | ||||||
|             }, |  | ||||||
|             |state: &mut BbandsState, value: T| { |  | ||||||
|                 let value = *value.borrow(); |  | ||||||
|  |  | ||||||
|                 let front = state.window.pop_front().unwrap(); |  | ||||||
|                 state.sum -= front; |  | ||||||
|                 state.squared_sum -= front.powi(2); |  | ||||||
|  |  | ||||||
|                 state.window.push_back(value); |  | ||||||
|                 state.sum += value; |  | ||||||
|                 state.squared_sum += value.powi(2); |  | ||||||
|  |  | ||||||
|                 let mean = state.sum / state.window.len() as f64; |  | ||||||
|                 let variance = |  | ||||||
|                     ((state.squared_sum / state.window.len() as f64) - mean.powi(2)).max(0.0); |  | ||||||
|                 let standard_deviation = variance.sqrt(); |  | ||||||
|  |  | ||||||
|                 let upper_band = mean + state.multiplier * standard_deviation; |  | ||||||
|                 let lower_band = mean - state.multiplier * standard_deviation; |  | ||||||
|  |  | ||||||
|                 Some((upper_band, mean, lower_band)) |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::*; |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_bbands() { |  | ||||||
|         let data = vec![1.0, 2.0, 3.0, 4.0, 5.0]; |  | ||||||
|         let bbands = data |  | ||||||
|             .into_iter() |  | ||||||
|             .bbands(NonZeroUsize::new(3).unwrap(), 2.0) |  | ||||||
|             .map(|(upper, mean, lower)| { |  | ||||||
|                 ( |  | ||||||
|                     (upper * 100.0).round() / 100.0, |  | ||||||
|                     (mean * 100.0).round() / 100.0, |  | ||||||
|                     (lower * 100.0).round() / 100.0, |  | ||||||
|                 ) |  | ||||||
|             }) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!( |  | ||||||
|             bbands, |  | ||||||
|             vec![ |  | ||||||
|                 (1.28, 0.33, -0.61), |  | ||||||
|                 (2.63, 1.0, -0.63), |  | ||||||
|                 (3.63, 2.0, 0.37), |  | ||||||
|                 (4.63, 3.0, 1.37), |  | ||||||
|                 (5.63, 4.0, 2.37) |  | ||||||
|             ] |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_bbands_empty() { |  | ||||||
|         let data = Vec::<f64>::new(); |  | ||||||
|         let bbands = data |  | ||||||
|             .into_iter() |  | ||||||
|             .bbands(NonZeroUsize::new(3).unwrap(), 2.0) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(bbands, Vec::<(f64, f64, f64)>::new()); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_bbands_1_period() { |  | ||||||
|         let data = vec![1.0, 2.0, 3.0, 4.0, 5.0]; |  | ||||||
|         let bbands = data |  | ||||||
|             .into_iter() |  | ||||||
|             .bbands(NonZeroUsize::new(1).unwrap(), 2.0) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!( |  | ||||||
|             bbands, |  | ||||||
|             vec![ |  | ||||||
|                 (1.0, 1.0, 1.0), |  | ||||||
|                 (2.0, 2.0, 2.0), |  | ||||||
|                 (3.0, 3.0, 3.0), |  | ||||||
|                 (4.0, 4.0, 4.0), |  | ||||||
|                 (5.0, 5.0, 5.0) |  | ||||||
|             ] |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_bbands_borrow() { |  | ||||||
|         let data = [1.0, 2.0, 3.0, 4.0, 5.0]; |  | ||||||
|         let bbands = data |  | ||||||
|             .iter() |  | ||||||
|             .bbands(NonZeroUsize::new(3).unwrap(), 2.0) |  | ||||||
|             .map(|(upper, mean, lower)| { |  | ||||||
|                 ( |  | ||||||
|                     (upper * 100.0).round() / 100.0, |  | ||||||
|                     (mean * 100.0).round() / 100.0, |  | ||||||
|                     (lower * 100.0).round() / 100.0, |  | ||||||
|                 ) |  | ||||||
|             }) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!( |  | ||||||
|             bbands, |  | ||||||
|             vec![ |  | ||||||
|                 (1.28, 0.33, -0.61), |  | ||||||
|                 (2.63, 1.0, -0.63), |  | ||||||
|                 (3.63, 2.0, 0.37), |  | ||||||
|                 (4.63, 3.0, 1.37), |  | ||||||
|                 (5.63, 4.0, 2.37) |  | ||||||
|             ] |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,59 +0,0 @@ | |||||||
| use std::{borrow::Borrow, iter::Scan}; |  | ||||||
|  |  | ||||||
| pub struct DerivState { |  | ||||||
|     pub last: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::type_complexity)] |  | ||||||
| pub trait Deriv<T>: Iterator + Sized { |  | ||||||
|     fn deriv(self) -> Scan<Self, DerivState, fn(&mut DerivState, T) -> Option<f64>>; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<I, T> Deriv<T> for I |  | ||||||
| where |  | ||||||
|     I: Iterator<Item = T>, |  | ||||||
|     T: Borrow<f64>, |  | ||||||
| { |  | ||||||
|     fn deriv(self) -> Scan<Self, DerivState, fn(&mut DerivState, T) -> Option<f64>> { |  | ||||||
|         self.scan( |  | ||||||
|             DerivState { last: 0.0 }, |  | ||||||
|             |state: &mut DerivState, value: T| { |  | ||||||
|                 let value = *value.borrow(); |  | ||||||
|  |  | ||||||
|                 let deriv = value - state.last; |  | ||||||
|                 state.last = value; |  | ||||||
|  |  | ||||||
|                 Some(deriv) |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::*; |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_deriv() { |  | ||||||
|         let data = vec![1.0, 3.0, 6.0, 3.0, 1.0]; |  | ||||||
|         let deriv = data.into_iter().deriv().collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(deriv, vec![1.0, 2.0, 3.0, -3.0, -2.0]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_deriv_empty() { |  | ||||||
|         let data = Vec::<f64>::new(); |  | ||||||
|         let deriv = data.into_iter().deriv().collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(deriv, Vec::<f64>::new()); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_deriv_borrow() { |  | ||||||
|         let data = [1.0, 3.0, 6.0, 3.0, 1.0]; |  | ||||||
|         let deriv = data.iter().deriv().collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(deriv, vec![1.0, 2.0, 3.0, -3.0, -2.0]); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,95 +0,0 @@ | |||||||
| use std::{ |  | ||||||
|     borrow::Borrow, |  | ||||||
|     iter::{Peekable, Scan}, |  | ||||||
|     num::NonZeroUsize, |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| pub struct EmaState { |  | ||||||
|     weight: f64, |  | ||||||
|     ema: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::type_complexity)] |  | ||||||
| pub trait Ema<T>: Iterator + Sized { |  | ||||||
|     fn ema( |  | ||||||
|         self, |  | ||||||
|         period: NonZeroUsize, |  | ||||||
|     ) -> Scan<Peekable<Self>, EmaState, fn(&mut EmaState, T) -> Option<f64>>; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<I, T> Ema<T> for I |  | ||||||
| where |  | ||||||
|     I: Iterator<Item = T>, |  | ||||||
|     T: Borrow<f64>, |  | ||||||
| { |  | ||||||
|     fn ema( |  | ||||||
|         self, |  | ||||||
|         period: NonZeroUsize, |  | ||||||
|     ) -> Scan<Peekable<Self>, EmaState, fn(&mut EmaState, T) -> Option<f64>> { |  | ||||||
|         let smoothing = 2.0; |  | ||||||
|         let weight = smoothing / (1.0 + period.get() as f64); |  | ||||||
|  |  | ||||||
|         let mut iter = self.peekable(); |  | ||||||
|         let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default(); |  | ||||||
|  |  | ||||||
|         iter.scan( |  | ||||||
|             EmaState { weight, ema: first }, |  | ||||||
|             |state: &mut EmaState, value: T| { |  | ||||||
|                 let value = *value.borrow(); |  | ||||||
|  |  | ||||||
|                 state.ema = (value * state.weight) + (state.ema * (1.0 - state.weight)); |  | ||||||
|  |  | ||||||
|                 Some(state.ema) |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::*; |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_ema() { |  | ||||||
|         let data = vec![1.0, 2.0, 3.0, 4.0, 5.0]; |  | ||||||
|         let ema = data |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(ema, vec![1.0, 1.5, 2.25, 3.125, 4.0625]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_ema_empty() { |  | ||||||
|         let data = Vec::<f64>::new(); |  | ||||||
|         let ema = data |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(ema, Vec::<f64>::new()); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_ema_1_period() { |  | ||||||
|         let data = vec![1.0, 2.0, 3.0, 4.0, 5.0]; |  | ||||||
|         let ema = data |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(1).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(ema, vec![1.0, 2.0, 3.0, 4.0, 5.0]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_ema_borrow() { |  | ||||||
|         let data = [1.0, 2.0, 3.0, 4.0, 5.0]; |  | ||||||
|         let ema = data |  | ||||||
|             .iter() |  | ||||||
|             .ema(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(ema, vec![1.0, 1.5, 2.25, 3.125, 4.0625]); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,216 +0,0 @@ | |||||||
| use std::{ |  | ||||||
|     borrow::Borrow, |  | ||||||
|     iter::{Peekable, Scan}, |  | ||||||
|     num::NonZeroUsize, |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| pub struct MacdState { |  | ||||||
|     short_weight: f64, |  | ||||||
|     long_weight: f64, |  | ||||||
|     signal_weight: f64, |  | ||||||
|     short_ema: f64, |  | ||||||
|     long_ema: f64, |  | ||||||
|     signal_ema: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::type_complexity)] |  | ||||||
| pub trait Macd<T>: Iterator + Sized { |  | ||||||
|     fn macd( |  | ||||||
|         self, |  | ||||||
|         short_period: NonZeroUsize,  // Typically 12 |  | ||||||
|         long_period: NonZeroUsize,   // Typically 26 |  | ||||||
|         signal_period: NonZeroUsize, // Typically 9 |  | ||||||
|     ) -> Scan<Peekable<Self>, MacdState, fn(&mut MacdState, T) -> Option<(f64, f64)>>; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<I, T> Macd<T> for I |  | ||||||
| where |  | ||||||
|     I: Iterator<Item = T>, |  | ||||||
|     T: Borrow<f64>, |  | ||||||
| { |  | ||||||
|     fn macd( |  | ||||||
|         self, |  | ||||||
|         short_period: NonZeroUsize, |  | ||||||
|         long_period: NonZeroUsize, |  | ||||||
|         signal_period: NonZeroUsize, |  | ||||||
|     ) -> Scan<Peekable<Self>, MacdState, fn(&mut MacdState, T) -> Option<(f64, f64)>> { |  | ||||||
|         let smoothing = 2.0; |  | ||||||
|         let short_weight = smoothing / (1.0 + short_period.get() as f64); |  | ||||||
|         let long_weight = smoothing / (1.0 + long_period.get() as f64); |  | ||||||
|         let signal_weight = smoothing / (1.0 + signal_period.get() as f64); |  | ||||||
|  |  | ||||||
|         let mut iter = self.peekable(); |  | ||||||
|         let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default(); |  | ||||||
|  |  | ||||||
|         iter.scan( |  | ||||||
|             MacdState { |  | ||||||
|                 short_weight, |  | ||||||
|                 long_weight, |  | ||||||
|                 signal_weight, |  | ||||||
|                 short_ema: first, |  | ||||||
|                 long_ema: first, |  | ||||||
|                 signal_ema: 0.0, |  | ||||||
|             }, |  | ||||||
|             |state: &mut MacdState, value: T| { |  | ||||||
|                 let value = *value.borrow(); |  | ||||||
|  |  | ||||||
|                 state.short_ema = |  | ||||||
|                     (value * state.short_weight) + (state.short_ema * (1.0 - state.short_weight)); |  | ||||||
|                 state.long_ema = |  | ||||||
|                     (value * state.long_weight) + (state.long_ema * (1.0 - state.long_weight)); |  | ||||||
|  |  | ||||||
|                 let macd = state.short_ema - state.long_ema; |  | ||||||
|                 state.signal_ema = |  | ||||||
|                     (macd * state.signal_weight) + (state.signal_ema * (1.0 - state.signal_weight)); |  | ||||||
|  |  | ||||||
|                 Some((macd, state.signal_ema)) |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::super::ema::Ema; |  | ||||||
|     use super::*; |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_macd() { |  | ||||||
|         let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; |  | ||||||
|  |  | ||||||
|         let short_ema = data |  | ||||||
|             .clone() |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let long_ema = data |  | ||||||
|             .clone() |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(5).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let macd = short_ema |  | ||||||
|             .into_iter() |  | ||||||
|             .zip(long_ema) |  | ||||||
|             .map(|(short, long)| short - long) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let signal = macd |  | ||||||
|             .clone() |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let expected = macd.into_iter().zip(signal).collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!( |  | ||||||
|             data.into_iter() |  | ||||||
|                 .macd( |  | ||||||
|                     NonZeroUsize::new(3).unwrap(), |  | ||||||
|                     NonZeroUsize::new(5).unwrap(), |  | ||||||
|                     NonZeroUsize::new(3).unwrap() |  | ||||||
|                 ) |  | ||||||
|                 .collect::<Vec<_>>(), |  | ||||||
|             expected |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_macd_empty() { |  | ||||||
|         let data = Vec::<f64>::new(); |  | ||||||
|  |  | ||||||
|         assert_eq!( |  | ||||||
|             data.into_iter() |  | ||||||
|                 .macd( |  | ||||||
|                     NonZeroUsize::new(3).unwrap(), |  | ||||||
|                     NonZeroUsize::new(5).unwrap(), |  | ||||||
|                     NonZeroUsize::new(3).unwrap() |  | ||||||
|                 ) |  | ||||||
|                 .collect::<Vec<_>>(), |  | ||||||
|             vec![] |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_macd_1_period() { |  | ||||||
|         let data = vec![1.0, 2.0, 3.0, 4.0, 5.0]; |  | ||||||
|  |  | ||||||
|         let short_ema = data |  | ||||||
|             .clone() |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(1).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let long_ema = data |  | ||||||
|             .clone() |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(1).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let macd = short_ema |  | ||||||
|             .into_iter() |  | ||||||
|             .zip(long_ema) |  | ||||||
|             .map(|(short, long)| short - long) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let signal = macd |  | ||||||
|             .clone() |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(1).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let expected = macd.into_iter().zip(signal).collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!( |  | ||||||
|             data.into_iter() |  | ||||||
|                 .macd( |  | ||||||
|                     NonZeroUsize::new(1).unwrap(), |  | ||||||
|                     NonZeroUsize::new(1).unwrap(), |  | ||||||
|                     NonZeroUsize::new(1).unwrap() |  | ||||||
|                 ) |  | ||||||
|                 .collect::<Vec<_>>(), |  | ||||||
|             expected |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_macd_borrow() { |  | ||||||
|         let data = [1.0, 2.0, 3.0, 4.0, 5.0]; |  | ||||||
|  |  | ||||||
|         let short_ema = data |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let long_ema = data |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(5).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let macd = short_ema |  | ||||||
|             .into_iter() |  | ||||||
|             .zip(long_ema) |  | ||||||
|             .map(|(short, long)| short - long) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let signal = macd |  | ||||||
|             .clone() |  | ||||||
|             .into_iter() |  | ||||||
|             .ema(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         let expected = macd.into_iter().zip(signal).collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!( |  | ||||||
|             data.iter() |  | ||||||
|                 .macd( |  | ||||||
|                     NonZeroUsize::new(3).unwrap(), |  | ||||||
|                     NonZeroUsize::new(5).unwrap(), |  | ||||||
|                     NonZeroUsize::new(3).unwrap() |  | ||||||
|                 ) |  | ||||||
|                 .collect::<Vec<_>>(), |  | ||||||
|             expected |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,17 +0,0 @@ | |||||||
| pub mod bbands; |  | ||||||
| pub mod deriv; |  | ||||||
| pub mod ema; |  | ||||||
| pub mod macd; |  | ||||||
| pub mod obv; |  | ||||||
| pub mod pct; |  | ||||||
| pub mod rsi; |  | ||||||
| pub mod sma; |  | ||||||
|  |  | ||||||
| pub use bbands::Bbands; |  | ||||||
| pub use deriv::Deriv; |  | ||||||
| pub use ema::Ema; |  | ||||||
| pub use macd::Macd; |  | ||||||
| pub use obv::Obv; |  | ||||||
| pub use pct::Pct; |  | ||||||
| pub use rsi::Rsi; |  | ||||||
| pub use sma::Sma; |  | ||||||
| @@ -1,73 +0,0 @@ | |||||||
| use std::{ |  | ||||||
|     borrow::Borrow, |  | ||||||
|     iter::{Peekable, Scan}, |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| pub struct ObvState { |  | ||||||
|     last: f64, |  | ||||||
|     obv: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::type_complexity)] |  | ||||||
| pub trait Obv<T>: Iterator + Sized { |  | ||||||
|     fn obv(self) -> Scan<Peekable<Self>, ObvState, fn(&mut ObvState, T) -> Option<f64>>; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<I, T> Obv<T> for I |  | ||||||
| where |  | ||||||
|     I: Iterator<Item = T>, |  | ||||||
|     T: Borrow<(f64, f64)>, |  | ||||||
| { |  | ||||||
|     fn obv(self) -> Scan<Peekable<Self>, ObvState, fn(&mut ObvState, T) -> Option<f64>> { |  | ||||||
|         let mut iter = self.peekable(); |  | ||||||
|         let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default(); |  | ||||||
|  |  | ||||||
|         iter.scan( |  | ||||||
|             ObvState { |  | ||||||
|                 last: first.0, |  | ||||||
|                 obv: 0.0, |  | ||||||
|             }, |  | ||||||
|             |state: &mut ObvState, value: T| { |  | ||||||
|                 let (close, volume) = *value.borrow(); |  | ||||||
|  |  | ||||||
|                 if close > state.last { |  | ||||||
|                     state.obv += volume; |  | ||||||
|                 } else if close < state.last { |  | ||||||
|                     state.obv -= volume; |  | ||||||
|                 } |  | ||||||
|                 state.last = close; |  | ||||||
|  |  | ||||||
|                 Some(state.obv) |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::*; |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_obv() { |  | ||||||
|         let data = vec![(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (2.0, 4.0), (1.0, 5.0)]; |  | ||||||
|         let obv = data.into_iter().obv().collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(obv, vec![0.0, 2.0, 5.0, 1.0, -4.0]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_obv_empty() { |  | ||||||
|         let data = Vec::<(f64, f64)>::new(); |  | ||||||
|         let obv = data.into_iter().obv().collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(obv, Vec::<f64>::new()); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_obv_borrow() { |  | ||||||
|         let data = [(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (2.0, 4.0), (1.0, 5.0)]; |  | ||||||
|         let obv = data.iter().obv().collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(obv, vec![0.0, 2.0, 5.0, 1.0, -4.0]); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,64 +0,0 @@ | |||||||
| use std::{borrow::Borrow, iter::Scan}; |  | ||||||
|  |  | ||||||
| pub struct PctState { |  | ||||||
|     pub last: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::type_complexity)] |  | ||||||
| pub trait Pct<T>: Iterator + Sized { |  | ||||||
|     fn pct(self) -> Scan<Self, PctState, fn(&mut PctState, T) -> Option<f64>>; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<I, T> Pct<T> for I |  | ||||||
| where |  | ||||||
|     I: Iterator<Item = T>, |  | ||||||
|     T: Borrow<f64>, |  | ||||||
| { |  | ||||||
|     fn pct(self) -> Scan<Self, PctState, fn(&mut PctState, T) -> Option<f64>> { |  | ||||||
|         self.scan(PctState { last: 0.0 }, |state: &mut PctState, value: T| { |  | ||||||
|             let value = *value.borrow(); |  | ||||||
|  |  | ||||||
|             let pct = value / state.last - 1.0; |  | ||||||
|             state.last = value; |  | ||||||
|  |  | ||||||
|             Some(pct) |  | ||||||
|         }) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::*; |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_pct() { |  | ||||||
|         let data = vec![1.0, 2.0, 4.0, 2.0, 1.0]; |  | ||||||
|         let pct = data.into_iter().pct().collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(pct, vec![f64::INFINITY, 1.0, 1.0, -0.5, -0.5]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_pct_empty() { |  | ||||||
|         let data = Vec::<f64>::new(); |  | ||||||
|         let pct = data.into_iter().pct().collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(pct, Vec::<f64>::new()); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_pct_0() { |  | ||||||
|         let data = vec![1.0, 0.0, 4.0, 2.0, 1.0]; |  | ||||||
|         let pct = data.into_iter().pct().collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(pct, vec![f64::INFINITY, -1.0, f64::INFINITY, -0.5, -0.5]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_pct_borrow() { |  | ||||||
|         let data = [1.0, 2.0, 4.0, 2.0, 1.0]; |  | ||||||
|         let pct = data.iter().pct().collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(pct, vec![f64::INFINITY, 1.0, 1.0, -0.5, -0.5]); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,135 +0,0 @@ | |||||||
| use std::{ |  | ||||||
|     borrow::Borrow, |  | ||||||
|     collections::VecDeque, |  | ||||||
|     iter::{Peekable, Scan}, |  | ||||||
|     num::NonZeroUsize, |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| pub struct RsiState { |  | ||||||
|     last: f64, |  | ||||||
|     window_gains: VecDeque<f64>, |  | ||||||
|     window_losses: VecDeque<f64>, |  | ||||||
|     sum_gains: f64, |  | ||||||
|     sum_losses: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::type_complexity)] |  | ||||||
| pub trait Rsi<T>: Iterator + Sized { |  | ||||||
|     fn rsi( |  | ||||||
|         self, |  | ||||||
|         period: NonZeroUsize, // Typically 14 |  | ||||||
|     ) -> Scan<Peekable<Self>, RsiState, fn(&mut RsiState, T) -> Option<f64>>; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<I, T> Rsi<T> for I |  | ||||||
| where |  | ||||||
|     I: Iterator<Item = T>, |  | ||||||
|     T: Borrow<f64>, |  | ||||||
| { |  | ||||||
|     fn rsi( |  | ||||||
|         self, |  | ||||||
|         period: NonZeroUsize, |  | ||||||
|     ) -> Scan<Peekable<Self>, RsiState, fn(&mut RsiState, T) -> Option<f64>> { |  | ||||||
|         let mut iter = self.peekable(); |  | ||||||
|         let first = iter.peek().map(|value| *value.borrow()).unwrap_or_default(); |  | ||||||
|  |  | ||||||
|         iter.scan( |  | ||||||
|             RsiState { |  | ||||||
|                 last: first, |  | ||||||
|                 window_gains: VecDeque::from(vec![0.0; period.get()]), |  | ||||||
|                 window_losses: VecDeque::from(vec![0.0; period.get()]), |  | ||||||
|                 sum_gains: 0.0, |  | ||||||
|                 sum_losses: 0.0, |  | ||||||
|             }, |  | ||||||
|             |state, value| { |  | ||||||
|                 let value = *value.borrow(); |  | ||||||
|  |  | ||||||
|                 state.sum_gains -= state.window_gains.pop_front().unwrap(); |  | ||||||
|                 state.sum_losses -= state.window_losses.pop_front().unwrap(); |  | ||||||
|  |  | ||||||
|                 let gain = (value - state.last).max(0.0); |  | ||||||
|                 let loss = (state.last - value).max(0.0); |  | ||||||
|  |  | ||||||
|                 state.last = value; |  | ||||||
|  |  | ||||||
|                 state.window_gains.push_back(gain); |  | ||||||
|                 state.window_losses.push_back(loss); |  | ||||||
|                 state.sum_gains += gain; |  | ||||||
|                 state.sum_losses += loss; |  | ||||||
|  |  | ||||||
|                 let avg_loss = state.sum_losses / state.window_losses.len() as f64; |  | ||||||
|  |  | ||||||
|                 if avg_loss == 0.0 { |  | ||||||
|                     return Some(100.0); |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 let avg_gain = state.sum_gains / state.window_gains.len() as f64; |  | ||||||
|                 let rs = avg_gain / avg_loss; |  | ||||||
|  |  | ||||||
|                 Some(100.0 - (100.0 / (1.0 + rs))) |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::*; |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_rsi() { |  | ||||||
|         let data = vec![1.0, 4.0, 7.0, 4.0, 1.0]; |  | ||||||
|         let rsi = data |  | ||||||
|             .into_iter() |  | ||||||
|             .rsi(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .map(|v| (v * 100.0).round() / 100.0) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(rsi, vec![100.0, 100.0, 100.0, 66.67, 33.33]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_rsi_empty() { |  | ||||||
|         let data = Vec::<f64>::new(); |  | ||||||
|         let rsi = data |  | ||||||
|             .into_iter() |  | ||||||
|             .rsi(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(rsi, Vec::<f64>::new()); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_rsi_no_loss() { |  | ||||||
|         let data = vec![1.0, 2.0, 3.0, 4.0, 5.0]; |  | ||||||
|         let rsi = data |  | ||||||
|             .into_iter() |  | ||||||
|             .rsi(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(rsi, vec![100.0, 100.0, 100.0, 100.0, 100.0]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_rsi_no_gain() { |  | ||||||
|         let data = vec![5.0, 4.0, 3.0, 2.0, 1.0]; |  | ||||||
|         let rsi = data |  | ||||||
|             .into_iter() |  | ||||||
|             .rsi(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(rsi, vec![100.0, 0.0, 0.0, 0.0, 0.0]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_rsi_borrow() { |  | ||||||
|         let data = [1.0, 4.0, 7.0, 4.0, 1.0]; |  | ||||||
|         let rsi = data |  | ||||||
|             .iter() |  | ||||||
|             .rsi(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .map(|v| (v * 100.0).round() / 100.0) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(rsi, vec![100.0, 100.0, 100.0, 66.67, 33.33]); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,88 +0,0 @@ | |||||||
| use std::{borrow::Borrow, collections::VecDeque, iter::Scan, num::NonZeroUsize}; |  | ||||||
|  |  | ||||||
| pub struct SmaState { |  | ||||||
|     window: VecDeque<f64>, |  | ||||||
|     sum: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::type_complexity)] |  | ||||||
| pub trait Sma<T>: Iterator + Sized { |  | ||||||
|     fn sma(self, period: NonZeroUsize) |  | ||||||
|         -> Scan<Self, SmaState, fn(&mut SmaState, T) -> Option<f64>>; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl<I, T> Sma<T> for I |  | ||||||
| where |  | ||||||
|     I: Iterator<Item = T>, |  | ||||||
|     T: Borrow<f64>, |  | ||||||
| { |  | ||||||
|     fn sma( |  | ||||||
|         self, |  | ||||||
|         period: NonZeroUsize, |  | ||||||
|     ) -> Scan<Self, SmaState, fn(&mut SmaState, T) -> Option<f64>> { |  | ||||||
|         self.scan( |  | ||||||
|             SmaState { |  | ||||||
|                 window: VecDeque::from(vec![0.0; period.get()]), |  | ||||||
|                 sum: 0.0, |  | ||||||
|             }, |  | ||||||
|             |state: &mut SmaState, value: T| { |  | ||||||
|                 let value = *value.borrow(); |  | ||||||
|  |  | ||||||
|                 state.sum -= state.window.pop_front().unwrap(); |  | ||||||
|                 state.window.push_back(value); |  | ||||||
|                 state.sum += value; |  | ||||||
|  |  | ||||||
|                 Some(state.sum / state.window.len() as f64) |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::*; |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_sma() { |  | ||||||
|         let data = vec![3.0, 6.0, 9.0, 12.0, 15.0]; |  | ||||||
|         let sma = data |  | ||||||
|             .into_iter() |  | ||||||
|             .sma(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(sma, vec![1.0, 3.0, 6.0, 9.0, 12.0]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_sma_empty() { |  | ||||||
|         let data = Vec::<f64>::new(); |  | ||||||
|         let sma = data |  | ||||||
|             .into_iter() |  | ||||||
|             .sma(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(sma, Vec::<f64>::new()); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_sma_1_period() { |  | ||||||
|         let data = vec![1.0, 2.0, 3.0, 4.0, 5.0]; |  | ||||||
|         let sma = data |  | ||||||
|             .into_iter() |  | ||||||
|             .sma(NonZeroUsize::new(1).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(sma, vec![1.0, 2.0, 3.0, 4.0, 5.0]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_sma_borrow() { |  | ||||||
|         let data = [3.0, 6.0, 9.0, 12.0, 15.0]; |  | ||||||
|         let sma = data |  | ||||||
|             .iter() |  | ||||||
|             .sma(NonZeroUsize::new(3).unwrap()) |  | ||||||
|             .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|         assert_eq!(sma, vec![1.0, 3.0, 6.0, 9.0, 12.0]); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,39 +0,0 @@ | |||||||
| use super::position::Position; |  | ||||||
| use crate::types::{self, alpaca::shared::asset}; |  | ||||||
| use serde::Deserialize; |  | ||||||
| use serde_aux::field_attributes::deserialize_option_number_from_string; |  | ||||||
| use uuid::Uuid; |  | ||||||
|  |  | ||||||
| pub use asset::{Class, Exchange, Status}; |  | ||||||
|  |  | ||||||
| #[allow(clippy::struct_excessive_bools)] |  | ||||||
| #[derive(Deserialize, Clone)] |  | ||||||
| pub struct Asset { |  | ||||||
|     pub id: Uuid, |  | ||||||
|     pub class: Class, |  | ||||||
|     pub exchange: Exchange, |  | ||||||
|     pub symbol: String, |  | ||||||
|     pub name: String, |  | ||||||
|     pub status: Status, |  | ||||||
|     pub tradable: bool, |  | ||||||
|     pub marginable: bool, |  | ||||||
|     pub shortable: bool, |  | ||||||
|     pub easy_to_borrow: bool, |  | ||||||
|     pub fractionable: bool, |  | ||||||
|     #[serde(deserialize_with = "deserialize_option_number_from_string")] |  | ||||||
|     pub maintenance_margin_requirement: Option<f32>, |  | ||||||
|     pub attributes: Option<Vec<String>>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl From<(Asset, Option<Position>)> for types::Asset { |  | ||||||
|     fn from((asset, position): (Asset, Option<Position>)) -> Self { |  | ||||||
|         Self { |  | ||||||
|             symbol: asset.symbol, |  | ||||||
|             class: asset.class.into(), |  | ||||||
|             exchange: asset.exchange.into(), |  | ||||||
|             status: asset.status.into(), |  | ||||||
|             time_added: time::OffsetDateTime::now_utc(), |  | ||||||
|             qty: position.map(|position| position.qty).unwrap_or_default(), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,39 +0,0 @@ | |||||||
| use crate::types; |  | ||||||
| use serde::Deserialize; |  | ||||||
| use time::OffsetDateTime; |  | ||||||
|  |  | ||||||
| #[derive(Deserialize)] |  | ||||||
| pub struct Bar { |  | ||||||
|     #[serde(rename = "t")] |  | ||||||
|     #[serde(with = "time::serde::rfc3339")] |  | ||||||
|     pub time: OffsetDateTime, |  | ||||||
|     #[serde(rename = "o")] |  | ||||||
|     pub open: f64, |  | ||||||
|     #[serde(rename = "h")] |  | ||||||
|     pub high: f64, |  | ||||||
|     #[serde(rename = "l")] |  | ||||||
|     pub low: f64, |  | ||||||
|     #[serde(rename = "c")] |  | ||||||
|     pub close: f64, |  | ||||||
|     #[serde(rename = "v")] |  | ||||||
|     pub volume: f64, |  | ||||||
|     #[serde(rename = "n")] |  | ||||||
|     pub trades: i64, |  | ||||||
|     #[serde(rename = "vw")] |  | ||||||
|     pub vwap: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl From<(Bar, String)> for types::Bar { |  | ||||||
|     fn from((bar, symbol): (Bar, String)) -> Self { |  | ||||||
|         Self { |  | ||||||
|             symbol, |  | ||||||
|             time: bar.time, |  | ||||||
|             open: bar.open, |  | ||||||
|             high: bar.high, |  | ||||||
|             low: bar.low, |  | ||||||
|             close: bar.close, |  | ||||||
|             volume: bar.volume, |  | ||||||
|             trades: bar.trades, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,26 +0,0 @@ | |||||||
| use crate::{ |  | ||||||
|     types, |  | ||||||
|     utils::{de, time::EST_OFFSET}, |  | ||||||
| }; |  | ||||||
| use serde::Deserialize; |  | ||||||
| use time::{Date, OffsetDateTime, Time}; |  | ||||||
|  |  | ||||||
| #[derive(Deserialize)] |  | ||||||
| pub struct Calendar { |  | ||||||
|     pub date: Date, |  | ||||||
|     #[serde(deserialize_with = "de::human_time_hh_mm")] |  | ||||||
|     pub open: Time, |  | ||||||
|     #[serde(deserialize_with = "de::human_time_hh_mm")] |  | ||||||
|     pub close: Time, |  | ||||||
|     pub settlement_date: Date, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl From<Calendar> for types::Calendar { |  | ||||||
|     fn from(calendar: Calendar) -> Self { |  | ||||||
|         Self { |  | ||||||
|             date: calendar.date, |  | ||||||
|             open: OffsetDateTime::new_in_offset(calendar.date, calendar.open, *EST_OFFSET), |  | ||||||
|             close: OffsetDateTime::new_in_offset(calendar.date, calendar.close, *EST_OFFSET), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,13 +0,0 @@ | |||||||
| use serde::Deserialize; |  | ||||||
| use time::OffsetDateTime; |  | ||||||
|  |  | ||||||
| #[derive(Deserialize)] |  | ||||||
| pub struct Clock { |  | ||||||
|     #[serde(with = "time::serde::rfc3339")] |  | ||||||
|     pub timestamp: OffsetDateTime, |  | ||||||
|     pub is_open: bool, |  | ||||||
|     #[serde(with = "time::serde::rfc3339")] |  | ||||||
|     pub next_open: OffsetDateTime, |  | ||||||
|     #[serde(with = "time::serde::rfc3339")] |  | ||||||
|     pub next_close: OffsetDateTime, |  | ||||||
| } |  | ||||||
| @@ -1,57 +0,0 @@ | |||||||
| use crate::{ |  | ||||||
|     types::{self, alpaca::shared::news::strip}, |  | ||||||
|     utils::de, |  | ||||||
| }; |  | ||||||
| use serde::Deserialize; |  | ||||||
| use time::OffsetDateTime; |  | ||||||
|  |  | ||||||
| #[derive(Deserialize)] |  | ||||||
| #[serde(rename_all = "snake_case")] |  | ||||||
| pub enum ImageSize { |  | ||||||
|     Thumb, |  | ||||||
|     Small, |  | ||||||
|     Large, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[derive(Deserialize)] |  | ||||||
| pub struct Image { |  | ||||||
|     pub size: ImageSize, |  | ||||||
|     pub url: String, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[derive(Deserialize)] |  | ||||||
| pub struct News { |  | ||||||
|     pub id: i64, |  | ||||||
|     #[serde(with = "time::serde::rfc3339")] |  | ||||||
|     #[serde(rename = "created_at")] |  | ||||||
|     pub time_created: OffsetDateTime, |  | ||||||
|     #[serde(with = "time::serde::rfc3339")] |  | ||||||
|     #[serde(rename = "updated_at")] |  | ||||||
|     pub time_updated: OffsetDateTime, |  | ||||||
|     #[serde(deserialize_with = "de::add_slash_to_symbols")] |  | ||||||
|     pub symbols: Vec<String>, |  | ||||||
|     pub headline: String, |  | ||||||
|     pub author: String, |  | ||||||
|     pub source: String, |  | ||||||
|     pub summary: String, |  | ||||||
|     pub content: String, |  | ||||||
|     pub url: Option<String>, |  | ||||||
|     pub images: Vec<Image>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl From<News> for types::News { |  | ||||||
|     fn from(news: News) -> Self { |  | ||||||
|         Self { |  | ||||||
|             id: news.id, |  | ||||||
|             time_created: news.time_created, |  | ||||||
|             time_updated: news.time_updated, |  | ||||||
|             symbols: news.symbols, |  | ||||||
|             headline: strip(&news.headline), |  | ||||||
|             author: strip(&news.author), |  | ||||||
|             source: strip(&news.source), |  | ||||||
|             summary: news.summary, |  | ||||||
|             content: news.content, |  | ||||||
|             url: news.url.unwrap_or_default(), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,3 +0,0 @@ | |||||||
| use crate::types::alpaca::shared::order; |  | ||||||
|  |  | ||||||
| pub use order::{Order, Side}; |  | ||||||
| @@ -1,61 +0,0 @@ | |||||||
| use crate::{ |  | ||||||
|     types::alpaca::api::incoming::{ |  | ||||||
|         asset::{Class, Exchange}, |  | ||||||
|         order, |  | ||||||
|     }, |  | ||||||
|     utils::de, |  | ||||||
| }; |  | ||||||
| use serde::Deserialize; |  | ||||||
| use serde_aux::field_attributes::deserialize_number_from_string; |  | ||||||
| use uuid::Uuid; |  | ||||||
|  |  | ||||||
| #[derive(Deserialize, Clone, Copy)] |  | ||||||
| #[serde(rename_all = "snake_case")] |  | ||||||
| pub enum Side { |  | ||||||
|     Long, |  | ||||||
|     Short, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl From<Side> for order::Side { |  | ||||||
|     fn from(side: Side) -> Self { |  | ||||||
|         match side { |  | ||||||
|             Side::Long => Self::Buy, |  | ||||||
|             Side::Short => Self::Sell, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[derive(Deserialize, Clone)] |  | ||||||
| pub struct Position { |  | ||||||
|     pub asset_id: Uuid, |  | ||||||
|     #[serde(deserialize_with = "de::add_slash_to_symbol")] |  | ||||||
|     pub symbol: String, |  | ||||||
|     pub exchange: Exchange, |  | ||||||
|     pub asset_class: Class, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub avg_entry_price: f64, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub qty: f64, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub qty_available: f64, |  | ||||||
|     pub side: Side, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub market_value: f64, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub cost_basis: f64, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub unrealized_pl: f64, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub unrealized_plpc: f64, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub unrealized_intraday_pl: f64, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub unrealized_intraday_plpc: f64, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub current_price: f64, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub lastday_price: f64, |  | ||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |  | ||||||
|     pub change_today: f64, |  | ||||||
|     pub asset_marginable: bool, |  | ||||||
| } |  | ||||||
| @@ -1,6 +0,0 @@ | |||||||
| pub mod incoming; |  | ||||||
| pub mod outgoing; |  | ||||||
|  |  | ||||||
| pub const ALPACA_US_EQUITY_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars"; |  | ||||||
| pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars"; |  | ||||||
| pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news"; |  | ||||||
| @@ -1,23 +0,0 @@ | |||||||
| use crate::types::alpaca::shared::asset; |  | ||||||
| use serde::Serialize; |  | ||||||
|  |  | ||||||
| pub use asset::{Class, Exchange, Status}; |  | ||||||
|  |  | ||||||
| #[derive(Serialize)] |  | ||||||
| pub struct Asset { |  | ||||||
|     pub status: Option<Status>, |  | ||||||
|     pub class: Option<Class>, |  | ||||||
|     pub exchange: Option<Exchange>, |  | ||||||
|     pub attributes: Option<Vec<String>>, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| impl Default for Asset { |  | ||||||
|     fn default() -> Self { |  | ||||||
|         Self { |  | ||||||
|             status: None, |  | ||||||
|             class: Some(Class::UsEquity), |  | ||||||
|             exchange: None, |  | ||||||
|             attributes: None, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,8 +0,0 @@ | |||||||
| pub mod auth; |  | ||||||
| pub mod data; |  | ||||||
| pub mod trading; |  | ||||||
|  |  | ||||||
| pub const ALPACA_US_EQUITY_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2"; |  | ||||||
| pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str = |  | ||||||
|     "wss://stream.data.alpaca.markets/v1beta3/crypto/us"; |  | ||||||
| pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news"; |  | ||||||
| @@ -1,11 +0,0 @@ | |||||||
| use clickhouse::Row; |  | ||||||
| use serde::{Deserialize, Serialize}; |  | ||||||
| use time::OffsetDateTime; |  | ||||||
|  |  | ||||||
| #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Row)] |  | ||||||
| pub struct Backfill { |  | ||||||
|     pub symbol: String, |  | ||||||
|     #[serde(with = "clickhouse::serde::time::datetime")] |  | ||||||
|     pub time: OffsetDateTime, |  | ||||||
|     pub fresh: bool, |  | ||||||
| } |  | ||||||
| @@ -1,19 +0,0 @@ | |||||||
| use clickhouse::Row; |  | ||||||
| use serde::{Deserialize, Serialize}; |  | ||||||
| use time::OffsetDateTime; |  | ||||||
|  |  | ||||||
| #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)] |  | ||||||
| pub struct News { |  | ||||||
|     pub id: i64, |  | ||||||
|     #[serde(with = "clickhouse::serde::time::datetime")] |  | ||||||
|     pub time_created: OffsetDateTime, |  | ||||||
|     #[serde(with = "clickhouse::serde::time::datetime")] |  | ||||||
|     pub time_updated: OffsetDateTime, |  | ||||||
|     pub symbols: Vec<String>, |  | ||||||
|     pub headline: String, |  | ||||||
|     pub author: String, |  | ||||||
|     pub source: String, |  | ||||||
|     pub summary: String, |  | ||||||
|     pub content: String, |  | ||||||
|     pub url: String, |  | ||||||
| } |  | ||||||
| @@ -1,277 +0,0 @@ | |||||||
| use super::Bar; |  | ||||||
| use crate::ta::{Bbands, Deriv, Ema, Macd, Obv, Pct, Rsi, Sma}; |  | ||||||
| use clickhouse::Row; |  | ||||||
| use itertools::Itertools; |  | ||||||
| use rayon::scope; |  | ||||||
| use serde::{Deserialize, Serialize}; |  | ||||||
| use std::num::NonZeroUsize; |  | ||||||
| use time::OffsetDateTime; |  | ||||||
|  |  | ||||||
| pub const HEAD_SIZE: usize = 72; |  | ||||||
| pub const FIELD_COUNT: usize = 33; |  | ||||||
| pub const NUMERICAL_FIELD_COUNT: usize = FIELD_COUNT - 2; |  | ||||||
|  |  | ||||||
| #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)] |  | ||||||
| pub struct IndicatedBar { |  | ||||||
|     pub symbol: String, |  | ||||||
|     #[serde(with = "clickhouse::serde::time::datetime")] |  | ||||||
|     pub time: OffsetDateTime, |  | ||||||
|     pub hour: u8, |  | ||||||
|     pub day: u8, |  | ||||||
|     pub open: f64, |  | ||||||
|     pub open_pct: f64, |  | ||||||
|     pub high: f64, |  | ||||||
|     pub high_pct: f64, |  | ||||||
|     pub low: f64, |  | ||||||
|     pub low_pct: f64, |  | ||||||
|     pub close: f64, |  | ||||||
|     pub close_pct: f64, |  | ||||||
|     pub volume: f64, |  | ||||||
|     pub volume_pct: f64, |  | ||||||
|     pub trades: f64, |  | ||||||
|     pub trades_pct: f64, |  | ||||||
|     pub sma_3: f64, |  | ||||||
|     pub sma_6: f64, |  | ||||||
|     pub sma_12: f64, |  | ||||||
|     pub sma_24: f64, |  | ||||||
|     pub sma_48: f64, |  | ||||||
|     pub sma_72: f64, |  | ||||||
|     pub ema_3: f64, |  | ||||||
|     pub ema_6: f64, |  | ||||||
|     pub ema_12: f64, |  | ||||||
|     pub ema_24: f64, |  | ||||||
|     pub ema_48: f64, |  | ||||||
|     pub ema_72: f64, |  | ||||||
|     pub macd: f64, |  | ||||||
|     pub macd_signal: f64, |  | ||||||
|     pub obv: f64, |  | ||||||
|     pub rsi: f64, |  | ||||||
|     pub bbands_lower: f64, |  | ||||||
|     pub bbands_mean: f64, |  | ||||||
|     pub bbands_upper: f64, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[allow(clippy::too_many_lines)] |  | ||||||
| fn _calculate_indicators(bars: &[Bar]) -> Vec<IndicatedBar> { |  | ||||||
|     let length = bars.len(); |  | ||||||
|  |  | ||||||
|     let (symbol, time, hour, day, open, high, low, close, volume, trades) = bars.iter().fold( |  | ||||||
|         ( |  | ||||||
|             Vec::with_capacity(length), |  | ||||||
|             Vec::with_capacity(length), |  | ||||||
|             Vec::with_capacity(length), |  | ||||||
|             Vec::with_capacity(length), |  | ||||||
|             Vec::with_capacity(length), |  | ||||||
|             Vec::with_capacity(length), |  | ||||||
|             Vec::with_capacity(length), |  | ||||||
|             Vec::with_capacity(length), |  | ||||||
|             Vec::with_capacity(length), |  | ||||||
|             Vec::with_capacity(length), |  | ||||||
|         ), |  | ||||||
|         |( |  | ||||||
|             mut symbol, |  | ||||||
|             mut time, |  | ||||||
|             mut hour, |  | ||||||
|             mut day, |  | ||||||
|             mut open, |  | ||||||
|             mut high, |  | ||||||
|             mut low, |  | ||||||
|             mut close, |  | ||||||
|             mut volume, |  | ||||||
|             mut trades, |  | ||||||
|         ), |  | ||||||
|          bar| { |  | ||||||
|             symbol.push(bar.symbol.clone()); |  | ||||||
|             time.push(bar.time); |  | ||||||
|             hour.push(bar.time.hour()); |  | ||||||
|             day.push(bar.time.day()); |  | ||||||
|             open.push(bar.open); |  | ||||||
|             high.push(bar.high); |  | ||||||
|             low.push(bar.low); |  | ||||||
|             close.push(bar.close); |  | ||||||
|             volume.push(bar.volume); |  | ||||||
|             trades.push(bar.trades as f64); |  | ||||||
|             ( |  | ||||||
|                 symbol, time, hour, day, open, high, low, close, volume, trades, |  | ||||||
|             ) |  | ||||||
|         }, |  | ||||||
|     ); |  | ||||||
|  |  | ||||||
|     let mut close_deriv = Vec::with_capacity(length); |  | ||||||
|     let mut sma_3 = Vec::with_capacity(length); |  | ||||||
|     let mut sma_6 = Vec::with_capacity(length); |  | ||||||
|     let mut sma_12 = Vec::with_capacity(length); |  | ||||||
|     let mut sma_24 = Vec::with_capacity(length); |  | ||||||
|     let mut sma_48 = Vec::with_capacity(length); |  | ||||||
|     let mut sma_72 = Vec::with_capacity(length); |  | ||||||
|     let mut ema_3 = Vec::with_capacity(length); |  | ||||||
|     let mut ema_6 = Vec::with_capacity(length); |  | ||||||
|     let mut ema_12 = Vec::with_capacity(length); |  | ||||||
|     let mut ema_24 = Vec::with_capacity(length); |  | ||||||
|     let mut ema_48 = Vec::with_capacity(length); |  | ||||||
|     let mut ema_72 = Vec::with_capacity(length); |  | ||||||
|     let mut macd = Vec::with_capacity(length); |  | ||||||
|     let mut macd_signal = Vec::with_capacity(length); |  | ||||||
|     let mut obv = Vec::with_capacity(length); |  | ||||||
|     let mut rsi = Vec::with_capacity(length); |  | ||||||
|     let mut bbands_upper = Vec::with_capacity(length); |  | ||||||
|     let mut bbands_mean = Vec::with_capacity(length); |  | ||||||
|     let mut bbands_lower = Vec::with_capacity(length); |  | ||||||
|  |  | ||||||
|     scope(|s| { |  | ||||||
|         s.spawn(|_| close_deriv.extend(close.iter().deriv())); |  | ||||||
|         s.spawn(|_| sma_3.extend(close.iter().sma(NonZeroUsize::new(3).unwrap()))); |  | ||||||
|         s.spawn(|_| sma_6.extend(close.iter().sma(NonZeroUsize::new(6).unwrap()))); |  | ||||||
|         s.spawn(|_| sma_12.extend(close.iter().sma(NonZeroUsize::new(12).unwrap()))); |  | ||||||
|         s.spawn(|_| sma_24.extend(close.iter().sma(NonZeroUsize::new(24).unwrap()))); |  | ||||||
|         s.spawn(|_| sma_48.extend(close.iter().sma(NonZeroUsize::new(48).unwrap()))); |  | ||||||
|         s.spawn(|_| sma_72.extend(close.iter().sma(NonZeroUsize::new(72).unwrap()))); |  | ||||||
|         s.spawn(|_| ema_3.extend(close.iter().ema(NonZeroUsize::new(3).unwrap()))); |  | ||||||
|         s.spawn(|_| ema_6.extend(close.iter().ema(NonZeroUsize::new(6).unwrap()))); |  | ||||||
|         s.spawn(|_| ema_12.extend(close.iter().ema(NonZeroUsize::new(12).unwrap()))); |  | ||||||
|         s.spawn(|_| ema_24.extend(close.iter().ema(NonZeroUsize::new(24).unwrap()))); |  | ||||||
|         s.spawn(|_| ema_48.extend(close.iter().ema(NonZeroUsize::new(48).unwrap()))); |  | ||||||
|         s.spawn(|_| ema_72.extend(close.iter().ema(NonZeroUsize::new(72).unwrap()))); |  | ||||||
|         s.spawn(|_| { |  | ||||||
|             close |  | ||||||
|                 .iter() |  | ||||||
|                 .macd( |  | ||||||
|                     NonZeroUsize::new(12).unwrap(), |  | ||||||
|                     NonZeroUsize::new(26).unwrap(), |  | ||||||
|                     NonZeroUsize::new(9).unwrap(), |  | ||||||
|                 ) |  | ||||||
|                 .for_each(|(macd_val, signal_val)| { |  | ||||||
|                     macd.push(macd_val); |  | ||||||
|                     macd_signal.push(signal_val); |  | ||||||
|                 }); |  | ||||||
|         }); |  | ||||||
|         s.spawn(|_| { |  | ||||||
|             obv.extend(bars.iter().map(|bar| (bar.close, bar.volume)).obv()); |  | ||||||
|         }); |  | ||||||
|         s.spawn(|_: &_| { |  | ||||||
|             rsi.extend(close.iter().rsi(NonZeroUsize::new(14).unwrap())); |  | ||||||
|         }); |  | ||||||
|         s.spawn(|_| { |  | ||||||
|             close |  | ||||||
|                 .iter() |  | ||||||
|                 .bbands(NonZeroUsize::new(20).unwrap(), 2.0) |  | ||||||
|                 .for_each(|(upper, mean, lower)| { |  | ||||||
|                     bbands_upper.push(upper); |  | ||||||
|                     bbands_mean.push(mean); |  | ||||||
|                     bbands_lower.push(lower); |  | ||||||
|                 }); |  | ||||||
|         }) |  | ||||||
|     }); |  | ||||||
|  |  | ||||||
|     let mut open_pct = Vec::with_capacity(length); |  | ||||||
|     let mut high_pct = Vec::with_capacity(length); |  | ||||||
|     let mut low_pct = Vec::with_capacity(length); |  | ||||||
|     let mut close_pct = Vec::with_capacity(length); |  | ||||||
|     let mut volume_pct = Vec::with_capacity(length); |  | ||||||
|     let mut trades_pct = Vec::with_capacity(length); |  | ||||||
|  |  | ||||||
|     scope(|s| { |  | ||||||
|         s.spawn(|_| open_pct.extend(open.iter().pct())); |  | ||||||
|         s.spawn(|_| high_pct.extend(high.iter().pct())); |  | ||||||
|         s.spawn(|_| low_pct.extend(low.iter().pct())); |  | ||||||
|         s.spawn(|_| close_pct.extend(close.iter().pct())); |  | ||||||
|         s.spawn(|_| volume_pct.extend(volume.iter().pct())); |  | ||||||
|         s.spawn(|_| trades_pct.extend(trades.iter().pct())); |  | ||||||
|     }); |  | ||||||
|  |  | ||||||
|     bars.iter() |  | ||||||
|         .enumerate() |  | ||||||
|         .map(|(i, _)| IndicatedBar { |  | ||||||
|             symbol: symbol[i].clone(), |  | ||||||
|             time: time[i], |  | ||||||
|             hour: hour[i], |  | ||||||
|             day: day[i], |  | ||||||
|             open: open[i], |  | ||||||
|             open_pct: open_pct[i], |  | ||||||
|             high: high[i], |  | ||||||
|             high_pct: high_pct[i], |  | ||||||
|             low: low[i], |  | ||||||
|             low_pct: low_pct[i], |  | ||||||
|             close: close[i], |  | ||||||
|             close_pct: close_pct[i], |  | ||||||
|             volume: volume[i], |  | ||||||
|             volume_pct: volume_pct[i], |  | ||||||
|             trades: trades[i], |  | ||||||
|             trades_pct: trades_pct[i], |  | ||||||
|             sma_3: sma_3[i], |  | ||||||
|             sma_6: sma_6[i], |  | ||||||
|             sma_12: sma_12[i], |  | ||||||
|             sma_24: sma_24[i], |  | ||||||
|             sma_48: sma_48[i], |  | ||||||
|             sma_72: sma_72[i], |  | ||||||
|             ema_3: ema_3[i], |  | ||||||
|             ema_6: ema_6[i], |  | ||||||
|             ema_12: ema_12[i], |  | ||||||
|             ema_24: ema_24[i], |  | ||||||
|             ema_48: ema_48[i], |  | ||||||
|             ema_72: ema_72[i], |  | ||||||
|             macd: macd[i], |  | ||||||
|             macd_signal: macd_signal[i], |  | ||||||
|             obv: obv[i], |  | ||||||
|             rsi: rsi[i], |  | ||||||
|             bbands_lower: bbands_lower[i], |  | ||||||
|             bbands_mean: bbands_mean[i], |  | ||||||
|             bbands_upper: bbands_upper[i], |  | ||||||
|         }) |  | ||||||
|         .collect() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn calculate_indicators<I>(bars: I) -> Vec<Vec<IndicatedBar>> |  | ||||||
| where |  | ||||||
|     I: IntoIterator<Item = Bar>, |  | ||||||
| { |  | ||||||
|     bars.into_iter() |  | ||||||
|         .filter(|bar| { |  | ||||||
|             bar.open > 0.0 |  | ||||||
|                 && bar.high > 0.0 |  | ||||||
|                 && bar.low > 0.0 |  | ||||||
|                 && bar.close > 0.0 |  | ||||||
|                 && bar.volume > 0.0 |  | ||||||
|                 && bar.trades > 0 |  | ||||||
|         }) |  | ||||||
|         .sorted_by_key(|bar| (bar.symbol.clone(), bar.time)) |  | ||||||
|         .group_by(|bar| bar.symbol.clone()) |  | ||||||
|         .into_iter() |  | ||||||
|         .map(|(_, group)| _calculate_indicators(&group.collect::<Vec<_>>())) |  | ||||||
|         .collect::<Vec<_>>() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::*; |  | ||||||
|     use rand::{ |  | ||||||
|         distributions::{Distribution, Uniform}, |  | ||||||
|         Rng, |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_calculate_indicators() { |  | ||||||
|         let length = 1_000_000; |  | ||||||
|  |  | ||||||
|         let mut rng = rand::thread_rng(); |  | ||||||
|         let uniform = Uniform::new(1.0, 100.0); |  | ||||||
|         let mut bars = Vec::with_capacity(length); |  | ||||||
|  |  | ||||||
|         for _ in 0..length { |  | ||||||
|             bars.push(Bar { |  | ||||||
|                 symbol: "AAPL".to_string(), |  | ||||||
|                 time: OffsetDateTime::now_utc(), |  | ||||||
|                 open: uniform.sample(&mut rng), |  | ||||||
|                 high: uniform.sample(&mut rng), |  | ||||||
|                 low: uniform.sample(&mut rng), |  | ||||||
|                 close: uniform.sample(&mut rng), |  | ||||||
|                 volume: uniform.sample(&mut rng), |  | ||||||
|                 trades: rng.gen_range(1..100), |  | ||||||
|             }); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         let indicated_bars = calculate_indicators(bars); |  | ||||||
|  |  | ||||||
|         assert_eq!(indicated_bars[0].len(), length); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,281 +0,0 @@ | |||||||
| use serde::{ser::SerializeSeq, Serializer}; |  | ||||||
| use std::time::Duration; |  | ||||||
|  |  | ||||||
| pub fn timeframe<S>(timeframe: &Duration, serializer: S) -> Result<S::Ok, S::Error> |  | ||||||
| where |  | ||||||
|     S: serde::Serializer, |  | ||||||
| { |  | ||||||
|     let secs = timeframe.as_secs(); |  | ||||||
|  |  | ||||||
|     if secs < 60 || secs % 60 != 0 { |  | ||||||
|         return Err(serde::ser::Error::custom("Invalid timeframe duration")); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     let mins = secs / 60; |  | ||||||
|  |  | ||||||
|     if mins < 60 { |  | ||||||
|         return serializer.serialize_str(&format!("{mins}Min")); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if mins % 60 != 0 { |  | ||||||
|         return Err(serde::ser::Error::custom("Invalid timeframe duration")); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     let hours = mins / 60; |  | ||||||
|  |  | ||||||
|     if hours < 24 { |  | ||||||
|         return serializer.serialize_str(&format!("{hours}Hour")); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if hours % 24 != 0 { |  | ||||||
|         return Err(serde::ser::Error::custom("Invalid timeframe duration")); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     let days = hours / 24; |  | ||||||
|  |  | ||||||
|     if days == 1 { |  | ||||||
|         return serializer.serialize_str("1Day"); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if days == 7 { |  | ||||||
|         return serializer.serialize_str("1Week"); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if days < 30 || days % 30 != 0 { |  | ||||||
|         return Err(serde::ser::Error::custom("Invalid timeframe duration")); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     let months = days / 30; |  | ||||||
|  |  | ||||||
|     if [1, 2, 3, 4, 6, 12].contains(&months) { |  | ||||||
|         return serializer.serialize_str(&format!("{months}Month")); |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     Err(serde::ser::Error::custom("Invalid timeframe duration")) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| fn remove_slash(pair: &str) -> String { |  | ||||||
|     pair.replace('/', "") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn join_symbols<S>(symbols: &[String], serializer: S) -> Result<S::Ok, S::Error> |  | ||||||
| where |  | ||||||
|     S: Serializer, |  | ||||||
| { |  | ||||||
|     let string = symbols.join(","); |  | ||||||
|     serializer.serialize_str(&string) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn join_symbols_option<S>( |  | ||||||
|     symbols: &Option<Vec<String>>, |  | ||||||
|     serializer: S, |  | ||||||
| ) -> Result<S::Ok, S::Error> |  | ||||||
| where |  | ||||||
|     S: Serializer, |  | ||||||
| { |  | ||||||
|     match symbols { |  | ||||||
|         Some(symbols) => join_symbols(symbols, serializer), |  | ||||||
|         None => serializer.serialize_none(), |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn remove_slash_from_symbols<'a, S, I>(symbols: I, serializer: S) -> Result<S::Ok, S::Error> |  | ||||||
| where |  | ||||||
|     S: Serializer, |  | ||||||
|     I: IntoIterator<Item = &'a String>, |  | ||||||
| { |  | ||||||
|     let symbols = symbols |  | ||||||
|         .into_iter() |  | ||||||
|         .map(|pair| remove_slash(pair)) |  | ||||||
|         .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|     let mut seq = serializer.serialize_seq(Some(symbols.len()))?; |  | ||||||
|     for symbol in symbols { |  | ||||||
|         seq.serialize_element(&symbol)?; |  | ||||||
|     } |  | ||||||
|     seq.end() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn remove_slash_and_join_symbols<'a, S, I>(symbols: I, serializer: S) -> Result<S::Ok, S::Error> |  | ||||||
| where |  | ||||||
|     S: Serializer, |  | ||||||
|     I: IntoIterator<Item = &'a String>, |  | ||||||
| { |  | ||||||
|     let symbols = symbols |  | ||||||
|         .into_iter() |  | ||||||
|         .map(|symbol| remove_slash(symbol)) |  | ||||||
|         .collect::<Vec<_>>(); |  | ||||||
|  |  | ||||||
|     join_symbols(&symbols, serializer) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|     use super::*; |  | ||||||
|     use serde::Serialize; |  | ||||||
|     use serde_test::{assert_ser_tokens, assert_ser_tokens_error, Token}; |  | ||||||
|  |  | ||||||
|     #[derive(Serialize)] |  | ||||||
|     #[serde(transparent)] |  | ||||||
|     struct Timeframe { |  | ||||||
|         #[serde(serialize_with = "timeframe")] |  | ||||||
|         duration: Duration, |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_timeframe_30_mins() { |  | ||||||
|         let timeframe = Timeframe { |  | ||||||
|             duration: Duration::from_secs(60 * 30), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens(&timeframe, &[Token::Str("30Min")]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_timeframe_2_hours() { |  | ||||||
|         let timeframe = Timeframe { |  | ||||||
|             duration: Duration::from_secs(60 * 60 * 2), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens(&timeframe, &[Token::Str("2Hour")]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_timeframe_1_day() { |  | ||||||
|         let timeframe = Timeframe { |  | ||||||
|             duration: Duration::from_secs(60 * 60 * 24), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens(&timeframe, &[Token::Str("1Day")]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_timeframe_1_week() { |  | ||||||
|         let timeframe = Timeframe { |  | ||||||
|             duration: Duration::from_secs(60 * 60 * 24 * 7), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens(&timeframe, &[Token::Str("1Week")]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_timeframe_6_months() { |  | ||||||
|         let timeframe = Timeframe { |  | ||||||
|             duration: Duration::from_secs(60 * 60 * 24 * 30 * 6), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens(&timeframe, &[Token::Str("6Month")]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_timeframe_invalid_1_second() { |  | ||||||
|         let timeframe = Timeframe { |  | ||||||
|             duration: Duration::from_secs(1), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens_error(&timeframe, &[], "Invalid timeframe duration"); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_timeframe_invalid_61_seconds() { |  | ||||||
|         let timeframe = Timeframe { |  | ||||||
|             duration: Duration::from_secs(61), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens_error(&timeframe, &[], "Invalid timeframe duration"); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_timeframe_invalid_6_days() { |  | ||||||
|         let timeframe = Timeframe { |  | ||||||
|             duration: Duration::from_secs(60 * 60 * 24 * 6), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens_error(&timeframe, &[], "Invalid timeframe duration"); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_remove_slash() { |  | ||||||
|         let pair = "BTC/USDT"; |  | ||||||
|         assert_eq!(remove_slash(pair), "BTCUSDT"); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[derive(Serialize)] |  | ||||||
|     #[serde(transparent)] |  | ||||||
|     struct JoinSymbols { |  | ||||||
|         #[serde(serialize_with = "join_symbols")] |  | ||||||
|         symbols: Vec<String>, |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_join_symbols() { |  | ||||||
|         let symbols = JoinSymbols { |  | ||||||
|             symbols: vec![String::from("BTC/USD"), String::from("ETH/USD")], |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens(&symbols, &[Token::Str("BTC/USD,ETH/USD")]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[derive(Serialize)] |  | ||||||
|     #[serde(transparent)] |  | ||||||
|     struct JoinSymbolsOption { |  | ||||||
|         #[serde(serialize_with = "join_symbols_option")] |  | ||||||
|         symbols: Option<Vec<String>>, |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_join_symbols_option_some() { |  | ||||||
|         let symbols = JoinSymbolsOption { |  | ||||||
|             symbols: Some(vec![String::from("BTC/USD"), String::from("ETH/USD")]), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens(&symbols, &[Token::Str("BTC/USD,ETH/USD")]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_join_symbols_option_none() { |  | ||||||
|         let symbols = JoinSymbolsOption { symbols: None }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens(&symbols, &[Token::None]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[derive(Serialize)] |  | ||||||
|     #[serde(transparent)] |  | ||||||
|     struct RemoveSlashFromSymbols { |  | ||||||
|         #[serde(serialize_with = "remove_slash_from_symbols")] |  | ||||||
|         symbols: Vec<String>, |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_remove_slash_from_symbols() { |  | ||||||
|         let symbols = RemoveSlashFromSymbols { |  | ||||||
|             symbols: vec![String::from("BTC/USD"), String::from("ETH/USD")], |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens( |  | ||||||
|             &symbols, |  | ||||||
|             &[ |  | ||||||
|                 Token::Seq { len: Some(2) }, |  | ||||||
|                 Token::Str("BTCUSD"), |  | ||||||
|                 Token::Str("ETHUSD"), |  | ||||||
|                 Token::SeqEnd, |  | ||||||
|             ], |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[derive(Serialize)] |  | ||||||
|     #[serde(transparent)] |  | ||||||
|     struct RemoveSlashAndJoinSymbols { |  | ||||||
|         #[serde(serialize_with = "remove_slash_and_join_symbols")] |  | ||||||
|         symbols: Vec<String>, |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     #[test] |  | ||||||
|     fn test_remove_slash_and_join_symbols() { |  | ||||||
|         let symbols = RemoveSlashAndJoinSymbols { |  | ||||||
|             symbols: vec![String::from("BTC/USD"), String::from("ETH/USD")], |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         assert_ser_tokens(&symbols, &[Token::Str("BTCUSD,ETHUSD")]); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
							
								
								
									
										71
									
								
								src/main.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								src/main.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | |||||||
|  | #![warn(clippy::all, clippy::pedantic, clippy::nursery)] | ||||||
|  | #![allow(clippy::missing_docs_in_private_items)] | ||||||
|  | #![feature(hash_extract_if)] | ||||||
|  |  | ||||||
|  | mod config; | ||||||
|  | mod database; | ||||||
|  | mod init; | ||||||
|  | mod routes; | ||||||
|  | mod threads; | ||||||
|  | mod types; | ||||||
|  | mod utils; | ||||||
|  |  | ||||||
|  | use config::Config; | ||||||
|  | use dotenv::dotenv; | ||||||
|  | use log4rs::config::Deserializers; | ||||||
|  | use tokio::{join, spawn, sync::mpsc, try_join}; | ||||||
|  |  | ||||||
|  | #[tokio::main] | ||||||
|  | async fn main() { | ||||||
|  |     dotenv().ok(); | ||||||
|  |     log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap(); | ||||||
|  |     let config = Config::arc_from_env(); | ||||||
|  |  | ||||||
|  |     try_join!( | ||||||
|  |         database::backfills_bars::unfresh(&config.clickhouse_client), | ||||||
|  |         database::backfills_news::unfresh(&config.clickhouse_client) | ||||||
|  |     ) | ||||||
|  |     .unwrap(); | ||||||
|  |  | ||||||
|  |     database::cleanup_all(&config.clickhouse_client) | ||||||
|  |         .await | ||||||
|  |         .unwrap(); | ||||||
|  |     database::optimize_all(&config.clickhouse_client) | ||||||
|  |         .await | ||||||
|  |         .unwrap(); | ||||||
|  |  | ||||||
|  |     init::check_account(&config).await; | ||||||
|  |     join!( | ||||||
|  |         init::rehydrate_orders(&config), | ||||||
|  |         init::rehydrate_positions(&config) | ||||||
|  |     ); | ||||||
|  |  | ||||||
|  |     spawn(threads::trading::run(config.clone())); | ||||||
|  |  | ||||||
|  |     let (data_sender, data_receiver) = mpsc::channel::<threads::data::Message>(100); | ||||||
|  |     let (clock_sender, clock_receiver) = mpsc::channel::<threads::clock::Message>(1); | ||||||
|  |  | ||||||
|  |     spawn(threads::data::run( | ||||||
|  |         config.clone(), | ||||||
|  |         data_receiver, | ||||||
|  |         clock_receiver, | ||||||
|  |     )); | ||||||
|  |  | ||||||
|  |     spawn(threads::clock::run(config.clone(), clock_sender)); | ||||||
|  |  | ||||||
|  |     let assets = database::assets::select(&config.clickhouse_client) | ||||||
|  |         .await | ||||||
|  |         .unwrap() | ||||||
|  |         .into_iter() | ||||||
|  |         .map(|asset| (asset.symbol, asset.class)) | ||||||
|  |         .collect::<Vec<_>>(); | ||||||
|  |  | ||||||
|  |     create_send_await!( | ||||||
|  |         data_sender, | ||||||
|  |         threads::data::Message::new, | ||||||
|  |         threads::data::Action::Enable, | ||||||
|  |         assets | ||||||
|  |     ); | ||||||
|  |  | ||||||
|  |     routes::run(config, data_sender).await; | ||||||
|  | } | ||||||
							
								
								
									
										99
									
								
								src/routes/assets.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								src/routes/assets.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,99 @@ | |||||||
|  | use crate::{ | ||||||
|  |     config::Config, | ||||||
|  |     create_send_await, database, threads, | ||||||
|  |     types::{alpaca, Asset}, | ||||||
|  | }; | ||||||
|  | use axum::{extract::Path, Extension, Json}; | ||||||
|  | use http::StatusCode; | ||||||
|  | use serde::Deserialize; | ||||||
|  | use std::sync::Arc; | ||||||
|  | use tokio::sync::mpsc; | ||||||
|  |  | ||||||
|  | pub async fn get( | ||||||
|  |     Extension(config): Extension<Arc<Config>>, | ||||||
|  | ) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> { | ||||||
|  |     let assets = database::assets::select(&config.clickhouse_client) | ||||||
|  |         .await | ||||||
|  |         .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||||||
|  |  | ||||||
|  |     Ok((StatusCode::OK, Json(assets))) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn get_where_symbol( | ||||||
|  |     Extension(config): Extension<Arc<Config>>, | ||||||
|  |     Path(symbol): Path<String>, | ||||||
|  | ) -> Result<(StatusCode, Json<Asset>), StatusCode> { | ||||||
|  |     let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol) | ||||||
|  |         .await | ||||||
|  |         .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||||||
|  |  | ||||||
|  |     asset.map_or(Err(StatusCode::NOT_FOUND), |asset| { | ||||||
|  |         Ok((StatusCode::OK, Json(asset))) | ||||||
|  |     }) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | pub struct AddAssetRequest { | ||||||
|  |     symbol: String, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn add( | ||||||
|  |     Extension(config): Extension<Arc<Config>>, | ||||||
|  |     Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>, | ||||||
|  |     Json(request): Json<AddAssetRequest>, | ||||||
|  | ) -> Result<StatusCode, StatusCode> { | ||||||
|  |     if database::assets::select_where_symbol(&config.clickhouse_client, &request.symbol) | ||||||
|  |         .await | ||||||
|  |         .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? | ||||||
|  |         .is_some() | ||||||
|  |     { | ||||||
|  |         return Err(StatusCode::CONFLICT); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     let asset = alpaca::api::incoming::asset::get_by_symbol( | ||||||
|  |         &config.alpaca_client, | ||||||
|  |         &config.alpaca_rate_limiter, | ||||||
|  |         &request.symbol, | ||||||
|  |         None, | ||||||
|  |     ) | ||||||
|  |     .await | ||||||
|  |     .map_err(|e| { | ||||||
|  |         e.status() | ||||||
|  |             .map_or(StatusCode::INTERNAL_SERVER_ERROR, |status| { | ||||||
|  |                 StatusCode::from_u16(status.as_u16()).unwrap() | ||||||
|  |             }) | ||||||
|  |     })?; | ||||||
|  |  | ||||||
|  |     if !asset.tradable || !asset.fractionable { | ||||||
|  |         return Err(StatusCode::FORBIDDEN); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     create_send_await!( | ||||||
|  |         data_sender, | ||||||
|  |         threads::data::Message::new, | ||||||
|  |         threads::data::Action::Add, | ||||||
|  |         vec![(asset.symbol, asset.class.into())] | ||||||
|  |     ); | ||||||
|  |  | ||||||
|  |     Ok(StatusCode::CREATED) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn delete( | ||||||
|  |     Extension(config): Extension<Arc<Config>>, | ||||||
|  |     Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>, | ||||||
|  |     Path(symbol): Path<String>, | ||||||
|  | ) -> Result<StatusCode, StatusCode> { | ||||||
|  |     let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol) | ||||||
|  |         .await | ||||||
|  |         .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? | ||||||
|  |         .ok_or(StatusCode::NOT_FOUND)?; | ||||||
|  |  | ||||||
|  |     create_send_await!( | ||||||
|  |         data_sender, | ||||||
|  |         threads::data::Message::new, | ||||||
|  |         threads::data::Action::Remove, | ||||||
|  |         vec![(asset.symbol, asset.class)] | ||||||
|  |     ); | ||||||
|  |  | ||||||
|  |     Ok(StatusCode::NO_CONTENT) | ||||||
|  | } | ||||||
| @@ -16,7 +16,6 @@ pub async fn run(config: Arc<Config>, data_sender: mpsc::Sender<threads::data::M | |||||||
|         .route("/assets", get(assets::get)) |         .route("/assets", get(assets::get)) | ||||||
|         .route("/assets/:symbol", get(assets::get_where_symbol)) |         .route("/assets/:symbol", get(assets::get_where_symbol)) | ||||||
|         .route("/assets", post(assets::add)) |         .route("/assets", post(assets::add)) | ||||||
|         .route("/assets/:symbol", post(assets::add_symbol)) |  | ||||||
|         .route("/assets/:symbol", delete(assets::delete)) |         .route("/assets/:symbol", delete(assets::delete)) | ||||||
|         .layer(Extension(config)) |         .layer(Extension(config)) | ||||||
|         .layer(Extension(data_sender)); |         .layer(Extension(data_sender)); | ||||||
| @@ -1,17 +1,14 @@ | |||||||
| use crate::{ | use crate::{ | ||||||
|     config::{Config, ALPACA_API_BASE}, |     config::Config, | ||||||
|     database, |     database, | ||||||
| }; |     types::{alpaca, Calendar}, | ||||||
| use log::info; |  | ||||||
| use qrust::{ |  | ||||||
|     alpaca, |  | ||||||
|     types::{self, Calendar}, |  | ||||||
|     utils::{backoff, duration_until}, |     utils::{backoff, duration_until}, | ||||||
| }; | }; | ||||||
|  | use log::info; | ||||||
| use std::sync::Arc; | use std::sync::Arc; | ||||||
|  | use time::OffsetDateTime; | ||||||
| use tokio::{join, sync::mpsc, time::sleep}; | use tokio::{join, sync::mpsc, time::sleep}; | ||||||
| 
 | 
 | ||||||
| #[derive(PartialEq, Eq)] |  | ||||||
| pub enum Status { | pub enum Status { | ||||||
|     Open, |     Open, | ||||||
|     Closed, |     Closed, | ||||||
| @@ -19,16 +16,21 @@ pub enum Status { | |||||||
| 
 | 
 | ||||||
| pub struct Message { | pub struct Message { | ||||||
|     pub status: Status, |     pub status: Status, | ||||||
|  |     pub next_switch: OffsetDateTime, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl From<types::alpaca::api::incoming::clock::Clock> for Message { | impl From<alpaca::api::incoming::clock::Clock> for Message { | ||||||
|     fn from(clock: types::alpaca::api::incoming::clock::Clock) -> Self { |     fn from(clock: alpaca::api::incoming::clock::Clock) -> Self { | ||||||
|  |         if clock.is_open { | ||||||
|             Self { |             Self { | ||||||
|             status: if clock.is_open { |                 status: Status::Open, | ||||||
|                 Status::Open |                 next_switch: clock.next_close, | ||||||
|  |             } | ||||||
|         } else { |         } else { | ||||||
|                 Status::Closed |             Self { | ||||||
|             }, |                 status: Status::Closed, | ||||||
|  |                 next_switch: clock.next_open, | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -36,23 +38,21 @@ impl From<types::alpaca::api::incoming::clock::Clock> for Message { | |||||||
| pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) { | pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) { | ||||||
|     loop { |     loop { | ||||||
|         let clock_future = async { |         let clock_future = async { | ||||||
|             alpaca::clock::get( |             alpaca::api::incoming::clock::get( | ||||||
|                 &config.alpaca_client, |                 &config.alpaca_client, | ||||||
|                 &config.alpaca_rate_limiter, |                 &config.alpaca_rate_limiter, | ||||||
|                 Some(backoff::infinite()), |                 Some(backoff::infinite()), | ||||||
|                 &ALPACA_API_BASE, |  | ||||||
|             ) |             ) | ||||||
|             .await |             .await | ||||||
|             .unwrap() |             .unwrap() | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let calendar_future = async { |         let calendar_future = async { | ||||||
|             alpaca::calendar::get( |             alpaca::api::incoming::calendar::get( | ||||||
|                 &config.alpaca_client, |                 &config.alpaca_client, | ||||||
|                 &config.alpaca_rate_limiter, |                 &config.alpaca_rate_limiter, | ||||||
|                 &types::alpaca::api::outgoing::calendar::Calendar::default(), |                 &alpaca::api::outgoing::calendar::Calendar::default(), | ||||||
|                 Some(backoff::infinite()), |                 Some(backoff::infinite()), | ||||||
|                 &ALPACA_API_BASE, |  | ||||||
|             ) |             ) | ||||||
|             .await |             .await | ||||||
|             .unwrap() |             .unwrap() | ||||||
| @@ -74,11 +74,7 @@ pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) { | |||||||
|         let sleep_future = sleep(sleep_until); |         let sleep_future = sleep(sleep_until); | ||||||
| 
 | 
 | ||||||
|         let calendar_future = async { |         let calendar_future = async { | ||||||
|             database::calendar::upsert_batch_and_delete( |             database::calendar::upsert_batch_and_delete(&config.clickhouse_client, &calendar) | ||||||
|                 &config.clickhouse_client, |  | ||||||
|                 &config.clickhouse_concurrency_limiter, |  | ||||||
|                 &calendar, |  | ||||||
|             ) |  | ||||||
|                 .await |                 .await | ||||||
|                 .unwrap(); |                 .unwrap(); | ||||||
|         }; |         }; | ||||||
							
								
								
									
										413
									
								
								src/threads/data/backfill.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										413
									
								
								src/threads/data/backfill.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,413 @@ | |||||||
|  | use super::ThreadType; | ||||||
|  | use crate::{ | ||||||
|  |     config::{ | ||||||
|  |         Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_SOURCE, ALPACA_STOCK_DATA_API_URL, | ||||||
|  |         MAX_BERT_INPUTS, | ||||||
|  |     }, | ||||||
|  |     database, | ||||||
|  |     types::{ | ||||||
|  |         alpaca::{self, shared::Source}, | ||||||
|  |         news::Prediction, | ||||||
|  |         Backfill, Bar, Class, News, | ||||||
|  |     }, | ||||||
|  |     utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE, ONE_SECOND}, | ||||||
|  | }; | ||||||
|  | use async_trait::async_trait; | ||||||
|  | use futures_util::future::join_all; | ||||||
|  | use log::{error, info, warn}; | ||||||
|  | use std::{collections::HashMap, sync::Arc}; | ||||||
|  | use time::OffsetDateTime; | ||||||
|  | use tokio::{ | ||||||
|  |     spawn, | ||||||
|  |     sync::{mpsc, oneshot, Mutex}, | ||||||
|  |     task::{block_in_place, JoinHandle}, | ||||||
|  |     time::sleep, | ||||||
|  |     try_join, | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | pub enum Action { | ||||||
|  |     Backfill, | ||||||
|  |     Purge, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<super::Action> for Option<Action> { | ||||||
|  |     fn from(action: super::Action) -> Self { | ||||||
|  |         match action { | ||||||
|  |             super::Action::Add | super::Action::Enable => Some(Action::Backfill), | ||||||
|  |             super::Action::Remove => Some(Action::Purge), | ||||||
|  |             super::Action::Disable => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub struct Message { | ||||||
|  |     pub action: Option<Action>, | ||||||
|  |     pub symbols: Vec<String>, | ||||||
|  |     pub response: oneshot::Sender<()>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Message { | ||||||
|  |     pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) { | ||||||
|  |         let (sender, receiver) = oneshot::channel::<()>(); | ||||||
|  |         ( | ||||||
|  |             Self { | ||||||
|  |                 action, | ||||||
|  |                 symbols, | ||||||
|  |                 response: sender, | ||||||
|  |             }, | ||||||
|  |             receiver, | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[async_trait] | ||||||
|  | pub trait Handler: Send + Sync { | ||||||
|  |     async fn select_latest_backfill( | ||||||
|  |         &self, | ||||||
|  |         symbol: String, | ||||||
|  |     ) -> Result<Option<Backfill>, clickhouse::error::Error>; | ||||||
|  |     async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>; | ||||||
|  |     async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>; | ||||||
|  |     async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime); | ||||||
|  |     async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime); | ||||||
|  |     fn log_string(&self) -> &'static str; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) { | ||||||
|  |     let backfill_jobs = Arc::new(Mutex::new(HashMap::new())); | ||||||
|  |  | ||||||
|  |     loop { | ||||||
|  |         let message = receiver.recv().await.unwrap(); | ||||||
|  |         spawn(handle_backfill_message( | ||||||
|  |             handler.clone(), | ||||||
|  |             backfill_jobs.clone(), | ||||||
|  |             message, | ||||||
|  |         )); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | async fn handle_backfill_message( | ||||||
|  |     handler: Arc<Box<dyn Handler>>, | ||||||
|  |     backfill_jobs: Arc<Mutex<HashMap<String, JoinHandle<()>>>>, | ||||||
|  |     message: Message, | ||||||
|  | ) { | ||||||
|  |     let mut backfill_jobs = backfill_jobs.lock().await; | ||||||
|  |  | ||||||
|  |     match message.action { | ||||||
|  |         Some(Action::Backfill) => { | ||||||
|  |             let log_string = handler.log_string(); | ||||||
|  |  | ||||||
|  |             for symbol in message.symbols { | ||||||
|  |                 if let Some(job) = backfill_jobs.get(&symbol) { | ||||||
|  |                     if !job.is_finished() { | ||||||
|  |                         warn!( | ||||||
|  |                             "Backfill for {} {} is already running, skipping.", | ||||||
|  |                             symbol, log_string | ||||||
|  |                         ); | ||||||
|  |                         continue; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 let handler = handler.clone(); | ||||||
|  |                 backfill_jobs.insert( | ||||||
|  |                     symbol.clone(), | ||||||
|  |                     spawn(async move { | ||||||
|  |                         let fetch_from = match handler | ||||||
|  |                             .select_latest_backfill(symbol.clone()) | ||||||
|  |                             .await | ||||||
|  |                             .unwrap() | ||||||
|  |                         { | ||||||
|  |                             Some(latest_backfill) => latest_backfill.time + ONE_SECOND, | ||||||
|  |                             None => OffsetDateTime::UNIX_EPOCH, | ||||||
|  |                         }; | ||||||
|  |  | ||||||
|  |                         let fetch_to = last_minute(); | ||||||
|  |  | ||||||
|  |                         if fetch_from > fetch_to { | ||||||
|  |                             info!("No need to backfill {} {}.", symbol, log_string,); | ||||||
|  |                             return; | ||||||
|  |                         } | ||||||
|  |  | ||||||
|  |                         handler.queue_backfill(&symbol, fetch_to).await; | ||||||
|  |                         handler.backfill(symbol, fetch_from, fetch_to).await; | ||||||
|  |                     }), | ||||||
|  |                 ); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         Some(Action::Purge) => { | ||||||
|  |             for symbol in &message.symbols { | ||||||
|  |                 if let Some(job) = backfill_jobs.remove(symbol) { | ||||||
|  |                     if !job.is_finished() { | ||||||
|  |                         job.abort(); | ||||||
|  |                     } | ||||||
|  |                     let _ = job.await; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             try_join!( | ||||||
|  |                 handler.delete_backfills(&message.symbols), | ||||||
|  |                 handler.delete_data(&message.symbols) | ||||||
|  |             ) | ||||||
|  |             .unwrap(); | ||||||
|  |         } | ||||||
|  |         None => {} | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     message.response.send(()).unwrap(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | struct BarHandler { | ||||||
|  |     config: Arc<Config>, | ||||||
|  |     data_url: &'static str, | ||||||
|  |     api_query_constructor: fn( | ||||||
|  |         symbol: String, | ||||||
|  |         fetch_from: OffsetDateTime, | ||||||
|  |         fetch_to: OffsetDateTime, | ||||||
|  |         next_page_token: Option<String>, | ||||||
|  |     ) -> alpaca::api::outgoing::bar::Bar, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn us_equity_query_constructor( | ||||||
|  |     symbol: String, | ||||||
|  |     fetch_from: OffsetDateTime, | ||||||
|  |     fetch_to: OffsetDateTime, | ||||||
|  |     next_page_token: Option<String>, | ||||||
|  | ) -> alpaca::api::outgoing::bar::Bar { | ||||||
|  |     alpaca::api::outgoing::bar::Bar::UsEquity(alpaca::api::outgoing::bar::UsEquity { | ||||||
|  |         symbols: vec![symbol], | ||||||
|  |         start: Some(fetch_from), | ||||||
|  |         end: Some(fetch_to), | ||||||
|  |         page_token: next_page_token, | ||||||
|  |         ..Default::default() | ||||||
|  |     }) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn crypto_query_constructor( | ||||||
|  |     symbol: String, | ||||||
|  |     fetch_from: OffsetDateTime, | ||||||
|  |     fetch_to: OffsetDateTime, | ||||||
|  |     next_page_token: Option<String>, | ||||||
|  | ) -> alpaca::api::outgoing::bar::Bar { | ||||||
|  |     alpaca::api::outgoing::bar::Bar::Crypto(alpaca::api::outgoing::bar::Crypto { | ||||||
|  |         symbols: vec![symbol], | ||||||
|  |         start: Some(fetch_from), | ||||||
|  |         end: Some(fetch_to), | ||||||
|  |         page_token: next_page_token, | ||||||
|  |         ..Default::default() | ||||||
|  |     }) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[async_trait] | ||||||
|  | impl Handler for BarHandler { | ||||||
|  |     async fn select_latest_backfill( | ||||||
|  |         &self, | ||||||
|  |         symbol: String, | ||||||
|  |     ) -> Result<Option<Backfill>, clickhouse::error::Error> { | ||||||
|  |         database::backfills_bars::select_where_symbol(&self.config.clickhouse_client, &symbol).await | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { | ||||||
|  |         database::backfills_bars::delete_where_symbols(&self.config.clickhouse_client, symbols) | ||||||
|  |             .await | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { | ||||||
|  |         database::bars::delete_where_symbols(&self.config.clickhouse_client, symbols).await | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) { | ||||||
|  |         if *ALPACA_SOURCE == Source::Iex { | ||||||
|  |             let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE); | ||||||
|  |             info!("Queing bar backfill for {} in {:?}.", symbol, run_delay); | ||||||
|  |             sleep(run_delay).await; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) { | ||||||
|  |         info!("Backfilling bars for {}.", symbol); | ||||||
|  |  | ||||||
|  |         let mut bars = vec![]; | ||||||
|  |         let mut next_page_token = None; | ||||||
|  |  | ||||||
|  |         loop { | ||||||
|  |             let Ok(message) = alpaca::api::incoming::bar::get_historical( | ||||||
|  |                 &self.config.alpaca_client, | ||||||
|  |                 &self.config.alpaca_rate_limiter, | ||||||
|  |                 self.data_url, | ||||||
|  |                 &(self.api_query_constructor)( | ||||||
|  |                     symbol.clone(), | ||||||
|  |                     fetch_from, | ||||||
|  |                     fetch_to, | ||||||
|  |                     next_page_token.clone(), | ||||||
|  |                 ), | ||||||
|  |                 None, | ||||||
|  |             ) | ||||||
|  |             .await | ||||||
|  |             else { | ||||||
|  |                 error!("Failed to backfill bars for {}.", symbol); | ||||||
|  |                 return; | ||||||
|  |             }; | ||||||
|  |  | ||||||
|  |             message.bars.into_iter().for_each(|(symbol, bar_vec)| { | ||||||
|  |                 for bar in bar_vec { | ||||||
|  |                     bars.push(Bar::from((bar, symbol.clone()))); | ||||||
|  |                 } | ||||||
|  |             }); | ||||||
|  |  | ||||||
|  |             if message.next_page_token.is_none() { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             next_page_token = message.next_page_token; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         if bars.is_empty() { | ||||||
|  |             info!("No bars to backfill for {}.", symbol); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         let backfill = bars.last().unwrap().clone().into(); | ||||||
|  |  | ||||||
|  |         database::bars::upsert_batch(&self.config.clickhouse_client, &bars) | ||||||
|  |             .await | ||||||
|  |             .unwrap(); | ||||||
|  |         database::backfills_bars::upsert(&self.config.clickhouse_client, &backfill) | ||||||
|  |             .await | ||||||
|  |             .unwrap(); | ||||||
|  |  | ||||||
|  |         info!("Backfilled bars for {}.", symbol); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn log_string(&self) -> &'static str { | ||||||
|  |         "bars" | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | struct NewsHandler { | ||||||
|  |     config: Arc<Config>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[async_trait] | ||||||
|  | impl Handler for NewsHandler { | ||||||
|  |     async fn select_latest_backfill( | ||||||
|  |         &self, | ||||||
|  |         symbol: String, | ||||||
|  |     ) -> Result<Option<Backfill>, clickhouse::error::Error> { | ||||||
|  |         database::backfills_news::select_where_symbol(&self.config.clickhouse_client, &symbol).await | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { | ||||||
|  |         database::backfills_news::delete_where_symbols(&self.config.clickhouse_client, symbols) | ||||||
|  |             .await | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { | ||||||
|  |         database::news::delete_where_symbols(&self.config.clickhouse_client, symbols).await | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) { | ||||||
|  |         let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE); | ||||||
|  |         info!("Queing news backfill for {} in {:?}.", symbol, run_delay); | ||||||
|  |         sleep(run_delay).await; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) { | ||||||
|  |         info!("Backfilling news for {}.", symbol); | ||||||
|  |  | ||||||
|  |         let mut news = vec![]; | ||||||
|  |         let mut next_page_token = None; | ||||||
|  |  | ||||||
|  |         loop { | ||||||
|  |             let Ok(message) = alpaca::api::incoming::news::get_historical( | ||||||
|  |                 &self.config.alpaca_client, | ||||||
|  |                 &self.config.alpaca_rate_limiter, | ||||||
|  |                 &alpaca::api::outgoing::news::News { | ||||||
|  |                     symbols: vec![symbol.clone()], | ||||||
|  |                     start: Some(fetch_from), | ||||||
|  |                     end: Some(fetch_to), | ||||||
|  |                     page_token: next_page_token.clone(), | ||||||
|  |                     ..Default::default() | ||||||
|  |                 }, | ||||||
|  |                 None, | ||||||
|  |             ) | ||||||
|  |             .await | ||||||
|  |             else { | ||||||
|  |                 error!("Failed to backfill news for {}.", symbol); | ||||||
|  |                 return; | ||||||
|  |             }; | ||||||
|  |  | ||||||
|  |             message.news.into_iter().for_each(|news_item| { | ||||||
|  |                 news.push(News::from(news_item)); | ||||||
|  |             }); | ||||||
|  |  | ||||||
|  |             if message.next_page_token.is_none() { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             next_page_token = message.next_page_token; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         if news.is_empty() { | ||||||
|  |             info!("No news to backfill for {}.", symbol); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         let inputs = news | ||||||
|  |             .iter() | ||||||
|  |             .map(|news| format!("{}\n\n{}", news.headline, news.content)) | ||||||
|  |             .collect::<Vec<_>>(); | ||||||
|  |  | ||||||
|  |         let predictions = join_all(inputs.chunks(*MAX_BERT_INPUTS).map(|inputs| async move { | ||||||
|  |             let sequence_classifier = self.config.sequence_classifier.lock().await; | ||||||
|  |             block_in_place(|| { | ||||||
|  |                 sequence_classifier | ||||||
|  |                     .predict(inputs.iter().map(String::as_str).collect::<Vec<_>>()) | ||||||
|  |                     .into_iter() | ||||||
|  |                     .map(|label| Prediction::try_from(label).unwrap()) | ||||||
|  |                     .collect::<Vec<_>>() | ||||||
|  |             }) | ||||||
|  |         })) | ||||||
|  |         .await | ||||||
|  |         .into_iter() | ||||||
|  |         .flatten(); | ||||||
|  |  | ||||||
|  |         let news = news | ||||||
|  |             .into_iter() | ||||||
|  |             .zip(predictions) | ||||||
|  |             .map(|(news, prediction)| News { | ||||||
|  |                 sentiment: prediction.sentiment, | ||||||
|  |                 confidence: prediction.confidence, | ||||||
|  |                 ..news | ||||||
|  |             }) | ||||||
|  |             .collect::<Vec<_>>(); | ||||||
|  |  | ||||||
|  |         let backfill = (news.last().unwrap().clone(), symbol.clone()).into(); | ||||||
|  |  | ||||||
|  |         database::news::upsert_batch(&self.config.clickhouse_client, &news) | ||||||
|  |             .await | ||||||
|  |             .unwrap(); | ||||||
|  |         database::backfills_news::upsert(&self.config.clickhouse_client, &backfill) | ||||||
|  |             .await | ||||||
|  |             .unwrap(); | ||||||
|  |  | ||||||
|  |         info!("Backfilled news for {}.", symbol); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn log_string(&self) -> &'static str { | ||||||
|  |         "news" | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> { | ||||||
|  |     match thread_type { | ||||||
|  |         ThreadType::Bars(Class::UsEquity) => Box::new(BarHandler { | ||||||
|  |             config, | ||||||
|  |             data_url: ALPACA_STOCK_DATA_API_URL, | ||||||
|  |             api_query_constructor: us_equity_query_constructor, | ||||||
|  |         }), | ||||||
|  |         ThreadType::Bars(Class::Crypto) => Box::new(BarHandler { | ||||||
|  |             config, | ||||||
|  |             data_url: ALPACA_CRYPTO_DATA_API_URL, | ||||||
|  |             api_query_constructor: crypto_query_constructor, | ||||||
|  |         }), | ||||||
|  |         ThreadType::News => Box::new(NewsHandler { config }), | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										326
									
								
								src/threads/data/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										326
									
								
								src/threads/data/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,326 @@ | |||||||
|  | mod backfill; | ||||||
|  | mod websocket; | ||||||
|  |  | ||||||
|  | use super::clock; | ||||||
|  | use crate::{ | ||||||
|  |     config::{ | ||||||
|  |         Config, ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL, ALPACA_SOURCE, | ||||||
|  |         ALPACA_STOCK_DATA_WEBSOCKET_URL, | ||||||
|  |     }, | ||||||
|  |     create_send_await, database, | ||||||
|  |     types::{alpaca, Asset, Class}, | ||||||
|  |     utils::backoff, | ||||||
|  | }; | ||||||
|  | use futures_util::{future::join_all, StreamExt}; | ||||||
|  | use itertools::{Either, Itertools}; | ||||||
|  | use std::sync::Arc; | ||||||
|  | use tokio::{ | ||||||
|  |     join, select, spawn, | ||||||
|  |     sync::{mpsc, oneshot}, | ||||||
|  | }; | ||||||
|  | use tokio_tungstenite::connect_async; | ||||||
|  |  | ||||||
|  | #[derive(Clone, Copy)] | ||||||
|  | #[allow(dead_code)] | ||||||
|  | pub enum Action { | ||||||
|  |     Add, | ||||||
|  |     Enable, | ||||||
|  |     Remove, | ||||||
|  |     Disable, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub struct Message { | ||||||
|  |     pub action: Action, | ||||||
|  |     pub assets: Vec<(String, Class)>, | ||||||
|  |     pub response: oneshot::Sender<()>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Message { | ||||||
|  |     pub fn new(action: Action, assets: Vec<(String, Class)>) -> (Self, oneshot::Receiver<()>) { | ||||||
|  |         let (sender, receiver) = oneshot::channel(); | ||||||
|  |         ( | ||||||
|  |             Self { | ||||||
|  |                 action, | ||||||
|  |                 assets, | ||||||
|  |                 response: sender, | ||||||
|  |             }, | ||||||
|  |             receiver, | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Clone, Copy)] | ||||||
|  | pub enum ThreadType { | ||||||
|  |     Bars(Class), | ||||||
|  |     News, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn run( | ||||||
|  |     config: Arc<Config>, | ||||||
|  |     mut receiver: mpsc::Receiver<Message>, | ||||||
|  |     mut clock_receiver: mpsc::Receiver<clock::Message>, | ||||||
|  | ) { | ||||||
|  |     let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) = | ||||||
|  |         init_thread(config.clone(), ThreadType::Bars(Class::UsEquity)).await; | ||||||
|  |     let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) = | ||||||
|  |         init_thread(config.clone(), ThreadType::Bars(Class::Crypto)).await; | ||||||
|  |     let (news_websocket_sender, news_backfill_sender) = | ||||||
|  |         init_thread(config.clone(), ThreadType::News).await; | ||||||
|  |  | ||||||
|  |     loop { | ||||||
|  |         select! { | ||||||
|  |             Some(message) = receiver.recv() => { | ||||||
|  |                 spawn(handle_message( | ||||||
|  |                     config.clone(), | ||||||
|  |                     bars_us_equity_websocket_sender.clone(), | ||||||
|  |                     bars_us_equity_backfill_sender.clone(), | ||||||
|  |                     bars_crypto_websocket_sender.clone(), | ||||||
|  |                     bars_crypto_backfill_sender.clone(), | ||||||
|  |                     news_websocket_sender.clone(), | ||||||
|  |                     news_backfill_sender.clone(), | ||||||
|  |                     message, | ||||||
|  |                 )); | ||||||
|  |             } | ||||||
|  |             Some(_) = clock_receiver.recv() => { | ||||||
|  |                 spawn(handle_clock_message( | ||||||
|  |                     config.clone(), | ||||||
|  |                     bars_us_equity_backfill_sender.clone(), | ||||||
|  |                     bars_crypto_backfill_sender.clone(), | ||||||
|  |                     news_backfill_sender.clone(), | ||||||
|  |                 )); | ||||||
|  |             } | ||||||
|  |             else => panic!("Communication channel unexpectedly closed.") | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | async fn init_thread( | ||||||
|  |     config: Arc<Config>, | ||||||
|  |     thread_type: ThreadType, | ||||||
|  | ) -> ( | ||||||
|  |     mpsc::Sender<websocket::Message>, | ||||||
|  |     mpsc::Sender<backfill::Message>, | ||||||
|  | ) { | ||||||
|  |     let websocket_url = match thread_type { | ||||||
|  |         ThreadType::Bars(Class::UsEquity) => { | ||||||
|  |             format!("{}/{}", ALPACA_STOCK_DATA_WEBSOCKET_URL, *ALPACA_SOURCE) | ||||||
|  |         } | ||||||
|  |         ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_WEBSOCKET_URL.into(), | ||||||
|  |         ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(), | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let (websocket, _) = connect_async(websocket_url).await.unwrap(); | ||||||
|  |     let (mut websocket_sink, mut websocket_stream) = websocket.split(); | ||||||
|  |     alpaca::websocket::data::authenticate(&mut websocket_sink, &mut websocket_stream).await; | ||||||
|  |  | ||||||
|  |     let (backfill_sender, backfill_receiver) = mpsc::channel(100); | ||||||
|  |     spawn(backfill::run( | ||||||
|  |         Arc::new(backfill::create_handler(thread_type, config.clone())), | ||||||
|  |         backfill_receiver, | ||||||
|  |     )); | ||||||
|  |  | ||||||
|  |     let (websocket_sender, websocket_receiver) = mpsc::channel(100); | ||||||
|  |     spawn(websocket::run( | ||||||
|  |         Arc::new(websocket::create_handler(thread_type, config.clone())), | ||||||
|  |         websocket_receiver, | ||||||
|  |         websocket_stream, | ||||||
|  |         websocket_sink, | ||||||
|  |     )); | ||||||
|  |  | ||||||
|  |     (websocket_sender, backfill_sender) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[allow(clippy::too_many_arguments)] | ||||||
|  | #[allow(clippy::too_many_lines)] | ||||||
|  | async fn handle_message( | ||||||
|  |     config: Arc<Config>, | ||||||
|  |     bars_us_equity_websocket_sender: mpsc::Sender<websocket::Message>, | ||||||
|  |     bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>, | ||||||
|  |     bars_crypto_websocket_sender: mpsc::Sender<websocket::Message>, | ||||||
|  |     bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>, | ||||||
|  |     news_websocket_sender: mpsc::Sender<websocket::Message>, | ||||||
|  |     news_backfill_sender: mpsc::Sender<backfill::Message>, | ||||||
|  |     message: Message, | ||||||
|  | ) { | ||||||
|  |     if message.assets.is_empty() { | ||||||
|  |         message.response.send(()).unwrap(); | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = message | ||||||
|  |         .assets | ||||||
|  |         .clone() | ||||||
|  |         .into_iter() | ||||||
|  |         .partition_map(|asset| match asset.1 { | ||||||
|  |             Class::UsEquity => Either::Left(asset.0), | ||||||
|  |             Class::Crypto => Either::Right(asset.0), | ||||||
|  |         }); | ||||||
|  |  | ||||||
|  |     let symbols = message | ||||||
|  |         .assets | ||||||
|  |         .into_iter() | ||||||
|  |         .map(|(symbol, _)| symbol) | ||||||
|  |         .collect::<Vec<_>>(); | ||||||
|  |  | ||||||
|  |     let bars_us_equity_future = async { | ||||||
|  |         if us_equity_symbols.is_empty() { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         create_send_await!( | ||||||
|  |             bars_us_equity_websocket_sender, | ||||||
|  |             websocket::Message::new, | ||||||
|  |             message.action.into(), | ||||||
|  |             us_equity_symbols.clone() | ||||||
|  |         ); | ||||||
|  |  | ||||||
|  |         create_send_await!( | ||||||
|  |             bars_us_equity_backfill_sender, | ||||||
|  |             backfill::Message::new, | ||||||
|  |             message.action.into(), | ||||||
|  |             us_equity_symbols | ||||||
|  |         ); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let bars_crypto_future = async { | ||||||
|  |         if crypto_symbols.is_empty() { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         create_send_await!( | ||||||
|  |             bars_crypto_websocket_sender, | ||||||
|  |             websocket::Message::new, | ||||||
|  |             message.action.into(), | ||||||
|  |             crypto_symbols.clone() | ||||||
|  |         ); | ||||||
|  |  | ||||||
|  |         create_send_await!( | ||||||
|  |             bars_crypto_backfill_sender, | ||||||
|  |             backfill::Message::new, | ||||||
|  |             message.action.into(), | ||||||
|  |             crypto_symbols | ||||||
|  |         ); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let news_future = async { | ||||||
|  |         create_send_await!( | ||||||
|  |             news_websocket_sender, | ||||||
|  |             websocket::Message::new, | ||||||
|  |             message.action.into(), | ||||||
|  |             symbols.clone() | ||||||
|  |         ); | ||||||
|  |  | ||||||
|  |         create_send_await!( | ||||||
|  |             news_backfill_sender, | ||||||
|  |             backfill::Message::new, | ||||||
|  |             message.action.into(), | ||||||
|  |             symbols.clone() | ||||||
|  |         ); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     join!(bars_us_equity_future, bars_crypto_future, news_future); | ||||||
|  |  | ||||||
|  |     match message.action { | ||||||
|  |         Action::Add => { | ||||||
|  |             let assets = join_all(symbols.into_iter().map(|symbol| { | ||||||
|  |                 let config = config.clone(); | ||||||
|  |                 async move { | ||||||
|  |                     let asset_future = async { | ||||||
|  |                         alpaca::api::incoming::asset::get_by_symbol( | ||||||
|  |                             &config.alpaca_client, | ||||||
|  |                             &config.alpaca_rate_limiter, | ||||||
|  |                             &symbol, | ||||||
|  |                             Some(backoff::infinite()), | ||||||
|  |                         ) | ||||||
|  |                         .await | ||||||
|  |                         .unwrap() | ||||||
|  |                     }; | ||||||
|  |  | ||||||
|  |                     let position_future = async { | ||||||
|  |                         alpaca::api::incoming::position::get_by_symbol( | ||||||
|  |                             &config.alpaca_client, | ||||||
|  |                             &config.alpaca_rate_limiter, | ||||||
|  |                             &symbol, | ||||||
|  |                             Some(backoff::infinite()), | ||||||
|  |                         ) | ||||||
|  |                         .await | ||||||
|  |                         .unwrap() | ||||||
|  |                     }; | ||||||
|  |  | ||||||
|  |                     let (asset, position) = join!(asset_future, position_future); | ||||||
|  |                     Asset::from((asset, position)) | ||||||
|  |                 } | ||||||
|  |             })) | ||||||
|  |             .await; | ||||||
|  |  | ||||||
|  |             database::assets::upsert_batch(&config.clickhouse_client, &assets) | ||||||
|  |                 .await | ||||||
|  |                 .unwrap(); | ||||||
|  |         } | ||||||
|  |         Action::Remove => { | ||||||
|  |             database::assets::delete_where_symbols(&config.clickhouse_client, &symbols) | ||||||
|  |                 .await | ||||||
|  |                 .unwrap(); | ||||||
|  |         } | ||||||
|  |         _ => {} | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     message.response.send(()).unwrap(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | async fn handle_clock_message( | ||||||
|  |     config: Arc<Config>, | ||||||
|  |     bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>, | ||||||
|  |     bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>, | ||||||
|  |     news_backfill_sender: mpsc::Sender<backfill::Message>, | ||||||
|  | ) { | ||||||
|  |     database::cleanup_all(&config.clickhouse_client) | ||||||
|  |         .await | ||||||
|  |         .unwrap(); | ||||||
|  |  | ||||||
|  |     let assets = database::assets::select(&config.clickhouse_client) | ||||||
|  |         .await | ||||||
|  |         .unwrap(); | ||||||
|  |  | ||||||
|  |     let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets | ||||||
|  |         .clone() | ||||||
|  |         .into_iter() | ||||||
|  |         .partition_map(|asset| match asset.class { | ||||||
|  |             Class::UsEquity => Either::Left(asset.symbol), | ||||||
|  |             Class::Crypto => Either::Right(asset.symbol), | ||||||
|  |         }); | ||||||
|  |  | ||||||
|  |     let symbols = assets | ||||||
|  |         .into_iter() | ||||||
|  |         .map(|asset| asset.symbol) | ||||||
|  |         .collect::<Vec<_>>(); | ||||||
|  |  | ||||||
|  |     let bars_us_equity_future = async { | ||||||
|  |         create_send_await!( | ||||||
|  |             bars_us_equity_backfill_sender, | ||||||
|  |             backfill::Message::new, | ||||||
|  |             Some(backfill::Action::Backfill), | ||||||
|  |             us_equity_symbols.clone() | ||||||
|  |         ); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let bars_crypto_future = async { | ||||||
|  |         create_send_await!( | ||||||
|  |             bars_crypto_backfill_sender, | ||||||
|  |             backfill::Message::new, | ||||||
|  |             Some(backfill::Action::Backfill), | ||||||
|  |             crypto_symbols.clone() | ||||||
|  |         ); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let news_future = async { | ||||||
|  |         create_send_await!( | ||||||
|  |             news_backfill_sender, | ||||||
|  |             backfill::Message::new, | ||||||
|  |             Some(backfill::Action::Backfill), | ||||||
|  |             symbols | ||||||
|  |         ); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     join!(bars_us_equity_future, bars_crypto_future, news_future); | ||||||
|  | } | ||||||
							
								
								
									
										427
									
								
								src/threads/data/websocket.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										427
									
								
								src/threads/data/websocket.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,427 @@ | |||||||
|  | use super::ThreadType; | ||||||
|  | use crate::{ | ||||||
|  |     config::Config, | ||||||
|  |     database, | ||||||
|  |     types::{alpaca::websocket, news::Prediction, Bar, Class, News}, | ||||||
|  | }; | ||||||
|  | use async_trait::async_trait; | ||||||
|  | use futures_util::{ | ||||||
|  |     future::join_all, | ||||||
|  |     stream::{SplitSink, SplitStream}, | ||||||
|  |     SinkExt, StreamExt, | ||||||
|  | }; | ||||||
|  | use log::{debug, error, info}; | ||||||
|  | use serde_json::{from_str, to_string}; | ||||||
|  | use std::{collections::HashMap, sync::Arc}; | ||||||
|  | use tokio::{ | ||||||
|  |     net::TcpStream, | ||||||
|  |     select, spawn, | ||||||
|  |     sync::{mpsc, oneshot, Mutex, RwLock}, | ||||||
|  |     task::block_in_place, | ||||||
|  | }; | ||||||
|  | use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; | ||||||
|  |  | ||||||
|  | pub enum Action { | ||||||
|  |     Subscribe, | ||||||
|  |     Unsubscribe, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<super::Action> for Option<Action> { | ||||||
|  |     fn from(action: super::Action) -> Self { | ||||||
|  |         match action { | ||||||
|  |             super::Action::Add | super::Action::Enable => Some(Action::Subscribe), | ||||||
|  |             super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub struct Message { | ||||||
|  |     pub action: Option<Action>, | ||||||
|  |     pub symbols: Vec<String>, | ||||||
|  |     pub response: oneshot::Sender<()>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Message { | ||||||
|  |     pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) { | ||||||
|  |         let (sender, receiver) = oneshot::channel(); | ||||||
|  |         ( | ||||||
|  |             Self { | ||||||
|  |                 action, | ||||||
|  |                 symbols, | ||||||
|  |                 response: sender, | ||||||
|  |             }, | ||||||
|  |             receiver, | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub struct Pending { | ||||||
|  |     pub subscriptions: HashMap<String, oneshot::Sender<()>>, | ||||||
|  |     pub unsubscriptions: HashMap<String, oneshot::Sender<()>>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[async_trait] | ||||||
|  | pub trait Handler: Send + Sync { | ||||||
|  |     fn create_subscription_message( | ||||||
|  |         &self, | ||||||
|  |         symbols: Vec<String>, | ||||||
|  |     ) -> websocket::data::outgoing::subscribe::Message; | ||||||
|  |     async fn handle_websocket_message( | ||||||
|  |         &self, | ||||||
|  |         pending: Arc<RwLock<Pending>>, | ||||||
|  |         message: websocket::data::incoming::Message, | ||||||
|  |     ); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn run( | ||||||
|  |     handler: Arc<Box<dyn Handler>>, | ||||||
|  |     mut receiver: mpsc::Receiver<Message>, | ||||||
|  |     mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>, | ||||||
|  |     websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>, | ||||||
|  | ) { | ||||||
|  |     let pending = Arc::new(RwLock::new(Pending { | ||||||
|  |         subscriptions: HashMap::new(), | ||||||
|  |         unsubscriptions: HashMap::new(), | ||||||
|  |     })); | ||||||
|  |     let websocket_sink = Arc::new(Mutex::new(websocket_sink)); | ||||||
|  |  | ||||||
|  |     loop { | ||||||
|  |         select! { | ||||||
|  |             Some(message) = receiver.recv() => { | ||||||
|  |                 spawn(handle_message( | ||||||
|  |                     handler.clone(), | ||||||
|  |                     pending.clone(), | ||||||
|  |                     websocket_sink.clone(), | ||||||
|  |                     message, | ||||||
|  |                 )); | ||||||
|  |             } | ||||||
|  |             Some(Ok(message)) = websocket_stream.next() => { | ||||||
|  |                 match message { | ||||||
|  |                     tungstenite::Message::Text(message) => { | ||||||
|  |                         let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message); | ||||||
|  |  | ||||||
|  |                         if parsed_message.is_err() { | ||||||
|  |                             error!("Failed to deserialize websocket message: {:?}", message); | ||||||
|  |                             continue; | ||||||
|  |                         } | ||||||
|  |  | ||||||
|  |                         for message in parsed_message.unwrap() { | ||||||
|  |                             let handler = handler.clone(); | ||||||
|  |                             let pending = pending.clone(); | ||||||
|  |                             spawn(async move { | ||||||
|  |                                 handler.handle_websocket_message(pending, message).await; | ||||||
|  |                             }); | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                     tungstenite::Message::Ping(_) => {} | ||||||
|  |                     _ => error!("Unexpected websocket message: {:?}", message), | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             else => panic!("Communication channel unexpectedly closed.") | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | async fn handle_message( | ||||||
|  |     handler: Arc<Box<dyn Handler>>, | ||||||
|  |     pending: Arc<RwLock<Pending>>, | ||||||
|  |     sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>, | ||||||
|  |     message: Message, | ||||||
|  | ) { | ||||||
|  |     if message.symbols.is_empty() { | ||||||
|  |         message.response.send(()).unwrap(); | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     match message.action { | ||||||
|  |         Some(Action::Subscribe) => { | ||||||
|  |             let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message | ||||||
|  |                 .symbols | ||||||
|  |                 .iter() | ||||||
|  |                 .map(|symbol| { | ||||||
|  |                     let (sender, receiver) = oneshot::channel(); | ||||||
|  |                     ((symbol.clone(), sender), receiver) | ||||||
|  |                 }) | ||||||
|  |                 .unzip(); | ||||||
|  |  | ||||||
|  |             pending | ||||||
|  |                 .write() | ||||||
|  |                 .await | ||||||
|  |                 .subscriptions | ||||||
|  |                 .extend(pending_subscriptions); | ||||||
|  |  | ||||||
|  |             sink.lock() | ||||||
|  |                 .await | ||||||
|  |                 .send(tungstenite::Message::Text( | ||||||
|  |                     to_string(&websocket::data::outgoing::Message::Subscribe( | ||||||
|  |                         handler.create_subscription_message(message.symbols), | ||||||
|  |                     )) | ||||||
|  |                     .unwrap(), | ||||||
|  |                 )) | ||||||
|  |                 .await | ||||||
|  |                 .unwrap(); | ||||||
|  |  | ||||||
|  |             join_all(receivers).await; | ||||||
|  |         } | ||||||
|  |         Some(Action::Unsubscribe) => { | ||||||
|  |             let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message | ||||||
|  |                 .symbols | ||||||
|  |                 .iter() | ||||||
|  |                 .map(|symbol| { | ||||||
|  |                     let (sender, receiver) = oneshot::channel(); | ||||||
|  |                     ((symbol.clone(), sender), receiver) | ||||||
|  |                 }) | ||||||
|  |                 .unzip(); | ||||||
|  |  | ||||||
|  |             pending | ||||||
|  |                 .write() | ||||||
|  |                 .await | ||||||
|  |                 .unsubscriptions | ||||||
|  |                 .extend(pending_unsubscriptions); | ||||||
|  |  | ||||||
|  |             sink.lock() | ||||||
|  |                 .await | ||||||
|  |                 .send(tungstenite::Message::Text( | ||||||
|  |                     to_string(&websocket::data::outgoing::Message::Unsubscribe( | ||||||
|  |                         handler.create_subscription_message(message.symbols.clone()), | ||||||
|  |                     )) | ||||||
|  |                     .unwrap(), | ||||||
|  |                 )) | ||||||
|  |                 .await | ||||||
|  |                 .unwrap(); | ||||||
|  |  | ||||||
|  |             join_all(receivers).await; | ||||||
|  |         } | ||||||
|  |         None => {} | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     message.response.send(()).unwrap(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | struct BarsHandler { | ||||||
|  |     config: Arc<Config>, | ||||||
|  |     subscription_message_constructor: | ||||||
|  |         fn(Vec<String>) -> websocket::data::outgoing::subscribe::Message, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[async_trait] | ||||||
|  | impl Handler for BarsHandler { | ||||||
|  |     fn create_subscription_message( | ||||||
|  |         &self, | ||||||
|  |         symbols: Vec<String>, | ||||||
|  |     ) -> websocket::data::outgoing::subscribe::Message { | ||||||
|  |         (self.subscription_message_constructor)(symbols) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn handle_websocket_message( | ||||||
|  |         &self, | ||||||
|  |         pending: Arc<RwLock<Pending>>, | ||||||
|  |         message: websocket::data::incoming::Message, | ||||||
|  |     ) { | ||||||
|  |         match message { | ||||||
|  |             websocket::data::incoming::Message::Subscription(message) => { | ||||||
|  |                 let websocket::data::incoming::subscription::Message::Market { | ||||||
|  |                     bars: symbols, .. | ||||||
|  |                 } = message | ||||||
|  |                 else { | ||||||
|  |                     unreachable!() | ||||||
|  |                 }; | ||||||
|  |  | ||||||
|  |                 let mut pending = pending.write().await; | ||||||
|  |  | ||||||
|  |                 let newly_subscribed = pending | ||||||
|  |                     .subscriptions | ||||||
|  |                     .extract_if(|symbol, _| symbols.contains(symbol)) | ||||||
|  |                     .collect::<HashMap<_, _>>(); | ||||||
|  |  | ||||||
|  |                 let newly_unsubscribed = pending | ||||||
|  |                     .unsubscriptions | ||||||
|  |                     .extract_if(|symbol, _| !symbols.contains(symbol)) | ||||||
|  |                     .collect::<HashMap<_, _>>(); | ||||||
|  |  | ||||||
|  |                 drop(pending); | ||||||
|  |  | ||||||
|  |                 if !newly_subscribed.is_empty() { | ||||||
|  |                     info!( | ||||||
|  |                         "Subscribed to bars for {:?}.", | ||||||
|  |                         newly_subscribed.keys().collect::<Vec<_>>() | ||||||
|  |                     ); | ||||||
|  |  | ||||||
|  |                     for sender in newly_subscribed.into_values() { | ||||||
|  |                         sender.send(()).unwrap(); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 if !newly_unsubscribed.is_empty() { | ||||||
|  |                     info!( | ||||||
|  |                         "Unsubscribed from bars for {:?}.", | ||||||
|  |                         newly_unsubscribed.keys().collect::<Vec<_>>() | ||||||
|  |                     ); | ||||||
|  |  | ||||||
|  |                     for sender in newly_unsubscribed.into_values() { | ||||||
|  |                         sender.send(()).unwrap(); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             websocket::data::incoming::Message::Bar(message) | ||||||
|  |             | websocket::data::incoming::Message::UpdatedBar(message) => { | ||||||
|  |                 let bar = Bar::from(message); | ||||||
|  |                 debug!("Received bar for {}: {}.", bar.symbol, bar.time); | ||||||
|  |  | ||||||
|  |                 database::bars::upsert(&self.config.clickhouse_client, &bar) | ||||||
|  |                     .await | ||||||
|  |                     .unwrap(); | ||||||
|  |             } | ||||||
|  |             websocket::data::incoming::Message::Status(message) => { | ||||||
|  |                 debug!( | ||||||
|  |                     "Received status message for {}: {:?}.", | ||||||
|  |                     message.symbol, message.status | ||||||
|  |                 ); | ||||||
|  |  | ||||||
|  |                 match message.status { | ||||||
|  |                     websocket::data::incoming::status::Status::TradingHalt(_) | ||||||
|  |                     | websocket::data::incoming::status::Status::VolatilityTradingPause(_) => { | ||||||
|  |                         database::assets::update_status_where_symbol( | ||||||
|  |                             &self.config.clickhouse_client, | ||||||
|  |                             &message.symbol, | ||||||
|  |                             false, | ||||||
|  |                         ) | ||||||
|  |                         .await | ||||||
|  |                         .unwrap(); | ||||||
|  |                     } | ||||||
|  |                     websocket::data::incoming::status::Status::Resume(_) | ||||||
|  |                     | websocket::data::incoming::status::Status::TradingResumption(_) => { | ||||||
|  |                         database::assets::update_status_where_symbol( | ||||||
|  |                             &self.config.clickhouse_client, | ||||||
|  |                             &message.symbol, | ||||||
|  |                             true, | ||||||
|  |                         ) | ||||||
|  |                         .await | ||||||
|  |                         .unwrap(); | ||||||
|  |                     } | ||||||
|  |                     _ => {} | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             websocket::data::incoming::Message::Error(message) => { | ||||||
|  |                 error!("Received error message: {}.", message.message); | ||||||
|  |             } | ||||||
|  |             _ => unreachable!(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | struct NewsHandler { | ||||||
|  |     config: Arc<Config>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[async_trait] | ||||||
|  | impl Handler for NewsHandler { | ||||||
|  |     fn create_subscription_message( | ||||||
|  |         &self, | ||||||
|  |         symbols: Vec<String>, | ||||||
|  |     ) -> websocket::data::outgoing::subscribe::Message { | ||||||
|  |         websocket::data::outgoing::subscribe::Message::new_news(symbols) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     async fn handle_websocket_message( | ||||||
|  |         &self, | ||||||
|  |         pending: Arc<RwLock<Pending>>, | ||||||
|  |         message: websocket::data::incoming::Message, | ||||||
|  |     ) { | ||||||
|  |         match message { | ||||||
|  |             websocket::data::incoming::Message::Subscription(message) => { | ||||||
|  |                 let websocket::data::incoming::subscription::Message::News { news: symbols } = | ||||||
|  |                     message | ||||||
|  |                 else { | ||||||
|  |                     unreachable!() | ||||||
|  |                 }; | ||||||
|  |  | ||||||
|  |                 let mut pending = pending.write().await; | ||||||
|  |  | ||||||
|  |                 let newly_subscribed = pending | ||||||
|  |                     .subscriptions | ||||||
|  |                     .extract_if(|symbol, _| symbols.contains(symbol)) | ||||||
|  |                     .collect::<HashMap<_, _>>(); | ||||||
|  |  | ||||||
|  |                 let newly_unsubscribed = pending | ||||||
|  |                     .unsubscriptions | ||||||
|  |                     .extract_if(|symbol, _| !symbols.contains(symbol)) | ||||||
|  |                     .collect::<HashMap<_, _>>(); | ||||||
|  |  | ||||||
|  |                 drop(pending); | ||||||
|  |  | ||||||
|  |                 if !newly_subscribed.is_empty() { | ||||||
|  |                     info!( | ||||||
|  |                         "Subscribed to news for {:?}.", | ||||||
|  |                         newly_subscribed.keys().collect::<Vec<_>>() | ||||||
|  |                     ); | ||||||
|  |  | ||||||
|  |                     for sender in newly_subscribed.into_values() { | ||||||
|  |                         sender.send(()).unwrap(); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 if !newly_unsubscribed.is_empty() { | ||||||
|  |                     info!( | ||||||
|  |                         "Unsubscribed from news for {:?}.", | ||||||
|  |                         newly_unsubscribed.keys().collect::<Vec<_>>() | ||||||
|  |                     ); | ||||||
|  |  | ||||||
|  |                     for sender in newly_unsubscribed.into_values() { | ||||||
|  |                         sender.send(()).unwrap(); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             websocket::data::incoming::Message::News(message) => { | ||||||
|  |                 let news = News::from(message); | ||||||
|  |  | ||||||
|  |                 debug!( | ||||||
|  |                     "Received news for {:?}: {}.", | ||||||
|  |                     news.symbols, news.time_created | ||||||
|  |                 ); | ||||||
|  |  | ||||||
|  |                 let input = format!("{}\n\n{}", news.headline, news.content); | ||||||
|  |  | ||||||
|  |                 let sequence_classifier = self.config.sequence_classifier.lock().await; | ||||||
|  |                 let prediction = block_in_place(|| { | ||||||
|  |                     sequence_classifier | ||||||
|  |                         .predict(vec![input.as_str()]) | ||||||
|  |                         .into_iter() | ||||||
|  |                         .map(|label| Prediction::try_from(label).unwrap()) | ||||||
|  |                         .collect::<Vec<_>>()[0] | ||||||
|  |                 }); | ||||||
|  |                 drop(sequence_classifier); | ||||||
|  |  | ||||||
|  |                 let news = News { | ||||||
|  |                     sentiment: prediction.sentiment, | ||||||
|  |                     confidence: prediction.confidence, | ||||||
|  |                     ..news | ||||||
|  |                 }; | ||||||
|  |  | ||||||
|  |                 database::news::upsert(&self.config.clickhouse_client, &news) | ||||||
|  |                     .await | ||||||
|  |                     .unwrap(); | ||||||
|  |             } | ||||||
|  |             websocket::data::incoming::Message::Error(message) => { | ||||||
|  |                 error!("Received error message: {}.", message.message); | ||||||
|  |             } | ||||||
|  |             _ => unreachable!(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> { | ||||||
|  |     match thread_type { | ||||||
|  |         ThreadType::Bars(Class::UsEquity) => Box::new(BarsHandler { | ||||||
|  |             config, | ||||||
|  |             subscription_message_constructor: | ||||||
|  |                 websocket::data::outgoing::subscribe::Message::new_market_us_equity, | ||||||
|  |         }), | ||||||
|  |         ThreadType::Bars(Class::Crypto) => Box::new(BarsHandler { | ||||||
|  |             config, | ||||||
|  |             subscription_message_constructor: | ||||||
|  |                 websocket::data::outgoing::subscribe::Message::new_market_crypto, | ||||||
|  |         }), | ||||||
|  |         ThreadType::News => Box::new(NewsHandler { config }), | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										20
									
								
								src/threads/trading/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								src/threads/trading/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | mod websocket; | ||||||
|  |  | ||||||
|  | use crate::{ | ||||||
|  |     config::{Config, ALPACA_WEBSOCKET_URL}, | ||||||
|  |     types::alpaca, | ||||||
|  | }; | ||||||
|  | use futures_util::StreamExt; | ||||||
|  | use std::sync::Arc; | ||||||
|  | use tokio::spawn; | ||||||
|  | use tokio_tungstenite::connect_async; | ||||||
|  |  | ||||||
|  | pub async fn run(config: Arc<Config>) { | ||||||
|  |     let (websocket, _) = connect_async(&*ALPACA_WEBSOCKET_URL).await.unwrap(); | ||||||
|  |     let (mut websocket_sink, mut websocket_stream) = websocket.split(); | ||||||
|  |  | ||||||
|  |     alpaca::websocket::trading::authenticate(&mut websocket_sink, &mut websocket_stream).await; | ||||||
|  |     alpaca::websocket::trading::subscribe(&mut websocket_sink, &mut websocket_stream).await; | ||||||
|  |  | ||||||
|  |     spawn(websocket::run(config, websocket_stream)); | ||||||
|  | } | ||||||
| @@ -1,7 +1,10 @@ | |||||||
| use crate::{config::Config, database}; | use crate::{ | ||||||
|  |     config::Config, | ||||||
|  |     database, | ||||||
|  |     types::{alpaca::websocket, Order}, | ||||||
|  | }; | ||||||
| use futures_util::{stream::SplitStream, StreamExt}; | use futures_util::{stream::SplitStream, StreamExt}; | ||||||
| use log::{debug, error}; | use log::{debug, error}; | ||||||
| use qrust::types::{alpaca::websocket, Order}; |  | ||||||
| use serde_json::from_str; | use serde_json::from_str; | ||||||
| use std::sync::Arc; | use std::sync::Arc; | ||||||
| use tokio::{net::TcpStream, spawn}; | use tokio::{net::TcpStream, spawn}; | ||||||
| @@ -21,7 +24,7 @@ pub async fn run( | |||||||
|                 ); |                 ); | ||||||
| 
 | 
 | ||||||
|                 if parsed_message.is_err() { |                 if parsed_message.is_err() { | ||||||
|                     error!("Failed to deserialize websocket message: {:?}.", message); |                     error!("Failed to deserialize websocket message: {:?}", message); | ||||||
|                     continue; |                     continue; | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
| @@ -31,7 +34,7 @@ pub async fn run( | |||||||
|                 )); |                 )); | ||||||
|             } |             } | ||||||
|             tungstenite::Message::Ping(_) => {} |             tungstenite::Message::Ping(_) => {} | ||||||
|             _ => error!("Unexpected websocket message: {:?}.", message), |             _ => error!("Unexpected websocket message: {:?}", message), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -43,17 +46,13 @@ async fn handle_websocket_message( | |||||||
|     match message { |     match message { | ||||||
|         websocket::trading::incoming::Message::Order(message) => { |         websocket::trading::incoming::Message::Order(message) => { | ||||||
|             debug!( |             debug!( | ||||||
|                 "Received order message for {}: {:?}.", |                 "Received order message for {}: {:?}", | ||||||
|                 message.order.symbol, message.event |                 message.order.symbol, message.event | ||||||
|             ); |             ); | ||||||
| 
 | 
 | ||||||
|             let order = Order::from(message.order); |             let order = Order::from(message.order); | ||||||
| 
 | 
 | ||||||
|             database::orders::upsert( |             database::orders::upsert(&config.clickhouse_client, &order) | ||||||
|                 &config.clickhouse_client, |  | ||||||
|                 &config.clickhouse_concurrency_limiter, |  | ||||||
|                 &order, |  | ||||||
|             ) |  | ||||||
|                 .await |                 .await | ||||||
|                 .unwrap(); |                 .unwrap(); | ||||||
| 
 | 
 | ||||||
| @@ -64,7 +63,6 @@ async fn handle_websocket_message( | |||||||
|                 } => { |                 } => { | ||||||
|                     database::assets::update_qty_where_symbol( |                     database::assets::update_qty_where_symbol( | ||||||
|                         &config.clickhouse_client, |                         &config.clickhouse_client, | ||||||
|                         &config.clickhouse_concurrency_limiter, |  | ||||||
|                         &order.symbol, |                         &order.symbol, | ||||||
|                         position_qty, |                         position_qty, | ||||||
|                     ) |                     ) | ||||||
| @@ -1,7 +1,13 @@ | |||||||
|  | use crate::config::ALPACA_API_URL; | ||||||
|  | use backoff::{future::retry_notify, ExponentialBackoff}; | ||||||
|  | use governor::DefaultDirectRateLimiter; | ||||||
|  | use log::warn; | ||||||
|  | use reqwest::{Client, Error}; | ||||||
| use serde::Deserialize; | use serde::Deserialize; | ||||||
| use serde_aux::field_attributes::{ | use serde_aux::field_attributes::{ | ||||||
|     deserialize_number_from_string, deserialize_option_number_from_string, |     deserialize_number_from_string, deserialize_option_number_from_string, | ||||||
| }; | }; | ||||||
|  | use std::time::Duration; | ||||||
| use time::OffsetDateTime; | use time::OffsetDateTime; | ||||||
| use uuid::Uuid; | use uuid::Uuid; | ||||||
| 
 | 
 | ||||||
| @@ -73,3 +79,38 @@ pub struct Account { | |||||||
|     #[serde(deserialize_with = "deserialize_number_from_string")] |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|     pub regt_buying_power: f64, |     pub regt_buying_power: f64, | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | pub async fn get( | ||||||
|  |     alpaca_client: &Client, | ||||||
|  |     alpaca_rate_limiter: &DefaultDirectRateLimiter, | ||||||
|  |     backoff: Option<ExponentialBackoff>, | ||||||
|  | ) -> Result<Account, Error> { | ||||||
|  |     retry_notify( | ||||||
|  |         backoff.unwrap_or_default(), | ||||||
|  |         || async { | ||||||
|  |             alpaca_rate_limiter.until_ready().await; | ||||||
|  |             alpaca_client | ||||||
|  |                 .get(&format!("{}/account", *ALPACA_API_URL)) | ||||||
|  |                 .send() | ||||||
|  |                 .await? | ||||||
|  |                 .error_for_status() | ||||||
|  |                 .map_err(|e| match e.status() { | ||||||
|  |                     Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { | ||||||
|  |                         backoff::Error::Permanent(e) | ||||||
|  |                     } | ||||||
|  |                     _ => e.into(), | ||||||
|  |                 })? | ||||||
|  |                 .json::<Account>() | ||||||
|  |                 .await | ||||||
|  |                 .map_err(backoff::Error::Permanent) | ||||||
|  |         }, | ||||||
|  |         |e, duration: Duration| { | ||||||
|  |             warn!( | ||||||
|  |                 "Failed to get account, will retry in {} seconds: {}", | ||||||
|  |                 duration.as_secs(), | ||||||
|  |                 e | ||||||
|  |             ); | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     .await | ||||||
|  | } | ||||||
							
								
								
									
										86
									
								
								src/types/alpaca/api/incoming/asset.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								src/types/alpaca/api/incoming/asset.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,86 @@ | |||||||
|  | use super::position::Position; | ||||||
|  | use crate::{ | ||||||
|  |     config::ALPACA_API_URL, | ||||||
|  |     types::{ | ||||||
|  |         self, | ||||||
|  |         alpaca::shared::asset::{Class, Exchange, Status}, | ||||||
|  |     }, | ||||||
|  | }; | ||||||
|  | use backoff::{future::retry_notify, ExponentialBackoff}; | ||||||
|  | use governor::DefaultDirectRateLimiter; | ||||||
|  | use log::warn; | ||||||
|  | use reqwest::{Client, Error}; | ||||||
|  | use serde::Deserialize; | ||||||
|  | use serde_aux::field_attributes::deserialize_option_number_from_string; | ||||||
|  | use std::time::Duration; | ||||||
|  | use uuid::Uuid; | ||||||
|  |  | ||||||
|  | #[allow(clippy::struct_excessive_bools)] | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | pub struct Asset { | ||||||
|  |     pub id: Uuid, | ||||||
|  |     pub class: Class, | ||||||
|  |     pub exchange: Exchange, | ||||||
|  |     pub symbol: String, | ||||||
|  |     pub name: String, | ||||||
|  |     pub status: Status, | ||||||
|  |     pub tradable: bool, | ||||||
|  |     pub marginable: bool, | ||||||
|  |     pub shortable: bool, | ||||||
|  |     pub easy_to_borrow: bool, | ||||||
|  |     pub fractionable: bool, | ||||||
|  |     #[serde(deserialize_with = "deserialize_option_number_from_string")] | ||||||
|  |     pub maintenance_margin_requirement: Option<f32>, | ||||||
|  |     pub attributes: Option<Vec<String>>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<(Asset, Option<Position>)> for types::Asset { | ||||||
|  |     fn from((asset, position): (Asset, Option<Position>)) -> Self { | ||||||
|  |         Self { | ||||||
|  |             symbol: asset.symbol, | ||||||
|  |             class: asset.class.into(), | ||||||
|  |             exchange: asset.exchange.into(), | ||||||
|  |             status: asset.status.into(), | ||||||
|  |             time_added: time::OffsetDateTime::now_utc(), | ||||||
|  |             qty: position.map(|position| position.qty).unwrap_or_default(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn get_by_symbol( | ||||||
|  |     alpaca_client: &Client, | ||||||
|  |     alpaca_rate_limiter: &DefaultDirectRateLimiter, | ||||||
|  |     symbol: &str, | ||||||
|  |     backoff: Option<ExponentialBackoff>, | ||||||
|  | ) -> Result<Asset, Error> { | ||||||
|  |     retry_notify( | ||||||
|  |         backoff.unwrap_or_default(), | ||||||
|  |         || async { | ||||||
|  |             alpaca_rate_limiter.until_ready().await; | ||||||
|  |             alpaca_client | ||||||
|  |                 .get(&format!("{}/assets/{}", *ALPACA_API_URL, symbol)) | ||||||
|  |                 .send() | ||||||
|  |                 .await? | ||||||
|  |                 .error_for_status() | ||||||
|  |                 .map_err(|e| match e.status() { | ||||||
|  |                     Some( | ||||||
|  |                         reqwest::StatusCode::BAD_REQUEST | ||||||
|  |                         | reqwest::StatusCode::FORBIDDEN | ||||||
|  |                         | reqwest::StatusCode::NOT_FOUND, | ||||||
|  |                     ) => backoff::Error::Permanent(e), | ||||||
|  |                     _ => e.into(), | ||||||
|  |                 })? | ||||||
|  |                 .json::<Asset>() | ||||||
|  |                 .await | ||||||
|  |                 .map_err(backoff::Error::Permanent) | ||||||
|  |         }, | ||||||
|  |         |e, duration: Duration| { | ||||||
|  |             warn!( | ||||||
|  |                 "Failed to get asset, will retry in {} seconds: {}", | ||||||
|  |                 duration.as_secs(), | ||||||
|  |                 e | ||||||
|  |             ); | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     .await | ||||||
|  | } | ||||||
							
								
								
									
										89
									
								
								src/types/alpaca/api/incoming/bar.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								src/types/alpaca/api/incoming/bar.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | |||||||
|  | use crate::types::{self, alpaca::api::outgoing}; | ||||||
|  | use backoff::{future::retry_notify, ExponentialBackoff}; | ||||||
|  | use governor::DefaultDirectRateLimiter; | ||||||
|  | use log::warn; | ||||||
|  | use reqwest::{Client, Error}; | ||||||
|  | use serde::Deserialize; | ||||||
|  | use std::{collections::HashMap, time::Duration}; | ||||||
|  | use time::OffsetDateTime; | ||||||
|  |  | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | pub struct Bar { | ||||||
|  |     #[serde(rename = "t")] | ||||||
|  |     #[serde(with = "time::serde::rfc3339")] | ||||||
|  |     pub time: OffsetDateTime, | ||||||
|  |     #[serde(rename = "o")] | ||||||
|  |     pub open: f64, | ||||||
|  |     #[serde(rename = "h")] | ||||||
|  |     pub high: f64, | ||||||
|  |     #[serde(rename = "l")] | ||||||
|  |     pub low: f64, | ||||||
|  |     #[serde(rename = "c")] | ||||||
|  |     pub close: f64, | ||||||
|  |     #[serde(rename = "v")] | ||||||
|  |     pub volume: f64, | ||||||
|  |     #[serde(rename = "n")] | ||||||
|  |     pub trades: i64, | ||||||
|  |     #[serde(rename = "vw")] | ||||||
|  |     pub vwap: f64, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<(Bar, String)> for types::Bar { | ||||||
|  |     fn from((bar, symbol): (Bar, String)) -> Self { | ||||||
|  |         Self { | ||||||
|  |             time: bar.time, | ||||||
|  |             symbol, | ||||||
|  |             open: bar.open, | ||||||
|  |             high: bar.high, | ||||||
|  |             low: bar.low, | ||||||
|  |             close: bar.close, | ||||||
|  |             volume: bar.volume, | ||||||
|  |             trades: bar.trades, | ||||||
|  |             vwap: bar.vwap, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | pub struct Message { | ||||||
|  |     pub bars: HashMap<String, Vec<Bar>>, | ||||||
|  |     pub next_page_token: Option<String>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn get_historical( | ||||||
|  |     alpaca_client: &Client, | ||||||
|  |     alpaca_rate_limiter: &DefaultDirectRateLimiter, | ||||||
|  |     data_url: &str, | ||||||
|  |     query: &outgoing::bar::Bar, | ||||||
|  |     backoff: Option<ExponentialBackoff>, | ||||||
|  | ) -> Result<Message, Error> { | ||||||
|  |     retry_notify( | ||||||
|  |         backoff.unwrap_or_default(), | ||||||
|  |         || async { | ||||||
|  |             alpaca_rate_limiter.until_ready().await; | ||||||
|  |             alpaca_client | ||||||
|  |                 .get(data_url) | ||||||
|  |                 .query(query) | ||||||
|  |                 .send() | ||||||
|  |                 .await? | ||||||
|  |                 .error_for_status() | ||||||
|  |                 .map_err(|e| match e.status() { | ||||||
|  |                     Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { | ||||||
|  |                         backoff::Error::Permanent(e) | ||||||
|  |                     } | ||||||
|  |                     _ => e.into(), | ||||||
|  |                 })? | ||||||
|  |                 .json::<Message>() | ||||||
|  |                 .await | ||||||
|  |                 .map_err(backoff::Error::Permanent) | ||||||
|  |         }, | ||||||
|  |         |e, duration: Duration| { | ||||||
|  |             warn!( | ||||||
|  |                 "Failed to get historical bars, will retry in {} seconds: {}", | ||||||
|  |                 duration.as_secs(), | ||||||
|  |                 e | ||||||
|  |             ); | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     .await | ||||||
|  | } | ||||||
							
								
								
									
										69
									
								
								src/types/alpaca/api/incoming/calendar.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								src/types/alpaca/api/incoming/calendar.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | |||||||
|  | use crate::{ | ||||||
|  |     config::ALPACA_API_URL, | ||||||
|  |     types::{self, alpaca::api::outgoing}, | ||||||
|  |     utils::{de, time::EST_OFFSET}, | ||||||
|  | }; | ||||||
|  | use backoff::{future::retry_notify, ExponentialBackoff}; | ||||||
|  | use governor::DefaultDirectRateLimiter; | ||||||
|  | use log::warn; | ||||||
|  | use reqwest::{Client, Error}; | ||||||
|  | use serde::Deserialize; | ||||||
|  | use std::time::Duration; | ||||||
|  | use time::{Date, OffsetDateTime, Time}; | ||||||
|  |  | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | pub struct Calendar { | ||||||
|  |     pub date: Date, | ||||||
|  |     #[serde(deserialize_with = "de::human_time_hh_mm")] | ||||||
|  |     pub open: Time, | ||||||
|  |     #[serde(deserialize_with = "de::human_time_hh_mm")] | ||||||
|  |     pub close: Time, | ||||||
|  |     pub settlement_date: Date, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<Calendar> for types::Calendar { | ||||||
|  |     fn from(calendar: Calendar) -> Self { | ||||||
|  |         Self { | ||||||
|  |             date: calendar.date, | ||||||
|  |             open: OffsetDateTime::new_in_offset(calendar.date, calendar.open, *EST_OFFSET), | ||||||
|  |             close: OffsetDateTime::new_in_offset(calendar.date, calendar.close, *EST_OFFSET), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn get( | ||||||
|  |     alpaca_client: &Client, | ||||||
|  |     alpaca_rate_limiter: &DefaultDirectRateLimiter, | ||||||
|  |     query: &outgoing::calendar::Calendar, | ||||||
|  |     backoff: Option<ExponentialBackoff>, | ||||||
|  | ) -> Result<Vec<Calendar>, Error> { | ||||||
|  |     retry_notify( | ||||||
|  |         backoff.unwrap_or_default(), | ||||||
|  |         || async { | ||||||
|  |             alpaca_rate_limiter.until_ready().await; | ||||||
|  |             alpaca_client | ||||||
|  |                 .get(&format!("{}/calendar", *ALPACA_API_URL)) | ||||||
|  |                 .query(query) | ||||||
|  |                 .send() | ||||||
|  |                 .await? | ||||||
|  |                 .error_for_status() | ||||||
|  |                 .map_err(|e| match e.status() { | ||||||
|  |                     Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { | ||||||
|  |                         backoff::Error::Permanent(e) | ||||||
|  |                     } | ||||||
|  |                     _ => e.into(), | ||||||
|  |                 })? | ||||||
|  |                 .json::<Vec<Calendar>>() | ||||||
|  |                 .await | ||||||
|  |                 .map_err(backoff::Error::Permanent) | ||||||
|  |         }, | ||||||
|  |         |e, duration: Duration| { | ||||||
|  |             warn!( | ||||||
|  |                 "Failed to get calendar, will retry in {} seconds: {}", | ||||||
|  |                 duration.as_secs(), | ||||||
|  |                 e | ||||||
|  |             ); | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     .await | ||||||
|  | } | ||||||
							
								
								
									
										54
									
								
								src/types/alpaca/api/incoming/clock.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								src/types/alpaca/api/incoming/clock.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,54 @@ | |||||||
|  | use crate::config::ALPACA_API_URL; | ||||||
|  | use backoff::{future::retry_notify, ExponentialBackoff}; | ||||||
|  | use governor::DefaultDirectRateLimiter; | ||||||
|  | use log::warn; | ||||||
|  | use reqwest::{Client, Error}; | ||||||
|  | use serde::Deserialize; | ||||||
|  | use std::time::Duration; | ||||||
|  | use time::OffsetDateTime; | ||||||
|  |  | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | pub struct Clock { | ||||||
|  |     #[serde(with = "time::serde::rfc3339")] | ||||||
|  |     pub timestamp: OffsetDateTime, | ||||||
|  |     pub is_open: bool, | ||||||
|  |     #[serde(with = "time::serde::rfc3339")] | ||||||
|  |     pub next_open: OffsetDateTime, | ||||||
|  |     #[serde(with = "time::serde::rfc3339")] | ||||||
|  |     pub next_close: OffsetDateTime, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn get( | ||||||
|  |     alpaca_client: &Client, | ||||||
|  |     alpaca_rate_limiter: &DefaultDirectRateLimiter, | ||||||
|  |     backoff: Option<ExponentialBackoff>, | ||||||
|  | ) -> Result<Clock, Error> { | ||||||
|  |     retry_notify( | ||||||
|  |         backoff.unwrap_or_default(), | ||||||
|  |         || async { | ||||||
|  |             alpaca_rate_limiter.until_ready().await; | ||||||
|  |             alpaca_client | ||||||
|  |                 .get(&format!("{}/clock", *ALPACA_API_URL)) | ||||||
|  |                 .send() | ||||||
|  |                 .await? | ||||||
|  |                 .error_for_status() | ||||||
|  |                 .map_err(|e| match e.status() { | ||||||
|  |                     Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { | ||||||
|  |                         backoff::Error::Permanent(e) | ||||||
|  |                     } | ||||||
|  |                     _ => e.into(), | ||||||
|  |                 })? | ||||||
|  |                 .json::<Clock>() | ||||||
|  |                 .await | ||||||
|  |                 .map_err(backoff::Error::Permanent) | ||||||
|  |         }, | ||||||
|  |         |e, duration: Duration| { | ||||||
|  |             warn!( | ||||||
|  |                 "Failed to get clock, will retry in {} seconds: {}", | ||||||
|  |                 duration.as_secs(), | ||||||
|  |                 e | ||||||
|  |             ); | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     .await | ||||||
|  | } | ||||||
							
								
								
									
										111
									
								
								src/types/alpaca/api/incoming/news.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								src/types/alpaca/api/incoming/news.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,111 @@ | |||||||
|  | use crate::{ | ||||||
|  |     config::ALPACA_NEWS_DATA_API_URL, | ||||||
|  |     types::{ | ||||||
|  |         self, | ||||||
|  |         alpaca::{api::outgoing, shared::news::normalize_html_content}, | ||||||
|  |     }, | ||||||
|  |     utils::de, | ||||||
|  | }; | ||||||
|  | use backoff::{future::retry_notify, ExponentialBackoff}; | ||||||
|  | use governor::DefaultDirectRateLimiter; | ||||||
|  | use log::warn; | ||||||
|  | use reqwest::{Client, Error}; | ||||||
|  | use serde::Deserialize; | ||||||
|  | use std::time::Duration; | ||||||
|  | use time::OffsetDateTime; | ||||||
|  |  | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | #[serde(rename_all = "snake_case")] | ||||||
|  | pub enum ImageSize { | ||||||
|  |     Thumb, | ||||||
|  |     Small, | ||||||
|  |     Large, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | pub struct Image { | ||||||
|  |     pub size: ImageSize, | ||||||
|  |     pub url: String, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | pub struct News { | ||||||
|  |     pub id: i64, | ||||||
|  |     #[serde(with = "time::serde::rfc3339")] | ||||||
|  |     #[serde(rename = "created_at")] | ||||||
|  |     pub time_created: OffsetDateTime, | ||||||
|  |     #[serde(with = "time::serde::rfc3339")] | ||||||
|  |     #[serde(rename = "updated_at")] | ||||||
|  |     pub time_updated: OffsetDateTime, | ||||||
|  |     #[serde(deserialize_with = "de::add_slash_to_symbols")] | ||||||
|  |     pub symbols: Vec<String>, | ||||||
|  |     pub headline: String, | ||||||
|  |     pub author: String, | ||||||
|  |     pub source: String, | ||||||
|  |     pub summary: String, | ||||||
|  |     pub content: String, | ||||||
|  |     pub url: Option<String>, | ||||||
|  |     pub images: Vec<Image>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<News> for types::News { | ||||||
|  |     fn from(news: News) -> Self { | ||||||
|  |         Self { | ||||||
|  |             id: news.id, | ||||||
|  |             time_created: news.time_created, | ||||||
|  |             time_updated: news.time_updated, | ||||||
|  |             symbols: news.symbols, | ||||||
|  |             headline: normalize_html_content(&news.headline), | ||||||
|  |             author: normalize_html_content(&news.author), | ||||||
|  |             source: normalize_html_content(&news.source), | ||||||
|  |             summary: normalize_html_content(&news.summary), | ||||||
|  |             content: normalize_html_content(&news.content), | ||||||
|  |             sentiment: types::news::Sentiment::Neutral, | ||||||
|  |             confidence: 0.0, | ||||||
|  |             url: news.url.unwrap_or_default(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | pub struct Message { | ||||||
|  |     pub news: Vec<News>, | ||||||
|  |     pub next_page_token: Option<String>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn get_historical( | ||||||
|  |     alpaca_client: &Client, | ||||||
|  |     alpaca_rate_limiter: &DefaultDirectRateLimiter, | ||||||
|  |     query: &outgoing::news::News, | ||||||
|  |     backoff: Option<ExponentialBackoff>, | ||||||
|  | ) -> Result<Message, Error> { | ||||||
|  |     retry_notify( | ||||||
|  |         backoff.unwrap_or_default(), | ||||||
|  |         || async { | ||||||
|  |             alpaca_rate_limiter.until_ready().await; | ||||||
|  |             alpaca_client | ||||||
|  |                 .get(ALPACA_NEWS_DATA_API_URL) | ||||||
|  |                 .query(query) | ||||||
|  |                 .send() | ||||||
|  |                 .await? | ||||||
|  |                 .error_for_status() | ||||||
|  |                 .map_err(|e| match e.status() { | ||||||
|  |                     Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { | ||||||
|  |                         backoff::Error::Permanent(e) | ||||||
|  |                     } | ||||||
|  |                     _ => e.into(), | ||||||
|  |                 })? | ||||||
|  |                 .json::<Message>() | ||||||
|  |                 .await | ||||||
|  |                 .map_err(backoff::Error::Permanent) | ||||||
|  |         }, | ||||||
|  |         |e, duration: Duration| { | ||||||
|  |             warn!( | ||||||
|  |                 "Failed to get historical news, will retry in {} seconds: {}", | ||||||
|  |                 duration.as_secs(), | ||||||
|  |                 e | ||||||
|  |             ); | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     .await | ||||||
|  | } | ||||||
| @@ -1,39 +1,44 @@ | |||||||
| use super::error_to_backoff; | use crate::{ | ||||||
| use crate::types::alpaca::{api::outgoing, shared::order}; |     config::ALPACA_API_URL, | ||||||
|  |     types::alpaca::{api::outgoing, shared}, | ||||||
|  | }; | ||||||
| use backoff::{future::retry_notify, ExponentialBackoff}; | use backoff::{future::retry_notify, ExponentialBackoff}; | ||||||
| use governor::DefaultDirectRateLimiter; | use governor::DefaultDirectRateLimiter; | ||||||
| use log::warn; | use log::warn; | ||||||
| use reqwest::{Client, Error}; | use reqwest::{Client, Error}; | ||||||
| use std::time::Duration; | use std::time::Duration; | ||||||
| 
 | 
 | ||||||
| pub use order::Order; | pub use shared::order::Order; | ||||||
| 
 | 
 | ||||||
| pub async fn get( | pub async fn get( | ||||||
|     client: &Client, |     alpaca_client: &Client, | ||||||
|     rate_limiter: &DefaultDirectRateLimiter, |     alpaca_rate_limiter: &DefaultDirectRateLimiter, | ||||||
|     query: &outgoing::order::Order, |     query: &outgoing::order::Order, | ||||||
|     backoff: Option<ExponentialBackoff>, |     backoff: Option<ExponentialBackoff>, | ||||||
|     api_base: &str, |  | ||||||
| ) -> Result<Vec<Order>, Error> { | ) -> Result<Vec<Order>, Error> { | ||||||
|     retry_notify( |     retry_notify( | ||||||
|         backoff.unwrap_or_default(), |         backoff.unwrap_or_default(), | ||||||
|         || async { |         || async { | ||||||
|             rate_limiter.until_ready().await; |             alpaca_rate_limiter.until_ready().await; | ||||||
|             client |             alpaca_client | ||||||
|                 .get(&format!("https://{}.alpaca.markets/v2/orders", api_base)) |                 .get(&format!("{}/orders", *ALPACA_API_URL)) | ||||||
|                 .query(query) |                 .query(query) | ||||||
|                 .send() |                 .send() | ||||||
|                 .await |                 .await? | ||||||
|                 .map_err(error_to_backoff)? |  | ||||||
|                 .error_for_status() |                 .error_for_status() | ||||||
|                 .map_err(error_to_backoff)? |                 .map_err(|e| match e.status() { | ||||||
|  |                     Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { | ||||||
|  |                         backoff::Error::Permanent(e) | ||||||
|  |                     } | ||||||
|  |                     _ => e.into(), | ||||||
|  |                 })? | ||||||
|                 .json::<Vec<Order>>() |                 .json::<Vec<Order>>() | ||||||
|                 .await |                 .await | ||||||
|                 .map_err(error_to_backoff) |                 .map_err(backoff::Error::Permanent) | ||||||
|         }, |         }, | ||||||
|         |e, duration: Duration| { |         |e, duration: Duration| { | ||||||
|             warn!( |             warn!( | ||||||
|                 "Failed to get orders, will retry in {} seconds: {}.", |                 "Failed to get orders, will retry in {} seconds: {}", | ||||||
|                 duration.as_secs(), |                 duration.as_secs(), | ||||||
|                 e |                 e | ||||||
|             ); |             ); | ||||||
							
								
								
									
										145
									
								
								src/types/alpaca/api/incoming/position.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										145
									
								
								src/types/alpaca/api/incoming/position.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,145 @@ | |||||||
|  | use crate::{ | ||||||
|  |     config::ALPACA_API_URL, | ||||||
|  |     types::alpaca::shared::{ | ||||||
|  |         self, | ||||||
|  |         asset::{Class, Exchange}, | ||||||
|  |     }, | ||||||
|  |     utils::de, | ||||||
|  | }; | ||||||
|  | use backoff::{future::retry_notify, ExponentialBackoff}; | ||||||
|  | use governor::DefaultDirectRateLimiter; | ||||||
|  | use log::warn; | ||||||
|  | use reqwest::Client; | ||||||
|  | use serde::Deserialize; | ||||||
|  | use serde_aux::field_attributes::deserialize_number_from_string; | ||||||
|  | use std::time::Duration; | ||||||
|  | use uuid::Uuid; | ||||||
|  |  | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | #[serde(rename_all = "snake_case")] | ||||||
|  | pub enum Side { | ||||||
|  |     Long, | ||||||
|  |     Short, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<Side> for shared::order::Side { | ||||||
|  |     fn from(side: Side) -> Self { | ||||||
|  |         match side { | ||||||
|  |             Side::Long => Self::Buy, | ||||||
|  |             Side::Short => Self::Sell, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[derive(Deserialize)] | ||||||
|  | pub struct Position { | ||||||
|  |     pub asset_id: Uuid, | ||||||
|  |     #[serde(deserialize_with = "de::add_slash_to_symbol")] | ||||||
|  |     pub symbol: String, | ||||||
|  |     pub exchange: Exchange, | ||||||
|  |     pub asset_class: Class, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub avg_entry_price: f64, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub qty: f64, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub qty_available: f64, | ||||||
|  |     pub side: Side, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub market_value: f64, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub cost_basis: f64, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub unrealized_pl: f64, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub unrealized_plpc: f64, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub unrealized_intraday_pl: f64, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub unrealized_intraday_plpc: f64, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub current_price: f64, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub lastday_price: f64, | ||||||
|  |     #[serde(deserialize_with = "deserialize_number_from_string")] | ||||||
|  |     pub change_today: f64, | ||||||
|  |     pub asset_marginable: bool, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn get( | ||||||
|  |     alpaca_client: &Client, | ||||||
|  |     alpaca_rate_limiter: &DefaultDirectRateLimiter, | ||||||
|  |     backoff: Option<ExponentialBackoff>, | ||||||
|  | ) -> Result<Vec<Position>, reqwest::Error> { | ||||||
|  |     retry_notify( | ||||||
|  |         backoff.unwrap_or_default(), | ||||||
|  |         || async { | ||||||
|  |             alpaca_rate_limiter.until_ready().await; | ||||||
|  |             alpaca_client | ||||||
|  |                 .get(&format!("{}/positions", *ALPACA_API_URL)) | ||||||
|  |                 .send() | ||||||
|  |                 .await? | ||||||
|  |                 .error_for_status() | ||||||
|  |                 .map_err(|e| match e.status() { | ||||||
|  |                     Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { | ||||||
|  |                         backoff::Error::Permanent(e) | ||||||
|  |                     } | ||||||
|  |                     _ => e.into(), | ||||||
|  |                 })? | ||||||
|  |                 .json::<Vec<Position>>() | ||||||
|  |                 .await | ||||||
|  |                 .map_err(backoff::Error::Permanent) | ||||||
|  |         }, | ||||||
|  |         |e, duration: Duration| { | ||||||
|  |             warn!( | ||||||
|  |                 "Failed to get positions, will retry in {} seconds: {}", | ||||||
|  |                 duration.as_secs(), | ||||||
|  |                 e | ||||||
|  |             ); | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     .await | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn get_by_symbol( | ||||||
|  |     alpaca_client: &Client, | ||||||
|  |     alpaca_rate_limiter: &DefaultDirectRateLimiter, | ||||||
|  |     symbol: &str, | ||||||
|  |     backoff: Option<ExponentialBackoff>, | ||||||
|  | ) -> Result<Option<Position>, reqwest::Error> { | ||||||
|  |     retry_notify( | ||||||
|  |         backoff.unwrap_or_default(), | ||||||
|  |         || async { | ||||||
|  |             alpaca_rate_limiter.until_ready().await; | ||||||
|  |             let response = alpaca_client | ||||||
|  |                 .get(&format!("{}/positions/{}", *ALPACA_API_URL, symbol)) | ||||||
|  |                 .send() | ||||||
|  |                 .await?; | ||||||
|  |  | ||||||
|  |             if response.status() == reqwest::StatusCode::NOT_FOUND { | ||||||
|  |                 return Ok(None); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             response | ||||||
|  |                 .error_for_status() | ||||||
|  |                 .map_err(|e| match e.status() { | ||||||
|  |                     Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => { | ||||||
|  |                         backoff::Error::Permanent(e) | ||||||
|  |                     } | ||||||
|  |                     _ => e.into(), | ||||||
|  |                 })? | ||||||
|  |                 .json::<Position>() | ||||||
|  |                 .await | ||||||
|  |                 .map_err(backoff::Error::Permanent) | ||||||
|  |                 .map(Some) | ||||||
|  |         }, | ||||||
|  |         |e, duration: Duration| { | ||||||
|  |             warn!( | ||||||
|  |                 "Failed to get position, will retry in {} seconds: {}", | ||||||
|  |                 duration.as_secs(), | ||||||
|  |                 e | ||||||
|  |             ); | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     .await | ||||||
|  | } | ||||||
							
								
								
									
										2
									
								
								src/types/alpaca/api/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								src/types/alpaca/api/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | pub mod incoming; | ||||||
|  | pub mod outgoing; | ||||||
| @@ -1,14 +1,12 @@ | |||||||
| use crate::{ | use crate::{ | ||||||
|     alpaca::bars::MAX_LIMIT, |     config::ALPACA_SOURCE, | ||||||
|     types::alpaca::shared, |     types::alpaca::shared::{Sort, Source}, | ||||||
|     utils::{ser, ONE_MINUTE}, |     utils::{ser, ONE_MINUTE}, | ||||||
| }; | }; | ||||||
| use serde::Serialize; | use serde::Serialize; | ||||||
| use std::time::Duration; | use std::time::Duration; | ||||||
| use time::OffsetDateTime; | use time::OffsetDateTime; | ||||||
| 
 | 
 | ||||||
| pub use shared::{Sort, Source}; |  | ||||||
| 
 |  | ||||||
| #[derive(Serialize)] | #[derive(Serialize)] | ||||||
| #[serde(rename_all = "snake_case")] | #[serde(rename_all = "snake_case")] | ||||||
| #[allow(dead_code)] | #[allow(dead_code)] | ||||||
| @@ -55,10 +53,10 @@ impl Default for UsEquity { | |||||||
|             timeframe: ONE_MINUTE, |             timeframe: ONE_MINUTE, | ||||||
|             start: None, |             start: None, | ||||||
|             end: None, |             end: None, | ||||||
|             limit: Some(MAX_LIMIT), |             limit: Some(10000), | ||||||
|             adjustment: Some(Adjustment::All), |             adjustment: Some(Adjustment::All), | ||||||
|             asof: None, |             asof: None, | ||||||
|             feed: Some(Source::Iex), |             feed: Some(*ALPACA_SOURCE), | ||||||
|             currency: None, |             currency: None, | ||||||
|             page_token: None, |             page_token: None, | ||||||
|             sort: Some(Sort::Asc), |             sort: Some(Sort::Asc), | ||||||
| @@ -93,7 +91,7 @@ impl Default for Crypto { | |||||||
|             timeframe: ONE_MINUTE, |             timeframe: ONE_MINUTE, | ||||||
|             start: None, |             start: None, | ||||||
|             end: None, |             end: None, | ||||||
|             limit: Some(MAX_LIMIT), |             limit: Some(10000), | ||||||
|             page_token: None, |             page_token: None, | ||||||
|             sort: Some(Sort::Asc), |             sort: Some(Sort::Asc), | ||||||
|         } |         } | ||||||
| @@ -1,4 +1,3 @@ | |||||||
| pub mod asset; |  | ||||||
| pub mod bar; | pub mod bar; | ||||||
| pub mod calendar; | pub mod calendar; | ||||||
| pub mod news; | pub mod news; | ||||||
| @@ -1,10 +1,10 @@ | |||||||
| use crate::{alpaca::news::MAX_LIMIT, types::alpaca::shared::Sort, utils::ser}; | use crate::{types::alpaca::shared::Sort, utils::ser}; | ||||||
| use serde::Serialize; | use serde::Serialize; | ||||||
| use time::OffsetDateTime; | use time::OffsetDateTime; | ||||||
| 
 | 
 | ||||||
| #[derive(Serialize)] | #[derive(Serialize)] | ||||||
| pub struct News { | pub struct News { | ||||||
|     #[serde(serialize_with = "ser::remove_slash_and_join_symbols")] |     #[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")] | ||||||
|     pub symbols: Vec<String>, |     pub symbols: Vec<String>, | ||||||
|     #[serde(skip_serializing_if = "Option::is_none")] |     #[serde(skip_serializing_if = "Option::is_none")] | ||||||
|     #[serde(with = "time::serde::rfc3339::option")] |     #[serde(with = "time::serde::rfc3339::option")] | ||||||
| @@ -30,7 +30,7 @@ impl Default for News { | |||||||
|             symbols: vec![], |             symbols: vec![], | ||||||
|             start: None, |             start: None, | ||||||
|             end: None, |             end: None, | ||||||
|             limit: Some(MAX_LIMIT), |             limit: Some(50), | ||||||
|             include_content: Some(true), |             include_content: Some(true), | ||||||
|             exclude_contentless: Some(false), |             exclude_contentless: Some(false), | ||||||
|             page_token: None, |             page_token: None, | ||||||
| @@ -1,12 +1,10 @@ | |||||||
| use crate::{ | use crate::{ | ||||||
|     types::alpaca::shared::{order, Sort}, |     types::alpaca::shared::{order::Side, Sort}, | ||||||
|     utils::ser, |     utils::ser, | ||||||
| }; | }; | ||||||
| use serde::Serialize; | use serde::Serialize; | ||||||
| use time::OffsetDateTime; | use time::OffsetDateTime; | ||||||
| 
 | 
 | ||||||
| pub use order::Side; |  | ||||||
| 
 |  | ||||||
| #[derive(Serialize)] | #[derive(Serialize)] | ||||||
| #[serde(rename_all = "snake_case")] | #[serde(rename_all = "snake_case")] | ||||||
| #[allow(dead_code)] | #[allow(dead_code)] | ||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user