use crate::types::alpaca::shared::{Mode, Source}; use governor::{DefaultDirectRateLimiter, Quota, RateLimiter}; use lazy_static::lazy_static; use reqwest::{ header::{HeaderMap, HeaderName, HeaderValue}, Client, }; use rust_bert::{ pipelines::{ common::{ModelResource, ModelType}, sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel}, }, resources::LocalResource, }; use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc}; use tokio::sync::Mutex; pub const ALPACA_STOCK_DATA_API_URL: &str = "https://data.alpaca.markets/v2/stocks/bars"; pub const ALPACA_CRYPTO_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars"; pub const ALPACA_NEWS_DATA_API_URL: &str = "https://data.alpaca.markets/v1beta1/news"; pub const ALPACA_STOCK_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2"; pub const ALPACA_CRYPTO_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta3/crypto/us"; pub const ALPACA_NEWS_DATA_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news"; lazy_static! { pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE") .expect("ALPACA_MODE must be set.") .parse() .expect("ALPACA_MODE must be 'live' or 'paper'"); pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE") .expect("ALPACA_SOURCE must be set.") .parse() .expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'"); pub static ref ALPACA_API_KEY: String = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set."); pub static ref ALPACA_API_SECRET: String = env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set."); #[derive(Debug)] pub static ref ALPACA_API_URL: String = format!( "https://{}.alpaca.markets/v2", match *ALPACA_MODE { Mode::Live => String::from("api"), Mode::Paper => String::from("paper-api"), } ); #[derive(Debug)] pub static ref ALPACA_WEBSOCKET_URL: String = format!( "wss://{}.alpaca.markets/stream", match *ALPACA_MODE { Mode::Live => String::from("api"), Mode::Paper => String::from("paper-api"), } ); pub static ref MAX_BERT_INPUTS: usize = env::var("MAX_BERT_INPUTS") .expect("MAX_BERT_INPUTS must be set.") .parse() .expect("MAX_BERT_INPUTS must be a positive integer."); } pub struct Config { pub alpaca_client: Client, pub alpaca_rate_limiter: DefaultDirectRateLimiter, pub clickhouse_client: clickhouse::Client, pub sequence_classifier: Mutex, } impl Config { pub fn from_env() -> Self { Self { alpaca_client: Client::builder() .default_headers(HeaderMap::from_iter([ ( HeaderName::from_static("apca-api-key-id"), HeaderValue::from_str(&ALPACA_API_KEY) .expect("Alpaca API key must not contain invalid characters."), ), ( HeaderName::from_static("apca-api-secret-key"), HeaderValue::from_str(&ALPACA_API_SECRET) .expect("Alpaca API secret must not contain invalid characters."), ), ])) .build() .unwrap(), alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE { Source::Iex => unsafe { NonZeroU32::new_unchecked(200) }, Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) }, Source::Otc => unimplemented!("OTC rate limit not implemented."), })), clickhouse_client: clickhouse::Client::default() .with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set.")) .with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set.")) .with_password( env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."), ) .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")), sequence_classifier: Mutex::new( SequenceClassificationModel::new(SequenceClassificationConfig::new( ModelType::Bert, ModelResource::Torch(Box::new(LocalResource { local_path: PathBuf::from("./models/finbert/rust_model.ot"), })), LocalResource { local_path: PathBuf::from("./models/finbert/config.json"), }, LocalResource { local_path: PathBuf::from("./models/finbert/vocab.txt"), }, None, true, None, None, )) .unwrap(), ), } } pub fn arc_from_env() -> Arc { Arc::new(Self::from_env()) } }