use governor::{DefaultDirectRateLimiter, Quota, RateLimiter}; use lazy_static::lazy_static; use qrust::types::alpaca::shared::{Mode, Source}; use reqwest::{ header::{HeaderMap, HeaderName, HeaderValue}, Client, }; use rust_bert::{ pipelines::{ common::{ModelResource, ModelType}, sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel}, }, resources::LocalResource, }; use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc}; use tokio::sync::Semaphore; lazy_static! { pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE") .expect("ALPACA_MODE must be set.") .parse() .expect("ALPACA_MODE must be 'live' or 'paper'"); pub static ref ALPACA_API_BASE: String = match *ALPACA_MODE { Mode::Live => String::from("api"), Mode::Paper => String::from("paper-api"), }; pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE") .expect("ALPACA_SOURCE must be set.") .parse() .expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'"); pub static ref ALPACA_API_KEY: String = env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set."); pub static ref ALPACA_API_SECRET: String = env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set."); pub static ref BATCH_BACKFILL_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE") .expect("BATCH_BACKFILL_BARS_SIZE must be set.") .parse() .expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer."); pub static ref BATCH_BACKFILL_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE") .expect("BATCH_BACKFILL_NEWS_SIZE must be set.") .parse() .expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer."); pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS") .expect("BERT_MAX_INPUTS must be set.") .parse() .expect("BERT_MAX_INPUTS must be a positive integer."); pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS") .expect("CLICKHOUSE_MAX_CONNECTIONS must be set.") .parse() .expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer."); } pub struct Config { pub alpaca_client: Client, pub alpaca_rate_limiter: DefaultDirectRateLimiter, pub clickhouse_client: clickhouse::Client, pub clickhouse_concurrency_limiter: Arc, pub sequence_classifier: std::sync::Mutex, } impl Config { pub fn from_env() -> Self { Self { alpaca_client: Client::builder() .default_headers(HeaderMap::from_iter([ ( HeaderName::from_static("apca-api-key-id"), HeaderValue::from_str(&ALPACA_API_KEY) .expect("Alpaca API key must not contain invalid characters."), ), ( HeaderName::from_static("apca-api-secret-key"), HeaderValue::from_str(&ALPACA_API_SECRET) .expect("Alpaca API secret must not contain invalid characters."), ), ])) .build() .unwrap(), alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE { Source::Iex => unsafe { NonZeroU32::new_unchecked(200) }, Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) }, Source::Otc => unimplemented!("OTC rate limit not implemented."), })), clickhouse_client: clickhouse::Client::default() .with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set.")) .with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set.")) .with_password( env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."), ) .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")), clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)), sequence_classifier: std::sync::Mutex::new( SequenceClassificationModel::new(SequenceClassificationConfig::new( ModelType::Bert, ModelResource::Torch(Box::new(LocalResource { local_path: PathBuf::from("./models/finbert/rust_model.ot"), })), LocalResource { local_path: PathBuf::from("./models/finbert/config.json"), }, LocalResource { local_path: PathBuf::from("./models/finbert/vocab.txt"), }, None, true, None, None, )) .unwrap(), ), } } pub fn arc_from_env() -> Arc { Arc::new(Self::from_env()) } }