Remove rust-bert

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-03-13 12:09:50 +00:00
parent 5ed0c7670a
commit f3f9c6336b
17 changed files with 47 additions and 31402 deletions

.gitignore (vendored, 3 changes)

@@ -11,6 +11,3 @@ log/
 *.pdb
 .env*
-# ML models
-models/*/rust_model.ot

Cargo.lock (generated, 693 changes)

File diff suppressed because it is too large.


@@ -51,6 +51,7 @@ clickhouse = { version = "0.11.6", features = [
 ] }
 uuid = { version = "1.6.1", features = [
     "serde",
+    "v4",
 ] }
 time = { version = "0.3.31", features = [
     "serde",
@@ -64,8 +65,6 @@ backoff = { version = "0.4.0", features = [
"tokio", "tokio",
] } ] }
regex = "1.10.3" regex = "1.10.3"
html-escape = "0.2.13"
rust-bert = "0.22.0"
async-trait = "0.1.77" async-trait = "0.1.77"
itertools = "0.12.1" itertools = "0.12.1"
lazy_static = "1.4.0" lazy_static = "1.4.0"

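Note: alongside the removals above, the "v4" feature is enabled on the uuid dependency, which turns on random (version 4) UUID generation. A minimal sketch of what that feature provides, as hypothetical usage rather than code taken from this repository:

use uuid::Uuid;

fn main() {
    // With the "v4" feature enabled, Uuid::new_v4() produces a random UUID.
    let id = Uuid::new_v4();
    println!("{id}");
}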

@@ -1,32 +0,0 @@
-{
-  "_name_or_path": "/home/ubuntu/finbert/models/language_model/finbertTRC2",
-  "architectures": [
-    "BertForSequenceClassification"
-  ],
-  "attention_probs_dropout_prob": 0.1,
-  "gradient_checkpointing": false,
-  "hidden_act": "gelu",
-  "hidden_dropout_prob": 0.1,
-  "hidden_size": 768,
-  "id2label": {
-    "0": "positive",
-    "1": "negative",
-    "2": "neutral"
-  },
-  "initializer_range": 0.02,
-  "intermediate_size": 3072,
-  "label2id": {
-    "positive": 0,
-    "negative": 1,
-    "neutral": 2
-  },
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 512,
-  "model_type": "bert",
-  "num_attention_heads": 12,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "type_vocab_size": 2,
-  "vocab_size": 30522
-}


@@ -1 +0,0 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}


@@ -1 +0,0 @@
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "bert-base-uncased"}

File diff suppressed because it is too large.


@@ -5,14 +5,8 @@ use reqwest::{
     header::{HeaderMap, HeaderName, HeaderValue},
     Client,
 };
-use rust_bert::{
-    pipelines::{
-        common::{ModelResource, ModelType},
-        sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel},
-    },
-    resources::LocalResource,
-};
-use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
+use std::{env, num::NonZeroU32, sync::Arc};
 use tokio::sync::Semaphore;

 lazy_static! {
@@ -40,10 +34,6 @@ lazy_static! {
.expect("BATCH_BACKFILL_NEWS_SIZE must be set.") .expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
.parse() .parse()
.expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer."); .expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS")
.expect("BERT_MAX_INPUTS must be set.")
.parse()
.expect("BERT_MAX_INPUTS must be a positive integer.");
pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS") pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS")
.expect("CLICKHOUSE_MAX_CONNECTIONS must be set.") .expect("CLICKHOUSE_MAX_CONNECTIONS must be set.")
.parse() .parse()
@@ -55,7 +45,6 @@ pub struct Config {
     pub alpaca_rate_limiter: DefaultDirectRateLimiter,
     pub clickhouse_client: clickhouse::Client,
     pub clickhouse_concurrency_limiter: Arc<Semaphore>,
-    pub sequence_classifier: std::sync::Mutex<SequenceClassificationModel>,
 }

 impl Config {
@@ -89,25 +78,6 @@ impl Config {
             )
             .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
             clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
-            sequence_classifier: std::sync::Mutex::new(
-                SequenceClassificationModel::new(SequenceClassificationConfig::new(
-                    ModelType::Bert,
-                    ModelResource::Torch(Box::new(LocalResource {
-                        local_path: PathBuf::from("./models/finbert/rust_model.ot"),
-                    })),
-                    LocalResource {
-                        local_path: PathBuf::from("./models/finbert/config.json"),
-                    },
-                    LocalResource {
-                        local_path: PathBuf::from("./models/finbert/vocab.txt"),
-                    },
-                    None,
-                    true,
-                    None,
-                    None,
-                ))
-                .unwrap(),
-            ),
         }
     }


@@ -1,5 +1,5 @@
 use crate::{
-    types::{self, alpaca::shared::news::normalize_html_content},
+    types::{self, alpaca::shared::news::strip},
     utils::de,
 };
 use serde::Deserialize;
@@ -46,13 +46,11 @@ impl From<News> for types::News {
             time_created: news.time_created,
             time_updated: news.time_updated,
             symbols: news.symbols,
-            headline: normalize_html_content(&news.headline),
-            author: normalize_html_content(&news.author),
-            source: normalize_html_content(&news.source),
-            summary: normalize_html_content(&news.summary),
-            content: normalize_html_content(&news.content),
-            sentiment: types::news::Sentiment::Neutral,
-            confidence: 0.0,
+            headline: strip(&news.headline),
+            author: strip(&news.author),
+            source: strip(&news.source),
+            summary: news.summary,
+            content: news.content,
             url: news.url.unwrap_or_default(),
         }
     }


@@ -1,4 +1,3 @@
-use html_escape::decode_html_entities;
 use lazy_static::lazy_static;
 use regex::Regex;
@@ -7,12 +6,10 @@ lazy_static! {
     static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap();
 }

-pub fn normalize_html_content(content: &str) -> String {
+pub fn strip(content: &str) -> String {
     let content = content.replace('\n', " ");
     let content = RE_TAGS.replace_all(&content, "");
     let content = RE_SPACES.replace_all(&content, " ");
-    let content = decode_html_entities(&content);
     let content = content.trim();
     content.to_string()
 }

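Note: with decode_html_entities removed, the renamed strip helper still drops tags and collapses whitespace but now leaves HTML entities untouched. A minimal sketch of the resulting behaviour, assuming a tag-matching pattern such as "<[^>]*>" for RE_TAGS (its real definition is not shown in this diff):

use lazy_static::lazy_static;
use regex::Regex;

lazy_static! {
    // Assumed pattern; the actual RE_TAGS regex is defined elsewhere in this module.
    static ref RE_TAGS: Regex = Regex::new("<[^>]*>").unwrap();
    static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap();
}

pub fn strip(content: &str) -> String {
    let content = content.replace('\n', " ");
    let content = RE_TAGS.replace_all(&content, "");
    let content = RE_SPACES.replace_all(&content, " ");
    let content = content.trim();
    content.to_string()
}

fn main() {
    // Tags and whitespace are still normalized, but entities such as &amp; now pass through verbatim.
    assert_eq!(strip("<p>AT&amp;T  beats\nestimates</p>"), "AT&amp;T beats estimates");
}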

@@ -1,5 +1,5 @@
 use crate::{
-    types::{alpaca::shared::news::normalize_html_content, news::Sentiment, News},
+    types::{alpaca::shared::news::strip, News},
     utils::de,
 };
 use serde::Deserialize;
@@ -31,13 +31,11 @@ impl From<Message> for News {
             time_created: news.time_created,
             time_updated: news.time_updated,
             symbols: news.symbols,
-            headline: normalize_html_content(&news.headline),
-            author: normalize_html_content(&news.author),
-            source: normalize_html_content(&news.source),
-            summary: normalize_html_content(&news.summary),
-            content: normalize_html_content(&news.content),
-            sentiment: Sentiment::Neutral,
-            confidence: 0.0,
+            headline: strip(&news.headline),
+            author: strip(&news.author),
+            source: strip(&news.source),
+            summary: news.summary,
+            content: news.content,
             url: news.url.unwrap_or_default(),
         }
     }


@@ -1,48 +1,7 @@
 use clickhouse::Row;
-use rust_bert::pipelines::sequence_classification::Label;
 use serde::{Deserialize, Serialize};
-use serde_repr::{Deserialize_repr, Serialize_repr};
-use std::str::FromStr;
 use time::OffsetDateTime;
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
-#[repr(i8)]
-pub enum Sentiment {
-    Positive = 1,
-    Neutral = 0,
-    Negative = -1,
-}
-
-impl FromStr for Sentiment {
-    type Err = ();
-
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        match s {
-            "positive" => Ok(Self::Positive),
-            "neutral" => Ok(Self::Neutral),
-            "negative" => Ok(Self::Negative),
-            _ => Err(()),
-        }
-    }
-}
-
-#[derive(Clone, Copy, Debug, PartialEq)]
-pub struct Prediction {
-    pub sentiment: Sentiment,
-    pub confidence: f64,
-}
-
-impl TryFrom<Label> for Prediction {
-    type Error = ();
-
-    fn try_from(label: Label) -> Result<Self, Self::Error> {
-        Ok(Self {
-            sentiment: Sentiment::from_str(&label.text)?,
-            confidence: label.score,
-        })
-    }
-}
-
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
 pub struct News {
     pub id: i64,
@@ -56,7 +15,5 @@ pub struct News {
     pub source: String,
     pub summary: String,
     pub content: String,
-    pub sentiment: Sentiment,
-    pub confidence: f64,
     pub url: String,
 }


@@ -9,7 +9,7 @@ mod threads;
 use config::{
     Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, BATCH_BACKFILL_BARS_SIZE,
-    BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS, CLICKHOUSE_MAX_CONNECTIONS,
+    BATCH_BACKFILL_NEWS_SIZE, CLICKHOUSE_MAX_CONNECTIONS,
 };
 use dotenv::dotenv;
 use log4rs::config::Deserializers;
@@ -27,7 +27,6 @@ async fn main() {
     let _ = *ALPACA_SOURCE;
     let _ = *BATCH_BACKFILL_BARS_SIZE;
     let _ = *BATCH_BACKFILL_NEWS_SIZE;
-    let _ = *BERT_MAX_INPUTS;
     let _ = *CLICKHOUSE_MAX_CONNECTIONS;

     try_join!(


@@ -1,6 +1,6 @@
 use super::Job;
 use crate::{
-    config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS},
+    config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_NEWS_SIZE},
     database,
 };
 use async_trait::async_trait;
@@ -10,7 +10,6 @@ use qrust::{
     types::{
         self,
         alpaca::shared::{Sort, Source},
-        news::Prediction,
         Backfill, News,
     },
     utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
@@ -19,7 +18,7 @@ use std::{
     collections::{HashMap, HashSet},
     sync::Arc,
 };
-use tokio::{task::block_in_place, time::sleep};
+use tokio::time::sleep;

 pub struct Handler {
     pub config: Arc<Config>,
@@ -83,8 +82,6 @@ impl super::Handler for Handler {
         let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();

         let mut news = Vec::with_capacity(*BATCH_BACKFILL_NEWS_SIZE);
-        let mut batch = Vec::with_capacity(*BERT_MAX_INPUTS);
-        let mut predictions = Vec::with_capacity(*BERT_MAX_INPUTS);
         let mut last_times = HashMap::new();
         let mut next_page_token = None;
@@ -122,46 +119,9 @@ impl super::Handler for Handler {
                     }
                 }

-                batch.push(news_item);
+                news.push(news_item);
             }

-            if batch.len() < *BERT_MAX_INPUTS
-                && batch.len() < *BATCH_BACKFILL_NEWS_SIZE
-                && message.next_page_token.is_some()
-            {
-                continue;
-            }
-
-            let inputs = batch
-                .iter()
-                .map(|news| format!("{}\n\n{}", news.headline, news.content))
-                .collect::<Vec<_>>();
-
-            for chunk in inputs.chunks(*BERT_MAX_INPUTS) {
-                let chunk_predictions = block_in_place(|| {
-                    self.config
-                        .sequence_classifier
-                        .lock()
-                        .unwrap()
-                        .predict(chunk.iter().map(String::as_str).collect::<Vec<_>>())
-                        .into_iter()
-                        .map(|label| Prediction::try_from(label).unwrap())
-                });
-                predictions.extend(chunk_predictions);
-            }
-
-            let zipped = batch
-                .drain(..)
-                .zip(predictions.drain(..))
-                .map(|(news, prediction)| News {
-                    sentiment: prediction.sentiment,
-                    confidence: prediction.confidence,
-                    ..news
-                });
-            news.extend(zipped);
-
             if news.len() < *BATCH_BACKFILL_NEWS_SIZE && message.next_page_token.is_some() {
                 continue;
             }


@@ -2,9 +2,9 @@ use super::State;
 use crate::{config::Config, database};
 use async_trait::async_trait;
 use log::{debug, error, info};
-use qrust::types::{alpaca::websocket, news::Prediction, News};
+use qrust::types::{alpaca::websocket, News};
 use std::{collections::HashMap, sync::Arc};
-use tokio::{sync::RwLock, task::block_in_place};
+use tokio::sync::RwLock;

 pub struct Handler {
     pub config: Arc<Config>,
@@ -80,25 +80,6 @@ impl super::Handler for Handler {
                     news.symbols, news.time_created
                 );

-                let input = format!("{}\n\n{}", news.headline, news.content);
-
-                let prediction = block_in_place(|| {
-                    self.config
-                        .sequence_classifier
-                        .lock()
-                        .unwrap()
-                        .predict(vec![input.as_str()])
-                        .into_iter()
-                        .map(|label| Prediction::try_from(label).unwrap())
-                        .collect::<Vec<_>>()[0]
-                });
-
-                let news = News {
-                    sentiment: prediction.sentiment,
-                    confidence: prediction.confidence,
-                    ..news
-                };
-
                 database::news::upsert(
                     &self.config.clickhouse_client,
                     &self.config.clickhouse_concurrency_limiter,


@@ -1,12 +1,4 @@
-FROM rust:bookworm
+FROM rust
-
-RUN apt-get update -y && apt-get upgrade -y
-RUN apt-get install -y python3 python3-setuptools python3-pip
-RUN apt-get clean
-RUN rm -rf /var/lib/apt/lists/*
-
-RUN pip3 install torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu --break-system-packages
-ENV LIBTORCH_USE_PYTORCH 1

 RUN rustup install nightly
 RUN rustup component add rustfmt clippy --toolchain nightly


@@ -51,8 +51,6 @@ CREATE TABLE IF NOT EXISTS qrust.news (
     source String,
     summary String,
     content String,
-    sentiment Enum('positive' = 1, 'neutral' = 0, 'negative' = -1),
-    confidence Float64,
     url String,
     INDEX index_symbols symbols TYPE bloom_filter()
 )