118 lines
5.0 KiB
Rust
118 lines
5.0 KiB
Rust
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
|
|
use lazy_static::lazy_static;
|
|
use qrust::types::alpaca::shared::{Mode, Source};
|
|
use reqwest::{
|
|
header::{HeaderMap, HeaderName, HeaderValue},
|
|
Client,
|
|
};
|
|
use rust_bert::{
|
|
pipelines::{
|
|
common::{ModelResource, ModelType},
|
|
sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel},
|
|
},
|
|
resources::LocalResource,
|
|
};
|
|
use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
|
|
use tokio::sync::Semaphore;
|
|
|
|
lazy_static! {
|
|
pub static ref ALPACA_MODE: Mode = env::var("ALPACA_MODE")
|
|
.expect("ALPACA_MODE must be set.")
|
|
.parse()
|
|
.expect("ALPACA_MODE must be 'live' or 'paper'");
|
|
pub static ref ALPACA_API_BASE: String = match *ALPACA_MODE {
|
|
Mode::Live => String::from("api"),
|
|
Mode::Paper => String::from("paper-api"),
|
|
};
|
|
pub static ref ALPACA_SOURCE: Source = env::var("ALPACA_SOURCE")
|
|
.expect("ALPACA_SOURCE must be set.")
|
|
.parse()
|
|
.expect("ALPACA_SOURCE must be 'iex', 'sip', or 'otc'");
|
|
pub static ref ALPACA_API_KEY: String =
|
|
env::var("ALPACA_API_KEY").expect("ALPACA_API_KEY must be set.");
|
|
pub static ref ALPACA_API_SECRET: String =
|
|
env::var("ALPACA_API_SECRET").expect("ALPACA_API_SECRET must be set.");
|
|
pub static ref BATCH_BACKFILL_BARS_SIZE: usize = env::var("BATCH_BACKFILL_BARS_SIZE")
|
|
.expect("BATCH_BACKFILL_BARS_SIZE must be set.")
|
|
.parse()
|
|
.expect("BATCH_BACKFILL_BARS_SIZE must be a positive integer.");
|
|
pub static ref BATCH_BACKFILL_NEWS_SIZE: usize = env::var("BATCH_BACKFILL_NEWS_SIZE")
|
|
.expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
|
|
.parse()
|
|
.expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
|
|
pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS")
|
|
.expect("BERT_MAX_INPUTS must be set.")
|
|
.parse()
|
|
.expect("BERT_MAX_INPUTS must be a positive integer.");
|
|
pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS")
|
|
.expect("CLICKHOUSE_MAX_CONNECTIONS must be set.")
|
|
.parse()
|
|
.expect("CLICKHOUSE_MAX_CONNECTIONS must be a positive integer.");
|
|
}
|
|
|
|
pub struct Config {
|
|
pub alpaca_client: Client,
|
|
pub alpaca_rate_limiter: DefaultDirectRateLimiter,
|
|
pub clickhouse_client: clickhouse::Client,
|
|
pub clickhouse_concurrency_limiter: Arc<Semaphore>,
|
|
pub sequence_classifier: std::sync::Mutex<SequenceClassificationModel>,
|
|
}
|
|
|
|
impl Config {
|
|
pub fn from_env() -> Self {
|
|
Self {
|
|
alpaca_client: Client::builder()
|
|
.default_headers(HeaderMap::from_iter([
|
|
(
|
|
HeaderName::from_static("apca-api-key-id"),
|
|
HeaderValue::from_str(&ALPACA_API_KEY)
|
|
.expect("Alpaca API key must not contain invalid characters."),
|
|
),
|
|
(
|
|
HeaderName::from_static("apca-api-secret-key"),
|
|
HeaderValue::from_str(&ALPACA_API_SECRET)
|
|
.expect("Alpaca API secret must not contain invalid characters."),
|
|
),
|
|
]))
|
|
.build()
|
|
.unwrap(),
|
|
alpaca_rate_limiter: RateLimiter::direct(Quota::per_minute(match *ALPACA_SOURCE {
|
|
Source::Iex => unsafe { NonZeroU32::new_unchecked(200) },
|
|
Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) },
|
|
Source::Otc => unimplemented!("OTC rate limit not implemented."),
|
|
})),
|
|
clickhouse_client: clickhouse::Client::default()
|
|
.with_url(env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."))
|
|
.with_user(env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."))
|
|
.with_password(
|
|
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set."),
|
|
)
|
|
.with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
|
|
clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
|
|
sequence_classifier: std::sync::Mutex::new(
|
|
SequenceClassificationModel::new(SequenceClassificationConfig::new(
|
|
ModelType::Bert,
|
|
ModelResource::Torch(Box::new(LocalResource {
|
|
local_path: PathBuf::from("./models/finbert/rust_model.ot"),
|
|
})),
|
|
LocalResource {
|
|
local_path: PathBuf::from("./models/finbert/config.json"),
|
|
},
|
|
LocalResource {
|
|
local_path: PathBuf::from("./models/finbert/vocab.txt"),
|
|
},
|
|
None,
|
|
true,
|
|
None,
|
|
None,
|
|
))
|
|
.unwrap(),
|
|
),
|
|
}
|
|
}
|
|
|
|
pub fn arc_from_env() -> Arc<Self> {
|
|
Arc::new(Self::from_env())
|
|
}
|
|
}
|