Remove rust-bert
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
.gitignore (vendored): 3 changed lines
@@ -11,6 +11,3 @@ log/
*.pdb
.env*
# ML models
models/*/rust_model.ot
Cargo.lock (generated): 693 changed lines
File diff suppressed because it is too large
@@ -51,6 +51,7 @@ clickhouse = { version = "0.11.6", features = [
] }
uuid = { version = "1.6.1", features = [
    "serde",
    "v4",
] }
time = { version = "0.3.31", features = [
    "serde",
@@ -64,8 +65,6 @@ backoff = { version = "0.4.0", features = [
    "tokio",
] }
regex = "1.10.3"
html-escape = "0.2.13"
rust-bert = "0.22.0"
async-trait = "0.1.77"
itertools = "0.12.1"
lazy_static = "1.4.0"
@@ -1,32 +0,0 @@
{
  "_name_or_path": "/home/ubuntu/finbert/models/language_model/finbertTRC2",
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "positive",
    "1": "negative",
    "2": "neutral"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "positive": 0,
    "negative": 1,
    "neutral": 2
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "type_vocab_size": 2,
  "vocab_size": 30522
}
@@ -1 +0,0 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
@@ -1 +0,0 @@
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "bert-base-uncased"}
models/finbert/vocab.txt: 30522 changed lines
File diff suppressed because it is too large
@@ -5,14 +5,8 @@ use reqwest::{
    header::{HeaderMap, HeaderName, HeaderValue},
    Client,
};
use rust_bert::{
    pipelines::{
        common::{ModelResource, ModelType},
        sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel},
    },
    resources::LocalResource,
};
use std::{env, num::NonZeroU32, path::PathBuf, sync::Arc};
use std::{env, num::NonZeroU32, sync::Arc};
use tokio::sync::Semaphore;

lazy_static! {
@@ -40,10 +34,6 @@ lazy_static! {
        .expect("BATCH_BACKFILL_NEWS_SIZE must be set.")
        .parse()
        .expect("BATCH_BACKFILL_NEWS_SIZE must be a positive integer.");
    pub static ref BERT_MAX_INPUTS: usize = env::var("BERT_MAX_INPUTS")
        .expect("BERT_MAX_INPUTS must be set.")
        .parse()
        .expect("BERT_MAX_INPUTS must be a positive integer.");
    pub static ref CLICKHOUSE_MAX_CONNECTIONS: usize = env::var("CLICKHOUSE_MAX_CONNECTIONS")
        .expect("CLICKHOUSE_MAX_CONNECTIONS must be set.")
        .parse()
@@ -55,7 +45,6 @@ pub struct Config {
    pub alpaca_rate_limiter: DefaultDirectRateLimiter,
    pub clickhouse_client: clickhouse::Client,
    pub clickhouse_concurrency_limiter: Arc<Semaphore>,
    pub sequence_classifier: std::sync::Mutex<SequenceClassificationModel>,
}

impl Config {
@@ -89,25 +78,6 @@ impl Config {
            )
            .with_database(env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.")),
            clickhouse_concurrency_limiter: Arc::new(Semaphore::new(*CLICKHOUSE_MAX_CONNECTIONS)),
            sequence_classifier: std::sync::Mutex::new(
                SequenceClassificationModel::new(SequenceClassificationConfig::new(
                    ModelType::Bert,
                    ModelResource::Torch(Box::new(LocalResource {
                        local_path: PathBuf::from("./models/finbert/rust_model.ot"),
                    })),
                    LocalResource {
                        local_path: PathBuf::from("./models/finbert/config.json"),
                    },
                    LocalResource {
                        local_path: PathBuf::from("./models/finbert/vocab.txt"),
                    },
                    None,
                    true,
                    None,
                    None,
                ))
                .unwrap(),
            ),
        }
    }
@@ -1,5 +1,5 @@
use crate::{
    types::{self, alpaca::shared::news::normalize_html_content},
    types::{self, alpaca::shared::news::strip},
    utils::de,
};
use serde::Deserialize;
@@ -46,13 +46,11 @@ impl From<News> for types::News {
            time_created: news.time_created,
            time_updated: news.time_updated,
            symbols: news.symbols,
            headline: normalize_html_content(&news.headline),
            author: normalize_html_content(&news.author),
            source: normalize_html_content(&news.source),
            summary: normalize_html_content(&news.summary),
            content: normalize_html_content(&news.content),
            sentiment: types::news::Sentiment::Neutral,
            confidence: 0.0,
            headline: strip(&news.headline),
            author: strip(&news.author),
            source: strip(&news.source),
            summary: news.summary,
            content: news.content,
            url: news.url.unwrap_or_default(),
        }
    }
@@ -1,4 +1,3 @@
use html_escape::decode_html_entities;
use lazy_static::lazy_static;
use regex::Regex;
@@ -7,12 +6,10 @@ lazy_static! {
    static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap();
}

pub fn normalize_html_content(content: &str) -> String {
pub fn strip(content: &str) -> String {
    let content = content.replace('\n', " ");
    let content = RE_TAGS.replace_all(&content, "");
    let content = RE_SPACES.replace_all(&content, " ");
    let content = decode_html_entities(&content);
    let content = content.trim();

    content.to_string()
}
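For reference, a minimal usage sketch of the renamed strip helper shown in the hunk above. It assumes RE_TAGS (defined outside this hunk) matches HTML tags with a pattern along the lines of "<[^>]*>"; the expected output is an assumption derived from that pattern and the function body as shown, not from the repository's tests.

// Hypothetical test for the strip helper above; RE_TAGS is assumed to match
// HTML tags (e.g. "<[^>]*>"), which is not visible in this hunk.
#[cfg(test)]
mod tests {
    use super::strip;

    #[test]
    fn strip_removes_tags_and_collapses_whitespace() {
        // Newlines become spaces, tags are dropped, whitespace runs collapse,
        // entities are decoded, and the result is trimmed.
        let input = "<p>Shares  rose</p>\n<em>sharply</em> ";
        assert_eq!(strip(input), "Shares rose sharply");
    }
}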
@@ -1,5 +1,5 @@
use crate::{
    types::{alpaca::shared::news::normalize_html_content, news::Sentiment, News},
    types::{alpaca::shared::news::strip, News},
    utils::de,
};
use serde::Deserialize;
@@ -31,13 +31,11 @@ impl From<Message> for News {
            time_created: news.time_created,
            time_updated: news.time_updated,
            symbols: news.symbols,
            headline: normalize_html_content(&news.headline),
            author: normalize_html_content(&news.author),
            source: normalize_html_content(&news.source),
            summary: normalize_html_content(&news.summary),
            content: normalize_html_content(&news.content),
            sentiment: Sentiment::Neutral,
            confidence: 0.0,
            headline: strip(&news.headline),
            author: strip(&news.author),
            source: strip(&news.source),
            summary: news.summary,
            content: news.content,
            url: news.url.unwrap_or_default(),
        }
    }
@@ -1,48 +1,7 @@
use clickhouse::Row;
use rust_bert::pipelines::sequence_classification::Label;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::str::FromStr;
use time::OffsetDateTime;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum Sentiment {
    Positive = 1,
    Neutral = 0,
    Negative = -1,
}

impl FromStr for Sentiment {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "positive" => Ok(Self::Positive),
            "neutral" => Ok(Self::Neutral),
            "negative" => Ok(Self::Negative),
            _ => Err(()),
        }
    }
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Prediction {
    pub sentiment: Sentiment,
    pub confidence: f64,
}

impl TryFrom<Label> for Prediction {
    type Error = ();

    fn try_from(label: Label) -> Result<Self, Self::Error> {
        Ok(Self {
            sentiment: Sentiment::from_str(&label.text)?,
            confidence: label.score,
        })
    }
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
pub struct News {
    pub id: i64,
@@ -56,7 +15,5 @@ pub struct News {
    pub source: String,
    pub summary: String,
    pub content: String,
    pub sentiment: Sentiment,
    pub confidence: f64,
    pub url: String,
}
@@ -9,7 +9,7 @@ mod threads;

use config::{
    Config, ALPACA_API_BASE, ALPACA_MODE, ALPACA_SOURCE, BATCH_BACKFILL_BARS_SIZE,
    BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS, CLICKHOUSE_MAX_CONNECTIONS,
    BATCH_BACKFILL_NEWS_SIZE, CLICKHOUSE_MAX_CONNECTIONS,
};
use dotenv::dotenv;
use log4rs::config::Deserializers;
@@ -27,7 +27,6 @@ async fn main() {
    let _ = *ALPACA_SOURCE;
    let _ = *BATCH_BACKFILL_BARS_SIZE;
    let _ = *BATCH_BACKFILL_NEWS_SIZE;
    let _ = *BERT_MAX_INPUTS;
    let _ = *CLICKHOUSE_MAX_CONNECTIONS;

    try_join!(
@@ -1,6 +1,6 @@
use super::Job;
use crate::{
    config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_NEWS_SIZE, BERT_MAX_INPUTS},
    config::{Config, ALPACA_SOURCE, BATCH_BACKFILL_NEWS_SIZE},
    database,
};
use async_trait::async_trait;
@@ -10,7 +10,6 @@ use qrust::{
    types::{
        self,
        alpaca::shared::{Sort, Source},
        news::Prediction,
        Backfill, News,
    },
    utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
@@ -19,7 +18,7 @@ use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};
use tokio::{task::block_in_place, time::sleep};
use tokio::time::sleep;

pub struct Handler {
    pub config: Arc<Config>,
@@ -83,8 +82,6 @@ impl super::Handler for Handler {
        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();

        let mut news = Vec::with_capacity(*BATCH_BACKFILL_NEWS_SIZE);
        let mut batch = Vec::with_capacity(*BERT_MAX_INPUTS);
        let mut predictions = Vec::with_capacity(*BERT_MAX_INPUTS);
        let mut last_times = HashMap::new();
        let mut next_page_token = None;
@@ -122,46 +119,9 @@ impl super::Handler for Handler {
                }
            }

            batch.push(news_item);
            news.push(news_item);
        }

        if batch.len() < *BERT_MAX_INPUTS
            && batch.len() < *BATCH_BACKFILL_NEWS_SIZE
            && message.next_page_token.is_some()
        {
            continue;
        }

        let inputs = batch
            .iter()
            .map(|news| format!("{}\n\n{}", news.headline, news.content))
            .collect::<Vec<_>>();

        for chunk in inputs.chunks(*BERT_MAX_INPUTS) {
            let chunk_predictions = block_in_place(|| {
                self.config
                    .sequence_classifier
                    .lock()
                    .unwrap()
                    .predict(chunk.iter().map(String::as_str).collect::<Vec<_>>())
                    .into_iter()
                    .map(|label| Prediction::try_from(label).unwrap())
            });

            predictions.extend(chunk_predictions);
        }

        let zipped = batch
            .drain(..)
            .zip(predictions.drain(..))
            .map(|(news, prediction)| News {
                sentiment: prediction.sentiment,
                confidence: prediction.confidence,
                ..news
            });

        news.extend(zipped);

        if news.len() < *BATCH_BACKFILL_NEWS_SIZE && message.next_page_token.is_some() {
            continue;
        }
@@ -2,9 +2,9 @@ use super::State;
use crate::{config::Config, database};
use async_trait::async_trait;
use log::{debug, error, info};
use qrust::types::{alpaca::websocket, news::Prediction, News};
use qrust::types::{alpaca::websocket, News};
use std::{collections::HashMap, sync::Arc};
use tokio::{sync::RwLock, task::block_in_place};
use tokio::sync::RwLock;

pub struct Handler {
    pub config: Arc<Config>,
@@ -80,25 +80,6 @@ impl super::Handler for Handler {
                    news.symbols, news.time_created
                );

                let input = format!("{}\n\n{}", news.headline, news.content);

                let prediction = block_in_place(|| {
                    self.config
                        .sequence_classifier
                        .lock()
                        .unwrap()
                        .predict(vec![input.as_str()])
                        .into_iter()
                        .map(|label| Prediction::try_from(label).unwrap())
                        .collect::<Vec<_>>()[0]
                });

                let news = News {
                    sentiment: prediction.sentiment,
                    confidence: prediction.confidence,
                    ..news
                };

                database::news::upsert(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
@@ -1,12 +1,4 @@
FROM rust:bookworm

RUN apt-get update -y && apt-get upgrade -y
RUN apt-get install -y python3 python3-setuptools python3-pip
RUN apt-get clean
RUN rm -rf /var/lib/apt/lists/*

RUN pip3 install torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu --break-system-packages
ENV LIBTORCH_USE_PYTORCH 1
FROM rust

RUN rustup install nightly
RUN rustup component add rustfmt clippy --toolchain nightly
@@ -51,8 +51,6 @@ CREATE TABLE IF NOT EXISTS qrust.news (
    source String,
    summary String,
    content String,
    sentiment Enum('positive' = 1, 'neutral' = 0, 'negative' = -1),
    confidence Float64,
    url String,
    INDEX index_symbols symbols TYPE bloom_filter()
)