Add finbert sentiment analysis

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-02-03 18:58:40 +00:00
parent 973917dad2
commit 65c9ae8b25
26 changed files with 31460 additions and 215 deletions

View File

@@ -4,7 +4,19 @@ use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client,
};
use std::{env, num::NonZeroU32, sync::Arc};
use rust_bert::{
pipelines::{
common::{ModelResource, ModelType},
sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel},
},
resources::LocalResource,
};
use std::{
env,
num::NonZeroU32,
path::PathBuf,
sync::{Arc, Mutex},
};
pub const ALPACA_ASSET_API_URL: &str = "https://api.alpaca.markets/v2/assets";
pub const ALPACA_CLOCK_API_URL: &str = "https://api.alpaca.markets/v2/clock";
@@ -23,6 +35,8 @@ pub struct Config {
pub alpaca_rate_limit: DefaultDirectRateLimiter,
pub alpaca_source: Source,
pub clickhouse_client: clickhouse::Client,
pub max_bert_inputs: usize,
pub sequence_classifier: Arc<Mutex<SequenceClassificationModel>>,
}
impl Config {
@@ -41,6 +55,11 @@ impl Config {
env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set.");
let clickhouse_db = env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.");
let max_bert_inputs: usize = env::var("MAX_BERT_INPUTS")
.expect("MAX_BERT_INPUTS must be set.")
.parse()
.expect("MAX_BERT_INPUTS must be a positive integer.");
Self {
alpaca_client: Client::builder()
.default_headers(HeaderMap::from_iter([
@@ -69,6 +88,26 @@ impl Config {
.with_database(clickhouse_db),
alpaca_api_key,
alpaca_api_secret,
max_bert_inputs,
sequence_classifier: Arc::new(Mutex::new(
SequenceClassificationModel::new(SequenceClassificationConfig::new(
ModelType::Bert,
ModelResource::Torch(Box::new(LocalResource {
local_path: PathBuf::from("./models/finbert/rust_model.ot"),
})),
LocalResource {
local_path: PathBuf::from("./models/finbert/config.json"),
},
LocalResource {
local_path: PathBuf::from("./models/finbert/vocab.txt"),
},
None,
true,
None,
None,
))
.unwrap(),
)),
}
}