26 Commits

Author SHA1 Message Date
2036e5fa32 Add LSTM experiment
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-27 09:53:52 +00:00
080f91b044 Fix README
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-26 19:01:22 +00:00
3006264af1 Fix calendar EST offset
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-26 18:55:15 +00:00
a84daea61c Add local market calendar storage
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-22 12:35:01 +00:00
0b9c6ca122 Add defaults for outgoing types
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-22 11:31:28 +00:00
4665891316 Fix error on initialization with no symbols
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-19 17:51:44 +00:00
4f73058792 Add calendar
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-17 20:36:02 +00:00
152a0b4682 Fix bad request response handling
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-17 20:35:50 +00:00
ae5044142d Fix status message deserialization
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-16 19:39:36 +00:00
a1781cdf29 Remove manual pongs
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-15 14:17:47 +00:00
cdaa2d20a9 Update random bits and bobs
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-15 01:09:16 +00:00
4b194e168f Add paper URL support
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-14 21:15:27 +00:00
6f85b9b0e8 Fix string to number deserialization
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-14 18:53:58 +00:00
6adf2b46c8 Add partial account management
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-14 17:38:56 +00:00
648d413ac7 Add order/position management
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-14 17:07:30 +00:00
6ec71ee144 Add position types
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-14 10:48:37 +00:00
5961717520 Add order types
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-12 16:45:11 +00:00
dee21d5324 Add asset status management
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-09 15:43:42 +00:00
76bf2fddcb Clean up error propagation
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-08 18:13:52 +00:00
52e88f4bc9 Remove asset_status thread
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-07 20:40:11 +00:00
85eef2bf0b Refactor threads to use trait implementations
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-05 13:47:43 +00:00
a796feb299 Lower incoming data log level
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-05 00:32:46 +00:00
a2bcb6d17e Make sentiment predictions blocking
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-05 00:30:32 +00:00
caaa31133a Improve outgoing Alpaca API types
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-05 00:30:11 +00:00
61c573cbc7 Remove stored abbreviation
- Alpaca is fuck

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-04 21:24:14 +00:00
65c9ae8b25 Add finbert sentiment analysis
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-03 18:58:40 +00:00
104 changed files with 39121 additions and 1915 deletions

6
.gitignore vendored

@@ -10,3 +10,9 @@ target/
 *.pdb
 .env*
+
+# ML models
+models/*/rust_model.ot
+notebooks/models/
+libdevice.10.bc

995
Cargo.lock generated

File diff suppressed because it is too large

Cargo.toml

@@ -27,6 +27,8 @@ log4rs = "1.2.0"
 serde = "1.0.188"
 serde_json = "1.0.105"
 serde_repr = "0.1.18"
+serde_with = "3.6.1"
+serde-aux = "4.4.0"
 futures-util = "0.3.28"
 reqwest = { version = "0.11.20", features = [
     "json",
@@ -39,15 +41,23 @@ clickhouse = { version = "0.11.6", features = [
     "time",
     "uuid",
 ] }
-uuid = "1.6.1"
+uuid = { version = "1.6.1", features = [
+    "serde",
+] }
 time = { version = "0.3.31", features = [
     "serde",
+    "serde-well-known",
+    "serde-human-readable",
     "formatting",
     "macros",
-    "serde-well-known",
+    "local-offset",
 ] }
 backoff = { version = "0.4.0", features = [
     "tokio",
 ] }
 regex = "1.10.3"
 html-escape = "0.2.13"
+rust-bert = "0.22.0"
+async-trait = "0.1.77"
+itertools = "0.12.1"
+lazy_static = "1.4.0"

README.md

@@ -1,3 +1,5 @@
-# QRust
+# qrust
 
-QRust (/kɹʌst/, QuantitativeRust) is an algorithmic trading library written in Rust.
+![XKCD - Engineer Syllogism](./static/engineer-syllogism.png)
+
+`qrust` (/kɹʌst/, QuantitativeRust) is an algorithmic trading library written in Rust.

docker-compose.yml

@@ -4,11 +4,6 @@ services:
       file: support/clickhouse/docker-compose.yml
       service: clickhouse
 
-  ollama:
-    extends:
-      file: support/ollama/docker-compose.yml
-      service: ollama
-
   grafana:
     extends:
       file: support/grafana/docker-compose.yml
@@ -24,12 +19,10 @@ services:
       - 7878:7878
     depends_on:
       - clickhouse
-      - ollama
     env_file:
       - .env.docker
 
 volumes:
   clickhouse-lib:
   clickhouse-log:
-  ollama:
   grafana-lib:

models/finbert/config.json Normal file

@@ -0,0 +1,32 @@
{
"_name_or_path": "/home/ubuntu/finbert/models/language_model/finbertTRC2",
"architectures": [
"BertForSequenceClassification"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"id2label": {
"0": "positive",
"1": "negative",
"2": "neutral"
},
"initializer_range": 0.02,
"intermediate_size": 3072,
"label2id": {
"positive": 0,
"negative": 1,
"neutral": 2
},
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"type_vocab_size": 2,
"vocab_size": 30522
}

models/finbert/special_tokens_map.json Normal file

@@ -0,0 +1 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}

models/finbert/tokenizer_config.json Normal file

@@ -0,0 +1 @@
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "bert-base-uncased"}

30522
models/finbert/vocab.txt Normal file

File diff suppressed because it is too large

3868
notebooks/lstm.ipynb Normal file

File diff suppressed because one or more lines are too long

src/config.rs

@@ -1,87 +1,119 @@
-use crate::types::alpaca::Source;
+use crate::types::alpaca::shared::{Mode, Source};
 use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
+use lazy_static::lazy_static;
 use reqwest::{
     header::{HeaderMap, HeaderName, HeaderValue},
     Client,
 };
-use std::{env, num::NonZeroU32, sync::Arc, time::Duration};
+use rust_bert::{
+    pipelines::{
+        common::{ModelResource, ModelType},
+        sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel},
+    },
+    resources::LocalResource,
+};
+use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
+use tokio::sync::Mutex;
 
-pub const ALPACA_ASSET_API_URL: &str = "https://api.alpaca.markets/v2/assets";
-pub const ALPACA_CLOCK_API_URL: &str = "https://api.alpaca.markets/v2/clock";
-pub const ALPACA_STOCK_DATA_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
-pub const ALPACA_CRYPTO_DATA_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
-pub const ALPACA_NEWS_DATA_URL: &str = "https://data.alpaca.markets/v1beta1/news";
-pub const ALPACA_STOCK_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
-pub const ALPACA_CRYPTO_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta3/crypto/us";
-pub const ALPACA_NEWS_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";
+pub const ALPACA_STOCK_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
+pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
+pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news";
+pub const ALPACA_STOCK_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
+pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str =
+    "wss://stream.data.alpaca.markets/v1beta3/crypto/us";
+pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";
+
+lazy_static! {
+    pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
+        .expect("ALPACA_MODE must be set.")
+        .parse()
+        .expect("ALPACA_MODE must be 'live' or 'paper'");
+    pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE")
+        .expect("ALPACA_SOURCE must be set.")
+        .parse()
+        .expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'");
+    pub static ref ALPACA_API_KEY: String =
+        env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
+    pub static ref ALPACA_API_SECRET: String =
+        env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
+    #[derive(Debug)]
+    pub static ref ALPACA_API_URL: String = format!(
+        "https://{}.alpaca.markets/v2",
+        match *ALPACA_MODE {
+            Mode::Live => String::from("api"),
+            Mode::Paper => String::from("paper-api"),
+        }
+    );
+    #[derive(Debug)]
+    pub static ref ALPACA_WEBSOCKET_URL: String = format!(
+        "wss://{}.alpaca.markets/stream",
+        match *ALPACA_MODE {
+            Mode::Live => String::from("api"),
+            Mode::Paper => String::from("paper-api"),
+        }
+    );
+    pub static ref MAX_BERT_INPUTS: usize = env::var("MAX_BERT_INPUTS")
+        .expect("MAX_BERT_INPUTS must be set.")
+        .parse()
+        .expect("MAX_BERT_INPUTS must be a positive integer.");
+}
 
 pub struct Config {
-    pub alpaca_api_key: String,
-    pub alpaca_api_secret: String,
-    pub alpaca_rate_limit: DefaultDirectRateLimiter,
-    pub alpaca_source: Source,
     pub alpaca_client: Client,
-    pub ollama_url: String,
-    pub ollama_model: String,
-    pub ollama_client: Client,
+    pub alpaca_rate_limiter: DefaultDirectRateLimiter,
     pub clickhouse_client: clickhouse::Client,
+    pub sequence_classifier: Mutex<SequenceClassificationModel>,
 }
 
 impl Config {
     pub fn from_env() -> Self {
-        let alpaca_api_key = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
-        let alpaca_api_secret =
-            env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
-        let alpaca_source: Source = env::var("ALPACA_SOURCE")
-            .expect("ALPACA_SOURCE must be set.")
-            .parse()
-            .expect("ALPACA_SOURCE must be a either 'iex' or 'sip'.");
-        let ollama_url = env::var("OLLAMA_URL").expect("OLLAMA_URL must be set.");
-        let ollama_model = env::var("OLLAMA_MODEL").expect("OLLAMA_MODEL must be set.");
-        let clickhouse_url = env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set.");
-        let clickhouse_user = env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set.");
-        let clickhouse_password =
-            env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set.");
-        let clickhouse_db = env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.");
-
         Self {
             alpaca_client: Client::builder()
                 .default_headers(HeaderMap::from_iter([
                     (
                         HeaderName::from_static("apca-api-key-id"),
-                        HeaderValue::from_str(&alpaca_api_key)
+                        HeaderValue::from_str(&ALPACA_API_KEY)
                             .expect("Alpaca API key must not contain invalid characters."),
                     ),
                     (
                         HeaderName::from_static("apca-api-secret-key"),
-                        HeaderValue::from_str(&alpaca_api_secret)
+                        HeaderValue::from_str(&ALPACA_API_SECRET)
                             .expect("Alpaca API secret must not contain invalid characters."),
                     ),
                 ]))
-                .timeout(Duration::from_secs(60))
                 .build()
                 .unwrap(),
-            alpaca_rate_limit: RateLimiter::direct(Quota::per_minute(match alpaca_source {
-                Source::Iex => unsafe { NonZeroU32::new_unchecked(190) },
-                Source::Sip => unsafe { NonZeroU32::new_unchecked(9990) },
+            alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE {
+                Source::Iex => unsafe { NonZeroU32::new_unchecked(200) },
+                Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) },
+                Source::Otc => unimplemented!("OTC rate limit not implemented."),
             })),
-            alpaca_source,
-            alpaca_api_key,
-            alpaca_api_secret,
-            ollama_url,
-            ollama_model,
-            ollama_client: Client::builder()
-                .timeout(Duration::from_secs(15))
-                .build()
-                .unwrap(),
             clickhouse_client: clickhouse::Client::default()
-                .with_url(clickhouse_url)
-                .with_user(clickhouse_user)
-                .with_password(clickhouse_password)
-                .with_database(clickhouse_db),
+                .with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
+                .with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
+                .with_password(
+                    env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
+                )
+                .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
+            sequence_classifier: Mutex::new(
+                SequenceClassificationModel::new(SequenceClassificationConfig::new(
+                    ModelType::Bert,
+                    ModelResource::Torch(Box::new(LocalResource {
+                        local_path: PathBuf::from("./models/finbert/rust_model.ot"),
+                    })),
+                    LocalResource {
+                        local_path: PathBuf::from("./models/finbert/config.json"),
+                    },
+                    LocalResource {
+                        local_path: PathBuf::from("./models/finbert/vocab.txt"),
+                    },
+                    None,
+                    true,
+                    None,
+                    None,
+                ))
+                .unwrap(),
+            ),
         }
     }
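A note on the new rate limiter: the quotas above (200 requests per minute for IEX, 10,000 for SIP) match Alpaca's published per-minute limits, and every call site is expected to wait on the limiter before issuing a request. A minimal caller sketch, assuming the Config above; the helper name and the /account path are illustrative, not part of this changeset:

use std::sync::Arc;

// Hypothetical caller, not in this diff: wait for rate-limit quota, then
// reuse the pre-authenticated client and the lazily-built base URL.
async fn get_account_raw(config: &Arc<Config>) -> reqwest::Result<String> {
    config.alpaca_rate_limiter.until_ready().await; // resolves once quota is free
    config
        .alpaca_client
        .get(format!("{}/account", *ALPACA_API_URL)) // e.g. https://paper-api.alpaca.markets/v2/account
        .send()
        .await?
        .error_for_status()?
        .text()
        .await
}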

src/database/assets.rs

@@ -1,48 +1,43 @@
-use crate::types::Asset;
-use clickhouse::Client;
+use crate::{
+    delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch,
+};
+use clickhouse::{error::Error, Client};
 use serde::Serialize;
 
-pub async fn select(clickhouse_client: &Client) -> Vec<Asset> {
-    clickhouse_client
-        .query("SELECT ?fields FROM assets FINAL")
-        .fetch_all::<Asset>()
-        .await
-        .unwrap()
-}
+select!(Asset, "assets");
+select_where_symbol!(Asset, "assets");
+upsert_batch!(Asset, "assets");
+delete_where_symbols!("assets");
+optimize!("assets");
 
-pub async fn select_where_symbol<T>(clickhouse_client: &Client, symbol: &T) -> Option<Asset>
+pub async fn update_status_where_symbol<T>(
+    clickhouse_client: &Client,
+    symbol: &T,
+    status: bool,
+) -> Result<(), Error>
 where
     T: AsRef<str> + Serialize + Send + Sync,
 {
     clickhouse_client
-        .query("SELECT ?fields FROM assets FINAL WHERE symbol = ? OR abbreviation = ?")
-        .bind(symbol)
-        .bind(symbol)
-        .fetch_optional::<Asset>()
-        .await
-        .unwrap()
-}
-
-pub async fn upsert_batch<T>(clickhouse_client: &Client, assets: T)
-where
-    T: IntoIterator<Item = Asset> + Send + Sync,
-    T::IntoIter: Send,
-{
-    let mut insert = clickhouse_client.insert("assets").unwrap();
-    for asset in assets {
-        insert.write(&asset).await.unwrap();
-    }
-    insert.end().await.unwrap();
-}
-
-pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
-where
-    T: AsRef<str> + Serialize + Send + Sync,
-{
-    clickhouse_client
-        .query("DELETE FROM assets WHERE symbol IN ?")
-        .bind(symbols)
+        .query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?")
+        .bind(status)
+        .bind(symbol)
+        .execute()
+        .await
+}
+
+pub async fn update_qty_where_symbol<T>(
+    clickhouse_client: &Client,
+    symbol: &T,
+    qty: f64,
+) -> Result<(), Error>
+where
+    T: AsRef<str> + Serialize + Send + Sync,
+{
+    clickhouse_client
+        .query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?")
+        .bind(qty)
+        .bind(symbol)
         .execute()
         .await
-        .unwrap();
 }

src/database/backfills.rs

@@ -1,93 +0,0 @@
use crate::{database::assets, threads::data::ThreadType, types::Backfill};
use clickhouse::Client;
use serde::Serialize;
use tokio::join;
pub async fn select_latest_where_symbol<T>(
clickhouse_client: &Client,
thread_type: &ThreadType,
symbol: &T,
) -> Option<Backfill>
where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol = ? ORDER BY time DESC LIMIT 1",
match thread_type {
ThreadType::Bars(_) => "backfills_bars",
ThreadType::News => "backfills_news",
}
))
.bind(symbol)
.fetch_optional::<Backfill>()
.await
.unwrap()
}
pub async fn upsert(clickhouse_client: &Client, thread_type: &ThreadType, backfill: &Backfill) {
let mut insert = clickhouse_client
.insert(match thread_type {
ThreadType::Bars(_) => "backfills_bars",
ThreadType::News => "backfills_news",
})
.unwrap();
insert.write(backfill).await.unwrap();
insert.end().await.unwrap();
}
pub async fn delete_where_symbols<T>(
clickhouse_client: &Client,
thread_type: &ThreadType,
symbols: &[T],
) where
T: AsRef<str> + Serialize + Send + Sync,
{
clickhouse_client
.query(&format!(
"DELETE FROM {} WHERE symbol IN ?",
match thread_type {
ThreadType::Bars(_) => "backfills_bars",
ThreadType::News => "backfills_news",
}
))
.bind(symbols)
.execute()
.await
.unwrap();
}
pub async fn cleanup(clickhouse_client: &Client) {
let assets = assets::select(clickhouse_client).await;
let bars_symbols = assets
.clone()
.into_iter()
.map(|asset| asset.symbol)
.collect::<Vec<_>>();
let news_symbols = assets
.into_iter()
.map(|asset| asset.abbreviation)
.collect::<Vec<_>>();
let delete_bars_future = async {
clickhouse_client
.query("DELETE FROM backfills_bars WHERE symbol NOT IN ?")
.bind(bars_symbols)
.execute()
.await
.unwrap();
};
let delete_news_future = async {
clickhouse_client
.query("DELETE FROM backfills_news WHERE symbol NOT IN ?")
.bind(news_symbols)
.execute()
.await
.unwrap();
};
join!(delete_bars_future, delete_news_future);
}

src/database/backfills_bars.rs Normal file

@@ -0,0 +1,17 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
};
use clickhouse::{error::Error, Client};
select_where_symbol!(Backfill, "backfills_bars");
upsert!(Backfill, "backfills_bars");
delete_where_symbols!("backfills_bars");
cleanup!("backfills_bars");
optimize!("backfills_bars");
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
.query("ALTER TABLE backfills_bars UPDATE fresh = false WHERE true")
.execute()
.await
}

src/database/backfills_news.rs Normal file

@@ -0,0 +1,17 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
};
use clickhouse::{error::Error, Client};
select_where_symbol!(Backfill, "backfills_news");
upsert!(Backfill, "backfills_news");
delete_where_symbols!("backfills_news");
cleanup!("backfills_news");
optimize!("backfills_news");
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
clickhouse_client
.query("ALTER TABLE backfills_news UPDATE fresh = false WHERE true")
.execute()
.await
}

src/database/bars.rs

@@ -1,50 +1,7 @@
-use super::assets;
-use crate::types::Bar;
-use clickhouse::Client;
-use serde::Serialize;
+use crate::{cleanup, delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
 
-pub async fn upsert(clickhouse_client: &Client, bar: &Bar) {
-    let mut insert = clickhouse_client.insert("bars").unwrap();
-    insert.write(bar).await.unwrap();
-    insert.end().await.unwrap();
-}
-
-pub async fn upsert_batch<T>(clickhouse_client: &Client, bars: T)
-where
-    T: IntoIterator<Item = Bar> + Send + Sync,
-    T::IntoIter: Send,
-{
-    let mut insert = clickhouse_client.insert("bars").unwrap();
-    for bar in bars {
-        insert.write(&bar).await.unwrap();
-    }
-    insert.end().await.unwrap();
-}
-
-pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
-where
-    T: AsRef<str> + Serialize + Send + Sync,
-{
-    clickhouse_client
-        .query("DELETE FROM bars WHERE symbol IN ?")
-        .bind(symbols)
-        .execute()
-        .await
-        .unwrap();
-}
-
-pub async fn cleanup(clickhouse_client: &Client) {
-    let assets = assets::select(clickhouse_client).await;
-    let symbols = assets
-        .into_iter()
-        .map(|asset| asset.symbol)
-        .collect::<Vec<_>>();
-
-    clickhouse_client
-        .query("DELETE FROM bars WHERE symbol NOT IN ?")
-        .bind(symbols)
-        .execute()
-        .await
-        .unwrap();
-}
+upsert!(Bar, "bars");
+upsert_batch!(Bar, "bars");
+delete_where_symbols!("bars");
+cleanup!("bars");
+optimize!("bars");

38
src/database/calendar.rs Normal file

@@ -0,0 +1,38 @@
use crate::{optimize, types::Calendar};
use clickhouse::error::Error;
use tokio::try_join;
optimize!("calendar");
pub async fn upsert_batch_and_delete<'a, T>(
client: &clickhouse::Client,
records: T,
) -> Result<(), Error>
where
T: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone,
T::IntoIter: Send,
{
let upsert_future = async {
let mut insert = client.insert("calendar")?;
for record in records.clone() {
insert.write(record).await?;
}
insert.end().await
};
let delete_future = async {
let dates = records
.clone()
.into_iter()
.map(|r| r.date)
.collect::<Vec<_>>();
client
.query("DELETE FROM calendar WHERE date NOT IN ?")
.bind(dates)
.execute()
.await
};
try_join!(upsert_future, delete_future).map(|_| ())
}

src/database/mod.rs

@@ -1,4 +1,152 @@
 pub mod assets;
-pub mod backfills;
+pub mod backfills_bars;
+pub mod backfills_news;
 pub mod bars;
+pub mod calendar;
 pub mod news;
+pub mod orders;
use clickhouse::{error::Error, Client};
use log::info;
use tokio::try_join;
#[macro_export]
macro_rules! select {
($record:ty, $table_name:expr) => {
pub async fn select(
client: &clickhouse::Client,
) -> Result<Vec<$record>, clickhouse::error::Error> {
client
.query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
.fetch_all::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! select_where_symbol {
($record:ty, $table_name:expr) => {
pub async fn select_where_symbol<T>(
client: &clickhouse::Client,
symbol: &T,
) -> Result<Option<$record>, clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
client
.query(&format!(
"SELECT ?fields FROM {} FINAL WHERE symbol = ?",
$table_name
))
.bind(symbol)
.fetch_optional::<$record>()
.await
}
};
}
#[macro_export]
macro_rules! upsert {
($record:ty, $table_name:expr) => {
pub async fn upsert(
client: &clickhouse::Client,
record: &$record,
) -> Result<(), clickhouse::error::Error> {
let mut insert = client.insert($table_name)?;
insert.write(record).await?;
insert.end().await
}
};
}
#[macro_export]
macro_rules! upsert_batch {
($record:ty, $table_name:expr) => {
pub async fn upsert_batch<'a, T>(
client: &clickhouse::Client,
records: T,
) -> Result<(), clickhouse::error::Error>
where
T: IntoIterator<Item = &'a $record> + Send + Sync,
T::IntoIter: Send,
{
let mut insert = client.insert($table_name)?;
for record in records {
insert.write(record).await?;
}
insert.end().await
}
};
}
#[macro_export]
macro_rules! delete_where_symbols {
($table_name:expr) => {
pub async fn delete_where_symbols<T>(
client: &clickhouse::Client,
symbols: &[T],
) -> Result<(), clickhouse::error::Error>
where
T: AsRef<str> + serde::Serialize + Send + Sync,
{
client
.query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
.bind(symbols)
.execute()
.await
}
};
}
#[macro_export]
macro_rules! cleanup {
($table_name:expr) => {
pub async fn cleanup(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
client
.query(&format!(
"DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
$table_name
))
.execute()
.await
}
};
}
#[macro_export]
macro_rules! optimize {
($table_name:expr) => {
pub async fn optimize(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
client
.query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
.execute()
.await
}
};
}
pub async fn cleanup_all(clickhouse_client: &Client) -> Result<(), Error> {
info!("Cleaning up database.");
try_join!(
bars::cleanup(clickhouse_client),
news::cleanup(clickhouse_client),
backfills_bars::cleanup(clickhouse_client),
backfills_news::cleanup(clickhouse_client)
)
.map(|_| ())
}
pub async fn optimize_all(clickhouse_client: &Client) -> Result<(), Error> {
info!("Optimizing database.");
try_join!(
assets::optimize(clickhouse_client),
bars::optimize(clickhouse_client),
news::optimize(clickhouse_client),
backfills_bars::optimize(clickhouse_client),
backfills_news::optimize(clickhouse_client),
orders::optimize(clickhouse_client),
calendar::optimize(clickhouse_client)
)
.map(|_| ())
}
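Since every table module above now reduces to a handful of macro invocations, it may help to see what one of them generates; select!(Asset, "assets") from src/database/assets.rs expands to roughly the following (an approximation of the macro output shown above, not new code in the changeset):

// Approximate expansion of select!(Asset, "assets"): the table name is
// interpolated into the query, and rows are deserialized as Asset.
pub async fn select(client: &clickhouse::Client) -> Result<Vec<Asset>, clickhouse::error::Error> {
    client
        .query("SELECT ?fields FROM assets FINAL")
        .fetch_all::<Asset>()
        .await
}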

src/database/news.rs

@@ -1,50 +1,27 @@
-use super::assets;
-use crate::types::News;
-use clickhouse::Client;
+use crate::{optimize, types::News, upsert, upsert_batch};
+use clickhouse::{error::Error, Client};
 use serde::Serialize;
 
-pub async fn upsert(clickhouse_client: &Client, news: &News) {
-    let mut insert = clickhouse_client.insert("news").unwrap();
-    insert.write(news).await.unwrap();
-    insert.end().await.unwrap();
-}
+upsert!(News, "news");
+upsert_batch!(News, "news");
+optimize!("news");
 
-pub async fn upsert_batch<T>(clickhouse_client: &Client, news: T)
-where
-    T: IntoIterator<Item = News> + Send + Sync,
-    T::IntoIter: Send,
-{
-    let mut insert = clickhouse_client.insert("news").unwrap();
-    for news in news {
-        insert.write(&news).await.unwrap();
-    }
-    insert.end().await.unwrap();
-}
-
-pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
+pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error>
 where
     T: AsRef<str> + Serialize + Send + Sync,
 {
     clickhouse_client
-        .query("DELETE FROM news WHERE hasAny(symbols, ?)")
+        .query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))")
         .bind(symbols)
         .execute()
         .await
-        .unwrap();
 }
 
-pub async fn cleanup(clickhouse_client: &Client) {
-    let assets = assets::select(clickhouse_client).await;
-    let symbols = assets
-        .into_iter()
-        .map(|asset| asset.abbreviation)
-        .collect::<Vec<_>>();
-
+pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> {
     clickhouse_client
-        .query("DELETE FROM news WHERE NOT hasAny(symbols, ?)")
-        .bind(symbols)
+        .query(
+            "DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))",
+        )
         .execute()
         .await
-        .unwrap();
 }

5
src/database/orders.rs Normal file

@@ -0,0 +1,5 @@
use crate::{optimize, types::Order, upsert, upsert_batch};
upsert!(Order, "orders");
upsert_batch!(Order, "orders");
optimize!("orders");

126
src/init.rs Normal file

@@ -0,0 +1,126 @@
use crate::{
config::{Config, ALPACA_MODE},
database,
types::alpaca,
};
use log::{info, warn};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::join;
pub async fn check_account(config: &Arc<Config>) {
let account = alpaca::api::incoming::account::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,
)
.await
.unwrap();
assert!(
!(account.status != alpaca::api::incoming::account::Status::Active),
"Account status is not active: {:?}.",
account.status
);
assert!(
!account.trade_suspend_by_user,
"Account trading is suspended by user."
);
assert!(!account.trading_blocked, "Account trading is blocked.");
assert!(!account.blocked, "Account is blocked.");
if account.cash == 0.0 {
warn!("Account cash is zero, qrust will not be able to trade.");
}
warn!(
"qrust active on {} account with {} {}, avoid transferring funds without shutting down.",
*ALPACA_MODE, account.currency, account.cash
);
}
pub async fn rehydrate_orders(config: &Arc<Config>) {
info!("Rehydrating order data.");
let mut orders = vec![];
let mut after = OffsetDateTime::UNIX_EPOCH;
while let Some(message) = alpaca::api::incoming::order::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&alpaca::api::outgoing::order::Order {
status: Some(alpaca::api::outgoing::order::Status::All),
after: Some(after),
..Default::default()
},
None,
)
.await
.ok()
.filter(|message| !message.is_empty())
{
orders.extend(message);
after = orders.last().unwrap().submitted_at;
}
let orders = orders
.into_iter()
.flat_map(&alpaca::api::incoming::order::Order::normalize)
.collect::<Vec<_>>();
database::orders::upsert_batch(&config.clickhouse_client, &orders)
.await
.unwrap();
info!("Rehydrated order data.");
}
pub async fn rehydrate_positions(config: &Arc<Config>) {
info!("Rehydrating position data.");
let positions_future = async {
alpaca::api::incoming::position::get(
&config.alpaca_client,
&config.alpaca_rate_limiter,
None,
)
.await
.unwrap()
.into_iter()
.map(|position| (position.symbol.clone(), position))
.collect::<HashMap<_, _>>()
};
let assets_future = async {
database::assets::select(&config.clickhouse_client)
.await
.unwrap()
};
let (mut positions, assets) = join!(positions_future, assets_future);
let assets = assets
.into_iter()
.map(|mut asset| {
if let Some(position) = positions.remove(&asset.symbol) {
asset.qty = position.qty_available;
} else {
asset.qty = 0.0;
}
asset
})
.collect::<Vec<_>>();
database::assets::upsert_batch(&config.clickhouse_client, &assets)
.await
.unwrap();
for position in positions.values() {
warn!(
"Position for unmonitored asset: {}, {} shares.",
position.symbol, position.qty
);
}
info!("Rehydrated position data.");
}

src/main.rs

@@ -4,6 +4,7 @@
 mod config;
 mod database;
+mod init;
 mod routes;
 mod threads;
 mod types;
@@ -12,39 +13,59 @@ mod utils;
 use config::Config;
 use dotenv::dotenv;
 use log4rs::config::Deserializers;
-use tokio::{spawn, sync::mpsc};
-use utils::{cleanup::cleanup, init};
+use tokio::{join, spawn, sync::mpsc, try_join};
 
 #[tokio::main]
 async fn main() {
     dotenv().ok();
     log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
 
-    let app_config = Config::arc_from_env();
-    cleanup(&app_config.clickhouse_client).await;
-    init::ollama(&app_config).await;
+    let config = Config::arc_from_env();
 
-    let (asset_status_sender, asset_status_receiver) =
-        mpsc::channel::<threads::data::asset_status::Message>(10);
+    try_join!(
+        database::backfills_bars::unfresh(&config.clickhouse_client),
+        database::backfills_news::unfresh(&config.clickhouse_client)
+    )
+    .unwrap();
+    database::cleanup_all(&config.clickhouse_client)
+        .await
+        .unwrap();
+    database::optimize_all(&config.clickhouse_client)
+        .await
+        .unwrap();
+
+    init::check_account(&config).await;
+    join!(
+        init::rehydrate_orders(&config),
+        init::rehydrate_positions(&config)
+    );
+
+    spawn(threads::trading::run(config.clone()));
+
+    let (data_sender, data_receiver) = mpsc::channel::<threads::data::Message>(100);
     let (clock_sender, clock_receiver) = mpsc::channel::<threads::clock::Message>(1);
-
     spawn(threads::data::run(
-        app_config.clone(),
-        asset_status_receiver,
+        config.clone(),
+        data_receiver,
         clock_receiver,
     ));
-    spawn(threads::clock::run(app_config.clone(), clock_sender));
+    spawn(threads::clock::run(config.clone(), clock_sender));
 
-    let assets = database::assets::select(&app_config.clickhouse_client).await;
-    let (asset_status_message, asset_status_receiver) =
-        threads::data::asset_status::Message::new(threads::data::asset_status::Action::Add, assets);
-    asset_status_sender
-        .send(asset_status_message)
-        .await
-        .unwrap();
-    asset_status_receiver.await.unwrap();
+    let assets = database::assets::select(&config.clickhouse_client)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|asset| (asset.symbol, asset.class))
+        .collect::<Vec<_>>();
 
-    routes::run(app_config, asset_status_sender).await;
+    create_send_await!(
+        data_sender,
+        threads::data::Message::new,
+        threads::data::Action::Enable,
+        assets
+    );
+
+    routes::run(config, data_sender).await;
 }
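The create_send_await! macro used above is not defined anywhere in this diff; judging from the oneshot handshake it replaces (the removed lines in this hunk), it presumably expands to something like:

// Hedged sketch of create_send_await!'s expansion (the macro definition is
// not shown in this changeset): build the message and its oneshot receiver,
// send it to the data thread, then wait for the acknowledgement.
let (message, response) = threads::data::Message::new(threads::data::Action::Enable, assets);
data_sender.send(message).await.unwrap();
response.await.unwrap();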

src/routes/assets.rs

@@ -1,31 +1,32 @@
 use crate::{
-    config::{Config, ALPACA_ASSET_API_URL},
-    database, threads,
-    types::{
-        alpaca::api::incoming::{self, asset::Status},
-        Asset,
-    },
+    config::Config,
+    create_send_await, database, threads,
+    types::{alpaca, Asset},
 };
 use axum::{extract::Path, Extension, Json};
-use backoff::{future::retry, ExponentialBackoff};
-use core::panic;
 use http::StatusCode;
 use serde::Deserialize;
 use std::sync::Arc;
 use tokio::sync::mpsc;
 
 pub async fn get(
-    Extension(app_config): Extension<Arc<Config>>,
+    Extension(config): Extension<Arc<Config>>,
 ) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> {
-    let assets = database::assets::select(&app_config.clickhouse_client).await;
+    let assets = database::assets::select(&config.clickhouse_client)
+        .await
+        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
     Ok((StatusCode::OK, Json(assets)))
 }
 
 pub async fn get_where_symbol(
-    Extension(app_config): Extension<Arc<Config>>,
+    Extension(config): Extension<Arc<Config>>,
     Path(symbol): Path<String>,
 ) -> Result<(StatusCode, Json<Asset>), StatusCode> {
-    let asset = database::assets::select_where_symbol(&app_config.clickhouse_client, &symbol).await;
+    let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol)
+        .await
+        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
     asset.map_or(Err(StatusCode::NOT_FOUND), |asset| {
         Ok((StatusCode::OK, Json(asset)))
     })
@@ -37,75 +38,62 @@ pub struct AddAssetRequest {
 }
 
 pub async fn add(
-    Extension(app_config): Extension<Arc<Config>>,
-    Extension(asset_status_sender): Extension<mpsc::Sender<threads::data::asset_status::Message>>,
+    Extension(config): Extension<Arc<Config>>,
+    Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
     Json(request): Json<AddAssetRequest>,
-) -> Result<(StatusCode, Json<Asset>), StatusCode> {
-    if database::assets::select_where_symbol(&app_config.clickhouse_client, &request.symbol)
+) -> Result<StatusCode, StatusCode> {
+    if database::assets::select_where_symbol(&config.clickhouse_client, &request.symbol)
         .await
+        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
         .is_some()
     {
         return Err(StatusCode::CONFLICT);
     }
 
-    let asset = retry(ExponentialBackoff::default(), || async {
-        app_config.alpaca_rate_limit.until_ready().await;
-        app_config
-            .alpaca_client
-            .get(&format!("{}/{}", ALPACA_ASSET_API_URL, request.symbol))
-            .send()
-            .await?
-            .error_for_status()
-            .map_err(backoff::Error::Permanent)?
-            .json::<incoming::asset::Asset>()
-            .await
-            .map_err(backoff::Error::Permanent)
-    })
-    .await
-    .map_err(|e| match e.status() {
-        Some(reqwest::StatusCode::NOT_FOUND) => StatusCode::NOT_FOUND,
-        _ => panic!("Unexpected error: {}.", e),
-    })?;
+    let asset = alpaca::api::incoming::asset::get_by_symbol(
+        &config.alpaca_client,
+        &config.alpaca_rate_limiter,
+        &request.symbol,
+        None,
+    )
+    .await
+    .map_err(|e| {
+        e.status()
+            .map_or(StatusCode::INTERNAL_SERVER_ERROR, |status| {
+                StatusCode::from_u16(status.as_u16()).unwrap()
+            })
+    })?;
 
-    if asset.status != Status::Active || !asset.tradable || !asset.fractionable {
+    if !asset.tradable || !asset.fractionable {
         return Err(StatusCode::FORBIDDEN);
     }
 
-    let asset = Asset::from(asset);
-
-    let (asset_status_message, asset_status_response) = threads::data::asset_status::Message::new(
-        threads::data::asset_status::Action::Add,
-        vec![asset.clone()],
-    );
-    asset_status_sender
-        .send(asset_status_message)
-        .await
-        .unwrap();
-    asset_status_response.await.unwrap();
-
-    Ok((StatusCode::CREATED, Json(asset)))
+    create_send_await!(
+        data_sender,
+        threads::data::Message::new,
+        threads::data::Action::Add,
+        vec![(asset.symbol, asset.class.into())]
+    );
+
+    Ok(StatusCode::CREATED)
 }
 
 pub async fn delete(
-    Extension(app_config): Extension<Arc<Config>>,
-    Extension(asset_status_sender): Extension<mpsc::Sender<threads::data::asset_status::Message>>,
+    Extension(config): Extension<Arc<Config>>,
+    Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
     Path(symbol): Path<String>,
 ) -> Result<StatusCode, StatusCode> {
-    let asset = database::assets::select_where_symbol(&app_config.clickhouse_client, &symbol)
+    let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol)
         .await
+        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        .ok_or(StatusCode::NOT_FOUND)?;
 
-    let (asset_status_message, asset_status_response) = threads::data::asset_status::Message::new(
-        threads::data::asset_status::Action::Remove,
-        vec![asset],
-    );
-    asset_status_sender
-        .send(asset_status_message)
-        .await
-        .unwrap();
-    asset_status_response.await.unwrap();
+    create_send_await!(
+        data_sender,
+        threads::data::Message::new,
+        threads::data::Action::Remove,
+        vec![(asset.symbol, asset.class)]
+    );
 
     Ok(StatusCode::NO_CONTENT)
 }

src/routes/mod.rs

@@ -1,5 +1,5 @@
-pub mod assets;
-pub mod health;
+mod assets;
+mod health;
 
 use crate::{config::Config, threads};
 use axum::{
@@ -10,18 +10,15 @@ use log::info;
 use std::{net::SocketAddr, sync::Arc};
 use tokio::{net::TcpListener, sync::mpsc};
 
-pub async fn run(
-    app_config: Arc<Config>,
-    asset_status_sender: mpsc::Sender<threads::data::asset_status::Message>,
-) {
+pub async fn run(config: Arc<Config>, data_sender: mpsc::Sender<threads::data::Message>) {
     let app = Router::new()
         .route("/health", get(health::get))
         .route("/assets", get(assets::get))
         .route("/assets/:symbol", get(assets::get_where_symbol))
         .route("/assets", post(assets::add))
        .route("/assets/:symbol", delete(assets::delete))
-        .layer(Extension(app_config))
-        .layer(Extension(asset_status_sender));
+        .layer(Extension(config))
+        .layer(Extension(data_sender));
 
     let addr = SocketAddr::from(([0, 0, 0, 0], 7878));
     let listener = TcpListener::bind(addr).await.unwrap();

src/threads/clock.rs

@@ -1,13 +1,13 @@
 use crate::{
-    config::{Config, ALPACA_CLOCK_API_URL},
-    types::alpaca,
-    utils::duration_until,
+    config::Config,
+    database,
+    types::{alpaca, Calendar},
+    utils::{backoff, duration_until},
 };
-use backoff::{future::retry, ExponentialBackoff};
 use log::info;
 use std::sync::Arc;
 use time::OffsetDateTime;
-use tokio::{sync::mpsc, time::sleep};
+use tokio::{join, sync::mpsc, time::sleep};
 
 pub enum Status {
     Open,
@@ -35,23 +35,33 @@ impl From<alpaca::api::incoming::clock::Clock> for Message {
     }
 }
 
-pub async fn run(app_config: Arc<Config>, clock_sender: mpsc::Sender<Message>) {
+pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
     loop {
-        let clock = retry(ExponentialBackoff::default(), || async {
-            app_config.alpaca_rate_limit.until_ready().await;
-            app_config
-                .alpaca_client
-                .get(ALPACA_CLOCK_API_URL)
-                .send()
-                .await?
-                .error_for_status()
-                .map_err(backoff::Error::Permanent)?
-                .json::<alpaca::api::incoming::clock::Clock>()
-                .await
-                .map_err(backoff::Error::Permanent)
-        })
-        .await
-        .unwrap();
+        let clock_future = async {
+            alpaca::api::incoming::clock::get(
+                &config.alpaca_client,
+                &config.alpaca_rate_limiter,
+                Some(backoff::infinite()),
+            )
+            .await
+            .unwrap()
+        };
+
+        let calendar_future = async {
+            alpaca::api::incoming::calendar::get(
+                &config.alpaca_client,
+                &config.alpaca_rate_limiter,
+                &alpaca::api::outgoing::calendar::Calendar::default(),
+                Some(backoff::infinite()),
+            )
+            .await
+            .unwrap()
+            .into_iter()
+            .map(Calendar::from)
+            .collect::<Vec<_>>()
+        };
+
+        let (clock, calendar) = join!(clock_future, calendar_future);
 
         let sleep_until = duration_until(if clock.is_open {
             info!("Market is open, will close at {}.", clock.next_close);
@@ -61,7 +71,15 @@ pub async fn run(app_config: Arc<Config>, clock_sender: mpsc::Sender<Message>) {
             clock.next_open
         });
 
-        sleep(sleep_until).await;
-        clock_sender.send(clock.into()).await.unwrap();
+        let sleep_future = sleep(sleep_until);
+        let calendar_future = async {
+            database::calendar::upsert_batch_and_delete(&config.clickhouse_client, &calendar)
+                .await
+                .unwrap();
+        };
+        join!(sleep_future, calendar_future);
+
+        sender.send(clock.into()).await.unwrap();
     }
 }

src/threads/data/asset_status.rs

@@ -1,170 +0,0 @@
use super::{Guard, ThreadType};
use crate::{
config::Config,
database,
types::{alpaca::websocket, Asset},
};
use futures_util::{stream::SplitSink, SinkExt};
use log::info;
use serde_json::to_string;
use std::sync::Arc;
use tokio::{
join,
net::TcpStream,
spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
#[derive(Clone)]
pub enum Action {
Add,
Remove,
}
pub struct Message {
pub action: Action,
pub assets: Vec<Asset>,
pub response: oneshot::Sender<()>,
}
impl Message {
pub fn new(action: Action, assets: Vec<Asset>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel::<()>();
(
Self {
action,
assets,
response: sender,
},
receiver,
)
}
}
pub async fn run(
app_config: Arc<Config>,
thread_type: ThreadType,
guard: Arc<RwLock<Guard>>,
mut asset_status_receiver: mpsc::Receiver<Message>,
websocket_sender: Arc<
Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>,
>,
) {
loop {
let message = asset_status_receiver.recv().await.unwrap();
spawn(handle_asset_status_message(
app_config.clone(),
thread_type,
guard.clone(),
websocket_sender.clone(),
message,
));
}
}
#[allow(clippy::significant_drop_tightening)]
async fn handle_asset_status_message(
app_config: Arc<Config>,
thread_type: ThreadType,
guard: Arc<RwLock<Guard>>,
websocket_sender: Arc<
Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>,
>,
message: Message,
) {
let symbols = message
.assets
.clone()
.into_iter()
.map(|asset| match thread_type {
ThreadType::Bars(_) => asset.symbol,
ThreadType::News => asset.abbreviation,
})
.collect::<Vec<_>>();
match message.action {
Action::Add => {
let mut guard = guard.write().await;
guard.symbols.extend(symbols.clone());
guard
.pending_subscriptions
.extend(symbols.clone().into_iter().zip(message.assets.clone()));
info!("{:?} - Added {:?}.", thread_type, symbols);
let database_future = async {
if matches!(thread_type, ThreadType::Bars(_)) {
database::assets::upsert_batch(&app_config.clickhouse_client, message.assets)
.await;
}
};
let websocket_future = async move {
websocket_sender
.lock()
.await
.send(tungstenite::Message::Text(
to_string(&websocket::outgoing::Message::Subscribe(
websocket_market_message_factory(thread_type, symbols),
))
.unwrap(),
))
.await
.unwrap();
};
join!(database_future, websocket_future);
}
Action::Remove => {
let mut guard = guard.write().await;
guard.symbols.retain(|symbol| !symbols.contains(symbol));
guard
.pending_unsubscriptions
.extend(symbols.clone().into_iter().zip(message.assets.clone()));
info!("{:?} - Removed {:?}.", thread_type, symbols);
let sybols_clone = symbols.clone();
let database_future = database::assets::delete_where_symbols(
&app_config.clickhouse_client,
&sybols_clone,
);
let websocket_future = async move {
websocket_sender
.lock()
.await
.send(tungstenite::Message::Text(
to_string(&websocket::outgoing::Message::Unsubscribe(
websocket_market_message_factory(thread_type, symbols),
))
.unwrap(),
))
.await
.unwrap();
};
join!(database_future, websocket_future);
}
}
message.response.send(()).unwrap();
}
fn websocket_market_message_factory(
thread_type: ThreadType,
symbols: Vec<String>,
) -> websocket::outgoing::subscribe::Message {
match thread_type {
ThreadType::Bars(_) => websocket::outgoing::subscribe::Message::Market(
websocket::outgoing::subscribe::MarketMessage::new(symbols),
),
ThreadType::News => websocket::outgoing::subscribe::Message::News(
websocket::outgoing::subscribe::NewsMessage::new(symbols),
),
}
}

src/threads/data/backfill.rs

@@ -1,23 +1,28 @@
use super::{Guard, ThreadType}; use super::ThreadType;
use crate::{ use crate::{
config::{Config, ALPACA_CRYPTO_DATA_URL, ALPACA_NEWS_DATA_URL, ALPACA_STOCK_DATA_URL}, config::{
Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_SOURCE, ALPACA_STOCK_DATA_API_URL,
MAX_BERT_INPUTS,
},
database, database,
types::{ types::{
alpaca::{self, Source}, alpaca::{self, shared::Source},
ollama, Asset, Bar, Class, News, Subset, news::Prediction,
Backfill, Bar, Class, News,
}, },
utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE}, utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE, ONE_SECOND},
}; };
use backoff::{future::retry, ExponentialBackoff}; use async_trait::async_trait;
use log::{error, info}; use futures_util::future::join_all;
use serde_json::{from_str, to_string}; use log::{error, info, warn};
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime; use time::OffsetDateTime;
use tokio::{ use tokio::{
join, spawn, spawn,
sync::{mpsc, oneshot, Mutex, RwLock}, sync::{mpsc, oneshot, Mutex},
task::JoinHandle, task::{block_in_place, JoinHandle},
time::sleep, time::sleep,
try_join,
}; };
pub enum Action { pub enum Action {
@@ -25,19 +30,29 @@ pub enum Action {
Purge, Purge,
} }
impl From<super::Action> for Option<Action> {
fn from(action: super::Action) -> Self {
match action {
super::Action::Add | super::Action::Enable => Some(Action::Backfill),
super::Action::Remove => Some(Action::Purge),
super::Action::Disable => None,
}
}
}
pub struct Message { pub struct Message {
pub action: Action, pub action: Option<Action>,
pub assets: Subset<Asset>, pub symbols: Vec<String>,
pub response: oneshot::Sender<()>, pub response: oneshot::Sender<()>,
} }
impl Message { impl Message {
pub fn new(action: Action, assets: Subset<Asset>) -> (Self, oneshot::Receiver<()>) { pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel::<()>(); let (sender, receiver) = oneshot::channel::<()>();
( (
Self { Self {
action, action,
assets, symbols,
response: sender, response: sender,
}, },
receiver, receiver,
@@ -45,111 +60,82 @@ impl Message {
} }
} }
pub async fn run( #[async_trait]
app_config: Arc<Config>, pub trait Handler: Send + Sync {
thread_type: ThreadType, async fn select_latest_backfill(
guard: Arc<RwLock<Guard>>, &self,
mut backfill_receiver: mpsc::Receiver<Message>, symbol: String,
) { ) -> Result<Option<Backfill>, clickhouse::error::Error>;
async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime);
async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime);
fn log_string(&self) -> &'static str;
}
pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
let backfill_jobs = Arc::new(Mutex::new(HashMap::new())); let backfill_jobs = Arc::new(Mutex::new(HashMap::new()));
let data_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => ALPACA_STOCK_DATA_URL.to_string(),
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_URL.to_string(),
ThreadType::News => ALPACA_NEWS_DATA_URL.to_string(),
};
loop { loop {
let message = backfill_receiver.recv().await.unwrap(); let message = receiver.recv().await.unwrap();
spawn(handle_backfill_message( spawn(handle_backfill_message(
app_config.clone(), handler.clone(),
thread_type,
guard.clone(),
data_url.clone(),
backfill_jobs.clone(), backfill_jobs.clone(),
message, message,
)); ));
} }
} }
#[allow(clippy::significant_drop_tightening)]
#[allow(clippy::too_many_lines)]
async fn handle_backfill_message( async fn handle_backfill_message(
app_config: Arc<Config>, handler: Arc<Box<dyn Handler>>,
thread_type: ThreadType,
guard: Arc<RwLock<Guard>>,
data_url: String,
backfill_jobs: Arc<Mutex<HashMap<String, JoinHandle<()>>>>, backfill_jobs: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
message: Message, message: Message,
) { ) {
let guard = guard.read().await;
let mut backfill_jobs = backfill_jobs.lock().await; let mut backfill_jobs = backfill_jobs.lock().await;
let symbols = match message.assets {
Subset::All => guard.symbols.clone().into_iter().collect::<Vec<_>>(),
Subset::Some(assets) => assets
.into_iter()
.map(|asset| match thread_type {
ThreadType::Bars(_) => asset.symbol,
ThreadType::News => asset.abbreviation,
})
.filter(|symbol| match message.action {
Action::Backfill => guard.symbols.contains(symbol),
Action::Purge => !guard.symbols.contains(symbol),
})
.collect::<Vec<_>>(),
};
match message.action { match message.action {
Action::Backfill => { Some(Action::Backfill) => {
for symbol in symbols { let log_string = handler.log_string();
if let Some(job) = backfill_jobs.remove(&symbol) {
for symbol in message.symbols {
if let Some(job) = backfill_jobs.get(&symbol) {
if !job.is_finished() { if !job.is_finished() {
job.abort(); warn!(
"Backfill for {} {} is already running, skipping.",
symbol, log_string
);
continue;
} }
let _ = job.await;
} }
let app_config = app_config.clone(); let handler = handler.clone();
let data_url = data_url.clone();
backfill_jobs.insert( backfill_jobs.insert(
symbol.clone(), symbol.clone(),
spawn(async move { spawn(async move {
let (fetch_from, fetch_to) = let fetch_from = match handler
queue_backfill(&app_config, thread_type, &symbol).await; .select_latest_backfill(symbol.clone())
.await
.unwrap()
{
Some(latest_backfill) => latest_backfill.time + ONE_SECOND,
None => OffsetDateTime::UNIX_EPOCH,
};
match thread_type { let fetch_to = last_minute();
ThreadType::Bars(_) => {
execute_backfill_bars( if fetch_from > fetch_to {
app_config, info!("No need to backfill {} {}.", symbol, log_string,);
thread_type, return;
data_url,
symbol,
fetch_from,
fetch_to,
)
.await;
}
ThreadType::News => {
execute_backfill_news(
app_config,
thread_type,
data_url,
symbol,
fetch_from,
fetch_to,
)
.await;
}
} }
handler.queue_backfill(&symbol, fetch_to).await;
handler.backfill(symbol, fetch_from, fetch_to).await;
}), }),
); );
} }
} }
Action::Purge => { Some(Action::Purge) => {
for symbol in &symbols { for symbol in &message.symbols {
if let Some(job) = backfill_jobs.remove(symbol) { if let Some(job) = backfill_jobs.remove(symbol) {
if !job.is_finished() { if !job.is_finished() {
job.abort(); job.abort();
@@ -158,120 +144,108 @@ async fn handle_backfill_message(
} }
} }
let backfills_future = database::backfills::delete_where_symbols( try_join!(
&app_config.clickhouse_client, handler.delete_backfills(&message.symbols),
&thread_type, handler.delete_data(&message.symbols)
&symbols,
);
let data_future = async {
match thread_type {
ThreadType::Bars(_) => {
database::bars::delete_where_symbols(
&app_config.clickhouse_client,
&symbols,
) )
.await; .unwrap();
}
ThreadType::News => {
database::news::delete_where_symbols(
&app_config.clickhouse_client,
&symbols,
)
.await;
}
}
};
join!(backfills_future, data_future);
} }
None => {}
} }
message.response.send(()).unwrap(); message.response.send(()).unwrap();
} }
async fn queue_backfill( struct BarHandler {
app_config: &Arc<Config>, config: Arc<Config>,
thread_type: ThreadType, data_url: &'static str,
symbol: &String, api_query_constructor: fn(
) -> (OffsetDateTime, OffsetDateTime) {
let latest_backfill = database::backfills::select_latest_where_symbol(
&app_config.clickhouse_client,
&thread_type,
&symbol,
)
.await;
let fetch_from = latest_backfill
.as_ref()
.map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
backfill.time + ONE_MINUTE
});
let fetch_to = last_minute();
if app_config.alpaca_source == Source::Iex {
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
info!(
"{:?} - Queing backfill for {} in {:?}.",
thread_type, symbol, run_delay
);
sleep(run_delay).await;
}
(fetch_from, fetch_to)
}
async fn execute_backfill_bars(
app_config: Arc<Config>,
thread_type: ThreadType,
data_url: String,
symbol: String, symbol: String,
fetch_from: OffsetDateTime, fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime, fetch_to: OffsetDateTime,
) { next_page_token: Option<String>,
if fetch_from > fetch_to { ) -> alpaca::api::outgoing::bar::Bar,
return; }
fn us_equity_query_constructor(
symbol: String,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
alpaca::api::outgoing::bar::Bar::UsEquity(alpaca::api::outgoing::bar::UsEquity {
symbols: vec![symbol],
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
..Default::default()
})
}
fn crypto_query_constructor(
symbol: String,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
alpaca::api::outgoing::bar::Bar::Crypto(alpaca::api::outgoing::bar::Crypto {
symbols: vec![symbol],
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token,
..Default::default()
})
}
#[async_trait]
impl Handler for BarHandler {
async fn select_latest_backfill(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error> {
database::backfills_bars::select_where_symbol(&self.config.clickhouse_client, &symbol).await
    }

    async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::backfills_bars::delete_where_symbols(&self.config.clickhouse_client, symbols)
            .await
    }

    async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::bars::delete_where_symbols(&self.config.clickhouse_client, symbols).await
    }

    async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
        if *ALPACA_SOURCE == Source::Iex {
            let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
            info!("Queueing bar backfill for {} in {:?}.", symbol, run_delay);
            sleep(run_delay).await;
        }
    }

    async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) {
        info!("Backfilling bars for {}.", symbol);

        let mut bars = vec![];
        let mut next_page_token = None;

        loop {
            let Ok(message) = alpaca::api::incoming::bar::get_historical(
                &self.config.alpaca_client,
                &self.config.alpaca_rate_limiter,
                self.data_url,
                &(self.api_query_constructor)(
                    symbol.clone(),
                    fetch_from,
                    fetch_to,
                    next_page_token.clone(),
                ),
                None,
            )
            .await
            else {
                error!("Failed to backfill bars for {}.", symbol);
                return;
            };

            message.bars.into_iter().for_each(|(symbol, bar_vec)| {

@@ -287,67 +261,78 @@ async fn execute_backfill_bars(
        }

        if bars.is_empty() {
            info!("No bars to backfill for {}.", symbol);
            return;
        }

        let backfill = bars.last().unwrap().clone().into();

        database::bars::upsert_batch(&self.config.clickhouse_client, &bars)
            .await
            .unwrap();
        database::backfills_bars::upsert(&self.config.clickhouse_client, &backfill)
            .await
            .unwrap();

        info!("Backfilled bars for {}.", symbol);
    }

    fn log_string(&self) -> &'static str {
        "bars"
    }
}

struct NewsHandler {
    config: Arc<Config>,
}

#[async_trait]
impl Handler for NewsHandler {
    async fn select_latest_backfill(
        &self,
        symbol: String,
    ) -> Result<Option<Backfill>, clickhouse::error::Error> {
        database::backfills_news::select_where_symbol(&self.config.clickhouse_client, &symbol).await
    }

    async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::backfills_news::delete_where_symbols(&self.config.clickhouse_client, symbols)
            .await
    }

    async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::news::delete_where_symbols(&self.config.clickhouse_client, symbols).await
    }

    async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
        let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
        info!("Queueing news backfill for {} in {:?}.", symbol, run_delay);
        sleep(run_delay).await;
    }

    async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) {
        info!("Backfilling news for {}.", symbol);

        let mut news = vec![];
        let mut next_page_token = None;

        loop {
            let Ok(message) = alpaca::api::incoming::news::get_historical(
                &self.config.alpaca_client,
                &self.config.alpaca_rate_limiter,
                &alpaca::api::outgoing::news::News {
                    symbols: vec![symbol.clone()],
                    start: Some(fetch_from),
                    end: Some(fetch_to),
                    page_token: next_page_token.clone(),
                    ..Default::default()
                },
                None,
            )
            .await
            else {
                error!("Failed to backfill news for {}.", symbol);
                return;
            };

            message.news.into_iter().for_each(|news_item| {

@@ -361,59 +346,68 @@ async fn execute_backfill_news(
        }

        if news.is_empty() {
            info!("No news to backfill for {}.", symbol);
            return;
        }

        let inputs = news
            .iter()
            .map(|news| format!("{}\n\n{}", news.headline, news.content))
            .collect::<Vec<_>>();

        let predictions = join_all(inputs.chunks(*MAX_BERT_INPUTS).map(|inputs| async move {
            let sequence_classifier = self.config.sequence_classifier.lock().await;
            block_in_place(|| {
                sequence_classifier
                    .predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
                    .into_iter()
                    .map(|label| Prediction::try_from(label).unwrap())
                    .collect::<Vec<_>>()
            })
        }))
        .await
        .into_iter()
        .flatten();

        let news = news
            .into_iter()
            .zip(predictions)
            .map(|(news, prediction)| News {
                sentiment: prediction.sentiment,
                confidence: prediction.confidence,
                ..news
            })
            .collect::<Vec<_>>();

        let backfill = (news.last().unwrap().clone(), symbol.clone()).into();

        database::news::upsert_batch(&self.config.clickhouse_client, &news)
            .await
            .unwrap();
        database::backfills_news::upsert(&self.config.clickhouse_client, &backfill)
            .await
            .unwrap();

        info!("Backfilled news for {}.", symbol);
    }

    fn log_string(&self) -> &'static str {
        "news"
    }
}

pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
    match thread_type {
        ThreadType::Bars(Class::UsEquity) => Box::new(BarHandler {
            config,
            data_url: ALPACA_STOCK_DATA_API_URL,
            api_query_constructor: us_equity_query_constructor,
        }),
        ThreadType::Bars(Class::Crypto) => Box::new(BarHandler {
            config,
            data_url: ALPACA_CRYPTO_DATA_API_URL,
            api_query_constructor: crypto_query_constructor,
        }),
        ThreadType::News => Box::new(NewsHandler { config }),
    }
}
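The 15-minute wait in `queue_backfill` presumably reflects the delay on recent data for Alpaca's free IEX feed, plus a minute of slack. A minimal sketch of what a `duration_until`-style helper could look like; this is a stand-in under that assumption, not the repo's actual util:

use std::time::Duration;
use time::OffsetDateTime;

// Clamp to zero when the target is already in the past, so `sleep` returns
// immediately instead of erroring on a negative duration.
fn duration_until(target: OffsetDateTime) -> Duration {
    let now = OffsetDateTime::now_utc();
    if target <= now {
        Duration::ZERO
    } else {
        (target - now).try_into().unwrap_or(Duration::ZERO)
    }
}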


@@ -1,218 +1,325 @@
mod backfill;
mod websocket;

use super::clock;
use crate::{
    config::{
        Config, ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL, ALPACA_SOURCE,
        ALPACA_STOCK_DATA_WEBSOCKET_URL,
    },
    create_send_await, database,
    types::{alpaca, Asset, Class},
    utils::backoff,
};
use futures_util::{future::join_all, StreamExt};
use itertools::{Either, Itertools};
use std::sync::Arc;
use tokio::{
    join, select, spawn,
    sync::{mpsc, oneshot},
};
use tokio_tungstenite::connect_async;

#[derive(Clone, Copy)]
#[allow(dead_code)]
pub enum Action {
    Add,
    Enable,
    Remove,
    Disable,
}

pub struct Message {
    pub action: Action,
    pub assets: Vec<(String, Class)>,
    pub response: oneshot::Sender<()>,
}

impl Message {
    pub fn new(action: Action, assets: Vec<(String, Class)>) -> (Self, oneshot::Receiver<()>) {
        let (sender, receiver) = oneshot::channel();
        (
            Self {
                action,
                assets,
                response: sender,
            },
            receiver,
        )
    }
}

#[derive(Clone, Copy)]
pub enum ThreadType {
    Bars(Class),
    News,
}

pub async fn run(
    config: Arc<Config>,
    mut receiver: mpsc::Receiver<Message>,
    mut clock_receiver: mpsc::Receiver<clock::Message>,
) {
    let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) =
        init_thread(config.clone(), ThreadType::Bars(Class::UsEquity)).await;
    let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) =
        init_thread(config.clone(), ThreadType::Bars(Class::Crypto)).await;
    let (news_websocket_sender, news_backfill_sender) =
        init_thread(config.clone(), ThreadType::News).await;

    loop {
        select! {
            Some(message) = receiver.recv() => {
                spawn(handle_message(
                    config.clone(),
                    bars_us_equity_websocket_sender.clone(),
                    bars_us_equity_backfill_sender.clone(),
                    bars_crypto_websocket_sender.clone(),
                    bars_crypto_backfill_sender.clone(),
                    news_websocket_sender.clone(),
                    news_backfill_sender.clone(),
                    message,
                ));
            }
            Some(_) = clock_receiver.recv() => {
                spawn(handle_clock_message(
                    config.clone(),
                    bars_us_equity_backfill_sender.clone(),
                    bars_crypto_backfill_sender.clone(),
                    news_backfill_sender.clone(),
                ));
            }
            else => panic!("Communication channel unexpectedly closed.")
        }
    }
}

async fn init_thread(
    config: Arc<Config>,
    thread_type: ThreadType,
) -> (
    mpsc::Sender<websocket::Message>,
    mpsc::Sender<backfill::Message>,
) {
    let websocket_url = match thread_type {
        ThreadType::Bars(Class::UsEquity) => {
            format!("{}/{}", ALPACA_STOCK_DATA_WEBSOCKET_URL, *ALPACA_SOURCE)
        }
        ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_WEBSOCKET_URL.into(),
        ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(),
    };

    let (websocket, _) = connect_async(websocket_url).await.unwrap();
    let (mut websocket_sink, mut websocket_stream) = websocket.split();
    alpaca::websocket::data::authenticate(&mut websocket_sink, &mut websocket_stream).await;

    let (backfill_sender, backfill_receiver) = mpsc::channel(100);
    spawn(backfill::run(
        Arc::new(backfill::create_handler(thread_type, config.clone())),
        backfill_receiver,
    ));

    let (websocket_sender, websocket_receiver) = mpsc::channel(100);
    spawn(websocket::run(
        Arc::new(websocket::create_handler(thread_type, config.clone())),
        websocket_receiver,
        websocket_stream,
        websocket_sink,
    ));

    (websocket_sender, backfill_sender)
}

#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_lines)]
async fn handle_message(
    config: Arc<Config>,
    bars_us_equity_websocket_sender: mpsc::Sender<websocket::Message>,
    bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
    bars_crypto_websocket_sender: mpsc::Sender<websocket::Message>,
    bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
    news_websocket_sender: mpsc::Sender<websocket::Message>,
    news_backfill_sender: mpsc::Sender<backfill::Message>,
    message: Message,
) {
    if message.assets.is_empty() {
        message.response.send(()).unwrap();
        return;
    }

    let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = message
        .assets
        .clone()
        .into_iter()
        .partition_map(|asset| match asset.1 {
            Class::UsEquity => Either::Left(asset.0),
            Class::Crypto => Either::Right(asset.0),
        });

    let symbols = message
        .assets
        .into_iter()
        .map(|(symbol, _)| symbol)
        .collect::<Vec<_>>();

    let bars_us_equity_future = async {
        if us_equity_symbols.is_empty() {
            return;
        }

        create_send_await!(
            bars_us_equity_websocket_sender,
            websocket::Message::new,
            message.action.into(),
            us_equity_symbols.clone()
        );

        create_send_await!(
            bars_us_equity_backfill_sender,
            backfill::Message::new,
            message.action.into(),
            us_equity_symbols
        );
    };

    let bars_crypto_future = async {
        if crypto_symbols.is_empty() {
            return;
        }

        create_send_await!(
            bars_crypto_websocket_sender,
            websocket::Message::new,
            message.action.into(),
            crypto_symbols.clone()
        );

        create_send_await!(
            bars_crypto_backfill_sender,
            backfill::Message::new,
            message.action.into(),
            crypto_symbols
        );
    };

    let news_future = async {
        create_send_await!(
            news_websocket_sender,
            websocket::Message::new,
            message.action.into(),
            symbols.clone()
        );

        create_send_await!(
            news_backfill_sender,
            backfill::Message::new,
            message.action.into(),
            symbols.clone()
        );
    };

    join!(bars_us_equity_future, bars_crypto_future, news_future);

    match message.action {
        Action::Add => {
            let assets = join_all(symbols.into_iter().map(|symbol| {
                let config = config.clone();
                async move {
                    let asset_future = async {
                        alpaca::api::incoming::asset::get_by_symbol(
                            &config.alpaca_client,
                            &config.alpaca_rate_limiter,
                            &symbol,
                            Some(backoff::infinite()),
                        )
                        .await
                        .unwrap()
                    };

                    let position_future = async {
                        alpaca::api::incoming::position::get_by_symbol(
                            &config.alpaca_client,
                            &config.alpaca_rate_limiter,
                            &symbol,
                            Some(backoff::infinite()),
                        )
                        .await
                        .unwrap()
                    };

                    let (asset, position) = join!(asset_future, position_future);
                    Asset::from((asset, position))
                }
            }))
            .await;

            database::assets::upsert_batch(&config.clickhouse_client, &assets)
                .await
                .unwrap();
        }
        Action::Remove => {
            database::assets::delete_where_symbols(&config.clickhouse_client, &symbols)
                .await
                .unwrap();
        }
        _ => {}
    }

    message.response.send(()).unwrap();
}

async fn handle_clock_message(
    config: Arc<Config>,
    bars_us_equity_backfill_sender: mpsc::Sender<backfill::Message>,
    bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
    news_backfill_sender: mpsc::Sender<backfill::Message>,
) {
    database::cleanup_all(&config.clickhouse_client)
        .await
        .unwrap();

    let assets = database::assets::select(&config.clickhouse_client)
        .await
        .unwrap();

    let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets
        .clone()
        .into_iter()
        .partition_map(|asset| match asset.class {
            Class::UsEquity => Either::Left(asset.symbol),
            Class::Crypto => Either::Right(asset.symbol),
        });

    let symbols = assets
        .into_iter()
        .map(|asset| asset.symbol)
        .collect::<Vec<_>>();

    let bars_us_equity_future = async {
        create_send_await!(
            bars_us_equity_backfill_sender,
            backfill::Message::new,
            Some(backfill::Action::Backfill),
            us_equity_symbols.clone()
        );
    };

    let bars_crypto_future = async {
        create_send_await!(
            bars_crypto_backfill_sender,
            backfill::Message::new,
            Some(backfill::Action::Backfill),
            crypto_symbols.clone()
        );
    };

    let news_future = async {
        create_send_await!(
            news_backfill_sender,
            backfill::Message::new,
            Some(backfill::Action::Backfill),
            symbols
        );
    };

    join!(bars_us_equity_future, bars_crypto_future, news_future);
}
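`create_send_await!` is invoked throughout but defined elsewhere in the crate. Given the `(message, receiver)` pairs returned by the `Message::new` constructors above, a plausible expansion is the following sketch; this is an assumption, not the repo's actual macro:

// Hypothetical expansion: build the message, send it over the mpsc channel,
// then block on the oneshot acknowledgment so the caller runs synchronously
// with respect to the receiving thread.
macro_rules! create_send_await {
    ($sender:expr, $constructor:expr, $($arg:expr),*) => {
        let (message, receiver) = $constructor($($arg),*);
        $sender.send(message).await.unwrap();
        receiver.await.unwrap();
    };
}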


@@ -1,256 +1,427 @@
use super::ThreadType;
use crate::{
    config::Config,
    database,
    types::{alpaca::websocket, news::Prediction, Bar, Class, News},
};
use async_trait::async_trait;
use futures_util::{
    future::join_all,
    stream::{SplitSink, SplitStream},
    SinkExt, StreamExt,
};
use log::{debug, error, info};
use serde_json::{from_str, to_string};
use std::{collections::HashMap, sync::Arc};
use tokio::{
    net::TcpStream,
    select, spawn,
    sync::{mpsc, oneshot, Mutex, RwLock},
    task::block_in_place,
};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};

pub enum Action {
    Subscribe,
    Unsubscribe,
}

impl From<super::Action> for Option<Action> {
    fn from(action: super::Action) -> Self {
        match action {
            super::Action::Add | super::Action::Enable => Some(Action::Subscribe),
            super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe),
        }
    }
}

pub struct Message {
    pub action: Option<Action>,
    pub symbols: Vec<String>,
    pub response: oneshot::Sender<()>,
}

impl Message {
    pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
        let (sender, receiver) = oneshot::channel();
        (
            Self {
                action,
                symbols,
                response: sender,
            },
            receiver,
        )
    }
}

pub struct Pending {
    pub subscriptions: HashMap<String, oneshot::Sender<()>>,
    pub unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}

#[async_trait]
pub trait Handler: Send + Sync {
    fn create_subscription_message(
        &self,
        symbols: Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message;

    async fn handle_websocket_message(
        &self,
        pending: Arc<RwLock<Pending>>,
        message: websocket::data::incoming::Message,
    );
}

pub async fn run(
    handler: Arc<Box<dyn Handler>>,
    mut receiver: mpsc::Receiver<Message>,
    mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>,
) {
    let pending = Arc::new(RwLock::new(Pending {
        subscriptions: HashMap::new(),
        unsubscriptions: HashMap::new(),
    }));
    let websocket_sink = Arc::new(Mutex::new(websocket_sink));

    loop {
        select! {
            Some(message) = receiver.recv() => {
                spawn(handle_message(
                    handler.clone(),
                    pending.clone(),
                    websocket_sink.clone(),
                    message,
                ));
            }
            Some(Ok(message)) = websocket_stream.next() => {
                match message {
                    tungstenite::Message::Text(message) => {
                        let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);

                        if parsed_message.is_err() {
                            error!("Failed to deserialize websocket message: {:?}", message);
                            continue;
                        }

                        for message in parsed_message.unwrap() {
                            let handler = handler.clone();
                            let pending = pending.clone();

                            spawn(async move {
                                handler.handle_websocket_message(pending, message).await;
                            });
                        }
                    }
                    tungstenite::Message::Ping(_) => {}
                    _ => error!("Unexpected websocket message: {:?}", message),
                }
            }
            else => panic!("Communication channel unexpectedly closed.")
        }
    }
}

async fn handle_message(
    handler: Arc<Box<dyn Handler>>,
    pending: Arc<RwLock<Pending>>,
    sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
    message: Message,
) {
    if message.symbols.is_empty() {
        message.response.send(()).unwrap();
        return;
    }

    match message.action {
        Some(Action::Subscribe) => {
            let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message
                .symbols
                .iter()
                .map(|symbol| {
                    let (sender, receiver) = oneshot::channel();
                    ((symbol.clone(), sender), receiver)
                })
                .unzip();

            pending
                .write()
                .await
                .subscriptions
                .extend(pending_subscriptions);

            sink.lock()
                .await
                .send(tungstenite::Message::Text(
                    to_string(&websocket::data::outgoing::Message::Subscribe(
                        handler.create_subscription_message(message.symbols),
                    ))
                    .unwrap(),
                ))
                .await
                .unwrap();

            join_all(receivers).await;
        }
        Some(Action::Unsubscribe) => {
            let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
                .symbols
                .iter()
                .map(|symbol| {
                    let (sender, receiver) = oneshot::channel();
                    ((symbol.clone(), sender), receiver)
                })
                .unzip();

            pending
                .write()
                .await
                .unsubscriptions
                .extend(pending_unsubscriptions);

            sink.lock()
                .await
                .send(tungstenite::Message::Text(
                    to_string(&websocket::data::outgoing::Message::Unsubscribe(
                        handler.create_subscription_message(message.symbols.clone()),
                    ))
                    .unwrap(),
                ))
                .await
                .unwrap();

            join_all(receivers).await;
        }
        None => {}
    }

    message.response.send(()).unwrap();
}

struct BarsHandler {
    config: Arc<Config>,
    subscription_message_constructor:
        fn(Vec<String>) -> websocket::data::outgoing::subscribe::Message,
}

#[async_trait]
impl Handler for BarsHandler {
    fn create_subscription_message(
        &self,
        symbols: Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message {
        (self.subscription_message_constructor)(symbols)
    }

    async fn handle_websocket_message(
        &self,
        pending: Arc<RwLock<Pending>>,
        message: websocket::data::incoming::Message,
    ) {
        match message {
            websocket::data::incoming::Message::Subscription(message) => {
                let websocket::data::incoming::subscription::Message::Market {
                    bars: symbols, ..
                } = message
                else {
                    unreachable!()
                };

                let mut pending = pending.write().await;
                let newly_subscribed = pending
                    .subscriptions
                    .extract_if(|symbol, _| symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();
                let newly_unsubscribed = pending
                    .unsubscriptions
                    .extract_if(|symbol, _| !symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();
                drop(pending);

                if !newly_subscribed.is_empty() {
                    info!(
                        "Subscribed to bars for {:?}.",
                        newly_subscribed.keys().collect::<Vec<_>>()
                    );

                    for sender in newly_subscribed.into_values() {
                        sender.send(()).unwrap();
                    }
                }

                if !newly_unsubscribed.is_empty() {
                    info!(
                        "Unsubscribed from bars for {:?}.",
                        newly_unsubscribed.keys().collect::<Vec<_>>()
                    );

                    for sender in newly_unsubscribed.into_values() {
                        sender.send(()).unwrap();
                    }
                }
            }
            websocket::data::incoming::Message::Bar(message)
            | websocket::data::incoming::Message::UpdatedBar(message) => {
                let bar = Bar::from(message);
                debug!("Received bar for {}: {}.", bar.symbol, bar.time);

                database::bars::upsert(&self.config.clickhouse_client, &bar)
                    .await
                    .unwrap();
            }
            websocket::data::incoming::Message::Status(message) => {
                debug!(
                    "Received status message for {}: {:?}.",
                    message.symbol, message.status
                );

                match message.status {
                    websocket::data::incoming::status::Status::TradingHalt(_)
                    | websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
                        database::assets::update_status_where_symbol(
                            &self.config.clickhouse_client,
                            &message.symbol,
                            false,
                        )
                        .await
                        .unwrap();
                    }
                    websocket::data::incoming::status::Status::Resume(_)
                    | websocket::data::incoming::status::Status::TradingResumption(_) => {
                        database::assets::update_status_where_symbol(
                            &self.config.clickhouse_client,
                            &message.symbol,
                            true,
                        )
                        .await
                        .unwrap();
                    }
                    _ => {}
                }
            }
            websocket::data::incoming::Message::Error(message) => {
                error!("Received error message: {}.", message.message);
            }
            _ => unreachable!(),
        }
    }
}
struct NewsHandler {
config: Arc<Config>,
}
#[async_trait]
impl Handler for NewsHandler {
fn create_subscription_message(
&self,
symbols: Vec<String>,
) -> websocket::data::outgoing::subscribe::Message {
websocket::data::outgoing::subscribe::Message::new_news(symbols)
}
async fn handle_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
let websocket::data::incoming::subscription::Message::News { news: symbols } =
message
else {
unreachable!()
};
let mut pending = pending.write().await;
let newly_subscribed = pending
.subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
let newly_unsubscribed = pending
.unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
drop(pending);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to news for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from news for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::News(message) => {
let news = News::from(message);
debug!(
"Received news for {:?}: {}.",
news.symbols, news.time_created
);
let input = format!("{}\n\n{}", news.headline, news.content);
let sequence_classifier = self.config.sequence_classifier.lock().await;
let prediction = block_in_place(|| {
sequence_classifier
.predict(vec![input.as_str()])
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()[0]
});
drop(sequence_classifier);
let news = News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
};
database::news::upsert(&self.config.clickhouse_client, &news)
.await
.unwrap();
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
}
pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
match thread_type {
ThreadType::Bars(Class::UsEquity) => Box::new(BarsHandler {
config,
subscription_message_constructor:
websocket::data::outgoing::subscribe::Message::new_market_us_equity,
}),
ThreadType::Bars(Class::Crypto) => Box::new(BarsHandler {
config,
subscription_message_constructor:
websocket::data::outgoing::subscribe::Message::new_market_crypto,
}),
ThreadType::News => Box::new(NewsHandler { config }),
}
}
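The `Pending` maps above implement a request/acknowledgment protocol over a single websocket: each subscribe or unsubscribe parks a `oneshot::Sender` per symbol, and the handler releases it once Alpaca's confirmation message lists (or stops listing) that symbol. A stripped-down, self-contained sketch of the same pattern:

use std::collections::HashMap;
use tokio::sync::oneshot;

#[tokio::main]
async fn main() {
    let mut pending: HashMap<String, oneshot::Sender<()>> = HashMap::new();

    // Requester side: park an acknowledgment channel under the symbol.
    let (tx, rx) = oneshot::channel();
    pending.insert("AAPL".into(), tx);

    // Websocket side: a confirmation naming "AAPL" arrives, so complete it.
    if let Some(ack) = pending.remove("AAPL") {
        ack.send(()).unwrap();
    }

    rx.await.unwrap(); // the requester unblocks here
}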


@@ -1,2 +1,3 @@
pub mod clock;
pub mod data;
pub mod trading;


@@ -0,0 +1,20 @@
mod websocket;
use crate::{
config::{Config, ALPACA_WEBSOCKET_URL},
types::alpaca,
};
use futures_util::StreamExt;
use std::sync::Arc;
use tokio::spawn;
use tokio_tungstenite::connect_async;
pub async fn run(config: Arc<Config>) {
let (websocket, _) = connect_async(&*ALPACA_WEBSOCKET_URL).await.unwrap();
let (mut websocket_sink, mut websocket_stream) = websocket.split();
alpaca::websocket::trading::authenticate(&mut websocket_sink, &mut websocket_stream).await;
alpaca::websocket::trading::subscribe(&mut websocket_sink, &mut websocket_stream).await;
spawn(websocket::run(config, websocket_stream));
}


@@ -0,0 +1,77 @@
use crate::{
config::Config,
database,
types::{alpaca::websocket, Order},
};
use futures_util::{stream::SplitStream, StreamExt};
use log::{debug, error};
use serde_json::from_str;
use std::sync::Arc;
use tokio::{net::TcpStream, spawn};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
pub async fn run(
config: Arc<Config>,
mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) {
loop {
let message = websocket_stream.next().await.unwrap().unwrap();
match message {
tungstenite::Message::Binary(message) => {
let parsed_message = from_str::<websocket::trading::incoming::Message>(
&String::from_utf8_lossy(&message),
);
if parsed_message.is_err() {
error!("Failed to deserialize websocket message: {:?}", message);
continue;
}
spawn(handle_websocket_message(
config.clone(),
parsed_message.unwrap(),
));
}
tungstenite::Message::Ping(_) => {}
_ => error!("Unexpected websocket message: {:?}", message),
}
}
}
async fn handle_websocket_message(
config: Arc<Config>,
message: websocket::trading::incoming::Message,
) {
match message {
websocket::trading::incoming::Message::Order(message) => {
debug!(
"Received order message for {}: {:?}",
message.order.symbol, message.event
);
let order = Order::from(message.order);
database::orders::upsert(&config.clickhouse_client, &order)
.await
.unwrap();
match message.event {
websocket::trading::incoming::order::Event::Fill { position_qty, .. }
| websocket::trading::incoming::order::Event::PartialFill {
position_qty, ..
} => {
database::assets::update_qty_where_symbol(
&config.clickhouse_client,
&order.symbol,
position_qty,
)
.await
.unwrap();
}
_ => (),
}
}
_ => unreachable!(),
}
}


@@ -1,3 +0,0 @@
pub mod subset;
pub use subset::Subset;


@@ -1,5 +0,0 @@
#[derive(Clone, Debug)]
pub enum Subset<T> {
Some(Vec<T>),
All,
}


@@ -0,0 +1,116 @@
use crate::config::ALPACA_API_URL;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use serde_aux::field_attributes::{
deserialize_number_from_string, deserialize_option_number_from_string,
};
use std::time::Duration;
use time::OffsetDateTime;
use uuid::Uuid;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Status {
Onboarding,
SubmissionFailed,
Submitted,
AccountUpdated,
ApprovalPending,
Active,
Rejected,
}
#[derive(Deserialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct Account {
pub id: Uuid,
#[serde(rename = "account_number")]
pub number: String,
pub status: Status,
pub currency: String,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub cash: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub non_marginable_buying_power: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub accrued_fees: f64,
#[serde(default)]
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub pending_transfer_in: Option<f64>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub pending_transfer_out: Option<f64>,
pub pattern_day_trader: bool,
#[serde(default)]
pub trade_suspend_by_user: bool,
pub trading_blocked: bool,
pub transfers_blocked: bool,
#[serde(rename = "account_blocked")]
pub blocked: bool,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
pub shorting_enabled: bool,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub long_market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub short_market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub equity: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub last_equity: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub multiplier: i8,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub buying_power: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub initial_margin: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub maintenance_margin: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub sma: f64,
pub daytrade_count: i64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub last_maintenance_margin: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub daytrading_buying_power: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub regt_buying_power: f64,
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/account", *ALPACA_API_URL))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Account>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
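Nearly every numeric field on `Account` carries a `serde_aux` attribute because Alpaca encodes numbers as JSON strings. A minimal, runnable demonstration of the same deserializer on a hypothetical `Snapshot` type:

use serde::Deserialize;
use serde_aux::field_attributes::deserialize_number_from_string;

#[derive(Deserialize, Debug)]
struct Snapshot {
    // Accepts both "123.45" (string) and 123.45 (number) in the JSON input.
    #[serde(deserialize_with = "deserialize_number_from_string")]
    cash: f64,
}

fn main() {
    let snapshot: Snapshot = serde_json::from_str(r#"{"cash": "123.45"}"#).unwrap();
    assert_eq!(snapshot.cash, 123.45);
    println!("{snapshot:?}");
}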


@@ -1,52 +1,24 @@
use super::position::Position;
use crate::{
    config::ALPACA_API_URL,
    types::{
        self,
        alpaca::shared::asset::{Class, Exchange, Status},
    },
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
use uuid::Uuid;

#[allow(clippy::struct_excessive_bools)]
#[derive(Deserialize)]
pub struct Asset {
    pub id: Uuid,
    pub class: Class,
    pub exchange: Exchange,
    pub symbol: String,

@@ -57,18 +29,58 @@ pub struct Asset {
    pub shortable: bool,
    pub easy_to_borrow: bool,
    pub fractionable: bool,
    pub maintenance_margin_requirement: Option<f32>,
    pub attributes: Option<Vec<String>>,
}

impl From<(Asset, Option<Position>)> for types::Asset {
    fn from((asset, position): (Asset, Option<Position>)) -> Self {
        Self {
            symbol: asset.symbol,
            class: asset.class.into(),
            exchange: asset.exchange.into(),
            status: asset.status.into(),
            time_added: time::OffsetDateTime::now_utc(),
            qty: position.map(|position| position.qty).unwrap_or_default(),
        }
    }
}
pub async fn get_by_symbol(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/assets/{}", *ALPACA_API_URL, symbol))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(
reqwest::StatusCode::BAD_REQUEST
| reqwest::StatusCode::FORBIDDEN
| reqwest::StatusCode::NOT_FOUND,
) => backoff::Error::Permanent(e),
_ => e.into(),
})?
.json::<Asset>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get asset, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}


@@ -1,9 +1,13 @@
use crate::types::{self, alpaca::api::outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::{collections::HashMap, time::Duration};
use time::OffsetDateTime;

#[derive(Deserialize)]
pub struct Bar {
    #[serde(rename = "t")]
    #[serde(with = "time::serde::rfc3339")]

@@ -40,8 +44,46 @@ impl From<(Bar, String)> for types::Bar {
    }
}

#[derive(Deserialize)]
pub struct Message {
    pub bars: HashMap<String, Vec<Bar>>,
    pub next_page_token: Option<String>,
}
pub async fn get_historical(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
data_url: &str,
query: &outgoing::bar::Bar,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(data_url)
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical bars, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}


@@ -0,0 +1,69 @@
use crate::{
config::ALPACA_API_URL,
types::{self, alpaca::api::outgoing},
utils::{de, time::EST_OFFSET},
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
use time::{Date, OffsetDateTime, Time};
#[derive(Deserialize)]
pub struct Calendar {
pub date: Date,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub open: Time,
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub close: Time,
pub settlement_date: Date,
}
impl From<Calendar> for types::Calendar {
fn from(calendar: Calendar) -> Self {
Self {
date: calendar.date,
open: OffsetDateTime::new_in_offset(calendar.date, calendar.open, *EST_OFFSET),
close: OffsetDateTime::new_in_offset(calendar.date, calendar.close, *EST_OFFSET),
}
}
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::calendar::Calendar,
backoff: Option<ExponentialBackoff>,
) -> Result<Vec<Calendar>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/calendar", *ALPACA_API_URL))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Calendar>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get calendar, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}


@@ -1,7 +1,13 @@
use crate::config::ALPACA_API_URL;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
use time::OffsetDateTime;

#[derive(Deserialize)]
pub struct Clock {
    #[serde(with = "time::serde::rfc3339")]
    pub timestamp: OffsetDateTime,

@@ -11,3 +17,38 @@ pub struct Clock {
    #[serde(with = "time::serde::rfc3339")]
    pub next_close: OffsetDateTime,
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/clock", *ALPACA_API_URL))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Clock>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get clock, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}


@@ -1,4 +1,8 @@
pub mod account;
pub mod asset;
pub mod bar;
pub mod calendar;
pub mod clock;
pub mod news;
pub mod order;
pub mod position;


@@ -1,28 +1,34 @@
use crate::{
    config::ALPACA_NEWS_DATA_API_URL,
    types::{
        self,
        alpaca::{api::outgoing, shared::news::normalize_html_content},
    },
    utils::de,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
use time::OffsetDateTime;

#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImageSize {
    Thumb,
    Small,
    Large,
}

#[derive(Deserialize)]
pub struct Image {
    pub size: ImageSize,
    pub url: String,
}

#[derive(Deserialize)]
pub struct News {
    pub id: i64,
    #[serde(with = "time::serde::rfc3339")]

@@ -31,6 +37,7 @@ pub struct News {
    #[serde(with = "time::serde::rfc3339")]
    #[serde(rename = "updated_at")]
    pub time_updated: OffsetDateTime,
    #[serde(deserialize_with = "de::add_slash_to_symbols")]
    pub symbols: Vec<String>,
    pub headline: String,
    pub author: String,

@@ -48,17 +55,57 @@ impl From<News> for types::News {
            time_created: news.time_created,
            time_updated: news.time_updated,
            symbols: news.symbols,
            headline: normalize_html_content(&news.headline),
            author: normalize_html_content(&news.author),
            source: normalize_html_content(&news.source),
            summary: normalize_html_content(&news.summary),
            content: normalize_html_content(&news.content),
            sentiment: types::news::Sentiment::Neutral,
            confidence: 0.0,
            url: news.url.unwrap_or_default(),
        }
    }
}

#[derive(Deserialize)]
pub struct Message {
    pub news: Vec<News>,
    pub next_page_token: Option<String>,
}
pub async fn get_historical(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(ALPACA_NEWS_DATA_API_URL)
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Message>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get historical news, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
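`de::add_slash_to_symbols` presumably restores the slashed crypto-pair form (`BTC/USD`) that the news feed drops (`BTCUSD`), matching the symbols used elsewhere in the pipeline. A hypothetical helper with that behavior; the real util's rules may differ (e.g. handling quote currencies other than USD):

// Assumption: crypto pairs arrive as "<BASE>USD" without a slash, and
// already-slashed symbols pass through untouched.
fn add_slash(symbol: &str) -> String {
    match symbol.strip_suffix("USD") {
        Some(base) if !base.is_empty() && !symbol.contains('/') => format!("{base}/USD"),
        _ => symbol.to_string(),
    }
}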


@@ -0,0 +1,48 @@
use crate::{
config::ALPACA_API_URL,
types::alpaca::{api::outgoing, shared},
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub use shared::order::Order;
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::order::Order,
backoff: Option<ExponentialBackoff>,
) -> Result<Vec<Order>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/orders", *ALPACA_API_URL))
.query(query)
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Order>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get orders, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}


@@ -0,0 +1,145 @@
use crate::{
config::ALPACA_API_URL,
types::alpaca::shared::{
self,
asset::{Class, Exchange},
},
utils::de,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::Client;
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_number_from_string;
use std::time::Duration;
use uuid::Uuid;
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Long,
Short,
}
impl From<Side> for shared::order::Side {
fn from(side: Side) -> Self {
match side {
Side::Long => Self::Buy,
Side::Short => Self::Sell,
}
}
}
#[derive(Deserialize)]
pub struct Position {
pub asset_id: Uuid,
#[serde(deserialize_with = "de::add_slash_to_symbol")]
pub symbol: String,
pub exchange: Exchange,
pub asset_class: Class,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub avg_entry_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty_available: f64,
pub side: Side,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub cost_basis: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub current_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub lastday_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub change_today: f64,
pub asset_marginable: bool,
}
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
) -> Result<Vec<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/positions", *ALPACA_API_URL))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Position>>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get positions, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
) -> Result<Option<Position>, reqwest::Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
let response = alpaca_client
.get(&format!("{}/positions/{}", *ALPACA_API_URL, symbol))
.send()
.await?;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Ok(None);
}
response
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Position>()
.await
.map_err(backoff::Error::Permanent)
.map(Some)
},
|e, duration: Duration| {
warn!(
"Failed to get position, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}
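The early `return Ok(None)` matters here: Alpaca answers 404 for "no open position in this symbol", which is an expected outcome rather than a retryable error, so it is caught before `error_for_status` would turn it into one. A hypothetical call site illustrating the contract:

async fn has_open_position(
    client: &reqwest::Client,
    limiter: &governor::DefaultDirectRateLimiter,
    symbol: &str,
) -> bool {
    // Ok(Some(_)) = open position, Ok(None) = flat, Err(_) = request failed.
    get_by_symbol(client, limiter, symbol, None)
        .await
        .map(|position| position.is_some())
        .unwrap_or(false)
}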


@@ -1,24 +1,2 @@
pub mod incoming;
pub mod outgoing;
macro_rules! impl_from_enum {
($source:ty, $target:ty, $( $variant:ident ),* ) => {
impl From<$source> for $target {
fn from(item: $source) -> Self {
match item {
$( <$source>::$variant => <$target>::$variant, )*
}
}
}
impl From<$target> for $source {
fn from(item: $target) -> Self {
match item {
$( <$target>::$variant => <$source>::$variant, )*
}
}
}
};
}
use impl_from_enum;


@@ -1,51 +1,106 @@
use crate::{
    config::ALPACA_SOURCE,
    types::alpaca::shared::{Sort, Source},
    utils::{ser, ONE_MINUTE},
};
use serde::Serialize;
use std::time::Duration;
use time::OffsetDateTime;

#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum Adjustment {
    Raw,
    Split,
    Dividend,
    All,
}

#[derive(Serialize)]
pub struct UsEquity {
    #[serde(serialize_with = "ser::join_symbols")]
    pub symbols: Vec<String>,
    #[serde(serialize_with = "ser::timeframe")]
    pub timeframe: Duration,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(with = "time::serde::rfc3339::option")]
    pub start: Option<OffsetDateTime>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(with = "time::serde::rfc3339::option")]
    pub end: Option<OffsetDateTime>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub adjustment: Option<Adjustment>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(with = "time::serde::rfc3339::option")]
    pub asof: Option<OffsetDateTime>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub feed: Option<Source>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub currency: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub page_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort: Option<Sort>,
}
impl Default for UsEquity {
fn default() -> Self {
Self {
symbols: vec![],
timeframe: ONE_MINUTE,
start: None,
end: None,
limit: Some(10000),
adjustment: Some(Adjustment::All),
asof: None,
feed: Some(*ALPACA_SOURCE),
currency: None,
page_token: None,
sort: Some(Sort::Asc),
}
}
}
#[derive(Serialize)]
pub struct Crypto {
#[serde(serialize_with = "ser::join_symbols")]
pub symbols: Vec<String>,
#[serde(serialize_with = "ser::timeframe")]
pub timeframe: Duration,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub start: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub end: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort: Option<Sort>,
}
impl Default for Crypto {
fn default() -> Self {
Self {
symbols: vec![],
timeframe: ONE_MINUTE,
start: None,
end: None,
limit: Some(10000),
page_token: None,
sort: Some(Sort::Asc),
}
}
}
#[derive(Serialize)]
#[serde(untagged)]
pub enum Bar {
UsEquity(UsEquity),
Crypto(Crypto),
} }
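As a usage sketch (not part of the diff), the Default impls reduce a typical request to a small struct-update literal; None fields are skipped at serialization, so they never reach the query string:
// Illustrative: one-minute AAPL bars for the last day, everything else defaulted.
let request = Bar::UsEquity(UsEquity {
    symbols: vec!["AAPL".to_string()],
    start: Some(OffsetDateTime::now_utc() - time::Duration::days(1)),
    ..Default::default()
});
// e.g. alpaca_client.get(url).query(&request) via reqwest's serde support.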


@@ -0,0 +1,31 @@
use crate::utils::time::MAX_TIMESTAMP;
use serde::Serialize;
use time::OffsetDateTime;
#[derive(Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
#[allow(dead_code)]
pub enum DateType {
Trading,
Settlement,
}
#[derive(Serialize)]
pub struct Calendar {
#[serde(with = "time::serde::rfc3339")]
pub start: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
pub end: OffsetDateTime,
#[serde(rename = "date")]
pub date_type: DateType,
}
impl Default for Calendar {
fn default() -> Self {
Self {
start: OffsetDateTime::UNIX_EPOCH,
end: *MAX_TIMESTAMP,
date_type: DateType::Trading,
}
}
}
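Usage note: with these defaults a bare call spans the whole supported range, which is presumably how the local market-calendar storage is seeded:
// Illustrative: request every known trading day in one call.
let query = Calendar::default(); // UNIX_EPOCH..=MAX_TIMESTAMP, trading dates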


@@ -1,12 +1,4 @@
 pub mod bar;
+pub mod calendar;
 pub mod news;
-
-use serde::Serializer;
-
-fn serialize_symbols<S>(symbols: &[String], serializer: S) -> Result<S::Ok, S::Error>
-where
-    S: Serializer,
-{
-    let string = symbols.join(",");
-    serializer.serialize_str(&string)
-}
+pub mod order;


@@ -1,18 +1,40 @@
-use super::serialize_symbols;
+use crate::{types::alpaca::shared::Sort, utils::ser};
 use serde::Serialize;
 use time::OffsetDateTime;
 
 #[derive(Serialize)]
 pub struct News {
-    #[serde(serialize_with = "serialize_symbols")]
+    #[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")]
     pub symbols: Vec<String>,
-    #[serde(with = "time::serde::rfc3339")]
-    pub start: OffsetDateTime,
-    #[serde(with = "time::serde::rfc3339")]
-    pub end: OffsetDateTime,
-    pub limit: i64,
-    pub include_content: bool,
-    pub exclude_contentless: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(with = "time::serde::rfc3339::option")]
+    pub start: Option<OffsetDateTime>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(with = "time::serde::rfc3339::option")]
+    pub end: Option<OffsetDateTime>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub limit: Option<i64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub include_content: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub exclude_contentless: Option<bool>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub page_token: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub sort: Option<Sort>,
+}
+
+impl Default for News {
+    fn default() -> Self {
+        Self {
+            symbols: vec![],
+            start: None,
+            end: None,
+            limit: Some(50),
+            include_content: Some(true),
+            exclude_contentless: Some(false),
+            page_token: None,
+            sort: Some(Sort::Asc),
+        }
+    }
 }


@@ -0,0 +1,53 @@
use crate::{
types::alpaca::shared::{order::Side, Sort},
utils::ser,
};
use serde::Serialize;
use time::OffsetDateTime;
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum Status {
Open,
Closed,
All,
}
#[derive(Serialize)]
pub struct Order {
#[serde(skip_serializing_if = "Option::is_none")]
pub status: Option<Status>,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub after: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")]
pub until: Option<OffsetDateTime>,
#[serde(skip_serializing_if = "Option::is_none")]
pub direction: Option<Sort>,
#[serde(skip_serializing_if = "Option::is_none")]
pub nested: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(serialize_with = "ser::join_symbols_option")]
pub symbols: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub side: Option<Side>,
}
impl Default for Order {
fn default() -> Self {
Self {
status: Some(Status::All),
limit: Some(500),
after: None,
until: None,
direction: Some(Sort::Asc),
nested: Some(true),
symbols: None,
side: None,
}
}
}


@@ -1,5 +1,3 @@
 pub mod api;
-pub mod source;
+pub mod shared;
 pub mod websocket;
-
-pub use source::Source;


@@ -0,0 +1,53 @@
use crate::{impl_from_enum, types};
use serde::Deserialize;
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Class {
UsEquity,
Crypto,
}
impl_from_enum!(types::Class, Class, UsEquity, Crypto);
#[derive(Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Exchange {
Amex,
Arca,
Bats,
Nyse,
Nasdaq,
Nysearca,
Otc,
Crypto,
}
impl_from_enum!(
types::Exchange,
Exchange,
Amex,
Arca,
Bats,
Nyse,
Nasdaq,
Nysearca,
Otc,
Crypto
);
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Status {
Active,
Inactive,
}
impl From<Status> for bool {
fn from(status: Status) -> Self {
match status {
Status::Active => true,
Status::Inactive => false,
}
}
}


@@ -0,0 +1,10 @@
pub mod asset;
pub mod mode;
pub mod news;
pub mod order;
pub mod sort;
pub mod source;
pub use mode::Mode;
pub use sort::Sort;
pub use source::Source;


@@ -0,0 +1,33 @@
use serde::{Deserialize, Serialize};
use std::{
fmt::{Display, Formatter},
str::FromStr,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Mode {
Live,
Paper,
}
impl FromStr for Mode {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"live" => Ok(Self::Live),
"paper" => Ok(Self::Paper),
_ => Err(format!("Unknown mode: {s}")),
}
}
}
impl Display for Mode {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
match self {
Self::Live => write!(f, "live"),
Self::Paper => write!(f, "paper"),
}
}
}
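FromStr and Display are exact inverses here, so mode strings round-trip; an illustrative check:
let mode: Mode = "paper".parse().expect("known mode");
assert_eq!(mode, Mode::Paper);
assert_eq!(mode.to_string(), "paper");
assert!("margin".parse::<Mode>().is_err()); // unknown modes are rejected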


@@ -0,0 +1,18 @@
use html_escape::decode_html_entities;
use lazy_static::lazy_static;
use regex::Regex;
lazy_static! {
static ref RE_TAGS: Regex = Regex::new("<[^>]+>").unwrap();
static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap();
}
pub fn normalize_html_content(content: &str) -> String {
let content = content.replace('\n', " ");
let content = RE_TAGS.replace_all(&content, "");
let content = RE_SPACES.replace_all(&content, " ");
let content = decode_html_entities(&content);
let content = content.trim();
content.to_string()
}
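An illustrative check of the pipeline (newlines to spaces, tags stripped, whitespace collapsed, entities decoded, then trimmed); note that entities decode only after the whitespace collapse:
assert_eq!(
    normalize_html_content("<p>Fed holds\nrates &amp; bonds</p> "),
    "Fed holds rates & bonds"
);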


@@ -0,0 +1,225 @@
use crate::{impl_from_enum, types};
use serde::{Deserialize, Serialize};
use serde_aux::field_attributes::{
deserialize_number_from_string, deserialize_option_number_from_string,
};
use time::OffsetDateTime;
use uuid::Uuid;
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Class {
#[serde(alias = "")]
Simple,
Bracket,
Oco,
Oto,
}
impl_from_enum!(types::order::Class, Class, Simple, Bracket, Oco, Oto);
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Type {
Market,
Limit,
Stop,
StopLimit,
TrailingStop,
}
impl_from_enum!(
types::order::Type,
Type,
Market,
Limit,
Stop,
StopLimit,
TrailingStop
);
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Buy,
Sell,
}
impl_from_enum!(types::order::Side, Side, Buy, Sell);
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TimeInForce {
Day,
Gtc,
Opg,
Cls,
Ioc,
Fok,
}
impl_from_enum!(
types::order::TimeInForce,
TimeInForce,
Day,
Gtc,
Opg,
Cls,
Ioc,
Fok
);
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Status {
New,
PartiallyFilled,
Filled,
DoneForDay,
Canceled,
Expired,
Replaced,
PendingCancel,
PendingReplace,
Accepted,
PendingNew,
AcceptedForBidding,
Stopped,
Rejected,
Suspended,
Calculated,
}
impl_from_enum!(
types::order::Status,
Status,
New,
PartiallyFilled,
Filled,
DoneForDay,
Canceled,
Expired,
Replaced,
PendingCancel,
PendingReplace,
Accepted,
PendingNew,
AcceptedForBidding,
Stopped,
Rejected,
Suspended,
Calculated
);
#[derive(Deserialize, Clone, Debug, PartialEq)]
#[allow(clippy::struct_field_names)]
pub struct Order {
pub id: Uuid,
pub client_order_id: Uuid,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339::option")]
pub updated_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339")]
pub submitted_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339::option")]
pub filled_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub expired_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub cancel_requested_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub canceled_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub failed_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option")]
pub replaced_at: Option<OffsetDateTime>,
pub replaced_by: Option<Uuid>,
pub replaces: Option<Uuid>,
pub asset_id: Uuid,
pub symbol: String,
pub asset_class: super::asset::Class,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub notional: Option<f64>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub qty: Option<f64>,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub filled_qty: f64,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub filled_avg_price: Option<f64>,
pub order_class: Class,
#[serde(rename = "type")]
pub order_type: Type,
pub side: Side,
pub time_in_force: TimeInForce,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub limit_price: Option<f64>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub stop_price: Option<f64>,
pub status: Status,
pub extended_hours: bool,
pub legs: Option<Vec<Order>>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub trail_percent: Option<f64>,
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub trail_price: Option<f64>,
pub hwm: Option<f64>,
}
impl From<Order> for types::Order {
fn from(order: Order) -> Self {
Self {
id: order.id,
client_order_id: order.client_order_id,
time_submitted: order.submitted_at,
time_created: order.created_at,
time_updated: order.updated_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_filled: order.filled_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_expired: order.expired_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_cancel_requested: order
.cancel_requested_at
.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_canceled: order.canceled_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_failed: order.failed_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
time_replaced: order.replaced_at.unwrap_or(OffsetDateTime::UNIX_EPOCH),
replaced_by: order.replaced_by.unwrap_or_default(),
replaces: order.replaces.unwrap_or_default(),
symbol: order.symbol,
order_class: order.order_class.into(),
order_type: order.order_type.into(),
side: order.side.into(),
time_in_force: order.time_in_force.into(),
notional: order.notional.unwrap_or_default(),
qty: order.qty.unwrap_or_default(),
filled_qty: order.filled_qty,
filled_avg_price: order.filled_avg_price.unwrap_or_default(),
status: order.status.into(),
extended_hours: order.extended_hours,
limit_price: order.limit_price.unwrap_or_default(),
stop_price: order.stop_price.unwrap_or_default(),
trail_percent: order.trail_percent.unwrap_or_default(),
trail_price: order.trail_price.unwrap_or_default(),
hwm: order.hwm.unwrap_or_default(),
legs: order
.legs
.unwrap_or_default()
.into_iter()
.map(|order| order.id)
.collect(),
}
}
}
impl Order {
pub fn normalize(self) -> Vec<types::Order> {
let mut orders = vec![self.clone().into()];
if let Some(legs) = self.legs {
for leg in legs {
orders.extend(leg.normalize());
}
}
orders
}
}
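Usage sketch: normalize flattens an order tree into parent-first rows, each row keeping only its children's ids in legs; `bracket_order` below is an assumed value, not from the diff:
// Illustrative: a bracket order with take-profit and stop-loss legs
// flattens into three rows, parent first.
let rows: Vec<types::Order> = bracket_order.normalize();
assert_eq!(rows.len(), 3);
assert_eq!(rows[0].legs.len(), 2); // parent keeps the leg ids only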


@@ -0,0 +1,9 @@
use serde::Serialize;
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum Sort {
Asc,
Desc,
}


@@ -1,12 +1,15 @@
use serde::{Deserialize, Serialize};
use std::{ use std::{
fmt::{Display, Formatter}, fmt::{Display, Formatter},
str::FromStr, str::FromStr,
}; };
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Source { pub enum Source {
Iex, Iex,
Sip, Sip,
Otc,
} }
impl FromStr for Source { impl FromStr for Source {
@@ -26,6 +29,7 @@ impl Display for Source {
match self { match self {
Self::Iex => write!(f, "iex"), Self::Iex => write!(f, "iex"),
Self::Sip => write!(f, "sip"), Self::Sip => write!(f, "sip"),
Self::Otc => write!(f, "otc"),
} }
} }
} }


@@ -1,8 +1,8 @@
-use crate::types;
+use crate::types::Bar;
 use serde::Deserialize;
 use time::OffsetDateTime;
 
-#[derive(Clone, Debug, PartialEq, Deserialize)]
+#[derive(Deserialize, Debug, PartialEq)]
 pub struct Message {
     #[serde(rename = "t")]
     #[serde(with = "time::serde::rfc3339")]
@@ -25,7 +25,7 @@ pub struct Message {
     pub vwap: f64,
 }
 
-impl From<Message> for types::Bar {
+impl From<Message> for Bar {
     fn from(bar: Message) -> Self {
         Self {
             time: bar.time,

@@ -1,7 +1,6 @@
 use serde::Deserialize;
 
-#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
-#[serde(rename_all = "camelCase")]
+#[derive(Deserialize, Debug, PartialEq, Eq)]
 pub struct Message {
     pub code: u16,
     #[serde(rename = "msg")]


@@ -1,12 +1,13 @@
pub mod bar; pub mod bar;
pub mod error; pub mod error;
pub mod news; pub mod news;
pub mod status;
pub mod subscription; pub mod subscription;
pub mod success; pub mod success;
use serde::Deserialize; use serde::Deserialize;
#[derive(Clone, Debug, PartialEq, Deserialize)] #[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "T")] #[serde(tag = "T")]
pub enum Message { pub enum Message {
#[serde(rename = "success")] #[serde(rename = "success")]
@@ -19,6 +20,8 @@ pub enum Message {
UpdatedBar(bar::Message), UpdatedBar(bar::Message),
#[serde(rename = "n")] #[serde(rename = "n")]
News(news::Message), News(news::Message),
#[serde(rename = "s")]
Status(status::Message),
#[serde(rename = "error")] #[serde(rename = "error")]
Error(error::Message), Error(error::Message),
} }


@@ -1,14 +1,11 @@
use crate::{ use crate::{
types::{ types::{alpaca::shared::news::normalize_html_content, news::Sentiment, News},
self, utils::de,
news::{Confidence, Sentiment},
},
utils::news,
}; };
use serde::Deserialize; use serde::Deserialize;
use time::OffsetDateTime; use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] #[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct Message { pub struct Message {
pub id: i64, pub id: i64,
#[serde(with = "time::serde::rfc3339")] #[serde(with = "time::serde::rfc3339")]
@@ -17,6 +14,7 @@ pub struct Message {
#[serde(with = "time::serde::rfc3339")] #[serde(with = "time::serde::rfc3339")]
#[serde(rename = "updated_at")] #[serde(rename = "updated_at")]
pub time_updated: OffsetDateTime, pub time_updated: OffsetDateTime,
#[serde(deserialize_with = "de::add_slash_to_symbols")]
pub symbols: Vec<String>, pub symbols: Vec<String>,
pub headline: String, pub headline: String,
pub author: String, pub author: String,
@@ -26,18 +24,21 @@ pub struct Message {
pub url: Option<String>, pub url: Option<String>,
} }
impl From<Message> for types::News { impl From<Message> for News {
fn from(news: Message) -> Self { fn from(news: Message) -> Self {
Self { Self {
id: news.id, id: news.id,
time_created: news.time_created, time_created: news.time_created,
time_updated: news.time_updated, time_updated: news.time_updated,
symbols: news.symbols, symbols: news.symbols,
headline: news::normalize(&news.headline), headline: normalize_html_content(&news.headline),
author: news::normalize(&news.author), author: normalize_html_content(&news.author),
content: news::normalize(&news.content), source: normalize_html_content(&news.source),
summary: normalize_html_content(&news.summary),
content: normalize_html_content(&news.content),
sentiment: Sentiment::Neutral, sentiment: Sentiment::Neutral,
confidence: Confidence::VeryUncertain, confidence: 0.0,
url: news.url.unwrap_or_default(),
} }
} }
} }


@@ -0,0 +1,154 @@
use serde::Deserialize;
use serde_with::{serde_as, NoneAsEmptyString};
use time::OffsetDateTime;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(tag = "sc", content = "sm")]
pub enum Status {
#[serde(rename = "2")]
#[serde(alias = "H")]
TradingHalt(String),
#[serde(rename = "3")]
Resume(String),
#[serde(rename = "5")]
PriceIndication(String),
#[serde(rename = "6")]
TradingRangeIndication(String),
#[serde(rename = "7")]
MarketImbalanceBuy(String),
#[serde(rename = "8")]
MarketImbalanceSell(String),
#[serde(rename = "9")]
MarketOnCloseImbalanceBuy(String),
#[serde(rename = "A")]
MarketOnCloseImbalanceSell(String),
#[serde(rename = "C")]
NoMarketImbalance(String),
#[serde(rename = "D")]
NoMarketOnCloseImbalance(String),
#[serde(rename = "E")]
ShortSaleRestriction(String),
#[serde(rename = "F")]
LimitUpLimitDown(String),
#[serde(rename = "Q")]
QuotationResumption(String),
#[serde(rename = "T")]
TradingResumption(String),
#[serde(rename = "P")]
VolatilityTradingPause(String),
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(tag = "rc", content = "rm")]
pub enum Reason {
#[serde(rename = "D")]
NewsReleased(String),
#[serde(rename = "I")]
OrderImbalance(String),
#[serde(rename = "M")]
LimitUpLimitDown(String),
#[serde(rename = "P")]
NewsPending(String),
#[serde(rename = "X")]
Operational(String),
#[serde(rename = "Y")]
SubPennyTrading(String),
#[serde(rename = "1")]
MarketWideCircuitBreakerL1Breached(String),
#[serde(rename = "2")]
MarketWideCircuitBreakerL2Breached(String),
#[serde(rename = "3")]
MarketWideCircuitBreakerL3Breached(String),
#[serde(rename = "T1")]
HaltNewsPending(String),
#[serde(rename = "T2")]
HaltNewsDissemination(String),
#[serde(rename = "T5")]
SingleStockTradingPauseInAffect(String),
#[serde(rename = "T6")]
RegulatoryHaltExtraordinaryMarketActivity(String),
#[serde(rename = "T8")]
HaltETF(String),
#[serde(rename = "T12")]
TradingHaltedForInformationRequestedByNASDAQ(String),
#[serde(rename = "H4")]
HaltNonCompliance(String),
#[serde(rename = "H9")]
HaltFilingsNotCurrent(String),
#[serde(rename = "H10")]
HaltSECTradingSuspension(String),
#[serde(rename = "H11")]
HaltRegulatoryConcern(String),
#[serde(rename = "01")]
OperationsHaltContactMarketOperations(String),
#[serde(rename = "IPO1")]
IPOIssueNotYetTrading(String),
#[serde(rename = "M1")]
CorporateAction(String),
#[serde(rename = "M2")]
QuotationNotAvailable(String),
#[serde(rename = "LUDP")]
VolatilityTradingPause(String),
#[serde(rename = "LUDS")]
VolatilityTradingPauseStraddleCondition(String),
#[serde(rename = "MWC1")]
MarketWideCircuitBreakerHaltL1(String),
#[serde(rename = "MWC2")]
MarketWideCircuitBreakerHaltL2(String),
#[serde(rename = "MWC3")]
MarketWideCircuitBreakerHaltL3(String),
#[serde(rename = "MWC0")]
MarketWideCircuitBreakerHaltCarryOverFromPreviousDay(String),
#[serde(rename = "T3")]
NewsAndResumptionTimes(String),
#[serde(rename = "T7")]
SingleStockTradingPauseQuotationOnlyPeriod(String),
#[serde(rename = "R4")]
QualificationsIssuesReviewedResolvedQuotationsTradingToResume(String),
#[serde(rename = "R9")]
FilingRequirementsSatisfiedResolvedQuotationsTradingToResume(String),
#[serde(rename = "C3")]
IssuerNewsNotForthcomingQuotationsTradingToResume(String),
#[serde(rename = "C4")]
QualificationsHaltEndedMaintReqMetResume(String),
#[serde(rename = "C9")]
QualificationsHaltConcludedFilingsMetQuotesTradesToResume(String),
#[serde(rename = "C11")]
TradeHaltConcludedByOtherRegulatoryAuthQuotesTradesResume(String),
#[serde(rename = "R1")]
NewIssueAvailable(String),
#[serde(rename = "R")]
IssueAvailable(String),
#[serde(rename = "IPOQ")]
IPOSecurityReleasedForQuotation(String),
#[serde(rename = "IPOE")]
IPOSecurityPositioningWindowExtension(String),
#[serde(rename = "MWCQ")]
MarketWideCircuitBreakerResumption(String),
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub enum Tape {
A,
B,
C,
O,
}
#[serde_as]
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[allow(clippy::struct_field_names)]
pub struct Message {
#[serde(rename = "t")]
#[serde(with = "time::serde::rfc3339")]
pub time: OffsetDateTime,
#[serde(rename = "S")]
pub symbol: String,
#[serde(flatten)]
pub status: Status,
#[serde(flatten)]
#[serde_as(as = "NoneAsEmptyString")]
pub reason: Option<Reason>,
#[serde(rename = "z")]
pub tape: Tape,
}
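For orientation, the payload shape the renames above imply, reconstructed from the field attributes rather than from a captured message:
// {"T":"s","t":"2024-02-16T14:30:00Z","S":"AAPL",
//  "sc":"2","sm":"Trading Halt",
//  "rc":"T1","rm":"Halt News Pending","z":"C"}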


@@ -0,0 +1,23 @@
use crate::utils::de;
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(untagged)]
pub enum Message {
#[serde(rename_all = "camelCase")]
Market {
trades: Vec<String>,
quotes: Vec<String>,
bars: Vec<String>,
updated_bars: Vec<String>,
daily_bars: Vec<String>,
orderbooks: Option<Vec<String>>,
statuses: Option<Vec<String>>,
lulds: Option<Vec<String>>,
cancel_errors: Option<Vec<String>>,
},
News {
#[serde(deserialize_with = "de::add_slash_to_symbols")]
news: Vec<String>,
},
}


@@ -1,8 +1,8 @@
use serde::Deserialize; use serde::Deserialize;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)] #[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(tag = "msg")] #[serde(tag = "msg")]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "snake_case")]
pub enum Message { pub enum Message {
Connected, Connected,
Authenticated, Authenticated,


@@ -0,0 +1,54 @@
pub mod incoming;
pub mod outgoing;
use crate::{
config::{ALPACA_API_KEY, ALPACA_API_SECRET},
types::alpaca::websocket,
};
use core::panic;
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use serde_json::{from_str, to_string};
use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub async fn authenticate(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) {
match stream.next().await.unwrap().unwrap() {
Message::Text(data)
if from_str::<Vec<websocket::data::incoming::Message>>(&data)
.unwrap()
.first()
== Some(&websocket::data::incoming::Message::Success(
websocket::data::incoming::success::Message::Connected,
)) => {}
_ => panic!("Failed to connect to Alpaca websocket."),
}
sink.send(Message::Text(
to_string(&websocket::data::outgoing::Message::Auth(
websocket::auth::Message {
key: (*ALPACA_API_KEY).clone(),
secret: (*ALPACA_API_SECRET).clone(),
},
))
.unwrap(),
))
.await
.unwrap();
match stream.next().await.unwrap().unwrap() {
Message::Text(data)
if from_str::<Vec<websocket::data::incoming::Message>>(&data)
.unwrap()
.first()
== Some(&websocket::data::incoming::Message::Success(
websocket::data::incoming::success::Message::Authenticated,
)) => {}
_ => panic!("Failed to authenticate with Alpaca websocket."),
};
}


@@ -1,11 +1,11 @@
-pub mod auth;
 pub mod subscribe;
 
+use crate::types::alpaca::websocket::auth;
 use serde::Serialize;
 
 #[derive(Serialize)]
 #[serde(tag = "action")]
-#[serde(rename_all = "camelCase")]
+#[serde(rename_all = "snake_case")]
 pub enum Message {
     Auth(auth::Message),
     Subscribe(subscribe::Message),


@@ -0,0 +1,49 @@
use crate::utils::ser;
use serde::Serialize;
#[derive(Serialize)]
#[serde(untagged)]
pub enum Market {
#[serde(rename_all = "camelCase")]
UsEquity {
bars: Vec<String>,
updated_bars: Vec<String>,
statuses: Vec<String>,
},
#[serde(rename_all = "camelCase")]
Crypto {
bars: Vec<String>,
updated_bars: Vec<String>,
},
}
#[derive(Serialize)]
#[serde(untagged)]
pub enum Message {
Market(Market),
News {
#[serde(serialize_with = "ser::remove_slash_from_symbols")]
news: Vec<String>,
},
}
impl Message {
pub fn new_market_us_equity(symbols: Vec<String>) -> Self {
Self::Market(Market::UsEquity {
bars: symbols.clone(),
updated_bars: symbols.clone(),
statuses: symbols,
})
}
pub fn new_market_crypto(symbols: Vec<String>) -> Self {
Self::Market(Market::Crypto {
bars: symbols.clone(),
updated_bars: symbols,
})
}
pub fn new_news(symbols: Vec<String>) -> Self {
Self::News { news: symbols }
}
}
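A usage sketch of the constructors; wrapped in the tagged outgoing Message enum from the module above, the frame serializes to roughly the commented JSON (symbols illustrative):
// One subscribe frame for two equity symbols.
let frame = Message::new_market_us_equity(vec!["AAPL".into(), "MSFT".into()]);
// {"action":"subscribe","bars":["AAPL","MSFT"],
//  "updatedBars":["AAPL","MSFT"],"statuses":["AAPL","MSFT"]}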


@@ -1,28 +0,0 @@
use serde::Deserialize;
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MarketMessage {
pub trades: Vec<String>,
pub quotes: Vec<String>,
pub bars: Vec<String>,
pub updated_bars: Vec<String>,
pub daily_bars: Vec<String>,
pub orderbooks: Option<Vec<String>>,
pub statuses: Option<Vec<String>>,
pub lulds: Option<Vec<String>>,
pub cancel_errors: Option<Vec<String>>,
}
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NewsMessage {
pub news: Vec<String>,
}
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[serde(untagged)]
pub enum Message {
Market(MarketMessage),
News(NewsMessage),
}


@@ -1,2 +1,3 @@
-pub mod incoming;
-pub mod outgoing;
+pub mod auth;
+pub mod data;
+pub mod trading;


@@ -1,36 +0,0 @@
use serde::Serialize;
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct MarketMessage {
bars: Vec<String>,
updated_bars: Vec<String>,
}
impl MarketMessage {
pub fn new(symbols: Vec<String>) -> Self {
Self {
bars: symbols.clone(),
updated_bars: symbols,
}
}
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct NewsMessage {
news: Vec<String>,
}
impl NewsMessage {
pub fn new(symbols: Vec<String>) -> Self {
Self { news: symbols }
}
}
#[derive(Serialize)]
#[serde(untagged)]
pub enum Message {
Market(MarketMessage),
News(NewsMessage),
}


@@ -0,0 +1,22 @@
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Status {
Authorized,
Unauthorized,
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub enum Action {
#[serde(rename = "authenticate")]
Auth,
#[serde(rename = "listen")]
Subscribe,
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct Message {
pub status: Status,
pub action: Action,
}


@@ -0,0 +1,16 @@
pub mod auth;
pub mod order;
pub mod subscription;
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "stream", content = "data")]
pub enum Message {
#[serde(rename = "authorization")]
Auth(auth::Message),
#[serde(rename = "listening")]
Subscription(subscription::Message),
#[serde(rename = "trade_updates")]
Order(order::Message),
}


@@ -0,0 +1,57 @@
use crate::types::alpaca::shared;
use serde::Deserialize;
use serde_aux::prelude::deserialize_number_from_string;
use time::OffsetDateTime;
use uuid::Uuid;
pub use shared::order::Order;
#[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "event")]
pub enum Event {
New,
Fill {
timestamp: OffsetDateTime,
#[serde(deserialize_with = "deserialize_number_from_string")]
position_qty: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
price: f64,
},
PartialFill {
timestamp: OffsetDateTime,
#[serde(deserialize_with = "deserialize_number_from_string")]
position_qty: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
price: f64,
},
Canceled {
timestamp: OffsetDateTime,
},
Expired {
timestamp: OffsetDateTime,
},
DoneForDay,
Replaced {
timestamp: OffsetDateTime,
},
Rejected {
timestamp: OffsetDateTime,
},
PendingNew,
Stopped,
PendingCancel,
PendingReplace,
Calculated,
Suspended,
OrderReplaceRejected,
OrderCancelRejected,
}
#[derive(Deserialize, Debug, PartialEq)]
pub struct Message {
pub execution_id: Uuid,
pub order: Order,
#[serde(flatten)]
pub event: Event,
}


@@ -0,0 +1,6 @@
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct Message {
pub streams: Vec<String>,
}


@@ -0,0 +1,83 @@
pub mod incoming;
pub mod outgoing;
use crate::{
config::{ALPACA_API_KEY, ALPACA_API_SECRET},
types::alpaca::websocket,
};
use core::panic;
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use serde_json::{from_str, to_string};
use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub async fn authenticate(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) {
sink.send(Message::Text(
to_string(&websocket::trading::outgoing::Message::Auth(
websocket::auth::Message {
key: (*ALPACA_API_KEY).clone(),
secret: (*ALPACA_API_SECRET).clone(),
},
))
.unwrap(),
))
.await
.unwrap();
match stream.next().await.unwrap().unwrap() {
Message::Binary(data) => {
let data = String::from_utf8(data).unwrap();
if from_str::<websocket::trading::incoming::Message>(&data).unwrap()
!= websocket::trading::incoming::Message::Auth(
websocket::trading::incoming::auth::Message {
status: websocket::trading::incoming::auth::Status::Authorized,
action: websocket::trading::incoming::auth::Action::Auth,
},
)
{
panic!("Failed to authenticate with Alpaca websocket.");
}
}
_ => panic!("Failed to authenticate with Alpaca websocket."),
};
}
pub async fn subscribe(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) {
sink.send(Message::Text(
to_string(&websocket::trading::outgoing::Message::Subscribe {
data: websocket::trading::outgoing::subscribe::Message {
streams: vec![String::from("trade_updates")],
},
})
.unwrap(),
))
.await
.unwrap();
match stream.next().await.unwrap().unwrap() {
Message::Binary(data) => {
let data = String::from_utf8(data).unwrap();
if from_str::<websocket::trading::incoming::Message>(&data).unwrap()
!= websocket::trading::incoming::Message::Subscription(
websocket::trading::incoming::subscription::Message {
streams: vec![String::from("trade_updates")],
},
)
{
panic!("Failed to subscribe to Alpaca websocket.");
}
}
_ => panic!("Failed to subscribe to Alpaca websocket."),
};
}


@@ -0,0 +1,15 @@
pub mod subscribe;
use crate::types::alpaca::websocket::auth;
use serde::Serialize;
#[derive(Serialize)]
#[serde(tag = "action")]
#[serde(rename_all = "snake_case")]
pub enum Message {
Auth(auth::Message),
#[serde(rename = "listen")]
Subscribe {
data: subscribe::Message,
},
}


@@ -0,0 +1,6 @@
use serde::Serialize;
#[derive(Serialize)]
pub struct Message {
pub streams: Vec<String>,
}


@@ -1,6 +1,7 @@
use clickhouse::Row; use clickhouse::Row;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr}; use serde_repr::{Deserialize_repr, Serialize_repr};
use std::hash::{Hash, Hasher};
use time::OffsetDateTime; use time::OffsetDateTime;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
@@ -23,12 +24,19 @@ pub enum Exchange {
Crypto = 8, Crypto = 8,
} }
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Row)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
pub struct Asset { pub struct Asset {
pub symbol: String, pub symbol: String,
pub abbreviation: String,
pub class: Class, pub class: Class,
pub exchange: Exchange, pub exchange: Exchange,
pub status: bool,
#[serde(with = "clickhouse::serde::time::datetime")] #[serde(with = "clickhouse::serde::time::datetime")]
pub time_added: OffsetDateTime, pub time_added: OffsetDateTime,
pub qty: f64,
}
impl Hash for Asset {
fn hash<H: Hasher>(&self, state: &mut H) {
self.symbol.hash(state);
}
} }

src/types/calendar.rs (new file)

@@ -0,0 +1,13 @@
use clickhouse::Row;
use serde::{Deserialize, Serialize};
use time::{Date, OffsetDateTime};
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Row)]
pub struct Calendar {
#[serde(with = "clickhouse::serde::time::date")]
pub date: Date,
#[serde(with = "clickhouse::serde::time::datetime")]
pub open: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub close: OffsetDateTime,
}


@@ -1,13 +1,14 @@
-pub mod algebraic;
 pub mod alpaca;
 pub mod asset;
 pub mod backfill;
 pub mod bar;
+pub mod calendar;
 pub mod news;
-pub mod ollama;
+pub mod order;
 
-pub use algebraic::Subset;
 pub use asset::{Asset, Class, Exchange};
 pub use backfill::Backfill;
 pub use bar::Bar;
+pub use calendar::Calendar;
 pub use news::News;
+pub use order::Order;


@@ -1,33 +1,49 @@
use clickhouse::Row; use clickhouse::Row;
use rust_bert::pipelines::sequence_classification::Label;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr}; use serde_repr::{Deserialize_repr, Serialize_repr};
use std::str::FromStr;
use time::OffsetDateTime; use time::OffsetDateTime;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)] #[repr(i8)]
pub enum Sentiment { pub enum Sentiment {
VeryNegative = -3, Positive = 1,
Negative = -2,
MildlyNegative = -1,
Neutral = 0, Neutral = 0,
MildlyPositive = 1, Negative = -1,
Positive = 2,
VeryPositive = 3,
} }
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)] impl FromStr for Sentiment {
#[repr(i8)] type Err = ();
pub enum Confidence {
VeryUncertain = -3, fn from_str(s: &str) -> Result<Self, Self::Err> {
Uncertain = -2, match s {
MildlyUncertain = -1, "positive" => Ok(Self::Positive),
Neutral = 0, "neutral" => Ok(Self::Neutral),
MildlyCertain = 1, "negative" => Ok(Self::Negative),
Certain = 2, _ => Err(()),
VeryCertain = 3, }
}
} }
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Row)] #[derive(Clone, Copy, Debug, PartialEq)]
pub struct Prediction {
pub sentiment: Sentiment,
pub confidence: f64,
}
impl TryFrom<Label> for Prediction {
type Error = ();
fn try_from(label: Label) -> Result<Self, Self::Error> {
Ok(Self {
sentiment: Sentiment::from_str(&label.text)?,
confidence: label.score,
})
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
pub struct News { pub struct News {
pub id: i64, pub id: i64,
#[serde(with = "clickhouse::serde::time::datetime")] #[serde(with = "clickhouse::serde::time::datetime")]
@@ -37,7 +53,10 @@ pub struct News {
pub symbols: Vec<String>, pub symbols: Vec<String>,
pub headline: String, pub headline: String,
pub author: String, pub author: String,
pub source: String,
pub summary: String,
pub content: String, pub content: String,
pub sentiment: Sentiment, pub sentiment: Sentiment,
pub confidence: Confidence, pub confidence: f64,
pub url: String,
} }


@@ -1,15 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
System,
User,
Assistant,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
pub content: String,
}


@@ -1,2 +0,0 @@
pub mod pull;
pub mod sentiment;


@@ -1,6 +0,0 @@
use serde::Deserialize;
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
pub struct Response {
pub status: String,
}


@@ -1,75 +0,0 @@
use crate::types::{self, ollama::chat::Message};
use serde::Deserialize;
use time::OffsetDateTime;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Sentiment {
VeryNegative,
Negative,
MildlyNegative,
Neutral,
MildlyPositive,
Positive,
VeryPositive,
}
impl From<Sentiment> for types::news::Sentiment {
fn from(sentiment: Sentiment) -> Self {
match sentiment {
Sentiment::VeryNegative => Self::VeryNegative,
Sentiment::Negative => Self::Negative,
Sentiment::MildlyNegative => Self::MildlyNegative,
Sentiment::Neutral => Self::Neutral,
Sentiment::MildlyPositive => Self::MildlyPositive,
Sentiment::Positive => Self::Positive,
Sentiment::VeryPositive => Self::VeryPositive,
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Confidence {
VeryUncertain,
Uncertain,
MildlyUncertain,
Neutral,
MildlyCertain,
Certain,
VeryCertain,
}
impl From<Confidence> for types::news::Confidence {
fn from(confidence: Confidence) -> Self {
match confidence {
Confidence::VeryUncertain => Self::VeryUncertain,
Confidence::Uncertain => Self::Uncertain,
Confidence::MildlyUncertain => Self::MildlyUncertain,
Confidence::Neutral => Self::Neutral,
Confidence::MildlyCertain => Self::MildlyCertain,
Confidence::Certain => Self::Certain,
Confidence::VeryCertain => Self::VeryCertain,
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)]
pub struct Prediction {
pub sentiment: Sentiment,
pub confidence: Confidence,
}
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
pub struct Response {
pub model: String,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
pub message: Message,
pub done: bool,
pub total_duration: i64,
pub load_duration: i64,
pub prompt_eval_duration: i64,
pub eval_count: i64,
pub eval_duration: i64,
}


@@ -1,3 +0,0 @@
pub mod chat;
pub mod incoming;
pub mod outgoing;


@@ -1,2 +0,0 @@
pub mod pull;
pub mod sentiment;


@@ -1,16 +0,0 @@
use serde::Serialize;
#[derive(Serialize)]
pub struct Pull {
name: String,
stream: bool,
}
impl Pull {
pub const fn new(name: String) -> Self {
Self {
name,
stream: false,
}
}
}


@@ -1,64 +0,0 @@
use crate::types::{
self,
ollama::chat::{self, Message},
};
use serde::Serialize;
use serde_json::to_string;
const PROMPT: &str = r#"You are SentimentLLama, a news classification AI. Users will input a news headline or article, and you will output a sentiment and confidence in one line of JSON.
For sentiment, pick out of ["very_negative", "negative", "mildly_negative", "neutral", "mildly_positive", "positive", "very_positive"].
For confidence, pick out of ["very_uncertain", "uncertain", "mildly_uncertain", "neutral", "mildly_certain", "certain", "very_certain"]."#;
#[derive(Serialize)]
struct ModelOptions {
temperature: f64,
seed: i64,
}
#[derive(Serialize)]
pub struct News {
headline: String,
}
impl From<types::News> for News {
fn from(news: types::News) -> Self {
Self {
headline: news.headline,
}
}
}
#[derive(Serialize)]
pub struct Sentiment {
model: String,
messages: Vec<Message>,
format: String,
stream: bool,
options: ModelOptions,
}
impl Sentiment {
pub fn new(model: String, content: &News) -> Self {
Self {
model,
messages: vec![
Message {
role: chat::Role::System,
content: PROMPT.to_string(),
},
Message {
role: chat::Role::User,
content: to_string(content).unwrap(),
},
],
format: "json".to_string(),
stream: false,
options: ModelOptions {
temperature: 0.0,
seed: 0,
},
}
}
}

src/types/order.rs (new file)

@@ -0,0 +1,107 @@
use clickhouse::Row;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use time::OffsetDateTime;
use uuid::Uuid;
#[derive(Clone, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum Class {
Simple = 1,
Bracket = 2,
Oco = 3,
Oto = 4,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum Type {
Market = 1,
Limit = 2,
Stop = 3,
StopLimit = 4,
TrailingStop = 5,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum Side {
Buy = 1,
Sell = -1,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum TimeInForce {
Day = 1,
Gtc = 2,
Opg = 3,
Cls = 4,
Ioc = 5,
Fok = 6,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum Status {
New = 1,
PartiallyFilled = 2,
Filled = 3,
DoneForDay = 4,
Canceled = 5,
Expired = 6,
Replaced = 7,
PendingCancel = 8,
PendingReplace = 9,
Accepted = 10,
PendingNew = 11,
AcceptedForBidding = 12,
Stopped = 13,
Rejected = 14,
Suspended = 15,
Calculated = 16,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
#[allow(clippy::struct_field_names)]
pub struct Order {
pub id: Uuid,
pub client_order_id: Uuid,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_submitted: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_created: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_updated: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_filled: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_expired: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_cancel_requested: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_canceled: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_failed: OffsetDateTime,
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_replaced: OffsetDateTime,
pub replaced_by: Uuid,
pub replaces: Uuid,
pub symbol: String,
pub order_class: Class,
pub order_type: Type,
pub side: Side,
pub time_in_force: TimeInForce,
pub extended_hours: bool,
pub notional: f64,
pub qty: f64,
pub filled_qty: f64,
pub filled_avg_price: f64,
pub status: Status,
pub limit_price: f64,
pub stop_price: f64,
pub trail_percent: f64,
pub trail_price: f64,
pub hwm: f64,
pub legs: Vec<Uuid>,
}

src/utils/backoff.rs (new file)

@@ -0,0 +1,8 @@
use backoff::ExponentialBackoff;
pub fn infinite() -> ExponentialBackoff {
ExponentialBackoff {
max_elapsed_time: None,
..ExponentialBackoff::default()
}
}
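Usage sketch: handing this to the HTTP helpers replaces the default bounded retry (ExponentialBackoff::default() gives up after roughly 15 minutes) with a never-give-up policy; the variable names follow the earlier positions helpers and are illustrative:
// Illustrative: retry the positions fetch indefinitely.
let positions = get(&alpaca_client, &alpaca_rate_limiter, Some(infinite())).await?;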


@@ -1,11 +0,0 @@
use crate::database;
use clickhouse::Client;
use tokio::join;
pub async fn cleanup(clickhouse_client: &Client) {
let bars_future = database::bars::cleanup(clickhouse_client);
let news_future = database::news::cleanup(clickhouse_client);
let backfills_future = database::backfills::cleanup(clickhouse_client);
join!(bars_future, news_future, backfills_future);
}

src/utils/de.rs (new file)

@@ -0,0 +1,103 @@
use lazy_static::lazy_static;
use regex::Regex;
use serde::{
de::{self, SeqAccess, Visitor},
Deserializer,
};
use std::fmt;
use time::{format_description::OwnedFormatItem, macros::format_description, Time};
lazy_static! {
static ref RE_SLASH: Regex = Regex::new(r"^(.+)(BTC|USD.?)$").unwrap();
static ref FMT_HH_MM: OwnedFormatItem = format_description!("[hour]:[minute]").into();
}
fn add_slash(pair: &str) -> String {
RE_SLASH.captures(pair).map_or_else(
|| pair.to_string(),
|caps| format!("{}/{}", &caps[1], &caps[2]),
)
}
pub fn add_slash_to_symbol<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
struct StringVisitor;
impl<'de> Visitor<'de> for StringVisitor {
type Value = String;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string without a slash")
}
fn visit_str<E>(self, pair: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(add_slash(pair))
}
fn visit_string<E>(self, pair: String) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(add_slash(&pair))
}
}
deserializer.deserialize_string(StringVisitor)
}
pub fn add_slash_to_symbols<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: Deserializer<'de>,
{
struct VecStringVisitor;
impl<'de> Visitor<'de> for VecStringVisitor {
type Value = Vec<String>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a list of strings without a slash")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Vec<String>, A::Error>
where
A: SeqAccess<'de>,
{
let mut vec = Vec::with_capacity(seq.size_hint().unwrap_or(25));
while let Some(value) = seq.next_element::<String>()? {
vec.push(add_slash(&value));
}
Ok(vec)
}
}
deserializer.deserialize_seq(VecStringVisitor)
}
pub fn human_time_hh_mm<'de, D>(deserializer: D) -> Result<Time, D::Error>
where
D: Deserializer<'de>,
{
struct TimeVisitor;
impl<'de> Visitor<'de> for TimeVisitor {
type Value = time::Time;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string in the format HH:MM")
}
fn visit_str<E>(self, time: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Time::parse(time, &FMT_HH_MM).map_err(|e| de::Error::custom(e.to_string()))
}
}
deserializer.deserialize_str(TimeVisitor)
}
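Illustrative behavior of the private add_slash helper (the trailing `.?` in the regex is what lets four-letter quote currencies like USDT through):
assert_eq!(add_slash("BTCUSD"), "BTC/USD");
assert_eq!(add_slash("SOLUSDT"), "SOL/USDT");
assert_eq!(add_slash("AAPL"), "AAPL"); // no crypto suffix, left untouched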


@@ -1,24 +0,0 @@
use crate::{config::Config, types::ollama};
use serde_json::to_string;
use std::time::Duration;
pub async fn ollama(app_config: &Config) {
let response = app_config
.ollama_client
.post(format!("{}/api/pull", app_config.ollama_url))
.body(
to_string(&ollama::outgoing::pull::Pull::new(
app_config.ollama_model.clone(),
))
.unwrap(),
)
.timeout(Duration::MAX)
.send()
.await
.unwrap()
.json::<ollama::incoming::pull::Response>()
.await
.unwrap();
assert!(response.status == "success", "Failed to pull Ollama model.");
}

src/utils/macros.rs (new file)

@@ -0,0 +1,29 @@
#[macro_export]
macro_rules! impl_from_enum {
($source:ty, $target:ty, $( $variant:ident ),* ) => {
impl From<$source> for $target {
fn from(item: $source) -> Self {
match item {
$( <$source>::$variant => <$target>::$variant, )*
}
}
}
impl From<$target> for $source {
fn from(item: $target) -> Self {
match item {
$( <$target>::$variant => <$source>::$variant, )*
}
}
}
};
}
#[macro_export]
macro_rules! create_send_await {
($sender:expr, $action:expr, $($contents:expr),*) => {
let (message, receiver) = $action($($contents),*);
$sender.send(message).await.unwrap();
receiver.await.unwrap()
};
}
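Usage sketch for both macros; the types are the ones the diff wires up elsewhere, and impl_from_enum! assumes the two enums share variant names exactly:
// After impl_from_enum!(types::Class, Class, UsEquity, Crypto);
let theirs: Class = types::Class::Crypto.into();
let ours: types::Class = theirs.into();

// create_send_await!(sender, action, args...) expands to:
// let (message, receiver) = action(args...);
// sender.send(message).await.unwrap();
// receiver.await.unwrap()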


@@ -1,8 +1,7 @@
pub mod cleanup; pub mod backoff;
pub mod init; pub mod de;
pub mod news; pub mod macros;
pub mod ser;
pub mod time; pub mod time;
pub mod websocket;
pub use time::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE}; pub use time::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE, ONE_SECOND};
pub use websocket::authenticate;


@@ -1,13 +0,0 @@
use html_escape::decode_html_entities;
use regex::Regex;
pub fn normalize(content: &str) -> String {
let re_tags = Regex::new("<[^>]+>").unwrap();
let re_spaces = Regex::new("[\\u00A0\\s]+").unwrap();
let content = content.replace('\n', " ");
let content = re_tags.replace_all(&content, "");
let content = re_spaces.replace_all(&content, " ");
let content = decode_html_entities(&content);
content.trim().to_string()
}

src/utils/ser.rs (new file)

@@ -0,0 +1,90 @@
use serde::{ser::SerializeSeq, Serializer};
use std::time::Duration;
pub fn timeframe<S>(timeframe: &Duration, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mins = timeframe.as_secs() / 60;
if mins < 60 {
return serializer.serialize_str(&format!("{mins}Min"));
}
let hours = mins / 60;
if hours < 24 {
return serializer.serialize_str(&format!("{hours}Hour"));
}
let days = hours / 24;
if days == 1 {
return serializer.serialize_str("1Day");
}
let weeks = days / 7;
if weeks == 1 {
return serializer.serialize_str("1Week");
}
let months = days / 30;
if [1, 2, 3, 4, 6, 12].contains(&months) {
return serializer.serialize_str(&format!("{months}Month"));
};
Err(serde::ser::Error::custom("Invalid timeframe duration"))
}
fn remove_slash(pair: &str) -> String {
pair.replace('/', "")
}
pub fn join_symbols<S>(symbols: &[String], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let string = symbols.join(",");
serializer.serialize_str(&string)
}
pub fn join_symbols_option<S>(
symbols: &Option<Vec<String>>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match symbols {
Some(symbols) => join_symbols(symbols, serializer),
None => serializer.serialize_none(),
}
}
pub fn remove_slash_from_symbols<S>(pairs: &[String], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let symbols = pairs
.iter()
.map(|pair| remove_slash(pair))
.collect::<Vec<_>>();
let mut seq = serializer.serialize_seq(Some(symbols.len()))?;
for symbol in symbols {
seq.serialize_element(&symbol)?;
}
seq.end()
}
pub fn remove_slash_from_pairs_join_symbols<S>(
symbols: &[String],
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let symbols = symbols
.iter()
.map(|symbol| remove_slash(symbol))
.collect::<Vec<_>>();
join_symbols(&symbols, serializer)
}
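Illustrative outputs of the timeframe serializer; integer division truncates, so only the listed shapes are reachable, and a duration landing on an unsupported count errors:
// Duration::from_secs(60)                -> "1Min"
// Duration::from_secs(15 * 60)           -> "15Min"
// Duration::from_secs(60 * 60)           -> "1Hour"
// Duration::from_secs(24 * 60 * 60)      -> "1Day"
// Duration::from_secs(7 * 24 * 60 * 60)  -> "1Week"
// Duration::from_secs(30 * 24 * 60 * 60) -> "1Month"
// Duration::from_secs(2 * 24 * 60 * 60)  -> Err("Invalid timeframe duration")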


@@ -1,9 +1,17 @@
use lazy_static::lazy_static;
use std::time::Duration; use std::time::Duration;
use time::OffsetDateTime; use time::{OffsetDateTime, UtcOffset};
pub const ONE_SECOND: Duration = Duration::from_secs(1);
pub const ONE_MINUTE: Duration = Duration::from_secs(60); pub const ONE_MINUTE: Duration = Duration::from_secs(60);
pub const FIFTEEN_MINUTES: Duration = Duration::from_secs(60 * 15); pub const FIFTEEN_MINUTES: Duration = Duration::from_secs(60 * 15);
lazy_static! {
pub static ref MAX_TIMESTAMP: OffsetDateTime =
OffsetDateTime::from_unix_timestamp(253_402_300_799).unwrap();
pub static ref EST_OFFSET: UtcOffset = UtcOffset::from_hms(-5, 0, 0).unwrap();
}
pub fn last_minute() -> OffsetDateTime { pub fn last_minute() -> OffsetDateTime {
let now_timestamp = OffsetDateTime::now_utc().unix_timestamp(); let now_timestamp = OffsetDateTime::now_utc().unix_timestamp();
OffsetDateTime::from_unix_timestamp(now_timestamp - now_timestamp % 60).unwrap() OffsetDateTime::from_unix_timestamp(now_timestamp - now_timestamp % 60).unwrap()

Some files were not shown because too many files have changed in this diff.