use crate::types::alpaca::Source; use governor::{DefaultDirectRateLimiter, Quota, RateLimiter}; use reqwest::{ header::{HeaderMap, HeaderName, HeaderValue}, Client, }; use rust_bert::{ pipelines::{ common::{ModelResource, ModelType}, sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel}, }, resources::LocalResource, }; use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc}; use tokio::sync::Mutex; pub const ALPACA_ASSET_API_URL: &str = "https://api.alpaca.markets/v2/assets"; pub const ALPACA_CLOCK_API_URL: &str = "https://api.alpaca.markets/v2/clock"; pub const ALPACA_STOCK_DATA_URL: &str = "https://data.alpaca.markets/v2/stocks/bars"; pub const ALPACA_CRYPTO_DATA_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars"; pub const ALPACA_NEWS_DATA_URL: &str = "https://data.alpaca.markets/v1beta1/news"; pub const ALPACA_STOCK_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2"; pub const ALPACA_CRYPTO_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta3/crypto/us"; pub const ALPACA_NEWS_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news"; pub struct Config { pub alpaca_api_key: String, pub alpaca_api_secret: String, pub alpaca_client: Client, pub alpaca_rate_limit: DefaultDirectRateLimiter, pub alpaca_source: Source, pub clickhouse_client: clickhouse::Client, pub max_bert_inputs: usize, pub sequence_classifier: Arc>, } impl Config { pub fn from_env() -> Self { let alpaca_api_key = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set."); let alpaca_api_secret = env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set."); let alpaca_source: Source = env::var("ALPACA_SOURCE") .expect("ALPACA_SOURCE must be set.") .parse() .expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'."); let clickhouse_url = env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."); let clickhouse_user = env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."); let clickhouse_password = env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."); let clickhouse_db = env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set."); let max_bert_inputs: usize = env::var("MAX_BERT_INPUTS") .expect("MAX_BERT_INPUTS must be set.") .parse() .expect("MAX_BERT_INPUTS must be a positive integer."); Self { alpaca_client: Client::builder() .default_headers(HeaderMap::from_iter([ ( HeaderName::from_static("apca-api-key-id"), HeaderValue::from_str(&alpaca_api_key) .expect("Alpaca API key must not contain invalid characters."), ), ( HeaderName::from_static("apca-api-secret-key"), HeaderValue::from_str(&alpaca_api_secret) .expect("Alpaca API secret must not contain invalid characters."), ), ])) .build() .unwrap(), alpaca_rate_limit: RateLimiter::direct(Quota::per_minute(match alpaca_source { Source::Iex => unsafe { NonZeroU32::new_unchecked(200) }, Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) }, Source::Otc => unimplemented!("OTC rate limit not implemented."), })), alpaca_source, clickhouse_client: clickhouse::Client::default() .with_url(clickhouse_url) .with_user(clickhouse_user) .with_password(clickhouse_password) .with_database(clickhouse_db), alpaca_api_key, alpaca_api_secret, max_bert_inputs, sequence_classifier: Arc::new(Mutex::new( SequenceClassificationModel::new(SequenceClassificationConfig::new( ModelType::Bert, ModelResource::Torch(Box::new(LocalResource { local_path: PathBuf::from("./models/finbert/rust_model.ot"), })), LocalResource { local_path: PathBuf::from("./models/finbert/config.json"), }, LocalResource { local_path: PathBuf::from("./models/finbert/vocab.txt"), }, None, true, None, None, )) .unwrap(), )), } } pub fn arc_from_env() -> Arc { Arc::new(Self::from_env()) } }