Compare commits

..

18 Commits

Author SHA1 Message Date
b7a175d5b4 Improve ser function naming
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-17 09:22:45 +00:00
e9012d6ec3 Add backfill progress logging
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-15 17:26:58 +00:00
10365745aa Attempt to fix bugs related to empty vecs
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-14 22:38:20 +00:00
8202255132 Fix backfill freshness
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-14 12:47:52 +00:00
0d276d537c Add websocket infinite inserting
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-14 01:46:18 +00:00
1707d74cf7 Improve alpaca request error handling
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-13 19:45:06 +00:00
f3f9c6336b Remove rust-bert
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-13 12:09:50 +00:00
5ed0c7670a Fix backfill sentiment batching bug
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-12 21:00:11 +00:00
d2d20e2978 Add automatic websocket reconnection
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 23:41:06 +00:00
d02f958865 Optimize backfill early saving allocations
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 20:41:59 +00:00
2d8972dce2 Fix possible crashes on .unwrap()s
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 20:15:19 +00:00
7bacc2565a Fix CI
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 16:53:22 +00:00
b60cbc891d Add backfill early saving
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-11 16:53:12 +00:00
2de86b46f7 Improve backfill error logging
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 19:51:41 +00:00
8c7ee3d12d Add shared lib
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 18:28:40 +00:00
a15fd2c3c9 Separate data management code
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 16:59:21 +00:00
acfc0ca4c9 Add pipelined backfilling
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-10 11:22:24 +00:00
681d7393d7 Add multiple asset adding route
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-09 20:13:36 +00:00
116 changed files with 3194 additions and 37380 deletions

7
.gitignore vendored
View File

@@ -2,6 +2,7 @@
# will have compiled files and executables # will have compiled files and executables
debug/ debug/
target/ target/
log/
# These are backup files generated by rustfmt # These are backup files generated by rustfmt
**/*.rs.bk **/*.rs.bk
@@ -10,9 +11,3 @@ target/
*.pdb *.pdb
.env* .env*
# ML models
models/*/rust_model.ot
notebooks/models/
libdevice.10.bc

View File

@@ -24,13 +24,13 @@ build:
script: script:
- cargo +nightly build - cargo +nightly build
test: # test:
image: registry.karaolidis.com/karaolidis/qrust/rust # image: registry.karaolidis.com/karaolidis/qrust/rust
stage: test # stage: test
cache: # cache:
<<: *global_cache # <<: *global_cache
script: # script:
- cargo +nightly test # - cargo +nightly test
lint: lint:
image: registry.karaolidis.com/karaolidis/qrust/rust image: registry.karaolidis.com/karaolidis/qrust/rust
@@ -48,7 +48,7 @@ depcheck:
<<: *global_cache <<: *global_cache
script: script:
- cargo +nightly outdated - cargo +nightly outdated
- cargo +nightly udeps - cargo +nightly udeps --all-targets
build-release: build-release:
image: registry.karaolidis.com/karaolidis/qrust/rust image: registry.karaolidis.com/karaolidis/qrust/rust

976
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,14 @@ name = "qrust"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[lib]
name = "qrust"
path = "src/lib/mod.rs"
[[bin]]
name = "qrust"
path = "src/main.rs"
[profile.release] [profile.release]
panic = 'abort' panic = 'abort'
strip = true strip = true
@@ -43,6 +51,7 @@ clickhouse = { version = "0.11.6", features = [
] } ] }
uuid = { version = "1.6.1", features = [ uuid = { version = "1.6.1", features = [
"serde", "serde",
"v4",
] } ] }
time = { version = "0.3.31", features = [ time = { version = "0.3.31", features = [
"serde", "serde",
@@ -56,8 +65,9 @@ backoff = { version = "0.4.0", features = [
"tokio", "tokio",
] } ] }
regex = "1.10.3" regex = "1.10.3"
html-escape = "0.2.13"
rust-bert = "0.22.0"
async-trait = "0.1.77" async-trait = "0.1.77"
itertools = "0.12.1" itertools = "0.12.1"
lazy_static = "1.4.0" lazy_static = "1.4.0"
nonempty = { version = "0.10.0", features = [
"serialize",
] }

View File

@@ -4,7 +4,14 @@ appenders:
encoder: encoder:
pattern: "{d} {h({l})} {M}::{L} - {m}{n}" pattern: "{d} {h({l})} {M}::{L} - {m}{n}"
file:
kind: file
path: "./log/output.log"
encoder:
pattern: "{d} {l} {M}::{L} - {m}{n}"
root: root:
level: info level: info
appenders: appenders:
- stdout - stdout
- file

View File

@@ -1,32 +0,0 @@
{
"_name_or_path": "/home/ubuntu/finbert/models/language_model/finbertTRC2",
"architectures": [
"BertForSequenceClassification"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"id2label": {
"0": "positive",
"1": "negative",
"2": "neutral"
},
"initializer_range": 0.02,
"intermediate_size": 3072,
"label2id": {
"positive": 0,
"negative": 1,
"neutral": 2
},
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"type_vocab_size": 2,
"vocab_size": 30522
}

View File

@@ -1 +0,0 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}

View File

@@ -1 +0,0 @@
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "bert-base-uncased"}

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@@ -1,68 +1,49 @@
use crate::types::alpaca::shared::{Mode, Source};
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter}; use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use qrust::types::alpaca::shared::{Mode, Source};
use reqwest::{ use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue}, header::{HeaderMap, HeaderName, HeaderValue},
Client, Client,
}; };
use rust_bert::{ use std::{env, num::NonZeroU32, sync::Arc};
pipelines::{ use tokio::sync::Semaphore;
common::{ModelResource, ModelType},
sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel},
},
resources::LocalResource,
};
use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
use tokio::sync::Mutex;
pub const ALPACA_STOCK_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news";
pub const ALPACA_STOCK_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str =
"wss://stream.data.alpaca.markets/v1beta3/crypto/us";
pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";
lazy_static! { lazy_static! {
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE") pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
.expect("ALPACA_MODE must be set.") .expect("ALPACA_MODE must be set.")
.parse() .parse()
.expect("ALPACA_MODE must be 'live' or 'paper'"); .expect("ALPACA_MODE must be 'live' or 'paper'");
pub static ref ALPACA_API_BASE: String = match *ALPACA_MODE {
Mode::Live => String::from("api"),
Mode::Paper => String::from("paper-api"),
};
pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE") pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE")
.expect("ALPACA_SOURCE must be set.") .expect("ALPACA_SOURCE must be set.")
.parse() .parse()
.expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'"); .expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'");
pub static ref ALPACA_API_KEY: String = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set."); pub static ref ALPACA_API_KEY: String =
pub static ref ALPACA_API_SECRET: String = env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set."); env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
#[derive(Debug)] pub static ref ALPACA_API_SECRET: String =
pub static ref ALPACA_API_URL: String = format!( env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
"https://{}.alpaca.markets/v2", pub static ref CLICKHOUSE_BATCH_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE")
match *ALPACA_MODE { .expect("BATCH_BACKFILL_BARS_SIZE must be set.")
Mode::Live => String::from("api"),
Mode::Paper => String::from("paper-api"),
}
);
#[derive(Debug)]
pub static ref ALPACA_WEBSOCKET_URL: String = format!(
"wss://{}.alpaca.markets/stream",
match *ALPACA_MODE {
Mode::Live => String::from("api"),
Mode::Paper => String::from("paper-api"),
}
);
pub static ref MAX_BERT_INPUTS: usize = env::var("MAX_BERT_INPUTS")
.expect("MAX_BERT_INPUTS must be set.")
.parse() .parse()
.expect("MAX_BERT_INPUTS must be a positive integer."); .expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer.");
pub static ref CLICKHOUSE_BATCH_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE")
.expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
.parse()
.expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS")
.expect("CLICKHOUSE_MAX_CONNECTIONS must be set.")
.parse()
.expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer.");
} }
pub struct Config { pub struct Config {
pub alpaca_client: Client, pub alpaca_client: Client,
pub alpaca_rate_limiter: DefaultDirectRateLimiter, pub alpaca_rate_limiter: DefaultDirectRateLimiter,
pub clickhouse_client: clickhouse::Client, pub clickhouse_client: clickhouse::Client,
pub sequence_classifier: Mutex<SequenceClassificationModel>, pub clickhouse_concurrency_limiter: Arc<Semaphore>,
} }
impl Config { impl Config {
@@ -85,7 +66,7 @@ impl Config {
.unwrap(), .unwrap(),
alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE { alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE {
Source::Iex => unsafe { NonZeroU32::new_unchecked(200) }, Source::Iex => unsafe { NonZeroU32::new_unchecked(200) },
Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) }, Source::Sip => unsafe { NonZeroU32::new_unchecked(10_000) },
Source::Otc => unimplemented!("OTC rate limit not implemented."), Source::Otc => unimplemented!("OTC rate limit not implemented."),
})), })),
clickhouse_client: clickhouse::Client::default() clickhouse_client: clickhouse::Client::default()
@@ -95,25 +76,7 @@ impl Config {
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."), env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
) )
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")), .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
sequence_classifier: Mutex::new( clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
SequenceClassificationModel::new(SequenceClassificationConfig::new(
ModelType::Bert,
ModelResource::Torch(Box::new(LocalResource {
local_path: PathBuf::from("./models/finbert/rust_model.ot"),
})),
LocalResource {
local_path: PathBuf::from("./models/finbert/config.json"),
},
LocalResource {
local_path: PathBuf::from("./models/finbert/vocab.txt"),
},
None,
true,
None,
None,
))
.unwrap(),
),
} }
} }

View File

@@ -1,17 +0,0 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
};
use clickhouse::{error::Error, Client};
// Instantiate the shared database macros for the `backfills_bars` table. Each
// invocation expands to one public async helper in this module, named after
// the macro (`select_where_symbol`, `upsert`, `delete_where_symbols`,
// `cleanup`, `optimize`).
select_where_symbol!(Backfill, "backfills_bars");
upsert!(Backfill, "backfills_bars");
delete_where_symbols!("backfills_bars");
cleanup!("backfills_bars");
optimize!("backfills_bars");
/// Clears the `fresh` flag on every row of `backfills_bars`.
///
/// Executes a single `ALTER TABLE ... UPDATE` statement whose `WHERE true`
/// predicate matches all rows.
///
/// # Errors
///
/// Returns the ClickHouse client error if the statement fails.
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
    let statement = "ALTER TABLE backfills_bars UPDATE fresh = false WHERE true";
    clickhouse_client.query(statement).execute().await
}

View File

@@ -1,17 +0,0 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert,
};
use clickhouse::{error::Error, Client};
// Instantiate the shared database macros for the `backfills_news` table. Each
// invocation expands to one public async helper in this module, named after
// the macro (`select_where_symbol`, `upsert`, `delete_where_symbols`,
// `cleanup`, `optimize`).
select_where_symbol!(Backfill, "backfills_news");
upsert!(Backfill, "backfills_news");
delete_where_symbols!("backfills_news");
cleanup!("backfills_news");
optimize!("backfills_news");
/// Clears the `fresh` flag on every row of `backfills_news`.
///
/// Executes a single `ALTER TABLE ... UPDATE` statement whose `WHERE true`
/// predicate matches all rows.
///
/// # Errors
///
/// Returns the ClickHouse client error if the statement fails.
pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> {
    let statement = "ALTER TABLE backfills_news UPDATE fresh = false WHERE true";
    clickhouse_client.query(statement).execute().await
}

View File

@@ -1,7 +0,0 @@
use crate::{cleanup, delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
// Instantiate the shared database macros for the `bars` table. Each invocation
// expands to one public async helper in this module, named after the macro
// (`upsert`, `upsert_batch`, `delete_where_symbols`, `cleanup`, `optimize`).
upsert!(Bar, "bars");
upsert_batch!(Bar, "bars");
delete_where_symbols!("bars");
cleanup!("bars");
optimize!("bars");

View File

@@ -1,152 +0,0 @@
pub mod assets;
pub mod backfills_bars;
pub mod backfills_news;
pub mod bars;
pub mod calendar;
pub mod news;
pub mod orders;
use clickhouse::{error::Error, Client};
use log::info;
use tokio::try_join;
/// Generates `select`, an async function that reads every row of
/// `$table_name` (queried with `FINAL`) and deserializes each row as `$record`.
#[macro_export]
macro_rules! select {
    ($record:ty, $table_name:expr) => {
        /// Fetches all rows of the table.
        ///
        /// # Errors
        ///
        /// Returns the ClickHouse client error if the query fails.
        pub async fn select(
            client: &clickhouse::Client,
        ) -> Result<Vec<$record>, clickhouse::error::Error> {
            let sql = format!("SELECT ?fields FROM {} FINAL", $table_name);
            client.query(&sql).fetch_all::<$record>().await
        }
    };
}
/// Generates `select_where_symbol`, an async function that looks up at most
/// one `$record` row of `$table_name` (queried with `FINAL`) by its `symbol`
/// column.
#[macro_export]
macro_rules! select_where_symbol {
    ($record:ty, $table_name:expr) => {
        /// Fetches the row whose `symbol` equals the bound value, if any.
        ///
        /// # Errors
        ///
        /// Returns the ClickHouse client error if the query fails.
        pub async fn select_where_symbol<T>(
            client: &clickhouse::Client,
            symbol: &T,
        ) -> Result<Option<$record>, clickhouse::error::Error>
        where
            T: AsRef<str> + serde::Serialize + Send + Sync,
        {
            let sql = format!(
                "SELECT ?fields FROM {} FINAL WHERE symbol = ?",
                $table_name
            );

            client
                .query(&sql)
                .bind(symbol)
                .fetch_optional::<$record>()
                .await
        }
    };
}
/// Generates `upsert`, an async function that writes one `$record` into
/// `$table_name` through a ClickHouse insert.
#[macro_export]
macro_rules! upsert {
    ($record:ty, $table_name:expr) => {
        /// Opens an insert on the table, writes `record`, and finishes the insert.
        ///
        /// # Errors
        ///
        /// Returns the ClickHouse client error from opening, writing to, or
        /// ending the insert.
        pub async fn upsert(
            client: &clickhouse::Client,
            record: &$record,
        ) -> Result<(), clickhouse::error::Error> {
            let mut writer = client.insert($table_name)?;
            writer.write(record).await?;
            writer.end().await
        }
    };
}
/// Generates `upsert_batch`, an async function that writes a sequence of
/// `$record` references into `$table_name` through a single ClickHouse insert.
#[macro_export]
macro_rules! upsert_batch {
    ($record:ty, $table_name:expr) => {
        /// Opens one insert on the table, writes every item of `records` in
        /// iteration order, and finishes the insert.
        ///
        /// # Errors
        ///
        /// Returns the first ClickHouse client error hit while opening,
        /// writing to, or ending the insert; remaining records are not written.
        pub async fn upsert_batch<'a, T>(
            client: &clickhouse::Client,
            records: T,
        ) -> Result<(), clickhouse::error::Error>
        where
            T: IntoIterator<Item = &'a $record> + Send + Sync,
            T::IntoIter: Send,
        {
            let mut writer = client.insert($table_name)?;

            for row in records {
                writer.write(row).await?;
            }

            writer.end().await
        }
    };
}
/// Generates `delete_where_symbols`, an async function that deletes the rows
/// of `$table_name` whose `symbol` is contained in the given slice.
#[macro_export]
macro_rules! delete_where_symbols {
    ($table_name:expr) => {
        /// Deletes every row whose `symbol` is one of `symbols`.
        ///
        /// # Errors
        ///
        /// Returns the ClickHouse client error if the statement fails.
        pub async fn delete_where_symbols<T>(
            client: &clickhouse::Client,
            symbols: &[T],
        ) -> Result<(), clickhouse::error::Error>
        where
            T: AsRef<str> + serde::Serialize + Send + Sync,
        {
            let sql = format!("DELETE FROM {} WHERE symbol IN ?", $table_name);
            client.query(&sql).bind(symbols).execute().await
        }
    };
}
/// Generates `cleanup`, an async function that deletes the rows of
/// `$table_name` whose `symbol` does not appear in the `assets` table.
#[macro_export]
macro_rules! cleanup {
    ($table_name:expr) => {
        /// Deletes rows whose `symbol` has no matching row in `assets`.
        ///
        /// # Errors
        ///
        /// Returns the ClickHouse client error if the statement fails.
        pub async fn cleanup(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
            let sql = format!(
                "DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
                $table_name
            );

            client.query(&sql).execute().await
        }
    };
}
/// Generates `optimize`, an async function that runs `OPTIMIZE TABLE ... FINAL`
/// on `$table_name`.
#[macro_export]
macro_rules! optimize {
    ($table_name:expr) => {
        /// Runs `OPTIMIZE TABLE ... FINAL` on the table.
        ///
        /// # Errors
        ///
        /// Returns the ClickHouse client error if the statement fails.
        pub async fn optimize(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> {
            let sql = format!("OPTIMIZE TABLE {} FINAL", $table_name);
            client.query(&sql).execute().await
        }
    };
}
/// Runs the `cleanup` helper of the bars, news, and both backfill tables
/// concurrently.
///
/// # Errors
///
/// Returns the first ClickHouse error produced by any of the cleanups.
pub async fn cleanup_all(clickhouse_client: &Client) -> Result<(), Error> {
    info!("Cleaning up database.");

    try_join!(
        bars::cleanup(clickhouse_client),
        news::cleanup(clickhouse_client),
        backfills_bars::cleanup(clickhouse_client),
        backfills_news::cleanup(clickhouse_client)
    )?;

    Ok(())
}
/// Runs the `optimize` helper of every table module (assets, bars, news, both
/// backfill tables, orders, calendar) concurrently.
///
/// # Errors
///
/// Returns the first ClickHouse error produced by any of the optimizations.
pub async fn optimize_all(clickhouse_client: &Client) -> Result<(), Error> {
    info!("Optimizing database.");

    try_join!(
        assets::optimize(clickhouse_client),
        bars::optimize(clickhouse_client),
        news::optimize(clickhouse_client),
        backfills_bars::optimize(clickhouse_client),
        backfills_news::optimize(clickhouse_client),
        orders::optimize(clickhouse_client),
        calendar::optimize(clickhouse_client)
    )?;

    Ok(())
}

View File

@@ -1,24 +1,25 @@
use crate::{ use crate::{
config::{Config, ALPACA_MODE}, config::{Config, ALPACA_API_BASE},
database, database,
types::alpaca,
}; };
use log::{info, warn}; use log::{info, warn};
use qrust::{alpaca, types};
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime; use time::OffsetDateTime;
use tokio::join; use tokio::join;
pub async fn check_account(config: &Arc<Config>) { pub async fn check_account(config: &Arc<Config>) {
let account = alpaca::api::incoming::account::get( let account = alpaca::account::get(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
None, None,
&ALPACA_API_BASE,
) )
.await .await
.unwrap(); .unwrap();
assert!( assert!(
!(account.status != alpaca::api::incoming::account::Status::Active), !(account.status != types::alpaca::api::incoming::account::Status::Active),
"Account status is not active: {:?}.", "Account status is not active: {:?}.",
account.status account.status
); );
@@ -33,56 +34,60 @@ pub async fn check_account(config: &Arc<Config>) {
warn!("Account cash is zero, qrust will not be able to trade."); warn!("Account cash is zero, qrust will not be able to trade.");
} }
warn!( info!(
"qrust active on {} account with {} {}, avoid transferring funds without shutting down.", "qrust running on {} account with {} {}, avoid transferring funds without shutting down.",
*ALPACA_MODE, account.currency, account.cash *ALPACA_API_BASE, account.currency, account.cash
); );
} }
pub async fn rehydrate_orders(config: &Arc<Config>) { pub async fn rehydrate_orders(config: &Arc<Config>) {
info!("Rehydrating order data.");
let mut orders = vec![]; let mut orders = vec![];
let mut after = OffsetDateTime::UNIX_EPOCH; let mut after = OffsetDateTime::UNIX_EPOCH;
while let Some(message) = alpaca::api::incoming::order::get( loop {
&config.alpaca_client, let message = alpaca::orders::get(
&config.alpaca_rate_limiter, &config.alpaca_client,
&alpaca::api::outgoing::order::Order { &config.alpaca_rate_limiter,
status: Some(alpaca::api::outgoing::order::Status::All), &types::alpaca::api::outgoing::order::Order {
after: Some(after), status: Some(types::alpaca::api::outgoing::order::Status::All),
..Default::default() after: Some(after),
}, ..Default::default()
None, },
) None,
.await &ALPACA_API_BASE,
.ok() )
.filter(|message| !message.is_empty()) .await
{ .unwrap();
if message.is_empty() {
break;
}
orders.extend(message); orders.extend(message);
after = orders.last().unwrap().submitted_at; after = orders.last().unwrap().submitted_at;
} }
let orders = orders let orders = orders
.into_iter() .into_iter()
.flat_map(&alpaca::api::incoming::order::Order::normalize) .flat_map(&types::alpaca::api::incoming::order::Order::normalize)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
database::orders::upsert_batch(&config.clickhouse_client, &orders) database::orders::upsert_batch(
.await &config.clickhouse_client,
.unwrap(); &config.clickhouse_concurrency_limiter,
&orders,
info!("Rehydrated order data."); )
.await
.unwrap();
} }
pub async fn rehydrate_positions(config: &Arc<Config>) { pub async fn rehydrate_positions(config: &Arc<Config>) {
info!("Rehydrating position data.");
let positions_future = async { let positions_future = async {
alpaca::api::incoming::position::get( alpaca::positions::get(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
None, None,
&ALPACA_API_BASE,
) )
.await .await
.unwrap() .unwrap()
@@ -92,9 +97,12 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
}; };
let assets_future = async { let assets_future = async {
database::assets::select(&config.clickhouse_client) database::assets::select(
.await &config.clickhouse_client,
.unwrap() &config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
}; };
let (mut positions, assets) = join!(positions_future, assets_future); let (mut positions, assets) = join!(positions_future, assets_future);
@@ -111,9 +119,13 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
database::assets::upsert_batch(&config.clickhouse_client, &assets) database::assets::upsert_batch(
.await &config.clickhouse_client,
.unwrap(); &config.clickhouse_concurrency_limiter,
&assets,
)
.await
.unwrap();
for position in positions.values() { for position in positions.values() {
warn!( warn!(
@@ -121,6 +133,4 @@ pub async fn rehydrate_positions(config: &Arc<Config>) {
position.symbol, position.qty position.symbol, position.qty
); );
} }
info!("Rehydrated position data.");
} }

39
src/lib/alpaca/account.rs Normal file
View File

@@ -0,0 +1,39 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::account::Account;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/account", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Account>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

132
src/lib/alpaca/assets.rs Normal file
View File

@@ -0,0 +1,132 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{
incoming::asset::{Asset, Class},
outgoing,
};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use itertools::Itertools;
use log::warn;
use reqwest::{Client, Error};
use std::{collections::HashSet, time::Duration};
use tokio::try_join;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::asset::Asset,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Asset>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/assets", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Asset>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get assets, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
pub async fn get_by_symbol(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
symbol: &str,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Asset, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!(
"https://{}.alpaca.markets/v2/assets/{}",
api_base, symbol
))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Asset>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get asset, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}
/// Fetches the assets matching `symbols`.
///
/// * No symbols: returns an empty vector without issuing a request.
/// * Exactly one symbol: delegates to [`get_by_symbol`].
/// * Otherwise: lists all `UsEquity` and all `Crypto` assets concurrently and
///   keeps those whose symbol was requested. Crypto assets come first in the
///   result, followed by equities.
///
/// Each requested symbol appears at most once in the result; when the
/// listings contain the same symbol more than once, the first occurrence wins.
///
/// # Errors
///
/// Returns the first `reqwest::Error` produced by any underlying request.
pub async fn get_by_symbols(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    symbols: &[String],
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Vec<Asset>, Error> {
    if symbols.is_empty() {
        return Ok(vec![]);
    }

    if symbols.len() == 1 {
        let asset = get_by_symbol(client, rate_limiter, &symbols[0], backoff, api_base).await?;
        return Ok(vec![asset]);
    }

    let symbols = symbols.iter().collect::<HashSet<_>>();

    // Both listing requests take the backoff policy by value, so one gets a clone.
    let backoff_clone = backoff.clone();

    let us_equity_query = outgoing::asset::Asset {
        class: Some(Class::UsEquity),
        ..Default::default()
    };

    let us_equity_assets = get(
        client,
        rate_limiter,
        &us_equity_query,
        backoff_clone,
        api_base,
    );

    let crypto_query = outgoing::asset::Asset {
        class: Some(Class::Crypto),
        ..Default::default()
    };

    let crypto_assets = get(client, rate_limiter, &crypto_query, backoff, api_base);

    let (us_equity_assets, crypto_assets) = try_join!(us_equity_assets, crypto_assets)?;

    Ok(crypto_assets
        .into_iter()
        .chain(us_equity_assets)
        // Filter first so only requested symbols are cloned as uniqueness keys.
        .filter(|asset| symbols.contains(&asset.symbol))
        // `dedup_by` only collapses *adjacent* duplicates, which does not
        // guarantee uniqueness on an unsorted, chained listing; `unique_by`
        // drops every repeated symbol while keeping the first occurrence.
        .unique_by(|asset| asset.symbol.clone())
        .collect())
}

50
src/lib/alpaca/bars.rs Normal file
View File

@@ -0,0 +1,50 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::bar::Bar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::{collections::HashMap, time::Duration};
/// Limit value for bar requests.
// NOTE(review): presumably the largest `limit` (page size) the bars endpoint
// accepts per request — confirm against the Alpaca market-data API docs.
pub const MAX_LIMIT: i64 = 10_000;
/// Response body that [`get`] deserializes from the bars endpoint.
#[derive(Deserialize)]
pub struct Message {
    /// Bars grouped under a string key.
    // NOTE(review): the key is presumably the asset symbol — confirm against the API.
    pub bars: HashMap<String, Vec<Bar>>,
    /// Pagination token, if the response carried one.
    // NOTE(review): presumably absent on the last page — confirm against the API.
    pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
data_url: &str,
query: &outgoing::bar::Bar,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(data_url)
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Message>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get historical bars, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -0,0 +1,41 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::calendar::Calendar, outgoing};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::calendar::Calendar,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Calendar>, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/calendar", api_base))
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Vec<Calendar>>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get calendar, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

39
src/lib/alpaca/clock.rs Normal file
View File

@@ -0,0 +1,39 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::clock::Clock;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use std::time::Duration;
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Clock, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(&format!("https://{}.alpaca.markets/v2/clock", api_base))
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Clock>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get clock, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

27
src/lib/alpaca/mod.rs Normal file
View File

@@ -0,0 +1,27 @@
pub mod account;
pub mod assets;
pub mod bars;
pub mod calendar;
pub mod clock;
pub mod news;
pub mod orders;
pub mod positions;
use reqwest::StatusCode;
/// Classifies a `reqwest::Error` as permanent (give up) or transient (retry)
/// for the `backoff` retry helpers.
///
/// * Status errors: `400 Bad Request`, `403 Forbidden`, `404 Not Found`, or a
///   missing status code are permanent; every other status is transient.
/// * Timeouts and connection failures are transient.
/// * Remaining builder, request, redirect, decode, and body errors are
///   permanent.
/// * Anything else is transient.
pub fn error_to_backoff(err: reqwest::Error) -> backoff::Error<reqwest::Error> {
    if err.is_status() {
        return match err.status() {
            Some(StatusCode::BAD_REQUEST | StatusCode::FORBIDDEN | StatusCode::NOT_FOUND)
            | None => backoff::Error::Permanent(err),
            _ => err.into(),
        };
    }

    // reqwest reports connection failures and timeouts as request (or body)
    // errors, so they must be recognized before the permanent check below;
    // otherwise a brief network outage would never be retried.
    if err.is_timeout() || err.is_connect() {
        return err.into();
    }

    if err.is_builder() || err.is_request() || err.is_redirect() || err.is_decode() || err.is_body()
    {
        return backoff::Error::Permanent(err);
    }

    err.into()
}

49
src/lib/alpaca/news.rs Normal file
View File

@@ -0,0 +1,49 @@
use super::error_to_backoff;
use crate::types::alpaca::api::{incoming::news::News, outgoing, ALPACA_NEWS_DATA_API_URL};
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize;
use std::time::Duration;
/// Limit value for news requests.
// NOTE(review): presumably the largest `limit` (page size) the news endpoint
// accepts per request — confirm against the Alpaca market-data API docs.
pub const MAX_LIMIT: i64 = 50;
/// Response body that [`get`] deserializes from the news endpoint.
#[derive(Deserialize)]
pub struct Message {
    /// News items contained in this response.
    pub news: Vec<News>,
    /// Pagination token, if the response carried one.
    // NOTE(review): presumably absent on the last page — confirm against the API.
    pub next_page_token: Option<String>,
}
pub async fn get(
client: &Client,
rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::news::News,
backoff: Option<ExponentialBackoff>,
) -> Result<Message, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
rate_limiter.until_ready().await;
client
.get(ALPACA_NEWS_DATA_API_URL)
.query(query)
.send()
.await
.map_err(error_to_backoff)?
.error_for_status()
.map_err(error_to_backoff)?
.json::<Message>()
.await
.map_err(error_to_backoff)
},
|e, duration: Duration| {
warn!(
"Failed to get historical news, will retry in {} seconds: {}.",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -1,44 +1,39 @@
use crate::{ use super::error_to_backoff;
config::ALPACA_API_URL, use crate::types::alpaca::{api::outgoing, shared::order};
types::alpaca::{api::outgoing, shared},
};
use backoff::{future::retry_notify, ExponentialBackoff}; use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter; use governor::DefaultDirectRateLimiter;
use log::warn; use log::warn;
use reqwest::{Client, Error}; use reqwest::{Client, Error};
use std::time::Duration; use std::time::Duration;
pub use shared::order::Order; pub use order::Order;
pub async fn get( pub async fn get(
alpaca_client: &Client, client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter, rate_limiter: &DefaultDirectRateLimiter,
query: &outgoing::order::Order, query: &outgoing::order::Order,
backoff: Option<ExponentialBackoff>, backoff: Option<ExponentialBackoff>,
api_base: &str,
) -> Result<Vec<Order>, Error> { ) -> Result<Vec<Order>, Error> {
retry_notify( retry_notify(
backoff.unwrap_or_default(), backoff.unwrap_or_default(),
|| async { || async {
alpaca_rate_limiter.until_ready().await; rate_limiter.until_ready().await;
alpaca_client client
.get(&format!("{}/orders", *ALPACA_API_URL)) .get(&format!("https://{}.alpaca.markets/v2/orders", api_base))
.query(query) .query(query)
.send() .send()
.await? .await
.map_err(error_to_backoff)?
.error_for_status() .error_for_status()
.map_err(|e| match e.status() { .map_err(error_to_backoff)?
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Vec<Order>>() .json::<Vec<Order>>()
.await .await
.map_err(backoff::Error::Permanent) .map_err(error_to_backoff)
}, },
|e, duration: Duration| { |e, duration: Duration| {
warn!( warn!(
"Failed to get orders, will retry in {} seconds: {}", "Failed to get orders, will retry in {} seconds: {}.",
duration.as_secs(), duration.as_secs(),
e e
); );

108
src/lib/alpaca/positions.rs Normal file
View File

@@ -0,0 +1,108 @@
use super::error_to_backoff;
use crate::types::alpaca::api::incoming::position::Position;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::Client;
use std::{collections::HashSet, time::Duration};
/// Fetches all positions from `https://{api_base}.alpaca.markets/v2/positions`.
///
/// Every attempt first waits on `rate_limiter`, sends the request, rejects
/// non-success HTTP statuses and decodes the JSON body into a list of
/// [`Position`]s. Each failure is converted with `error_to_backoff` and handed
/// to `retry_notify`, which is driven by `backoff` (or the default
/// `ExponentialBackoff` when `None`). Every retry is logged as a warning.
pub async fn get(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
    // The URL does not change between attempts, so build it once up front.
    let url = format!("https://{}.alpaca.markets/v2/positions", api_base);
    let policy = backoff.unwrap_or_default();

    retry_notify(
        policy,
        || async {
            rate_limiter.until_ready().await;

            let response = client.get(&url).send().await.map_err(error_to_backoff)?;
            let response = response.error_for_status().map_err(error_to_backoff)?;

            response
                .json::<Vec<Position>>()
                .await
                .map_err(error_to_backoff)
        },
        |e, duration: Duration| {
            warn!(
                "Failed to get positions, will retry in {} seconds: {}.",
                duration.as_secs(),
                e
            );
        },
    )
    .await
}
/// Fetches the position held for a single `symbol`, if any.
///
/// Requests `https://{api_base}.alpaca.markets/v2/positions/{symbol}`.
/// A `404 Not Found` response yields `Ok(None)`; any other non-success status,
/// transport error or JSON decoding error is converted with `error_to_backoff`
/// and handed to `retry_notify`, driven by `backoff` (or the default
/// `ExponentialBackoff` when `None`). Every retry is logged as a warning.
///
/// Any `/` in `symbol` is removed before it is placed in the URL path, so a
/// pair written as `BTC/USD` is requested as `BTCUSD`. Left in place, the
/// slash would add a path segment and the request would hit a different
/// route, which would then be reported as "no position".
/// NOTE(review): assumes the API accepts the slash-less pair form — confirm.
pub async fn get_by_symbol(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    symbol: &str,
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Option<Position>, reqwest::Error> {
    // The URL does not change between attempts, so build it once up front.
    let url = format!(
        "https://{}.alpaca.markets/v2/positions/{}",
        api_base,
        symbol.replace('/', "")
    );

    retry_notify(
        backoff.unwrap_or_default(),
        || async {
            rate_limiter.until_ready().await;

            let response = client.get(&url).send().await.map_err(error_to_backoff)?;

            // A missing position is a valid answer, not an error to retry.
            if response.status() == reqwest::StatusCode::NOT_FOUND {
                return Ok(None);
            }

            response
                .error_for_status()
                .map_err(error_to_backoff)?
                .json::<Position>()
                .await
                .map_err(error_to_backoff)
                .map(Some)
        },
        |e, duration: Duration| {
            warn!(
                "Failed to get position, will retry in {} seconds: {}.",
                duration.as_secs(),
                e
            );
        },
    )
    .await
}
/// Fetches the positions held for the given `symbols`.
///
/// * No symbols: returns an empty list without issuing a request.
/// * Exactly one symbol: delegates to [`get_by_symbol`] and returns zero or
///   one position.
/// * Several symbols: fetches all positions with [`get`] and keeps only those
///   whose `symbol` is among the requested ones.
pub async fn get_by_symbols(
    client: &Client,
    rate_limiter: &DefaultDirectRateLimiter,
    symbols: &[String],
    backoff: Option<ExponentialBackoff>,
    api_base: &str,
) -> Result<Vec<Position>, reqwest::Error> {
    match symbols {
        [] => Ok(Vec::new()),
        [only] => {
            let found = get_by_symbol(client, rate_limiter, only, backoff, api_base).await?;
            Ok(found.into_iter().collect())
        }
        _ => {
            // Set lookup avoids scanning the requested list for every position.
            let wanted = symbols.iter().collect::<HashSet<_>>();
            let all = get(client, rate_limiter, backoff, api_base).await?;

            Ok(all
                .into_iter()
                .filter(|held| wanted.contains(&held.symbol))
                .collect())
        }
    }
}

View File

@@ -1,8 +1,11 @@
use std::sync::Arc;
use crate::{ use crate::{
delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch, delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch,
}; };
use clickhouse::{error::Error, Client}; use clickhouse::{error::Error, Client};
use serde::Serialize; use serde::Serialize;
use tokio::sync::Semaphore;
select!(Asset, "assets"); select!(Asset, "assets");
select_where_symbol!(Asset, "assets"); select_where_symbol!(Asset, "assets");
@@ -11,14 +14,16 @@ delete_where_symbols!("assets");
optimize!("assets"); optimize!("assets");
pub async fn update_status_where_symbol<T>( pub async fn update_status_where_symbol<T>(
clickhouse_client: &Client, client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbol: &T, symbol: &T,
status: bool, status: bool,
) -> Result<(), Error> ) -> Result<(), Error>
where where
T: AsRef<str> + Serialize + Send + Sync, T: AsRef<str> + Serialize + Send + Sync,
{ {
clickhouse_client let _ = concurrency_limiter.acquire().await.unwrap();
client
.query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?") .query("ALTER TABLE assets UPDATE status = ? WHERE symbol = ?")
.bind(status) .bind(status)
.bind(symbol) .bind(symbol)
@@ -27,14 +32,16 @@ where
} }
pub async fn update_qty_where_symbol<T>( pub async fn update_qty_where_symbol<T>(
clickhouse_client: &Client, client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbol: &T, symbol: &T,
qty: f64, qty: f64,
) -> Result<(), Error> ) -> Result<(), Error>
where where
T: AsRef<str> + Serialize + Send + Sync, T: AsRef<str> + Serialize + Send + Sync,
{ {
clickhouse_client let _ = concurrency_limiter.acquire().await.unwrap();
client
.query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?") .query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?")
.bind(qty) .bind(qty)
.bind(symbol) .bind(symbol)

View File

@@ -0,0 +1,11 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
types::Backfill, upsert_batch,
};
// Query helpers for the `backfills_bars` table, generated by the crate-level macros.
// `select_where_symbols`: fetch `Backfill` rows for a set of symbols.
select_where_symbols!(Backfill, "backfills_bars");
// `upsert_batch`: insert a batch of `Backfill` rows.
upsert_batch!(Backfill, "backfills_bars");
// `delete_where_symbols`: delete the rows of the given symbols.
delete_where_symbols!("backfills_bars");
// `cleanup`: delete rows whose symbol is not present in `assets`.
cleanup!("backfills_bars");
// `optimize`: run `OPTIMIZE TABLE ... FINAL` on the table.
optimize!("backfills_bars");
// `set_fresh_where_symbols`: update the `fresh` flag for the given symbols.
set_fresh_where_symbols!("backfills_bars");

View File

@@ -0,0 +1,11 @@
use crate::{
cleanup, delete_where_symbols, optimize, select_where_symbols, set_fresh_where_symbols,
types::Backfill, upsert_batch,
};
// Query helpers for the `backfills_news` table, generated by the crate-level macros.
// `select_where_symbols`: fetch `Backfill` rows for a set of symbols.
select_where_symbols!(Backfill, "backfills_news");
// `upsert_batch`: insert a batch of `Backfill` rows.
upsert_batch!(Backfill, "backfills_news");
// `delete_where_symbols`: delete the rows of the given symbols.
delete_where_symbols!("backfills_news");
// `cleanup`: delete rows whose symbol is not present in `assets`.
cleanup!("backfills_news");
// `optimize`: run `OPTIMIZE TABLE ... FINAL` on the table.
optimize!("backfills_news");
// `set_fresh_where_symbols`: update the `fresh` flag for the given symbols.
set_fresh_where_symbols!("backfills_news");

21
src/lib/database/bars.rs Normal file
View File

@@ -0,0 +1,21 @@
use std::sync::Arc;
use crate::{delete_where_symbols, optimize, types::Bar, upsert, upsert_batch};
use clickhouse::Client;
use tokio::sync::Semaphore;
// Query helpers for the `bars` table, generated by the crate-level macros.
// `upsert`: insert a single `Bar` row.
upsert!(Bar, "bars");
// `upsert_batch`: insert a batch of `Bar` rows.
upsert_batch!(Bar, "bars");
// `delete_where_symbols`: delete the rows of the given symbols.
delete_where_symbols!("bars");
// `optimize`: run `OPTIMIZE TABLE ... FINAL` on the table.
optimize!("bars");
/// Deletes every row of `bars` whose symbol is missing from `assets` or
/// missing from `backfills_bars`.
///
/// One permit of `concurrency_limiter` is held for the duration of the query.
///
/// # Panics
///
/// Panics if `concurrency_limiter` has been closed.
///
/// # Errors
///
/// Returns the ClickHouse error if the query fails.
pub async fn cleanup(
    client: &Client,
    concurrency_limiter: &Arc<Semaphore>,
) -> Result<(), clickhouse::error::Error> {
    // Bind the permit to a named variable: `let _ = ...` would drop it right
    // away and release the slot before the query even starts.
    let _permit = concurrency_limiter.acquire().await.unwrap();

    client
        .query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets) OR symbol NOT IN (SELECT symbol FROM backfills_bars)")
        .execute()
        .await
}

View File

@@ -1,16 +1,19 @@
use std::sync::Arc;
use crate::{optimize, types::Calendar}; use crate::{optimize, types::Calendar};
use clickhouse::error::Error; use clickhouse::{error::Error, Client};
use tokio::try_join; use tokio::{sync::Semaphore, try_join};
optimize!("calendar"); optimize!("calendar");
pub async fn upsert_batch_and_delete<'a, T>( pub async fn upsert_batch_and_delete<'a, I>(
client: &clickhouse::Client, client: &Client,
records: T, concurrency_limiter: &Arc<Semaphore>,
records: I,
) -> Result<(), Error> ) -> Result<(), Error>
where where
T: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone, I: IntoIterator<Item = &'a Calendar> + Send + Sync + Clone,
T::IntoIter: Send, I::IntoIter: Send,
{ {
let upsert_future = async { let upsert_future = async {
let mut insert = client.insert("calendar")?; let mut insert = client.insert("calendar")?;
@@ -34,5 +37,6 @@ where
.await .await
}; };
let _ = concurrency_limiter.acquire_many(2).await.unwrap();
try_join!(upsert_future, delete_future).map(|_| ()) try_join!(upsert_future, delete_future).map(|_| ())
} }

223
src/lib/database/mod.rs Normal file
View File

@@ -0,0 +1,223 @@
pub mod assets;
pub mod backfills_bars;
pub mod backfills_news;
pub mod bars;
pub mod calendar;
pub mod news;
pub mod orders;
use clickhouse::{error::Error, Client};
use tokio::try_join;
/// Generates `select`, which fetches every row of `$table_name` (read with
/// `FINAL`) as a `Vec<$record>`.
///
/// The generated function holds one permit of `concurrency_limiter` while the
/// query runs and panics if that semaphore has been closed.
#[macro_export]
macro_rules! select {
    ($record:ty, $table_name:expr) => {
        pub async fn select(
            client: &clickhouse::Client,
            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
        ) -> Result<Vec<$record>, clickhouse::error::Error> {
            // Bind the permit to a named variable: `let _ = ...` would drop it
            // right away and release the slot before the query starts.
            let _permit = concurrency_limiter.acquire().await.unwrap();
            client
                .query(&format!("SELECT ?fields FROM {} FINAL", $table_name))
                .fetch_all::<$record>()
                .await
        }
    };
}
/// Generates `select_where_symbol`, which fetches the row of `$table_name`
/// (read with `FINAL`) whose `symbol` column equals the given symbol, if any.
///
/// The generated function holds one permit of `concurrency_limiter` while the
/// query runs and panics if that semaphore has been closed.
#[macro_export]
macro_rules! select_where_symbol {
    ($record:ty, $table_name:expr) => {
        pub async fn select_where_symbol<T>(
            client: &clickhouse::Client,
            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
            symbol: &T,
        ) -> Result<Option<$record>, clickhouse::error::Error>
        where
            T: AsRef<str> + serde::Serialize + Send + Sync,
        {
            // Bind the permit to a named variable: `let _ = ...` would drop it
            // right away and release the slot before the query starts.
            let _permit = concurrency_limiter.acquire().await.unwrap();
            client
                .query(&format!(
                    "SELECT ?fields FROM {} FINAL WHERE symbol = ?",
                    $table_name
                ))
                .bind(symbol)
                .fetch_optional::<$record>()
                .await
        }
    };
}
/// Generates `select_where_symbols`, which fetches the rows of `$table_name`
/// (read with `FINAL`) whose `symbol` column is one of the given symbols.
///
/// The generated function holds one permit of `concurrency_limiter` while the
/// query runs and panics if that semaphore has been closed.
#[macro_export]
macro_rules! select_where_symbols {
    ($record:ty, $table_name:expr) => {
        pub async fn select_where_symbols<T>(
            client: &clickhouse::Client,
            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
            symbols: &[T],
        ) -> Result<Vec<$record>, clickhouse::error::Error>
        where
            T: AsRef<str> + serde::Serialize + Send + Sync,
        {
            // Bind the permit to a named variable: `let _ = ...` would drop it
            // right away and release the slot before the query starts.
            let _permit = concurrency_limiter.acquire().await.unwrap();
            client
                .query(&format!(
                    "SELECT ?fields FROM {} FINAL WHERE symbol IN ?",
                    $table_name
                ))
                .bind(symbols)
                .fetch_all::<$record>()
                .await
        }
    };
}
/// Generates `upsert`, which writes a single `$record` into `$table_name`
/// through a ClickHouse insert.
///
/// The generated function holds one permit of `concurrency_limiter` until the
/// insert has ended and panics if that semaphore has been closed.
#[macro_export]
macro_rules! upsert {
    ($record:ty, $table_name:expr) => {
        pub async fn upsert(
            client: &clickhouse::Client,
            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
            record: &$record,
        ) -> Result<(), clickhouse::error::Error> {
            // Bind the permit to a named variable: `let _ = ...` would drop it
            // right away and release the slot before the insert starts.
            let _permit = concurrency_limiter.acquire().await.unwrap();
            let mut insert = client.insert($table_name)?;
            insert.write(record).await?;
            insert.end().await
        }
    };
}
/// Generates `upsert_batch`, which writes every `$record` yielded by
/// `records` into `$table_name` through a single ClickHouse insert.
///
/// The generated function holds one permit of `concurrency_limiter` until the
/// insert has ended and panics if that semaphore has been closed.
#[macro_export]
macro_rules! upsert_batch {
    ($record:ty, $table_name:expr) => {
        pub async fn upsert_batch<'a, I>(
            client: &clickhouse::Client,
            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
            records: I,
        ) -> Result<(), clickhouse::error::Error>
        where
            I: IntoIterator<Item = &'a $record> + Send + Sync,
            I::IntoIter: Send,
        {
            // Bind the permit to a named variable: `let _ = ...` would drop it
            // right away and release the slot before the insert starts.
            let _permit = concurrency_limiter.acquire().await.unwrap();
            let mut insert = client.insert($table_name)?;
            for record in records {
                insert.write(record).await?;
            }
            insert.end().await
        }
    };
}
/// Generates `delete_where_symbols`, which deletes the rows of `$table_name`
/// whose `symbol` column is one of the given symbols.
///
/// The generated function holds one permit of `concurrency_limiter` while the
/// query runs and panics if that semaphore has been closed.
#[macro_export]
macro_rules! delete_where_symbols {
    ($table_name:expr) => {
        pub async fn delete_where_symbols<T>(
            client: &clickhouse::Client,
            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
            symbols: &[T],
        ) -> Result<(), clickhouse::error::Error>
        where
            T: AsRef<str> + serde::Serialize + Send + Sync,
        {
            // Bind the permit to a named variable: `let _ = ...` would drop it
            // right away and release the slot before the query starts.
            let _permit = concurrency_limiter.acquire().await.unwrap();
            client
                .query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name))
                .bind(symbols)
                .execute()
                .await
        }
    };
}
/// Generates `cleanup`, which deletes the rows of `$table_name` whose
/// `symbol` does not appear in the `assets` table.
///
/// The generated function holds one permit of `concurrency_limiter` while the
/// query runs and panics if that semaphore has been closed.
#[macro_export]
macro_rules! cleanup {
    ($table_name:expr) => {
        pub async fn cleanup(
            client: &clickhouse::Client,
            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
        ) -> Result<(), clickhouse::error::Error> {
            // Bind the permit to a named variable: `let _ = ...` would drop it
            // right away and release the slot before the query starts.
            let _permit = concurrency_limiter.acquire().await.unwrap();
            client
                .query(&format!(
                    "DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)",
                    $table_name
                ))
                .execute()
                .await
        }
    };
}
/// Generates `optimize`, which runs `OPTIMIZE TABLE ... FINAL` on
/// `$table_name`.
///
/// The generated function holds one permit of `concurrency_limiter` while the
/// query runs and panics if that semaphore has been closed.
#[macro_export]
macro_rules! optimize {
    ($table_name:expr) => {
        pub async fn optimize(
            client: &clickhouse::Client,
            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
        ) -> Result<(), clickhouse::error::Error> {
            // Bind the permit to a named variable: `let _ = ...` would drop it
            // right away and release the slot before the query starts.
            let _permit = concurrency_limiter.acquire().await.unwrap();
            client
                .query(&format!("OPTIMIZE TABLE {} FINAL", $table_name))
                .execute()
                .await
        }
    };
}
/// Generates `set_fresh_where_symbols`, which sets the `fresh` column of
/// `$table_name` to the given value for every row whose `symbol` is one of
/// the given symbols.
///
/// The generated function holds one permit of `concurrency_limiter` while the
/// query runs and panics if that semaphore has been closed.
#[macro_export]
macro_rules! set_fresh_where_symbols {
    ($table_name:expr) => {
        pub async fn set_fresh_where_symbols<T>(
            client: &clickhouse::Client,
            concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
            fresh: bool,
            symbols: &[T],
        ) -> Result<(), clickhouse::error::Error>
        where
            T: AsRef<str> + serde::Serialize + Send + Sync,
        {
            // Bind the permit to a named variable: `let _ = ...` would drop it
            // right away and release the slot before the query starts.
            let _permit = concurrency_limiter.acquire().await.unwrap();
            client
                .query(&format!(
                    "ALTER TABLE {} UPDATE fresh = ? WHERE symbol IN ?",
                    $table_name
                ))
                .bind(fresh)
                .bind(symbols)
                .execute()
                .await
        }
    };
}
pub async fn cleanup_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
try_join!(
bars::cleanup(clickhouse_client, concurrency_limiter),
news::cleanup(clickhouse_client, concurrency_limiter),
backfills_bars::cleanup(clickhouse_client, concurrency_limiter),
backfills_news::cleanup(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}
pub async fn optimize_all(
clickhouse_client: &Client,
concurrency_limiter: &std::sync::Arc<tokio::sync::Semaphore>,
) -> Result<(), Error> {
try_join!(
assets::optimize(clickhouse_client, concurrency_limiter),
bars::optimize(clickhouse_client, concurrency_limiter),
news::optimize(clickhouse_client, concurrency_limiter),
backfills_bars::optimize(clickhouse_client, concurrency_limiter),
backfills_news::optimize(clickhouse_client, concurrency_limiter),
orders::optimize(clickhouse_client, concurrency_limiter),
calendar::optimize(clickhouse_client, concurrency_limiter)
)
.map(|_| ())
}

View File

@@ -1,24 +1,33 @@
use std::sync::Arc;
use crate::{optimize, types::News, upsert, upsert_batch}; use crate::{optimize, types::News, upsert, upsert_batch};
use clickhouse::{error::Error, Client}; use clickhouse::{error::Error, Client};
use serde::Serialize; use serde::Serialize;
use tokio::sync::Semaphore;
upsert!(News, "news"); upsert!(News, "news");
upsert_batch!(News, "news"); upsert_batch!(News, "news");
optimize!("news"); optimize!("news");
pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error> pub async fn delete_where_symbols<T>(
client: &Client,
concurrency_limiter: &Arc<Semaphore>,
symbols: &[T],
) -> Result<(), Error>
where where
T: AsRef<str> + Serialize + Send + Sync, T: AsRef<str> + Serialize + Send + Sync,
{ {
clickhouse_client let _ = concurrency_limiter.acquire().await.unwrap();
client
.query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))") .query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))")
.bind(symbols) .bind(symbols)
.execute() .execute()
.await .await
} }
pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> { pub async fn cleanup(client: &Client, concurrency_limiter: &Arc<Semaphore>) -> Result<(), Error> {
clickhouse_client let _ = concurrency_limiter.acquire().await.unwrap();
client
.query( .query(
"DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))", "DELETE FROM news WHERE NOT hasAny(symbols, (SELECT groupArray(symbol) FROM assets))",
) )

4
src/lib/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod alpaca;
pub mod database;
pub mod types;
pub mod utils;

View File

@@ -1,13 +1,7 @@
use crate::config::ALPACA_API_URL;
use backoff::{future::retry_notify, ExponentialBackoff};
use governor::DefaultDirectRateLimiter;
use log::warn;
use reqwest::{Client, Error};
use serde::Deserialize; use serde::Deserialize;
use serde_aux::field_attributes::{ use serde_aux::field_attributes::{
deserialize_number_from_string, deserialize_option_number_from_string, deserialize_number_from_string, deserialize_option_number_from_string,
}; };
use std::time::Duration;
use time::OffsetDateTime; use time::OffsetDateTime;
use uuid::Uuid; use uuid::Uuid;
@@ -79,38 +73,3 @@ pub struct Account {
#[serde(deserialize_with = "deserialize_number_from_string")] #[serde(deserialize_with = "deserialize_number_from_string")]
pub regt_buying_power: f64, pub regt_buying_power: f64,
} }
pub async fn get(
alpaca_client: &Client,
alpaca_rate_limiter: &DefaultDirectRateLimiter,
backoff: Option<ExponentialBackoff>,
) -> Result<Account, Error> {
retry_notify(
backoff.unwrap_or_default(),
|| async {
alpaca_rate_limiter.until_ready().await;
alpaca_client
.get(&format!("{}/account", *ALPACA_API_URL))
.send()
.await?
.error_for_status()
.map_err(|e| match e.status() {
Some(reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::FORBIDDEN) => {
backoff::Error::Permanent(e)
}
_ => e.into(),
})?
.json::<Account>()
.await
.map_err(backoff::Error::Permanent)
},
|e, duration: Duration| {
warn!(
"Failed to get account, will retry in {} seconds: {}",
duration.as_secs(),
e
);
},
)
.await
}

View File

@@ -0,0 +1,39 @@
use super::position::Position;
use crate::types::{self, alpaca::shared::asset};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_option_number_from_string;
use uuid::Uuid;
pub use asset::{Class, Exchange, Status};
/// Asset payload deserialized from an incoming API response.
// The payload carries several independent boolean flags, hence the lint allowance.
#[allow(clippy::struct_excessive_bools)]
#[derive(Deserialize, Clone)]
pub struct Asset {
pub id: Uuid,
pub class: Class,
pub exchange: Exchange,
pub symbol: String,
pub name: String,
pub status: Status,
pub tradable: bool,
pub marginable: bool,
pub shortable: bool,
pub easy_to_borrow: bool,
pub fractionable: bool,
/// Optional; read through serde_aux's `deserialize_option_number_from_string`
/// (presumably the API may send the number as a string — confirm).
#[serde(deserialize_with = "deserialize_option_number_from_string")]
pub maintenance_margin_requirement: Option<f32>,
pub attributes: Option<Vec<String>>,
}
impl From<(Asset, Option<Position>)> for types::Asset {
fn from((asset, position): (Asset, Option<Position>)) -> Self {
Self {
symbol: asset.symbol,
class: asset.class.into(),
exchange: asset.exchange.into(),
status: asset.status.into(),
time_added: time::OffsetDateTime::now_utc(),
qty: position.map(|position| position.qty).unwrap_or_default(),
}
}
}

View File

@@ -0,0 +1,40 @@
use crate::types;
use serde::Deserialize;
use time::OffsetDateTime;
/// Price bar deserialized from an incoming API response.
///
/// The payload uses abbreviated keys; each field is renamed to a descriptive name.
#[derive(Deserialize)]
pub struct Bar {
/// Bar timestamp (`t`), parsed as RFC 3339.
#[serde(rename = "t")]
#[serde(with = "time::serde::rfc3339")]
pub time: OffsetDateTime,
/// Payload key `o`.
#[serde(rename = "o")]
pub open: f64,
/// Payload key `h`.
#[serde(rename = "h")]
pub high: f64,
/// Payload key `l`.
#[serde(rename = "l")]
pub low: f64,
/// Payload key `c`.
#[serde(rename = "c")]
pub close: f64,
/// Payload key `v`.
#[serde(rename = "v")]
pub volume: f64,
/// Payload key `n`.
#[serde(rename = "n")]
pub trades: i64,
/// Payload key `vw`.
#[serde(rename = "vw")]
pub vwap: f64,
}
impl From<(Bar, String)> for types::Bar {
fn from((bar, symbol): (Bar, String)) -> Self {
Self {
time: bar.time,
symbol,
open: bar.open,
high: bar.high,
low: bar.low,
close: bar.close,
volume: bar.volume,
trades: bar.trades,
vwap: bar.vwap,
}
}
}

View File

@@ -0,0 +1,26 @@
use crate::{
types,
utils::{de, time::EST_OFFSET},
};
use serde::Deserialize;
use time::{Date, OffsetDateTime, Time};
/// Market calendar entry deserialized from an incoming API response.
#[derive(Deserialize)]
pub struct Calendar {
pub date: Date,
/// Parsed by `de::human_time_hh_mm` (presumably an `HH:MM` string — confirm).
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub open: Time,
/// Parsed by `de::human_time_hh_mm` (presumably an `HH:MM` string — confirm).
#[serde(deserialize_with = "de::human_time_hh_mm")]
pub close: Time,
pub settlement_date: Date,
}
impl From<Calendar> for types::Calendar {
fn from(calendar: Calendar) -> Self {
Self {
date: calendar.date,
open: OffsetDateTime::new_in_offset(calendar.date, calendar.open, *EST_OFFSET),
close: OffsetDateTime::new_in_offset(calendar.date, calendar.close, *EST_OFFSET),
}
}
}

View File

@@ -0,0 +1,13 @@
use serde::Deserialize;
use time::OffsetDateTime;
/// Market clock deserialized from an incoming API response.
///
/// All timestamps are parsed as RFC 3339.
#[derive(Deserialize)]
pub struct Clock {
#[serde(with = "time::serde::rfc3339")]
pub timestamp: OffsetDateTime,
pub is_open: bool,
#[serde(with = "time::serde::rfc3339")]
pub next_open: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
pub next_close: OffsetDateTime,
}

View File

@@ -0,0 +1,57 @@
use crate::{
types::{self, alpaca::shared::news::strip},
utils::de,
};
use serde::Deserialize;
use time::OffsetDateTime;
/// Size label of a news image; deserialized from its `snake_case` name.
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImageSize {
Thumb,
Small,
Large,
}
/// Image attached to a news item: a size label and the image URL.
#[derive(Deserialize)]
pub struct Image {
pub size: ImageSize,
pub url: String,
}
/// News item deserialized from an incoming API response.
#[derive(Deserialize)]
pub struct News {
pub id: i64,
/// Payload key `created_at`, parsed as RFC 3339.
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "created_at")]
pub time_created: OffsetDateTime,
/// Payload key `updated_at`, parsed as RFC 3339.
#[serde(with = "time::serde::rfc3339")]
#[serde(rename = "updated_at")]
pub time_updated: OffsetDateTime,
/// Post-processed by `de::add_slash_to_symbols` while deserializing.
#[serde(deserialize_with = "de::add_slash_to_symbols")]
pub symbols: Vec<String>,
pub headline: String,
pub author: String,
pub source: String,
pub summary: String,
pub content: String,
pub url: Option<String>,
pub images: Vec<Image>,
}
impl From<News> for types::News {
fn from(news: News) -> Self {
Self {
id: news.id,
time_created: news.time_created,
time_updated: news.time_updated,
symbols: news.symbols,
headline: strip(&news.headline),
author: strip(&news.author),
source: strip(&news.source),
summary: news.summary,
content: news.content,
url: news.url.unwrap_or_default(),
}
}
}

View File

@@ -0,0 +1,3 @@
use crate::types::alpaca::shared::order;
pub use order::{Order, Side};

View File

@@ -0,0 +1,61 @@
use crate::{
types::alpaca::api::incoming::{
asset::{Class, Exchange},
order,
},
utils::de,
};
use serde::Deserialize;
use serde_aux::field_attributes::deserialize_number_from_string;
use uuid::Uuid;
/// Side of a position; deserialized from its `snake_case` name.
#[derive(Deserialize, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum Side {
Long,
Short,
}
impl From<Side> for order::Side {
    /// Maps a position side onto an order side: `Long` becomes `Buy` and
    /// `Short` becomes `Sell`.
    fn from(position_side: Side) -> Self {
        match position_side {
            Side::Short => Self::Sell,
            Side::Long => Self::Buy,
        }
    }
}
/// Position payload deserialized from an incoming API response.
///
/// Every `f64` field is read through serde_aux's
/// `deserialize_number_from_string` (presumably because the API sends these
/// numbers as strings — confirm).
#[derive(Deserialize, Clone)]
pub struct Position {
pub asset_id: Uuid,
/// Post-processed by `de::add_slash_to_symbol` while deserializing.
#[serde(deserialize_with = "de::add_slash_to_symbol")]
pub symbol: String,
pub exchange: Exchange,
pub asset_class: Class,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub avg_entry_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub qty_available: f64,
pub side: Side,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub market_value: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub cost_basis: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_pl: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub unrealized_intraday_plpc: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub current_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub lastday_price: f64,
#[serde(deserialize_with = "deserialize_number_from_string")]
pub change_today: f64,
pub asset_marginable: bool,
}

View File

@@ -0,0 +1,6 @@
pub mod incoming;
pub mod outgoing;
/// REST endpoint for US equity bars.
pub const ALPACA_US_EQUITY_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars";
/// REST endpoint for US crypto bars.
pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars";
/// REST endpoint for news.
pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news";

View File

@@ -0,0 +1,23 @@
use crate::types::alpaca::shared::asset;
use serde::Serialize;
pub use asset::{Class, Exchange, Status};
/// Serializable set of optional asset filters for an outgoing request.
#[derive(Serialize)]
pub struct Asset {
pub status: Option<Status>,
pub class: Option<Class>,
pub exchange: Option<Exchange>,
pub attributes: Option<Vec<String>>,
}
impl Default for Asset {
    /// Default filter: class restricted to `Class::UsEquity`, every other
    /// field left unset.
    fn default() -> Self {
        let class = Some(Class::UsEquity);

        Self {
            status: None,
            class,
            exchange: None,
            attributes: None,
        }
    }
}

View File

@@ -1,12 +1,14 @@
use crate::{ use crate::{
config::ALPACA_SOURCE, alpaca::bars::MAX_LIMIT,
types::alpaca::shared::{Sort, Source}, types::alpaca::shared,
utils::{ser, ONE_MINUTE}, utils::{ser, ONE_MINUTE},
}; };
use serde::Serialize; use serde::Serialize;
use std::time::Duration; use std::time::Duration;
use time::OffsetDateTime; use time::OffsetDateTime;
pub use shared::{Sort, Source};
#[derive(Serialize)] #[derive(Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[allow(dead_code)] #[allow(dead_code)]
@@ -53,10 +55,10 @@ impl Default for UsEquity {
timeframe: ONE_MINUTE, timeframe: ONE_MINUTE,
start: None, start: None,
end: None, end: None,
limit: Some(10000), limit: Some(MAX_LIMIT),
adjustment: Some(Adjustment::All), adjustment: Some(Adjustment::All),
asof: None, asof: None,
feed: Some(*ALPACA_SOURCE), feed: Some(Source::Iex),
currency: None, currency: None,
page_token: None, page_token: None,
sort: Some(Sort::Asc), sort: Some(Sort::Asc),
@@ -91,7 +93,7 @@ impl Default for Crypto {
timeframe: ONE_MINUTE, timeframe: ONE_MINUTE,
start: None, start: None,
end: None, end: None,
limit: Some(10000), limit: Some(MAX_LIMIT),
page_token: None, page_token: None,
sort: Some(Sort::Asc), sort: Some(Sort::Asc),
} }

View File

@@ -1,3 +1,4 @@
pub mod asset;
pub mod bar; pub mod bar;
pub mod calendar; pub mod calendar;
pub mod news; pub mod news;

View File

@@ -1,10 +1,10 @@
use crate::{types::alpaca::shared::Sort, utils::ser}; use crate::{alpaca::news::MAX_LIMIT, types::alpaca::shared::Sort, utils::ser};
use serde::Serialize; use serde::Serialize;
use time::OffsetDateTime; use time::OffsetDateTime;
#[derive(Serialize)] #[derive(Serialize)]
pub struct News { pub struct News {
#[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")] #[serde(serialize_with = "ser::remove_slash_and_join_symbols")]
pub symbols: Vec<String>, pub symbols: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "time::serde::rfc3339::option")] #[serde(with = "time::serde::rfc3339::option")]
@@ -30,7 +30,7 @@ impl Default for News {
symbols: vec![], symbols: vec![],
start: None, start: None,
end: None, end: None,
limit: Some(50), limit: Some(MAX_LIMIT),
include_content: Some(true), include_content: Some(true),
exclude_contentless: Some(false), exclude_contentless: Some(false),
page_token: None, page_token: None,

View File

@@ -1,10 +1,12 @@
use crate::{ use crate::{
types::alpaca::shared::{order::Side, Sort}, types::alpaca::shared::{order, Sort},
utils::ser, utils::ser,
}; };
use serde::Serialize; use serde::Serialize;
use time::OffsetDateTime; use time::OffsetDateTime;
pub use order::Side;
#[derive(Serialize)] #[derive(Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[allow(dead_code)] #[allow(dead_code)]

View File

@@ -1,7 +1,7 @@
use crate::{impl_from_enum, types}; use crate::{impl_from_enum, types};
use serde::Deserialize; use serde::{Deserialize, Serialize};
#[derive(Deserialize, Clone, Copy, Debug, PartialEq, Eq)] #[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum Class { pub enum Class {
UsEquity, UsEquity,
@@ -10,7 +10,7 @@ pub enum Class {
impl_from_enum!(types::Class, Class, UsEquity, Crypto); impl_from_enum!(types::Class, Class, UsEquity, Crypto);
#[derive(Deserialize)] #[derive(Serialize, Deserialize, Clone, Copy)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Exchange { pub enum Exchange {
Amex, Amex,
@@ -36,7 +36,7 @@ impl_from_enum!(
Crypto Crypto
); );
#[derive(Deserialize)] #[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum Status { pub enum Status {
Active, Active,

View File

@@ -1,4 +1,3 @@
use html_escape::decode_html_entities;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
@@ -7,12 +6,10 @@ lazy_static! {
static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap(); static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap();
} }
pub fn normalize_html_content(content: &str) -> String { pub fn strip(content: &str) -> String {
let content = content.replace('\n', " "); let content = content.replace('\n', " ");
let content = RE_TAGS.replace_all(&content, ""); let content = RE_TAGS.replace_all(&content, "");
let content = RE_SPACES.replace_all(&content, " "); let content = RE_SPACES.replace_all(&content, " ");
let content = decode_html_entities(&content);
let content = content.trim(); let content = content.trim();
content.to_string() content.to_string()
} }

View File

@@ -1,5 +1,5 @@
use crate::{ use crate::{
types::{alpaca::shared::news::normalize_html_content, news::Sentiment, News}, types::{alpaca::shared::news::strip, News},
utils::de, utils::de,
}; };
use serde::Deserialize; use serde::Deserialize;
@@ -31,13 +31,11 @@ impl From<Message> for News {
time_created: news.time_created, time_created: news.time_created,
time_updated: news.time_updated, time_updated: news.time_updated,
symbols: news.symbols, symbols: news.symbols,
headline: normalize_html_content(&news.headline), headline: strip(&news.headline),
author: normalize_html_content(&news.author), author: strip(&news.author),
source: normalize_html_content(&news.source), source: strip(&news.source),
summary: normalize_html_content(&news.summary), summary: news.summary,
content: normalize_html_content(&news.content), content: news.content,
sentiment: Sentiment::Neutral,
confidence: 0.0,
url: news.url.unwrap_or_default(), url: news.url.unwrap_or_default(),
} }
} }

View File

@@ -1,10 +1,7 @@
pub mod incoming; pub mod incoming;
pub mod outgoing; pub mod outgoing;
use crate::{ use crate::types::alpaca::websocket;
config::{ALPACA_API_KEY, ALPACA_API_SECRET},
types::alpaca::websocket,
};
use core::panic; use core::panic;
use futures_util::{ use futures_util::{
stream::{SplitSink, SplitStream}, stream::{SplitSink, SplitStream},
@@ -17,6 +14,8 @@ use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub async fn authenticate( pub async fn authenticate(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>, sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>, stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
api_key: String,
api_secret: String,
) { ) {
match stream.next().await.unwrap().unwrap() { match stream.next().await.unwrap().unwrap() {
Message::Text(data) Message::Text(data)
@@ -32,8 +31,8 @@ pub async fn authenticate(
sink.send(Message::Text( sink.send(Message::Text(
to_string(&websocket::data::outgoing::Message::Auth( to_string(&websocket::data::outgoing::Message::Auth(
websocket::auth::Message { websocket::auth::Message {
key: (*ALPACA_API_KEY).clone(), key: api_key,
secret: (*ALPACA_API_SECRET).clone(), secret: api_secret,
}, },
)) ))
.unwrap(), .unwrap(),

View File

@@ -1,4 +1,5 @@
use crate::utils::ser; use crate::utils::ser;
use nonempty::NonEmpty;
use serde::Serialize; use serde::Serialize;
#[derive(Serialize)] #[derive(Serialize)]
@@ -6,14 +7,14 @@ use serde::Serialize;
pub enum Market { pub enum Market {
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
UsEquity { UsEquity {
bars: Vec<String>, bars: NonEmpty<String>,
updated_bars: Vec<String>, updated_bars: NonEmpty<String>,
statuses: Vec<String>, statuses: NonEmpty<String>,
}, },
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
Crypto { Crypto {
bars: Vec<String>, bars: NonEmpty<String>,
updated_bars: Vec<String>, updated_bars: NonEmpty<String>,
}, },
} }
@@ -23,12 +24,12 @@ pub enum Message {
Market(Market), Market(Market),
News { News {
#[serde(serialize_with = "ser::remove_slash_from_symbols")] #[serde(serialize_with = "ser::remove_slash_from_symbols")]
news: Vec<String>, news: NonEmpty<String>,
}, },
} }
impl Message { impl Message {
pub fn new_market_us_equity(symbols: Vec<String>) -> Self { pub fn new_market_us_equity(symbols: NonEmpty<String>) -> Self {
Self::Market(Market::UsEquity { Self::Market(Market::UsEquity {
bars: symbols.clone(), bars: symbols.clone(),
updated_bars: symbols.clone(), updated_bars: symbols.clone(),
@@ -36,14 +37,14 @@ impl Message {
}) })
} }
pub fn new_market_crypto(symbols: Vec<String>) -> Self { pub fn new_market_crypto(symbols: NonEmpty<String>) -> Self {
Self::Market(Market::Crypto { Self::Market(Market::Crypto {
bars: symbols.clone(), bars: symbols.clone(),
updated_bars: symbols, updated_bars: symbols,
}) })
} }
pub fn new_news(symbols: Vec<String>) -> Self { pub fn new_news(symbols: NonEmpty<String>) -> Self {
Self::News { news: symbols } Self::News { news: symbols }
} }
} }

View File

@@ -0,0 +1,8 @@
pub mod auth;
pub mod data;
pub mod trading;
/// Base websocket URL of the US equity data stream.
pub const ALPACA_US_EQUITY_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2";
/// Websocket URL of the US crypto data stream.
pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str =
"wss://stream.data.alpaca.markets/v1beta3/crypto/us";
/// Websocket URL of the news stream.
pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news";

View File

@@ -1,10 +1,10 @@
use crate::types::alpaca::shared; use crate::types::alpaca::shared::order;
use serde::Deserialize; use serde::Deserialize;
use serde_aux::prelude::deserialize_number_from_string; use serde_aux::prelude::deserialize_number_from_string;
use time::OffsetDateTime; use time::OffsetDateTime;
use uuid::Uuid; use uuid::Uuid;
pub use shared::order::Order; pub use order::Order;
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]

View File

@@ -1,10 +1,7 @@
pub mod incoming; pub mod incoming;
pub mod outgoing; pub mod outgoing;
use crate::{ use crate::types::alpaca::websocket;
config::{ALPACA_API_KEY, ALPACA_API_SECRET},
types::alpaca::websocket,
};
use core::panic; use core::panic;
use futures_util::{ use futures_util::{
stream::{SplitSink, SplitStream}, stream::{SplitSink, SplitStream},
@@ -17,12 +14,14 @@ use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub async fn authenticate( pub async fn authenticate(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>, sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>, stream: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
api_key: String,
api_secret: String,
) { ) {
sink.send(Message::Text( sink.send(Message::Text(
to_string(&websocket::trading::outgoing::Message::Auth( to_string(&websocket::trading::outgoing::Message::Auth(
websocket::auth::Message { websocket::auth::Message {
key: (*ALPACA_API_KEY).clone(), key: api_key,
secret: (*ALPACA_API_SECRET).clone(), secret: api_secret,
}, },
)) ))
.unwrap(), .unwrap(),

11
src/lib/types/backfill.rs Normal file
View File

@@ -0,0 +1,11 @@
use clickhouse::Row;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
/// ClickHouse row (derives `clickhouse::Row`) describing the backfill state of one symbol.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Row)]
pub struct Backfill {
/// Symbol this record belongs to.
pub symbol: String,
/// Timestamp stored with the record; (de)serialized through the
/// `clickhouse::serde::time::datetime` helper.
/// NOTE(review): presumably the time of the most recent backfilled item — confirm against writers.
#[serde(with = "clickhouse::serde::time::datetime")]
pub time: OffsetDateTime,
/// Freshness flag.
/// NOTE(review): appears to mark whether the backfill has completed since startup — confirm.
pub fresh: bool,
}

19
src/lib/types/news.rs Normal file
View File

@@ -0,0 +1,19 @@
use clickhouse::Row;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
/// ClickHouse row (derives `clickhouse::Row`) holding a single news article.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
pub struct News {
/// Numeric identifier of the article.
pub id: i64,
/// Creation time; (de)serialized through the `clickhouse::serde::time::datetime` helper.
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_created: OffsetDateTime,
/// Last-update time; (de)serialized through the `clickhouse::serde::time::datetime` helper.
#[serde(with = "clickhouse::serde::time::datetime")]
pub time_updated: OffsetDateTime,
/// Symbols attached to the article.
pub symbols: Vec<String>,
/// Article headline.
pub headline: String,
/// Article author.
pub author: String,
/// Article source.
pub source: String,
/// Article summary.
pub summary: String,
/// Article body.
pub content: String,
/// Article URL.
pub url: String,
}

View File

@@ -8,7 +8,8 @@ use std::fmt;
use time::{format_description::OwnedFormatItem, macros::format_description, Time}; use time::{format_description::OwnedFormatItem, macros::format_description, Time};
lazy_static! { lazy_static! {
static ref RE_SLASH: Regex = Regex::new(r"^(.+)(BTC|USD.?)$").unwrap(); // This *will* break in the future if a crypto pair with one letter is added
static ref RE_SLASH: Regex = Regex::new(r"^(.{2,})(BTC|USD.?)$").unwrap();
static ref FMT_HH_MM: OwnedFormatItem = format_description!("[hour]:[minute]").into(); static ref FMT_HH_MM: OwnedFormatItem = format_description!("[hour]:[minute]").into();
} }

View File

@@ -58,12 +58,13 @@ where
} }
} }
pub fn remove_slash_from_symbols<S>(pairs: &[String], serializer: S) -> Result<S::Ok, S::Error> pub fn remove_slash_from_symbols<'a, S, I>(symbols: I, serializer: S) -> Result<S::Ok, S::Error>
where where
S: Serializer, S: Serializer,
I: IntoIterator<Item = &'a String>,
{ {
let symbols = pairs let symbols = symbols
.iter() .into_iter()
.map(|pair| remove_slash(pair)) .map(|pair| remove_slash(pair))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@@ -74,15 +75,13 @@ where
seq.end() seq.end()
} }
pub fn remove_slash_from_pairs_join_symbols<S>( pub fn remove_slash_and_join_symbols<'a, S, I>(symbols: I, serializer: S) -> Result<S::Ok, S::Error>
symbols: &[String],
serializer: S,
) -> Result<S::Ok, S::Error>
where where
S: Serializer, S: Serializer,
I: IntoIterator<Item = &'a String>,
{ {
let symbols = symbols let symbols = symbols
.iter() .into_iter()
.map(|symbol| remove_slash(symbol)) .map(|symbol| remove_slash(symbol))
.collect::<Vec<_>>(); .collect::<Vec<_>>();

View File

@@ -3,16 +3,19 @@
#![feature(hash_extract_if)] #![feature(hash_extract_if)]
mod config; mod config;
mod database;
mod init; mod init;
mod routes; mod routes;
mod threads; mod threads;
mod types;
mod utils;
use config::Config; use config::{
Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE,
CLICKHOUSE_BATCH_NEWS_SIZE, CLICKHOUSE_MAX_CONNECTIONS,
};
use dotenv::dotenv; use dotenv::dotenv;
use log::info;
use log4rs::config::Deserializers; use log4rs::config::Deserializers;
use nonempty::NonEmpty;
use qrust::{create_send_await, database};
use tokio::{join, spawn, sync::mpsc, try_join}; use tokio::{join, spawn, sync::mpsc, try_join};
#[tokio::main] #[tokio::main]
@@ -21,18 +24,62 @@ async fn main() {
log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap(); log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap();
let config = Config::arc_from_env(); let config = Config::arc_from_env();
let _ = *ALPACA_MODE;
let _ = *ALPACA_API_BASE;
let _ = *ALPACA_SOURCE;
let _ = *CLICKHOUSE_BATCH_BARS_SIZE;
let _ = *CLICKHOUSE_BATCH_NEWS_SIZE;
let _ = *CLICKHOUSE_MAX_CONNECTIONS;
info!("Marking all assets as stale.");
let assets = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap()
.into_iter()
.map(|asset| (asset.symbol, asset.class))
.collect::<Vec<_>>();
let symbols = assets.iter().map(|(symbol, _)| symbol).collect::<Vec<_>>();
try_join!( try_join!(
database::backfills_bars::unfresh(&config.clickhouse_client), database::backfills_bars::set_fresh_where_symbols(
database::backfills_news::unfresh(&config.clickhouse_client) &config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
false,
&symbols
),
database::backfills_news::set_fresh_where_symbols(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
false,
&symbols
)
) )
.unwrap(); .unwrap();
database::cleanup_all(&config.clickhouse_client) info!("Cleaning up database.");
.await
.unwrap(); database::cleanup_all(
database::optimize_all(&config.clickhouse_client) &config.clickhouse_client,
.await &config.clickhouse_concurrency_limiter,
.unwrap(); )
.await
.unwrap();
info!("Optimizing database.");
database::optimize_all(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
info!("Rehydrating account data.");
init::check_account(&config).await; init::check_account(&config).await;
join!( join!(
@@ -40,6 +87,8 @@ async fn main() {
init::rehydrate_positions(&config) init::rehydrate_positions(&config)
); );
info!("Starting threads.");
spawn(threads::trading::run(config.clone())); spawn(threads::trading::run(config.clone()));
let (data_sender, data_receiver) = mpsc::channel::<threads::data::Message>(100); let (data_sender, data_receiver) = mpsc::channel::<threads::data::Message>(100);
@@ -53,19 +102,14 @@ async fn main() {
spawn(threads::clock::run(config.clone(), clock_sender)); spawn(threads::clock::run(config.clone(), clock_sender));
let assets = database::assets::select(&config.clickhouse_client) if let Some(assets) = NonEmpty::from_vec(assets) {
.await create_send_await!(
.unwrap() data_sender,
.into_iter() threads::data::Message::new,
.map(|asset| (asset.symbol, asset.class)) threads::data::Action::Enable,
.collect::<Vec<_>>(); assets
);
create_send_await!( }
data_sender,
threads::data::Message::new,
threads::data::Action::Enable,
assets
);
routes::run(config, data_sender).await; routes::run(config, data_sender).await;
} }

View File

@@ -1,20 +1,30 @@
use crate::{ use crate::{
config::Config, config::{Config, ALPACA_API_BASE},
create_send_await, database, threads, create_send_await, database, threads,
types::{alpaca, Asset},
}; };
use axum::{extract::Path, Extension, Json}; use axum::{extract::Path, Extension, Json};
use http::StatusCode; use http::StatusCode;
use serde::Deserialize; use nonempty::{nonempty, NonEmpty};
use std::sync::Arc; use qrust::{
alpaca,
types::{self, Asset},
};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::mpsc; use tokio::sync::mpsc;
pub async fn get( pub async fn get(
Extension(config): Extension<Arc<Config>>, Extension(config): Extension<Arc<Config>>,
) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> { ) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> {
let assets = database::assets::select(&config.clickhouse_client) let assets = database::assets::select(
.await &config.clickhouse_client,
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; &config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok((StatusCode::OK, Json(assets))) Ok((StatusCode::OK, Json(assets)))
} }
@@ -23,9 +33,13 @@ pub async fn get_where_symbol(
Extension(config): Extension<Arc<Config>>, Extension(config): Extension<Arc<Config>>,
Path(symbol): Path<String>, Path(symbol): Path<String>,
) -> Result<(StatusCode, Json<Asset>), StatusCode> { ) -> Result<(StatusCode, Json<Asset>), StatusCode> {
let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol) let asset = database::assets::select_where_symbol(
.await &config.clickhouse_client,
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; &config.clickhouse_concurrency_limiter,
&symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
asset.map_or(Err(StatusCode::NOT_FOUND), |asset| { asset.map_or(Err(StatusCode::NOT_FOUND), |asset| {
Ok((StatusCode::OK, Json(asset))) Ok((StatusCode::OK, Json(asset)))
@@ -33,28 +47,115 @@ pub async fn get_where_symbol(
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct AddAssetRequest { pub struct AddAssetsRequest {
symbol: String, symbols: Vec<String>,
}
#[derive(Serialize)]
pub struct AddAssetsResponse {
added: Vec<String>,
skipped: Vec<String>,
failed: Vec<String>,
} }
pub async fn add( pub async fn add(
Extension(config): Extension<Arc<Config>>, Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>, Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Json(request): Json<AddAssetRequest>, Json(request): Json<AddAssetsRequest>,
) -> Result<(StatusCode, Json<AddAssetsResponse>), StatusCode> {
let database_symbols = database::assets::select(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.into_iter()
.map(|asset| asset.symbol)
.collect::<HashSet<_>>();
let mut alpaca_assets = alpaca::assets::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&request.symbols,
None,
&ALPACA_API_BASE,
)
.await
.map_err(|e| {
e.status()
.map_or(StatusCode::INTERNAL_SERVER_ERROR, |status| {
StatusCode::from_u16(status.as_u16()).unwrap()
})
})?
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>();
let num_symbols = request.symbols.len();
let (assets, skipped, failed) = request.symbols.into_iter().fold(
(Vec::with_capacity(num_symbols), vec![], vec![]),
|(mut assets, mut skipped, mut failed), symbol| {
if database_symbols.contains(&symbol) {
skipped.push(symbol);
} else if let Some(asset) = alpaca_assets.remove(&symbol) {
if asset.status == types::alpaca::api::incoming::asset::Status::Active
&& asset.tradable
&& asset.fractionable
{
assets.push((asset.symbol, asset.class.into()));
} else {
failed.push(asset.symbol);
}
} else {
failed.push(symbol);
}
(assets, skipped, failed)
},
);
if let Some(assets) = NonEmpty::from_vec(assets.clone()) {
create_send_await!(
data_sender,
threads::data::Message::new,
threads::data::Action::Add,
assets
);
}
Ok((
StatusCode::OK,
Json(AddAssetsResponse {
added: assets.into_iter().map(|asset| asset.0).collect(),
skipped,
failed,
}),
))
}
pub async fn add_symbol(
Extension(config): Extension<Arc<Config>>,
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> { ) -> Result<StatusCode, StatusCode> {
if database::assets::select_where_symbol(&config.clickhouse_client, &request.symbol) if database::assets::select_where_symbol(
.await &config.clickhouse_client,
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? &config.clickhouse_concurrency_limiter,
.is_some() &symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.is_some()
{ {
return Err(StatusCode::CONFLICT); return Err(StatusCode::CONFLICT);
} }
let asset = alpaca::api::incoming::asset::get_by_symbol( let asset = alpaca::assets::get_by_symbol(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
&request.symbol, &symbol,
None, None,
&ALPACA_API_BASE,
) )
.await .await
.map_err(|e| { .map_err(|e| {
@@ -64,7 +165,10 @@ pub async fn add(
}) })
})?; })?;
if !asset.tradable || !asset.fractionable { if asset.status != types::alpaca::api::incoming::asset::Status::Active
|| !asset.tradable
|| !asset.fractionable
{
return Err(StatusCode::FORBIDDEN); return Err(StatusCode::FORBIDDEN);
} }
@@ -72,7 +176,7 @@ pub async fn add(
data_sender, data_sender,
threads::data::Message::new, threads::data::Message::new,
threads::data::Action::Add, threads::data::Action::Add,
vec![(asset.symbol, asset.class.into())] nonempty![(asset.symbol, asset.class.into())]
); );
Ok(StatusCode::CREATED) Ok(StatusCode::CREATED)
@@ -83,16 +187,20 @@ pub async fn delete(
Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>, Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
Path(symbol): Path<String>, Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> { ) -> Result<StatusCode, StatusCode> {
let asset = database::assets::select_where_symbol(&config.clickhouse_client, &symbol) let asset = database::assets::select_where_symbol(
.await &config.clickhouse_client,
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? &config.clickhouse_concurrency_limiter,
.ok_or(StatusCode::NOT_FOUND)?; &symbol,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
create_send_await!( create_send_await!(
data_sender, data_sender,
threads::data::Message::new, threads::data::Message::new,
threads::data::Action::Remove, threads::data::Action::Remove,
vec![(asset.symbol, asset.class)] nonempty![(asset.symbol, asset.class)]
); );
Ok(StatusCode::NO_CONTENT) Ok(StatusCode::NO_CONTENT)

View File

@@ -16,6 +16,7 @@ pub async fn run(config: Arc<Config>, data_sender: mpsc::Sender<threads::data::M
.route("/assets", get(assets::get)) .route("/assets", get(assets::get))
.route("/assets/:symbol", get(assets::get_where_symbol)) .route("/assets/:symbol", get(assets::get_where_symbol))
.route("/assets", post(assets::add)) .route("/assets", post(assets::add))
.route("/assets/:symbol", post(assets::add_symbol))
.route("/assets/:symbol", delete(assets::delete)) .route("/assets/:symbol", delete(assets::delete))
.layer(Extension(config)) .layer(Extension(config))
.layer(Extension(data_sender)); .layer(Extension(data_sender));

View File

@@ -1,10 +1,13 @@
use crate::{ use crate::{
config::Config, config::{Config, ALPACA_API_BASE},
database, database,
types::{alpaca, Calendar},
utils::{backoff, duration_until},
}; };
use log::info; use log::info;
use qrust::{
alpaca,
types::{self, Calendar},
utils::{backoff, duration_until},
};
use std::sync::Arc; use std::sync::Arc;
use time::OffsetDateTime; use time::OffsetDateTime;
use tokio::{join, sync::mpsc, time::sleep}; use tokio::{join, sync::mpsc, time::sleep};
@@ -19,8 +22,8 @@ pub struct Message {
pub next_switch: OffsetDateTime, pub next_switch: OffsetDateTime,
} }
impl From<alpaca::api::incoming::clock::Clock> for Message { impl From<types::alpaca::api::incoming::clock::Clock> for Message {
fn from(clock: alpaca::api::incoming::clock::Clock) -> Self { fn from(clock: types::alpaca::api::incoming::clock::Clock) -> Self {
if clock.is_open { if clock.is_open {
Self { Self {
status: Status::Open, status: Status::Open,
@@ -38,21 +41,23 @@ impl From<alpaca::api::incoming::clock::Clock> for Message {
pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) { pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
loop { loop {
let clock_future = async { let clock_future = async {
alpaca::api::incoming::clock::get( alpaca::clock::get(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
Some(backoff::infinite()), Some(backoff::infinite()),
&ALPACA_API_BASE,
) )
.await .await
.unwrap() .unwrap()
}; };
let calendar_future = async { let calendar_future = async {
alpaca::api::incoming::calendar::get( alpaca::calendar::get(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
&alpaca::api::outgoing::calendar::Calendar::default(), &types::alpaca::api::outgoing::calendar::Calendar::default(),
Some(backoff::infinite()), Some(backoff::infinite()),
&ALPACA_API_BASE,
) )
.await .await
.unwrap() .unwrap()
@@ -74,9 +79,13 @@ pub async fn run(config: Arc<Config>, sender: mpsc::Sender<Message>) {
let sleep_future = sleep(sleep_until); let sleep_future = sleep(sleep_until);
let calendar_future = async { let calendar_future = async {
database::calendar::upsert_batch_and_delete(&config.clickhouse_client, &calendar) database::calendar::upsert_batch_and_delete(
.await &config.clickhouse_client,
.unwrap(); &config.clickhouse_concurrency_limiter,
&calendar,
)
.await
.unwrap();
}; };
join!(sleep_future, calendar_future); join!(sleep_future, calendar_future);

View File

@@ -1,413 +0,0 @@
use super::ThreadType;
use crate::{
config::{
Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_SOURCE, ALPACA_STOCK_DATA_API_URL,
MAX_BERT_INPUTS,
},
database,
types::{
alpaca::{self, shared::Source},
news::Prediction,
Backfill, Bar, Class, News,
},
utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE, ONE_SECOND},
};
use async_trait::async_trait;
use futures_util::future::join_all;
use log::{error, info, warn};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::{
spawn,
sync::{mpsc, oneshot, Mutex},
task::{block_in_place, JoinHandle},
time::sleep,
try_join,
};
/// Operation a backfill thread can be asked to perform for a set of symbols.
pub enum Action {
/// Spawn a backfill job for each symbol.
Backfill,
/// Abort any running job for each symbol and delete its backfill records and data.
Purge,
}
/// Maps the parent module's action onto the backfill thread's own action.
///
/// `Add` and `Enable` request a backfill, `Remove` requests a purge, and
/// `Disable` has no backfill-side counterpart (`None`).
impl From<super::Action> for Option<Action> {
    fn from(action: super::Action) -> Self {
        let mapped = match action {
            super::Action::Disable => return None,
            super::Action::Remove => Action::Purge,
            super::Action::Add | super::Action::Enable => Action::Backfill,
        };
        Some(mapped)
    }
}
/// Request sent to a backfill thread.
pub struct Message {
/// Action to perform; `None` means nothing is done beyond acknowledging the message.
pub action: Option<Action>,
/// Symbols the action applies to.
pub symbols: Vec<String>,
/// One-shot channel on which `()` is sent once the message has been handled.
pub response: oneshot::Sender<()>,
}
impl Message {
    /// Creates a message together with the receiving half of its acknowledgement channel.
    ///
    /// The returned receiver resolves once the handler sends `()` on `response`.
    pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
        let (response, receiver) = oneshot::channel::<()>();
        let message = Self {
            action,
            symbols,
            response,
        };
        (message, receiver)
    }
}
/// Data-kind-specific operations used by the generic backfill loop.
#[async_trait]
pub trait Handler: Send + Sync {
/// Returns the stored backfill record for `symbol`, if one exists.
async fn select_latest_backfill(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error>;
/// Deletes the backfill records of the given symbols.
async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
/// Deletes the stored data of the given symbols.
async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
/// Awaited before `backfill`; lets an implementation delay the job.
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime);
/// Fetches and stores data for `symbol` between `fetch_from` and `fetch_to`.
async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime);
/// Short name of the data kind, used in log messages.
fn log_string(&self) -> &'static str;
}
/// Receives backfill messages and spawns a handler task for each one.
///
/// All spawned tasks share one map of running backfill jobs, keyed by symbol.
/// The loop ends (and this function returns) once every sender has been
/// dropped and the channel is drained, instead of panicking on the closed
/// channel.
pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
    let backfill_jobs = Arc::new(Mutex::new(HashMap::new()));

    // `recv` yields `None` only when the channel is closed and empty.
    while let Some(message) = receiver.recv().await {
        spawn(handle_backfill_message(
            handler.clone(),
            backfill_jobs.clone(),
            message,
        ));
    }
}
async fn handle_backfill_message(
handler: Arc<Box<dyn Handler>>,
backfill_jobs: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
message: Message,
) {
let mut backfill_jobs = backfill_jobs.lock().await;
match message.action {
Some(Action::Backfill) => {
let log_string = handler.log_string();
for symbol in message.symbols {
if let Some(job) = backfill_jobs.get(&symbol) {
if !job.is_finished() {
warn!(
"Backfill for {} {} is already running, skipping.",
symbol, log_string
);
continue;
}
}
let handler = handler.clone();
backfill_jobs.insert(
symbol.clone(),
spawn(async move {
let fetch_from = match handler
.select_latest_backfill(symbol.clone())
.await
.unwrap()
{
Some(latest_backfill) => latest_backfill.time + ONE_SECOND,
None => OffsetDateTime::UNIX_EPOCH,
};
let fetch_to = last_minute();
if fetch_from > fetch_to {
info!("No need to backfill {} {}.", symbol, log_string,);
return;
}
handler.queue_backfill(&symbol, fetch_to).await;
handler.backfill(symbol, fetch_from, fetch_to).await;
}),
);
}
}
Some(Action::Purge) => {
for symbol in &message.symbols {
if let Some(job) = backfill_jobs.remove(symbol) {
if !job.is_finished() {
job.abort();
}
let _ = job.await;
}
}
try_join!(
handler.delete_backfills(&message.symbols),
handler.delete_data(&message.symbols)
)
.unwrap();
}
None => {}
}
message.response.send(()).unwrap();
}
/// Backfill handler for bars.
struct BarHandler {
// Shared application configuration (database and API clients are read from it).
config: Arc<Config>,
// URL passed to the historical-bars request.
data_url: &'static str,
// Builds the outgoing bar query for one symbol, time window and optional page token.
api_query_constructor: fn(
symbol: String,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar,
}
/// Builds the US-equity bar query for a single symbol, time window and
/// optional page token; every other field keeps its default.
fn us_equity_query_constructor(
    symbol: String,
    fetch_from: OffsetDateTime,
    fetch_to: OffsetDateTime,
    next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
    let query = alpaca::api::outgoing::bar::UsEquity {
        page_token: next_page_token,
        end: Some(fetch_to),
        start: Some(fetch_from),
        symbols: vec![symbol],
        ..Default::default()
    };

    alpaca::api::outgoing::bar::Bar::UsEquity(query)
}
/// Builds the crypto bar query for a single symbol, time window and optional
/// page token; every other field keeps its default.
fn crypto_query_constructor(
    symbol: String,
    fetch_from: OffsetDateTime,
    fetch_to: OffsetDateTime,
    next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
    let query = alpaca::api::outgoing::bar::Crypto {
        page_token: next_page_token,
        end: Some(fetch_to),
        start: Some(fetch_from),
        symbols: vec![symbol],
        ..Default::default()
    };

    alpaca::api::outgoing::bar::Bar::Crypto(query)
}
#[async_trait]
impl Handler for BarHandler {
    /// Reads the stored bar backfill record for `symbol`.
    async fn select_latest_backfill(
        &self,
        symbol: String,
    ) -> Result<Option<Backfill>, clickhouse::error::Error> {
        let client = &self.config.clickhouse_client;
        database::backfills_bars::select_where_symbol(client, &symbol).await
    }

    /// Deletes the bar backfill records of `symbols`.
    async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        let client = &self.config.clickhouse_client;
        database::backfills_bars::delete_where_symbols(client, symbols).await
    }

    /// Deletes the stored bars of `symbols`.
    async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        let client = &self.config.clickhouse_client;
        database::bars::delete_where_symbols(client, symbols).await
    }

    /// With the IEX source, sleeps until sixteen minutes past `fetch_to`;
    /// with any other source, returns immediately.
    async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
        if *ALPACA_SOURCE != Source::Iex {
            return;
        }

        let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
        info!("Queing bar backfill for {} in {:?}.", symbol, run_delay);
        sleep(run_delay).await;
    }

    /// Downloads every page of bars for `symbol` in the given window, then
    /// stores the bars and a backfill record derived from the last bar.
    ///
    /// A failed request is logged and aborts the backfill without storing anything.
    async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) {
        info!("Backfilling bars for {}.", symbol);

        let mut bars = vec![];
        let mut page_token = None;

        loop {
            let query = (self.api_query_constructor)(
                symbol.clone(),
                fetch_from,
                fetch_to,
                page_token.clone(),
            );
            let response = alpaca::api::incoming::bar::get_historical(
                &self.config.alpaca_client,
                &self.config.alpaca_rate_limiter,
                self.data_url,
                &query,
                None,
            )
            .await;

            let Ok(message) = response else {
                error!("Failed to backfill bars for {}.", symbol);
                return;
            };

            for (bar_symbol, bar_vec) in message.bars {
                bars.extend(
                    bar_vec
                        .into_iter()
                        .map(|bar| Bar::from((bar, bar_symbol.clone()))),
                );
            }

            // Keep paging until the API stops handing out a continuation token.
            match message.next_page_token {
                Some(token) => page_token = Some(token),
                None => break,
            }
        }

        let Some(last_bar) = bars.last() else {
            info!("No bars to backfill for {}.", symbol);
            return;
        };
        let backfill = last_bar.clone().into();

        database::bars::upsert_batch(&self.config.clickhouse_client, &bars)
            .await
            .unwrap();
        database::backfills_bars::upsert(&self.config.clickhouse_client, &backfill)
            .await
            .unwrap();

        info!("Backfilled bars for {}.", symbol);
    }

    /// Name used for this data kind in log messages.
    fn log_string(&self) -> &'static str {
        "bars"
    }
}
/// Backfill handler for news.
struct NewsHandler {
// Shared application configuration (database/API clients and the sequence classifier).
config: Arc<Config>,
}
#[async_trait]
impl Handler for NewsHandler {
/// Reads the stored news backfill record for `symbol`.
async fn select_latest_backfill(
&self,
symbol: String,
) -> Result<Option<Backfill>, clickhouse::error::Error> {
database::backfills_news::select_where_symbol(&self.config.clickhouse_client, &symbol).await
}
/// Deletes the news backfill records of `symbols`.
async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::backfills_news::delete_where_symbols(&self.config.clickhouse_client, symbols)
.await
}
/// Deletes the stored news of `symbols`.
async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
database::news::delete_where_symbols(&self.config.clickhouse_client, symbols).await
}
/// Sleeps until sixteen minutes past `fetch_to` before the backfill runs.
async fn queue_backfill(&self, symbol: &str, fetch_to: OffsetDateTime) {
let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
info!("Queing news backfill for {} in {:?}.", symbol, run_delay);
sleep(run_delay).await;
}
/// Downloads every page of news for `symbol` in the given window, attaches a
/// sentiment prediction to each item, then stores the news and a backfill record.
///
/// A failed request is logged and aborts the backfill without storing anything.
async fn backfill(&self, symbol: String, fetch_from: OffsetDateTime, fetch_to: OffsetDateTime) {
info!("Backfilling news for {}.", symbol);
let mut news = vec![];
let mut next_page_token = None;
// Page through the API until no continuation token is returned.
loop {
let Ok(message) = alpaca::api::incoming::news::get_historical(
&self.config.alpaca_client,
&self.config.alpaca_rate_limiter,
&alpaca::api::outgoing::news::News {
symbols: vec![symbol.clone()],
start: Some(fetch_from),
end: Some(fetch_to),
page_token: next_page_token.clone(),
..Default::default()
},
None,
)
.await
else {
error!("Failed to backfill news for {}.", symbol);
return;
};
message.news.into_iter().for_each(|news_item| {
news.push(News::from(news_item));
});
if message.next_page_token.is_none() {
break;
}
next_page_token = message.next_page_token;
}
if news.is_empty() {
info!("No news to backfill for {}.", symbol);
return;
}
// Classifier input per article: headline, blank line, content.
let inputs = news
.iter()
.map(|news| format!("{}\n\n{}", news.headline, news.content))
.collect::<Vec<_>>();
// Classify in chunks of at most `MAX_BERT_INPUTS`; each chunk takes the
// classifier lock and runs the prediction inside `block_in_place`.
let predictions = join_all(inputs.chunks(*MAX_BERT_INPUTS).map(|inputs| async move {
let sequence_classifier = self.config.sequence_classifier.lock().await;
block_in_place(|| {
sequence_classifier
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()
})
}))
.await
.into_iter()
.flatten();
// Pair each article with its prediction, in order.
let news = news
.into_iter()
.zip(predictions)
.map(|(news, prediction)| News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
})
.collect::<Vec<_>>();
// The backfill record is derived from the last news item and the symbol.
let backfill = (news.last().unwrap().clone(), symbol.clone()).into();
database::news::upsert_batch(&self.config.clickhouse_client, &news)
.await
.unwrap();
database::backfills_news::upsert(&self.config.clickhouse_client, &backfill)
.await
.unwrap();
info!("Backfilled news for {}.", symbol);
}
/// Name used for this data kind in log messages.
fn log_string(&self) -> &'static str {
"news"
}
}
/// Builds the boxed backfill handler for `thread_type`.
///
/// Bar threads get a `BarHandler` wired with the data URL and query
/// constructor of their asset class; the news thread gets a `NewsHandler`.
pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
    match thread_type {
        ThreadType::News => Box::new(NewsHandler { config }),
        ThreadType::Bars(class) => {
            let handler = match class {
                Class::UsEquity => BarHandler {
                    config,
                    data_url: ALPACA_STOCK_DATA_API_URL,
                    api_query_constructor: us_equity_query_constructor,
                },
                Class::Crypto => BarHandler {
                    config,
                    data_url: ALPACA_CRYPTO_DATA_API_URL,
                    api_query_constructor: crypto_query_constructor,
                },
            };
            Box::new(handler)
        }
    }
}

View File

@@ -0,0 +1,238 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_BARS_SIZE},
database,
threads::data::ThreadType,
};
use async_trait::async_trait;
use log::{error, info};
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
self,
alpaca::{
api::{ALPACA_CRYPTO_DATA_API_URL, ALPACA_US_EQUITY_DATA_API_URL},
shared::{Sort, Source},
},
Backfill, Bar, Class,
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::time::sleep;
/// Backfill handler for bars.
pub struct Handler {
/// Shared application configuration (database and API clients are read from it).
pub config: Arc<Config>,
/// URL passed to the bars request.
pub data_url: &'static str,
/// Builds the outgoing bar query for a set of symbols, a time window and an
/// optional page token.
pub api_query_constructor: fn(
symbols: Vec<String>,
fetch_from: OffsetDateTime,
fetch_to: OffsetDateTime,
next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar,
}
/// Builds the US-equity bar query for `symbols` over the given window.
///
/// Results are requested in ascending order from the configured
/// `ALPACA_SOURCE` feed; every other field keeps its default.
pub fn us_equity_query_constructor(
    symbols: Vec<String>,
    fetch_from: OffsetDateTime,
    fetch_to: OffsetDateTime,
    next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar {
    let query = types::alpaca::api::outgoing::bar::UsEquity {
        feed: Some(*ALPACA_SOURCE),
        sort: Some(Sort::Asc),
        page_token: next_page_token,
        end: Some(fetch_to),
        start: Some(fetch_from),
        symbols,
        ..Default::default()
    };

    types::alpaca::api::outgoing::bar::Bar::UsEquity(query)
}
/// Builds the crypto bar query for `symbols` over the given window.
///
/// Results are requested in ascending order; every other field keeps its default.
pub fn crypto_query_constructor(
    symbols: Vec<String>,
    fetch_from: OffsetDateTime,
    fetch_to: OffsetDateTime,
    next_page_token: Option<String>,
) -> types::alpaca::api::outgoing::bar::Bar {
    let query = types::alpaca::api::outgoing::bar::Crypto {
        sort: Some(Sort::Asc),
        page_token: next_page_token,
        end: Some(fetch_to),
        start: Some(fetch_from),
        symbols,
        ..Default::default()
    };

    types::alpaca::api::outgoing::bar::Bar::Crypto(query)
}
#[async_trait]
impl super::Handler for Handler {
    /// Reads the stored bar backfill records for `symbols`.
    async fn select_latest_backfills(
        &self,
        symbols: &[String],
    ) -> Result<Vec<Backfill>, clickhouse::error::Error> {
        database::backfills_bars::select_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    /// Deletes the bar backfill records of `symbols`.
    async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::backfills_bars::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    /// Deletes the stored bars of `symbols`.
    async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::bars::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    /// With the SIP source, returns immediately; otherwise sleeps until
    /// sixteen minutes past the latest `fetch_to` among `jobs`.
    async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
        if *ALPACA_SOURCE == Source::Sip {
            return;
        }

        let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
        let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
        let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();

        info!("Queing bar backfill for {:?} in {:?}.", symbols, run_delay);
        sleep(run_delay).await;
    }

    /// Downloads bars for all `jobs` in one paginated request series and
    /// stores them in batches.
    ///
    /// The request window spans the earliest `fetch_from` to the latest
    /// `fetch_to` of the jobs. Bars are buffered until the buffer holds at
    /// least `CLICKHOUSE_BATCH_BARS_SIZE` entries or the last page has been
    /// read; each flush also stores a backfill record per symbol seen since
    /// the previous flush. After the last page every symbol is marked fresh.
    ///
    /// A failed request is logged and aborts the backfill; bars still in the
    /// buffer at that point are not stored.
    async fn backfill(&self, jobs: NonEmpty<Job>) {
        // Only the symbols are needed here, so clone those rather than every job.
        let symbols = jobs
            .iter()
            .map(|job| job.symbol.clone())
            .collect::<Vec<_>>();
        let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
        let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
        let freshness = jobs
            .into_iter()
            .map(|job| (job.symbol, job.fresh))
            .collect::<HashMap<_, _>>();

        let mut bars = Vec::with_capacity(*CLICKHOUSE_BATCH_BARS_SIZE);
        let mut last_times = HashMap::new();
        let mut next_page_token = None;

        info!("Backfilling bars for {:?}.", symbols);

        loop {
            let message = match alpaca::bars::get(
                &self.config.alpaca_client,
                &self.config.alpaca_rate_limiter,
                self.data_url,
                &(self.api_query_constructor)(
                    symbols.clone(),
                    fetch_from,
                    fetch_to,
                    next_page_token.clone(),
                ),
                None,
            )
            .await
            {
                Ok(message) => message,
                Err(err) => {
                    error!("Failed to backfill bars for {:?}: {:?}.", symbols, err);
                    return;
                }
            };

            for (symbol, bars_vec) in message.bars {
                if let Some(last) = bars_vec.last() {
                    last_times.insert(symbol.clone(), last.time);
                }

                for bar in bars_vec {
                    bars.push(Bar::from((bar, symbol.clone())));
                }
            }

            // Advance the page token before any `continue`: otherwise the next
            // iteration would request the same page again and buffer its bars twice.
            let has_next_page = message.next_page_token.is_some();
            next_page_token = message.next_page_token;

            // Keep buffering while the batch is not full and more pages exist.
            if has_next_page && bars.len() < *CLICKHOUSE_BATCH_BARS_SIZE {
                continue;
            }

            database::bars::upsert_batch(
                &self.config.clickhouse_client,
                &self.config.clickhouse_concurrency_limiter,
                &bars,
            )
            .await
            .unwrap();

            // NOTE(review): `freshness[&symbol]` panics if the response contains
            // a symbol that was not among the jobs — confirm the API never does that.
            let backfilled = last_times
                .drain()
                .map(|(symbol, time)| Backfill {
                    fresh: freshness[&symbol],
                    symbol,
                    time,
                })
                .collect::<Vec<_>>();

            database::backfills_bars::upsert_batch(
                &self.config.clickhouse_client,
                &self.config.clickhouse_concurrency_limiter,
                &backfilled,
            )
            .await
            .unwrap();

            if !has_next_page {
                break;
            }

            bars.clear();
        }

        database::backfills_bars::set_fresh_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            true,
            &symbols,
        )
        .await
        .unwrap();

        info!("Backfilled bars for {:?}.", symbols);
    }

    /// Returns `alpaca::bars::MAX_LIMIT`.
    fn max_limit(&self) -> i64 {
        alpaca::bars::MAX_LIMIT
    }

    /// Name used for this data kind in log messages.
    fn log_string(&self) -> &'static str {
        "bars"
    }
}
/// Builds the boxed bar backfill handler for `thread_type`.
///
/// Selects the data URL and query constructor that match the asset class.
///
/// # Panics
///
/// Panics (`unreachable!`) when `thread_type` is not one of the two bar variants.
pub fn create_handler(config: Arc<Config>, thread_type: ThreadType) -> Box<dyn super::Handler> {
    let handler = match thread_type {
        ThreadType::Bars(Class::UsEquity) => Handler {
            config,
            data_url: ALPACA_US_EQUITY_DATA_API_URL,
            api_query_constructor: us_equity_query_constructor,
        },
        ThreadType::Bars(Class::Crypto) => Handler {
            config,
            data_url: ALPACA_CRYPTO_DATA_API_URL,
            api_query_constructor: crypto_query_constructor,
        },
        _ => unreachable!(),
    };

    Box::new(handler)
}

View File

@@ -0,0 +1,244 @@
pub mod bars;
pub mod news;
use async_trait::async_trait;
use itertools::Itertools;
use log::{info, warn};
use nonempty::{nonempty, NonEmpty};
use qrust::{
types::Backfill,
utils::{last_minute, ONE_SECOND},
};
use std::{collections::HashMap, hash::Hash, sync::Arc};
use time::OffsetDateTime;
use tokio::{
spawn,
sync::{mpsc, oneshot, Mutex},
task::JoinHandle,
try_join,
};
use uuid::Uuid;
/// Operation requested of a backfill thread.
pub enum Action {
/// Schedule backfill jobs for the message's symbols.
Backfill,
/// Abort running backfill jobs for the symbols and delete their stored backfills and data.
Purge,
}
/// Request sent to a backfill thread over its channel.
pub struct Message {
/// What to do with `symbols`.
pub action: Action,
/// Symbols the action applies to.
pub symbols: NonEmpty<String>,
/// Signalled with `()` once the message has been handled.
pub response: oneshot::Sender<()>,
}
impl Message {
    /// Creates a message together with the receiver on which its completion
    /// is signalled.
    ///
    /// The returned `oneshot::Receiver` is paired with the `response` sender
    /// stored inside the message.
    pub fn new(action: Action, symbols: NonEmpty<String>) -> (Self, oneshot::Receiver<()>) {
        let (response, receiver) = oneshot::channel::<()>();
        let message = Self {
            action,
            symbols,
            response,
        };
        (message, receiver)
    }
}
/// Backfill work item for a single symbol.
#[derive(Clone)]
pub struct Job {
/// Symbol to backfill.
pub symbol: String,
/// Start of the window to fetch.
pub fetch_from: OffsetDateTime,
/// End of the window to fetch.
pub fetch_to: OffsetDateTime,
/// Freshness flag taken from the symbol's latest stored backfill (`false` when none exists).
pub fresh: bool,
}
/// Data-type-specific operations used by the generic backfill message handling in this module.
#[async_trait]
pub trait Handler: Send + Sync {
/// Loads stored backfill records for `symbols`; their `time` and `fresh` fields seed the next jobs.
async fn select_latest_backfills(
&self,
symbols: &[String],
) -> Result<Vec<Backfill>, clickhouse::error::Error>;
/// Deletes stored backfill records for the given symbols (called on `Action::Purge`).
async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
/// Deletes stored data for the given symbols (called on `Action::Purge`).
async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
/// Awaited immediately before `backfill` for the same job group.
async fn queue_backfill(&self, jobs: &NonEmpty<Job>);
/// Performs the backfill for one group of jobs.
async fn backfill(&self, jobs: NonEmpty<Job>);
/// Upper bound on the summed fetch-window minutes of the jobs placed in one group.
fn max_limit(&self) -> i64;
/// Short label for this handler's data type, interpolated into log messages.
fn log_string(&self) -> &'static str;
}
/// Registry of running backfill tasks.
///
/// Several symbols may map to the same uuid because one task backfills a whole group.
pub struct Jobs {
/// Symbol -> id of the task currently backfilling it.
pub symbol_to_uuid: HashMap<String, Uuid>,
/// Task id -> handle of the spawned backfill task.
pub uuid_to_job: HashMap<Uuid, JoinHandle<()>>,
}
impl Jobs {
    /// Registers one spawned task as the backfill job for every symbol in `jobs`.
    ///
    /// All symbols share a freshly generated uuid that keys the task handle.
    pub fn insert(&mut self, jobs: Vec<String>, fut: JoinHandle<()>) {
        let uuid = Uuid::new_v4();
        // `jobs` is owned, so the symbols can be moved into the map directly.
        self.symbol_to_uuid
            .extend(jobs.into_iter().map(|symbol| (symbol, uuid)));
        self.uuid_to_job.insert(uuid, fut);
    }

    /// Returns `true` when a backfill task is registered for `symbol`.
    pub fn contains_key(&self, symbol: &str) -> bool {
        self.symbol_to_uuid.contains_key(symbol)
    }

    /// Unregisters `symbol` and returns the handle of the task that was
    /// backfilling it, if that handle is still registered.
    ///
    /// A task handle is shared by every symbol of its group. Once the handle
    /// is handed out (callers abort it), the remaining symbols of the group
    /// are unregistered as well; otherwise they would keep pointing at a task
    /// that no longer exists and be reported as "already running" forever.
    pub fn remove(&mut self, symbol: &str) -> Option<JoinHandle<()>> {
        let uuid = self.symbol_to_uuid.remove(symbol)?;
        let job = self.uuid_to_job.remove(&uuid)?;
        self.symbol_to_uuid.retain(|_, other| *other != uuid);
        Some(job)
    }

    /// Unregisters every symbol in `symbols` and drops (without aborting) the
    /// task handles they referred to.
    pub fn remove_many<T>(&mut self, symbols: &[T])
    where
        T: AsRef<str> + Hash + Eq,
    {
        for symbol in symbols {
            if let Some(uuid) = self.symbol_to_uuid.remove(symbol.as_ref()) {
                self.uuid_to_job.remove(&uuid);
            }
        }
    }

    /// Number of symbols that currently have a registered backfill task.
    pub fn len(&self) -> usize {
        self.symbol_to_uuid.len()
    }
}
/// Receives backfill messages forever and handles each one on its own task.
///
/// A single job registry is shared by all spawned handlers.
///
/// # Panics
///
/// Panics when `receiver` yields `None`, i.e. when the channel is closed.
pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
    let registry = Jobs {
        symbol_to_uuid: HashMap::new(),
        uuid_to_job: HashMap::new(),
    };
    let backfill_jobs = Arc::new(Mutex::new(registry));

    loop {
        let message = receiver.recv().await.unwrap();
        let task = handle_message(Arc::clone(&handler), Arc::clone(&backfill_jobs), message);
        spawn(task);
    }
}
/// Handles one backfill-thread message and then signals `message.response`.
///
/// * `Action::Backfill` — builds a [`Job`] for every symbol that is not
///   already being backfilled and is not yet up to date, sorts the jobs by
///   `fetch_from`, packs them into groups whose summed fetch windows (in
///   minutes) stay within `handler.max_limit()`, and spawns one task per
///   group. Each task registers itself in `backfill_jobs` and unregisters its
///   symbols when it finishes.
/// * `Action::Purge` — aborts and awaits any running task for the symbols,
///   then deletes their stored backfills and data.
///
/// The job registry lock is held for the whole call.
///
/// # Panics
///
/// Panics if a database call fails or if the response receiver was dropped.
async fn handle_message(
    handler: Arc<Box<dyn Handler>>,
    backfill_jobs: Arc<Mutex<Jobs>>,
    message: Message,
) {
    let backfill_jobs_clone = backfill_jobs.clone();
    let mut backfill_jobs = backfill_jobs.lock().await;
    let symbols = Vec::from(message.symbols);

    match message.action {
        Action::Backfill => {
            let log_string = handler.log_string();
            let max_limit = handler.max_limit();

            let backfills = handler
                .select_latest_backfills(&symbols)
                .await
                .unwrap()
                .into_iter()
                .map(|backfill| (backfill.symbol.clone(), backfill))
                .collect::<HashMap<_, _>>();

            let mut jobs = Vec::with_capacity(symbols.len());

            for symbol in symbols {
                if backfill_jobs.contains_key(&symbol) {
                    warn!(
                        "Backfill for {} {} is already running, skipping.",
                        symbol, log_string
                    );
                    continue;
                }

                let backfill = backfills.get(&symbol);

                // Resume one second after the last stored point, or from the
                // epoch when nothing has been backfilled yet.
                let fetch_from = backfill.map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
                    backfill.time + ONE_SECOND
                });
                let fetch_to = last_minute();

                if fetch_from > fetch_to {
                    info!("No need to backfill {} {}.", symbol, log_string,);
                    // Skip only this symbol. Returning here would abandon the
                    // remaining symbols and drop `message.response` unsent.
                    continue;
                }

                let fresh = backfill.map_or(false, |backfill| backfill.fresh);

                jobs.push(Job {
                    symbol,
                    fetch_from,
                    fetch_to,
                    fresh,
                });
            }

            let jobs = jobs
                .into_iter()
                .sorted_unstable_by_key(|job| job.fetch_from)
                .collect::<Vec<_>>();

            let mut job_groups: Vec<NonEmpty<Job>> = vec![];
            let mut current_minutes = 0;

            for job in jobs {
                let minutes = (job.fetch_to - job.fetch_from).whole_minutes();

                match job_groups.last_mut() {
                    // The job still fits into the group being filled.
                    Some(job_group) if current_minutes + minutes <= max_limit => {
                        job_group.push(job);
                        current_minutes += minutes;
                    }
                    // First job, or the current group is full: start a new one.
                    _ => {
                        job_groups.push(nonempty![job]);
                        current_minutes = minutes;
                    }
                }
            }

            for job_group in job_groups {
                let symbols = job_group
                    .iter()
                    .map(|job| job.symbol.clone())
                    .collect::<Vec<_>>();

                let handler = handler.clone();
                let symbols_clone = symbols.clone();
                let backfill_jobs_clone = backfill_jobs_clone.clone();

                let fut = spawn(async move {
                    handler.queue_backfill(&job_group).await;
                    handler.backfill(job_group).await;

                    let mut backfill_jobs = backfill_jobs_clone.lock().await;
                    backfill_jobs.remove_many(&symbols_clone);
                    let remaining = backfill_jobs.len();
                    drop(backfill_jobs);

                    info!("{} {} backfills remaining.", remaining, log_string);
                });

                backfill_jobs.insert(symbols, fut);
            }
        }
        Action::Purge => {
            for symbol in &symbols {
                if let Some(job) = backfill_jobs.remove(symbol) {
                    job.abort();
                    // Wait for the aborted task to wind down; its result is irrelevant.
                    let _ = job.await;
                }
            }

            try_join!(
                handler.delete_backfills(&symbols),
                handler.delete_data(&symbols)
            )
            .unwrap();
        }
    }

    message.response.send(()).unwrap();
}

View File

@@ -0,0 +1,186 @@
use super::Job;
use crate::{
config::{Config, ALPACA_SOURCE, CLICKHOUSE_BATCH_NEWS_SIZE},
database,
};
use async_trait::async_trait;
use log::{error, info};
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
self,
alpaca::shared::{Sort, Source},
Backfill, News,
},
utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::time::sleep;
/// News backfill handler.
pub struct Handler {
/// Shared application configuration; this handler reads its Alpaca and ClickHouse clients and limiters from it.
pub config: Arc<Config>,
}
#[async_trait]
impl super::Handler for Handler {
    /// Selects the stored news backfill records for `symbols`.
    async fn select_latest_backfills(
        &self,
        symbols: &[String],
    ) -> Result<Vec<Backfill>, clickhouse::error::Error> {
        database::backfills_news::select_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    /// Deletes the stored news backfill records for `symbols`.
    async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::backfills_news::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    /// Deletes the stored news for `symbols`.
    async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::news::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    /// Delays the backfill of `jobs` unless the configured source is `Source::Sip`.
    ///
    /// Otherwise sleeps for `duration_until(latest fetch_to + FIFTEEN_MINUTES + ONE_MINUTE)`.
    async fn queue_backfill(&self, jobs: &NonEmpty<Job>) {
        if *ALPACA_SOURCE == Source::Sip {
            return;
        }

        let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
        let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
        let symbols = jobs.iter().map(|job| &job.symbol).collect::<Vec<_>>();

        info!("Queing news backfill for {:?} in {:?}.", symbols, run_delay);
        sleep(run_delay).await;
    }

    /// Fetches news for all symbols in `jobs` page by page and stores it.
    ///
    /// One request window spans the earliest `fetch_from` to the latest
    /// `fetch_to` of the group. Items are buffered until the buffer reaches
    /// `CLICKHOUSE_BATCH_NEWS_SIZE` or the last page was read; then the news
    /// and the per-symbol backfill progress are upserted. After the last page
    /// the symbols are marked fresh.
    ///
    /// A failed Alpaca request is logged and ends the backfill early.
    ///
    /// # Panics
    ///
    /// Panics if a database call fails.
    #[allow(clippy::too_many_lines)]
    #[allow(clippy::iter_with_drain)]
    async fn backfill(&self, jobs: NonEmpty<Job>) {
        let symbols = jobs
            .iter()
            .map(|job| job.symbol.clone())
            .collect::<Vec<_>>();
        let symbols_set = symbols.iter().cloned().collect::<HashSet<_>>();
        let fetch_from = jobs.minimum_by_key(|job| job.fetch_from).fetch_from;
        let fetch_to = jobs.maximum_by_key(|job| job.fetch_to).fetch_to;
        let freshness = jobs
            .into_iter()
            .map(|job| (job.symbol, job.fresh))
            .collect::<HashMap<_, _>>();

        let mut news = Vec::with_capacity(*CLICKHOUSE_BATCH_NEWS_SIZE);
        let mut last_times = HashMap::new();
        let mut next_page_token = None;

        info!("Backfilling news for {:?}.", symbols);

        loop {
            let message = match alpaca::news::get(
                &self.config.alpaca_client,
                &self.config.alpaca_rate_limiter,
                &types::alpaca::api::outgoing::news::News {
                    symbols: symbols.clone(),
                    start: Some(fetch_from),
                    end: Some(fetch_to),
                    page_token: next_page_token.clone(),
                    sort: Some(Sort::Asc),
                    ..Default::default()
                },
                None,
            )
            .await
            {
                Ok(message) => message,
                Err(err) => {
                    error!("Failed to backfill news for {:?}: {:?}.", symbols, err);
                    return;
                }
            };

            for news_item in message.news {
                let news_item = News::from(news_item);

                // Track progress only for symbols that belong to this job group.
                for symbol in &news_item.symbols {
                    if symbols_set.contains(symbol) {
                        last_times.insert(symbol.clone(), news_item.time_created);
                    }
                }

                news.push(news_item);
            }

            // Advance the cursor before deciding whether to flush. Doing it
            // only after a flush made the `continue` below re-request the
            // same page until the buffer filled up with duplicates.
            next_page_token = message.next_page_token;
            let has_next_page = next_page_token.is_some();

            if news.len() < *CLICKHOUSE_BATCH_NEWS_SIZE && has_next_page {
                continue;
            }

            database::news::upsert_batch(
                &self.config.clickhouse_client,
                &self.config.clickhouse_concurrency_limiter,
                &news,
            )
            .await
            .unwrap();

            let backfilled = last_times
                .drain()
                .map(|(symbol, time)| Backfill {
                    fresh: freshness[&symbol],
                    symbol,
                    time,
                })
                .collect::<Vec<_>>();

            database::backfills_news::upsert_batch(
                &self.config.clickhouse_client,
                &self.config.clickhouse_concurrency_limiter,
                &backfilled,
            )
            .await
            .unwrap();

            if !has_next_page {
                break;
            }

            news.clear();
        }

        database::backfills_news::set_fresh_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            true,
            &symbols,
        )
        .await
        .unwrap();

        info!("Backfilled news for {:?}.", symbols);
    }

    /// Returns the Alpaca news API limit constant as this handler's job-group limit.
    fn max_limit(&self) -> i64 {
        alpaca::news::MAX_LIMIT
    }

    /// Returns the static label identifying this handler's data type.
    fn log_string(&self) -> &'static str {
        "news"
    }
}
/// Builds the boxed news backfill handler around the shared configuration.
pub fn create_handler(config: Arc<Config>) -> Box<dyn super::Handler> {
Box::new(Handler { config })
}

View File

@@ -3,24 +3,29 @@ mod websocket;
use super::clock; use super::clock;
use crate::{ use crate::{
config::{ config::{Config, ALPACA_API_BASE, ALPACA_SOURCE},
Config, ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL, ALPACA_SOURCE,
ALPACA_STOCK_DATA_WEBSOCKET_URL,
},
create_send_await, database, create_send_await, database,
types::{alpaca, Asset, Class},
utils::backoff,
}; };
use futures_util::{future::join_all, StreamExt};
use itertools::{Either, Itertools}; use itertools::{Either, Itertools};
use std::sync::Arc; use log::error;
use nonempty::NonEmpty;
use qrust::{
alpaca,
types::{
alpaca::websocket::{
ALPACA_CRYPTO_DATA_WEBSOCKET_URL, ALPACA_NEWS_DATA_WEBSOCKET_URL,
ALPACA_US_EQUITY_DATA_WEBSOCKET_URL,
},
Asset, Class,
},
};
use std::{collections::HashMap, sync::Arc};
use tokio::{ use tokio::{
join, select, spawn, join, select, spawn,
sync::{mpsc, oneshot}, sync::{mpsc, oneshot},
}; };
use tokio_tungstenite::connect_async;
#[derive(Clone, Copy)] #[derive(Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)] #[allow(dead_code)]
pub enum Action { pub enum Action {
Add, Add,
@@ -31,12 +36,12 @@ pub enum Action {
pub struct Message { pub struct Message {
pub action: Action, pub action: Action,
pub assets: Vec<(String, Class)>, pub assets: NonEmpty<(String, Class)>,
pub response: oneshot::Sender<()>, pub response: oneshot::Sender<()>,
} }
impl Message { impl Message {
pub fn new(action: Action, assets: Vec<(String, Class)>) -> (Self, oneshot::Receiver<()>) { pub fn new(action: Action, assets: NonEmpty<(String, Class)>) -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel(); let (sender, receiver) = oneshot::channel();
( (
Self { Self {
@@ -61,11 +66,11 @@ pub async fn run(
mut clock_receiver: mpsc::Receiver<clock::Message>, mut clock_receiver: mpsc::Receiver<clock::Message>,
) { ) {
let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) = let (bars_us_equity_websocket_sender, bars_us_equity_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::UsEquity)).await; init_thread(config.clone(), ThreadType::Bars(Class::UsEquity));
let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) = let (bars_crypto_websocket_sender, bars_crypto_backfill_sender) =
init_thread(config.clone(), ThreadType::Bars(Class::Crypto)).await; init_thread(config.clone(), ThreadType::Bars(Class::Crypto));
let (news_websocket_sender, news_backfill_sender) = let (news_websocket_sender, news_backfill_sender) =
init_thread(config.clone(), ThreadType::News).await; init_thread(config.clone(), ThreadType::News);
loop { loop {
select! { select! {
@@ -94,7 +99,7 @@ pub async fn run(
} }
} }
async fn init_thread( fn init_thread(
config: Arc<Config>, config: Arc<Config>,
thread_type: ThreadType, thread_type: ThreadType,
) -> ( ) -> (
@@ -103,28 +108,32 @@ async fn init_thread(
) { ) {
let websocket_url = match thread_type { let websocket_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => { ThreadType::Bars(Class::UsEquity) => {
format!("{}/{}", ALPACA_STOCK_DATA_WEBSOCKET_URL, *ALPACA_SOURCE) format!("{}/{}", ALPACA_US_EQUITY_DATA_WEBSOCKET_URL, *ALPACA_SOURCE)
} }
ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_WEBSOCKET_URL.into(), ThreadType::Bars(Class::Crypto) => ALPACA_CRYPTO_DATA_WEBSOCKET_URL.into(),
ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(), ThreadType::News => ALPACA_NEWS_DATA_WEBSOCKET_URL.into(),
}; };
let (websocket, _) = connect_async(websocket_url).await.unwrap(); let backfill_handler = match thread_type {
let (mut websocket_sink, mut websocket_stream) = websocket.split(); ThreadType::Bars(_) => backfill::bars::create_handler(config.clone(), thread_type),
alpaca::websocket::data::authenticate(&mut websocket_sink, &mut websocket_stream).await; ThreadType::News => backfill::news::create_handler(config.clone()),
};
let (backfill_sender, backfill_receiver) = mpsc::channel(100); let (backfill_sender, backfill_receiver) = mpsc::channel(100);
spawn(backfill::run(
Arc::new(backfill::create_handler(thread_type, config.clone())), spawn(backfill::run(backfill_handler.into(), backfill_receiver));
backfill_receiver,
)); let websocket_handler = match thread_type {
ThreadType::Bars(_) => websocket::bars::create_handler(config, thread_type),
ThreadType::News => websocket::news::create_handler(config),
};
let (websocket_sender, websocket_receiver) = mpsc::channel(100); let (websocket_sender, websocket_receiver) = mpsc::channel(100);
spawn(websocket::run( spawn(websocket::run(
Arc::new(websocket::create_handler(thread_type, config.clone())), websocket_handler.into(),
websocket_receiver, websocket_receiver,
websocket_stream, websocket_url,
websocket_sink,
)); ));
(websocket_sender, backfill_sender) (websocket_sender, backfill_sender)
@@ -142,11 +151,6 @@ async fn handle_message(
news_backfill_sender: mpsc::Sender<backfill::Message>, news_backfill_sender: mpsc::Sender<backfill::Message>,
message: Message, message: Message,
) { ) {
if message.assets.is_empty() {
message.response.send(()).unwrap();
return;
}
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = message let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = message
.assets .assets
.clone() .clone()
@@ -156,50 +160,28 @@ async fn handle_message(
Class::Crypto => Either::Right(asset.0), Class::Crypto => Either::Right(asset.0),
}); });
let symbols = message let symbols = message.assets.map(|(symbol, _)| symbol);
.assets
.into_iter()
.map(|(symbol, _)| symbol)
.collect::<Vec<_>>();
let bars_us_equity_future = async { let bars_us_equity_future = async {
if us_equity_symbols.is_empty() { if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols.clone()) {
return; create_send_await!(
bars_us_equity_websocket_sender,
websocket::Message::new,
message.action.into(),
us_equity_symbols
);
} }
create_send_await!(
bars_us_equity_websocket_sender,
websocket::Message::new,
message.action.into(),
us_equity_symbols.clone()
);
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
message.action.into(),
us_equity_symbols
);
}; };
let bars_crypto_future = async { let bars_crypto_future = async {
if crypto_symbols.is_empty() { if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols.clone()) {
return; create_send_await!(
bars_crypto_websocket_sender,
websocket::Message::new,
message.action.into(),
crypto_symbols
);
} }
create_send_await!(
bars_crypto_websocket_sender,
websocket::Message::new,
message.action.into(),
crypto_symbols.clone()
);
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
message.action.into(),
crypto_symbols
);
}; };
let news_future = async { let news_future = async {
@@ -209,62 +191,127 @@ async fn handle_message(
message.action.into(), message.action.into(),
symbols.clone() symbols.clone()
); );
create_send_await!(
news_backfill_sender,
backfill::Message::new,
message.action.into(),
symbols.clone()
);
}; };
join!(bars_us_equity_future, bars_crypto_future, news_future); join!(bars_us_equity_future, bars_crypto_future, news_future);
if message.action == Action::Disable {
message.response.send(()).unwrap();
return;
}
match message.action { match message.action {
Action::Add => { Action::Add | Action::Enable => {
let assets = join_all(symbols.into_iter().map(|symbol| { let symbols = Vec::from(symbols.clone());
let config = config.clone();
async move {
let asset_future = async {
alpaca::api::incoming::asset::get_by_symbol(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbol,
Some(backoff::infinite()),
)
.await
.unwrap()
};
let position_future = async { let assets = async {
alpaca::api::incoming::position::get_by_symbol( alpaca::assets::get_by_symbols(
&config.alpaca_client, &config.alpaca_client,
&config.alpaca_rate_limiter, &config.alpaca_rate_limiter,
&symbol, &symbols,
Some(backoff::infinite()), None,
) &ALPACA_API_BASE,
.await )
.unwrap()
};
let (asset, position) = join!(asset_future, position_future);
Asset::from((asset, position))
}
}))
.await;
database::assets::upsert_batch(&config.clickhouse_client, &assets)
.await .await
.unwrap(); .unwrap()
.into_iter()
.map(|asset| (asset.symbol.clone(), asset))
.collect::<HashMap<_, _>>()
};
let positions = async {
alpaca::positions::get_by_symbols(
&config.alpaca_client,
&config.alpaca_rate_limiter,
&symbols,
None,
&ALPACA_API_BASE,
)
.await
.unwrap()
.into_iter()
.map(|position| (position.symbol.clone(), position))
.collect::<HashMap<_, _>>()
};
let (mut assets, mut positions) = join!(assets, positions);
let mut batch = Vec::with_capacity(symbols.len());
for symbol in &symbols {
if let Some(asset) = assets.remove(symbol) {
let position = positions.remove(symbol);
batch.push(Asset::from((asset, position)));
} else {
error!("Failed to find asset for symbol: {}.", symbol);
}
}
database::assets::upsert_batch(
&config.clickhouse_client,
&config.clickhouse_concurrency_limiter,
&batch,
)
.await
.unwrap();
} }
Action::Remove => { Action::Remove => {
database::assets::delete_where_symbols(&config.clickhouse_client, &symbols) database::assets::delete_where_symbols(
.await &config.clickhouse_client,
.unwrap(); &config.clickhouse_concurrency_limiter,
&Vec::from(symbols.clone()),
)
.await
.unwrap();
} }
_ => {} Action::Disable => unreachable!(),
} }
let bars_us_equity_future = async {
if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
create_send_await!(
bars_us_equity_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
us_equity_symbols
);
}
};
let bars_crypto_future = async {
if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
create_send_await!(
bars_crypto_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
crypto_symbols
);
}
};
let news_future = async {
create_send_await!(
news_backfill_sender,
backfill::Message::new,
match message.action {
Action::Add | Action::Enable => backfill::Action::Backfill,
Action::Remove => backfill::Action::Purge,
Action::Disable => unreachable!(),
},
symbols
);
};
join!(bars_us_equity_future, bars_crypto_future, news_future);
message.response.send(()).unwrap(); message.response.send(()).unwrap();
} }
@@ -274,13 +321,19 @@ async fn handle_clock_message(
bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>, bars_crypto_backfill_sender: mpsc::Sender<backfill::Message>,
news_backfill_sender: mpsc::Sender<backfill::Message>, news_backfill_sender: mpsc::Sender<backfill::Message>,
) { ) {
database::cleanup_all(&config.clickhouse_client) database::cleanup_all(
.await &config.clickhouse_client,
.unwrap(); &config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
let assets = database::assets::select(&config.clickhouse_client) let assets = database::assets::select(
.await &config.clickhouse_client,
.unwrap(); &config.clickhouse_concurrency_limiter,
)
.await
.unwrap();
let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets let (us_equity_symbols, crypto_symbols): (Vec<_>, Vec<_>) = assets
.clone() .clone()
@@ -296,30 +349,36 @@ async fn handle_clock_message(
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let bars_us_equity_future = async { let bars_us_equity_future = async {
create_send_await!( if let Some(us_equity_symbols) = NonEmpty::from_vec(us_equity_symbols) {
bars_us_equity_backfill_sender, create_send_await!(
backfill::Message::new, bars_us_equity_backfill_sender,
Some(backfill::Action::Backfill), backfill::Message::new,
us_equity_symbols.clone() backfill::Action::Backfill,
); us_equity_symbols
);
}
}; };
let bars_crypto_future = async { let bars_crypto_future = async {
create_send_await!( if let Some(crypto_symbols) = NonEmpty::from_vec(crypto_symbols) {
bars_crypto_backfill_sender, create_send_await!(
backfill::Message::new, bars_crypto_backfill_sender,
Some(backfill::Action::Backfill), backfill::Message::new,
crypto_symbols.clone() backfill::Action::Backfill,
); crypto_symbols
);
}
}; };
let news_future = async { let news_future = async {
create_send_await!( if let Some(symbols) = NonEmpty::from_vec(symbols) {
news_backfill_sender, create_send_await!(
backfill::Message::new, news_backfill_sender,
Some(backfill::Action::Backfill), backfill::Message::new,
symbols backfill::Action::Backfill,
); symbols
);
}
}; };
join!(bars_us_equity_future, bars_crypto_future, news_future); join!(bars_us_equity_future, bars_crypto_future, news_future);

View File

@@ -1,427 +0,0 @@
use super::ThreadType;
use crate::{
config::Config,
database,
types::{alpaca::websocket, news::Prediction, Bar, Class, News},
};
use async_trait::async_trait;
use futures_util::{
future::join_all,
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use log::{debug, error, info};
use serde_json::{from_str, to_string};
use std::{collections::HashMap, sync::Arc};
use tokio::{
net::TcpStream,
select, spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
task::block_in_place,
};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
/// Operation requested of a websocket thread.
pub enum Action {
/// Subscribe to the message's symbols.
Subscribe,
/// Unsubscribe from the message's symbols.
Unsubscribe,
}
impl From<super::Action> for Option<Action> {
    /// Maps a data-thread action onto the websocket action it implies.
    ///
    /// `Add`/`Enable` become `Subscribe`; `Remove`/`Disable` become
    /// `Unsubscribe`. The result is always `Some`.
    fn from(action: super::Action) -> Self {
        let websocket_action = match action {
            super::Action::Add | super::Action::Enable => Action::Subscribe,
            super::Action::Remove | super::Action::Disable => Action::Unsubscribe,
        };
        Some(websocket_action)
    }
}
/// Request sent to a websocket thread over its channel.
pub struct Message {
/// Requested operation; `None` is handled as a no-op.
pub action: Option<Action>,
/// Symbols the action applies to.
pub symbols: Vec<String>,
/// Signalled with `()` once the message has been handled.
pub response: oneshot::Sender<()>,
}
impl Message {
    /// Creates a message together with the receiver on which its completion
    /// is signalled.
    ///
    /// The returned `oneshot::Receiver` is paired with the `response` sender
    /// stored inside the message.
    pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
        let (response, receiver) = oneshot::channel();
        let message = Self {
            action,
            symbols,
            response,
        };
        (message, receiver)
    }
}
/// Subscription changes that have been requested but not yet confirmed.
pub struct Pending {
/// Symbol -> sender fired once a subscription message lists the symbol.
pub subscriptions: HashMap<String, oneshot::Sender<()>>,
/// Symbol -> sender fired once a subscription message no longer lists the symbol.
pub unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}
/// Data-type-specific behaviour of a websocket thread.
#[async_trait]
pub trait Handler: Send + Sync {
/// Builds the payload sent for both subscribe and unsubscribe requests covering `symbols`.
fn create_subscription_message(
&self,
symbols: Vec<String>,
) -> websocket::data::outgoing::subscribe::Message;
/// Processes one parsed incoming websocket message; `pending` holds the unconfirmed (un)subscriptions.
async fn handle_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
message: websocket::data::incoming::Message,
);
}
/// Drives one websocket thread: forwards control messages to
/// `handle_message` and dispatches incoming websocket frames to the handler.
///
/// Text frames are parsed as a list of incoming messages, each handled on its
/// own task; frames that fail to parse are logged and skipped. Ping frames
/// are ignored and any other frame is logged as unexpected.
///
/// # Panics
///
/// Panics when neither the control channel nor the websocket stream can
/// produce a matching item any more.
pub async fn run(
    handler: Arc<Box<dyn Handler>>,
    mut receiver: mpsc::Receiver<Message>,
    mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>,
) {
    let pending = Arc::new(RwLock::new(Pending {
        subscriptions: HashMap::new(),
        unsubscriptions: HashMap::new(),
    }));
    let websocket_sink = Arc::new(Mutex::new(websocket_sink));

    loop {
        select! {
            Some(message) = receiver.recv() => {
                spawn(handle_message(
                    Arc::clone(&handler),
                    Arc::clone(&pending),
                    Arc::clone(&websocket_sink),
                    message,
                ));
            }
            Some(Ok(message)) = websocket_stream.next() => {
                match message {
                    tungstenite::Message::Text(text) => {
                        let Ok(incoming) =
                            from_str::<Vec<websocket::data::incoming::Message>>(&text)
                        else {
                            error!("Failed to deserialize websocket message: {:?}", text);
                            continue;
                        };

                        for incoming_message in incoming {
                            let handler = Arc::clone(&handler);
                            let pending = Arc::clone(&pending);
                            spawn(async move {
                                handler.handle_websocket_message(pending, incoming_message).await;
                            });
                        }
                    }
                    tungstenite::Message::Ping(_) => {}
                    _ => error!("Unexpected websocket message: {:?}", message),
                }
            }
            else => panic!("Communication channel unexpectedly closed.")
        }
    }
}
/// Handles one websocket-thread message and then signals `message.response`.
///
/// For `Subscribe`/`Unsubscribe` a oneshot channel is registered per symbol
/// in the matching `pending` map, the request is sent over the websocket, and
/// the function waits until every per-symbol receiver has resolved. An empty
/// symbol list or a `None` action is acknowledged immediately.
///
/// # Panics
///
/// Panics if serialising or sending the request fails, or if the response
/// receiver was dropped.
async fn handle_message(
    handler: Arc<Box<dyn Handler>>,
    pending: Arc<RwLock<Pending>>,
    sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
    message: Message,
) {
    if message.symbols.is_empty() {
        message.response.send(()).unwrap();
        return;
    }

    if let Some(action) = message.action {
        // One confirmation channel per symbol.
        let (senders, receivers): (Vec<_>, Vec<_>) = message
            .symbols
            .iter()
            .map(|symbol| {
                let (sender, receiver) = oneshot::channel();
                ((symbol.clone(), sender), receiver)
            })
            .unzip();

        // Register the senders before the request goes out so a fast
        // confirmation cannot be missed; the write guard is released at the
        // end of this scope, before any further await.
        {
            let mut pending = pending.write().await;
            match action {
                Action::Subscribe => pending.subscriptions.extend(senders),
                Action::Unsubscribe => pending.unsubscriptions.extend(senders),
            }
        }

        // The symbols are no longer needed afterwards, so they are moved
        // into the payload instead of cloned.
        let subscription = handler.create_subscription_message(message.symbols);
        let outgoing = match action {
            Action::Subscribe => websocket::data::outgoing::Message::Subscribe(subscription),
            Action::Unsubscribe => websocket::data::outgoing::Message::Unsubscribe(subscription),
        };

        sink.lock()
            .await
            .send(tungstenite::Message::Text(to_string(&outgoing).unwrap()))
            .await
            .unwrap();

        join_all(receivers).await;
    }

    message.response.send(()).unwrap();
}
/// Websocket handler for bar streams.
struct BarsHandler {
/// Shared application configuration; the ClickHouse client is read from it.
config: Arc<Config>,
/// Builds the (un)subscribe payload for a list of symbols; chosen per asset class in `create_handler`.
subscription_message_constructor:
fn(Vec<String>) -> websocket::data::outgoing::subscribe::Message,
}
#[async_trait]
impl Handler for BarsHandler {
/// Delegates to the constructor stored in `subscription_message_constructor`.
fn create_subscription_message(
&self,
symbols: Vec<String>,
) -> websocket::data::outgoing::subscribe::Message {
(self.subscription_message_constructor)(symbols)
}
/// Handles one incoming message on a bars websocket.
///
/// Subscription messages resolve pending (un)subscriptions, bar messages are
/// upserted, status messages update the asset's status flag, and error messages
/// are logged. Any other message kind hits `unreachable!()`.
async fn handle_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
// Only the `Market` variant is expected here; `symbols` is its `bars` list.
let websocket::data::incoming::subscription::Message::Market {
bars: symbols, ..
} = message
else {
unreachable!()
};
let mut pending = pending.write().await;
// Pending subscriptions whose symbol now appears in the list are confirmed.
let newly_subscribed = pending
.subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
// Pending unsubscriptions whose symbol no longer appears in the list are confirmed.
let newly_unsubscribed = pending
.unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
// Release the write lock before notifying the waiters.
drop(pending);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to bars for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from bars for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
// New and updated bars are stored the same way.
websocket::data::incoming::Message::Bar(message)
| websocket::data::incoming::Message::UpdatedBar(message) => {
let bar = Bar::from(message);
debug!("Received bar for {}: {}.", bar.symbol, bar.time);
database::bars::upsert(&self.config.clickhouse_client, &bar)
.await
.unwrap();
}
websocket::data::incoming::Message::Status(message) => {
debug!(
"Received status message for {}: {:?}.",
message.symbol, message.status
);
match message.status {
// Halts and volatility pauses set the symbol's status to `false`.
websocket::data::incoming::status::Status::TradingHalt(_)
| websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&message.symbol,
false,
)
.await
.unwrap();
}
// Resumes set the symbol's status back to `true`.
websocket::data::incoming::status::Status::Resume(_)
| websocket::data::incoming::status::Status::TradingResumption(_) => {
database::assets::update_status_where_symbol(
&self.config.clickhouse_client,
&message.symbol,
true,
)
.await
.unwrap();
}
// Every other status is ignored.
_ => {}
}
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
}
/// Websocket handler for the news stream.
struct NewsHandler {
/// Shared application configuration; the ClickHouse client and the sequence classifier are read from it.
config: Arc<Config>,
}
#[async_trait]
impl Handler for NewsHandler {
/// Builds the news (un)subscribe payload for `symbols`.
fn create_subscription_message(
&self,
symbols: Vec<String>,
) -> websocket::data::outgoing::subscribe::Message {
websocket::data::outgoing::subscribe::Message::new_news(symbols)
}
/// Handles one incoming message on the news websocket.
///
/// Subscription messages resolve pending (un)subscriptions, news messages are
/// classified and upserted, and error messages are logged. Any other message
/// kind hits `unreachable!()`.
async fn handle_websocket_message(
&self,
pending: Arc<RwLock<Pending>>,
message: websocket::data::incoming::Message,
) {
match message {
websocket::data::incoming::Message::Subscription(message) => {
// Only the `News` variant is expected here; `symbols` is its `news` list.
let websocket::data::incoming::subscription::Message::News { news: symbols } =
message
else {
unreachable!()
};
let mut pending = pending.write().await;
// Pending subscriptions whose symbol now appears in the list are confirmed.
let newly_subscribed = pending
.subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
// Pending unsubscriptions whose symbol no longer appears in the list are confirmed.
let newly_unsubscribed = pending
.unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
// Release the write lock before notifying the waiters.
drop(pending);
if !newly_subscribed.is_empty() {
info!(
"Subscribed to news for {:?}.",
newly_subscribed.keys().collect::<Vec<_>>()
);
for sender in newly_subscribed.into_values() {
sender.send(()).unwrap();
}
}
if !newly_unsubscribed.is_empty() {
info!(
"Unsubscribed from news for {:?}.",
newly_unsubscribed.keys().collect::<Vec<_>>()
);
for sender in newly_unsubscribed.into_values() {
sender.send(()).unwrap();
}
}
}
websocket::data::incoming::Message::News(message) => {
let news = News::from(message);
debug!(
"Received news for {:?}: {}.",
news.symbols, news.time_created
);
// Classifier input: headline and content separated by a blank line.
let input = format!("{}\n\n{}", news.headline, news.content);
let sequence_classifier = self.config.sequence_classifier.lock().await;
// Prediction runs inside `block_in_place`; a single input is passed, so the first result is taken.
let prediction = block_in_place(|| {
sequence_classifier
.predict(vec![input.as_str()])
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()[0]
});
// Release the classifier lock before the database write.
drop(sequence_classifier);
let news = News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,
..news
};
database::news::upsert(&self.config.clickhouse_client, &news)
.await
.unwrap();
}
websocket::data::incoming::Message::Error(message) => {
error!("Received error message: {}.", message.message);
}
_ => unreachable!(),
}
}
}
/// Builds the boxed websocket handler for `thread_type`.
///
/// `ThreadType::News` yields a `NewsHandler`; both `ThreadType::Bars`
/// variants yield a `BarsHandler` that differs only in the subscription
/// message constructor it carries.
pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
    let subscription_message_constructor: fn(
        Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message = match thread_type {
        ThreadType::Bars(Class::UsEquity) => {
            websocket::data::outgoing::subscribe::Message::new_market_us_equity
        }
        ThreadType::Bars(Class::Crypto) => {
            websocket::data::outgoing::subscribe::Message::new_market_crypto
        }
        // News needs no constructor, so its handler is returned right away.
        ThreadType::News => return Box::new(NewsHandler { config }),
    };

    Box::new(BarsHandler {
        config,
        subscription_message_constructor,
    })
}

Some files were not shown because too many files have changed in this diff Show More