Add finbert sentiment analysis

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
2024-02-03 18:58:40 +00:00
parent 973917dad2
commit 65c9ae8b25
26 changed files with 31460 additions and 215 deletions

.gitignore

@@ -10,3 +10,6 @@ target/
 *.pdb
 .env*
+
+# ML models
+models/*/rust_model.ot

Cargo.lock

File diff suppressed because it is too large.

Cargo.toml

@@ -27,7 +27,6 @@ log4rs = "1.2.0"
 serde = "1.0.188"
 serde_json = "1.0.105"
 serde_repr = "0.1.18"
-serde_with = "3.5.1"
 futures-util = "0.3.28"
 reqwest = { version = "0.11.20", features = [
     "json",
@@ -52,3 +51,4 @@ backoff = { version = "0.4.0", features = [
 ] }
 regex = "1.10.3"
 html-escape = "0.2.13"
+rust-bert = "0.22.0"

models/finbert/config.json

@@ -0,0 +1,32 @@
+{
+  "_name_or_path": "/home/ubuntu/finbert/models/language_model/finbertTRC2",
+  "architectures": [
+    "BertForSequenceClassification"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "id2label": {
+    "0": "positive",
+    "1": "negative",
+    "2": "neutral"
+  },
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "label2id": {
+    "positive": 0,
+    "negative": 1,
+    "neutral": 2
+  },
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "position_embedding_type": "absolute",
+  "type_vocab_size": 2,
+  "vocab_size": 30522
+}

models/finbert/special_tokens_map.json

@@ -0,0 +1 @@
+{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}

models/finbert/tokenizer_config.json

@@ -0,0 +1 @@
+{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "bert-base-uncased"}

models/finbert/vocab.txt

File diff suppressed because it is too large.


@@ -4,7 +4,19 @@ use reqwest::{
     header::{HeaderMap, HeaderName, HeaderValue},
     Client,
 };
-use std::{env, num::NonZeroU32, sync::Arc};
+use rust_bert::{
+    pipelines::{
+        common::{ModelResource, ModelType},
+        sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel},
+    },
+    resources::LocalResource,
+};
+use std::{
+    env,
+    num::NonZeroU32,
+    path::PathBuf,
+    sync::{Arc, Mutex},
+};
 
 pub const ALPACA_ASSET_API_URL: &str = "https://api.alpaca.markets/v2/assets";
 pub const ALPACA_CLOCK_API_URL: &str = "https://api.alpaca.markets/v2/clock";
@@ -23,6 +35,8 @@ pub struct Config {
     pub alpaca_rate_limit: DefaultDirectRateLimiter,
     pub alpaca_source: Source,
     pub clickhouse_client: clickhouse::Client,
+    pub max_bert_inputs: usize,
+    pub sequence_classifier: Arc<Mutex<SequenceClassificationModel>>,
 }
 
 impl Config {
@@ -41,6 +55,11 @@ impl Config {
             env::var("CLICKHOUSE_PASSWORD").expect("CLICKHOUSE_PASSWORD must be set.");
         let clickhouse_db = env::var("CLICKHOUSE_DB").expect("CLICKHOUSE_DB must be set.");
 
+        let max_bert_inputs: usize = env::var("MAX_BERT_INPUTS")
+            .expect("MAX_BERT_INPUTS must be set.")
+            .parse()
+            .expect("MAX_BERT_INPUTS must be a positive integer.");
+
         Self {
             alpaca_client: Client::builder()
                 .default_headers(HeaderMap::from_iter([
@@ -69,6 +88,26 @@ impl Config {
                 .with_database(clickhouse_db),
             alpaca_api_key,
             alpaca_api_secret,
+            max_bert_inputs,
+            sequence_classifier: Arc::new(Mutex::new(
+                SequenceClassificationModel::new(SequenceClassificationConfig::new(
+                    ModelType::Bert,
+                    ModelResource::Torch(Box::new(LocalResource {
+                        local_path: PathBuf::from("./models/finbert/rust_model.ot"),
+                    })),
+                    LocalResource {
+                        local_path: PathBuf::from("./models/finbert/config.json"),
+                    },
+                    LocalResource {
+                        local_path: PathBuf::from("./models/finbert/vocab.txt"),
+                    },
+                    None,
+                    true,
+                    None,
+                    None,
+                ))
+                .unwrap(),
+            )),
         }
     }
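
Note on the Arc<Mutex<…>> wrapper: the tch-backed rust-bert model is costly to run and is serialized behind a Mutex here, with inference pushed onto Tokio's blocking pool elsewhere in this commit. A minimal sketch of that call pattern, mirroring the later diffs (the headline string is illustrative):

use tokio::task::spawn_blocking;

// Clone the Arc so the blocking task owns a handle to the shared config.
let app_config_clone = app_config.clone();
let labels = spawn_blocking(move || {
    app_config_clone
        .sequence_classifier
        .lock()
        .unwrap()
        .predict(vec!["Shares surged after record quarterly earnings."])
})
.await
.unwrap();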


@@ -25,25 +25,31 @@ pub async fn delete_where_symbols<T>(clickhouse_client: &Client, symbols: &[T])
 where
     T: AsRef<str> + Serialize + Send + Sync,
 {
+    let remaining_symbols = assets::select(clickhouse_client)
+        .await
+        .into_iter()
+        .map(|asset| asset.abbreviation)
+        .collect::<Vec<_>>();
+
     clickhouse_client
-        .query("DELETE FROM news WHERE hasAny(symbols, ?)")
+        .query("DELETE FROM news WHERE hasAny(symbols, ?) AND NOT hasAny(symbols, ?)")
         .bind(symbols)
+        .bind(remaining_symbols)
         .execute()
         .await
         .unwrap();
 }
 
 pub async fn cleanup(clickhouse_client: &Client) {
-    let assets = assets::select(clickhouse_client).await;
-
-    let symbols = assets
+    let remaining_symbols = assets::select(clickhouse_client)
+        .await
         .into_iter()
         .map(|asset| asset.abbreviation)
         .collect::<Vec<_>>();
 
     clickhouse_client
         .query("DELETE FROM news WHERE NOT hasAny(symbols, ?)")
-        .bind(symbols)
+        .bind(remaining_symbols)
         .execute()
         .await
         .unwrap();


@@ -4,18 +4,19 @@ use crate::{
     database,
     types::{
         alpaca::{api, Source},
+        news::Prediction,
         Asset, Bar, Class, News, Subset,
     },
     utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE},
 };
 use backoff::{future::retry, ExponentialBackoff};
-use log::{error, info};
+use log::{error, info, warn};
 use std::{collections::HashMap, sync::Arc};
 use time::OffsetDateTime;
 use tokio::{
     join, spawn,
     sync::{mpsc, oneshot, Mutex, RwLock},
-    task::JoinHandle,
+    task::{spawn_blocking, JoinHandle},
     time::sleep,
 };
@@ -103,11 +104,14 @@ async fn handle_backfill_message(
     match message.action {
         Action::Backfill => {
             for symbol in symbols {
-                if let Some(job) = backfill_jobs.remove(&symbol) {
+                if let Some(job) = backfill_jobs.get(&symbol) {
                     if !job.is_finished() {
-                        job.abort();
+                        warn!(
+                            "{:?} - Backfill for {} is already running, skipping.",
+                            thread_type, symbol
+                        );
+                        continue;
                     }
-                    let _ = job.await;
                 }
 
                 let app_config = app_config.clone();
@@ -361,7 +365,41 @@ async fn execute_backfill_news(
         return;
     }
 
-    let backfill = (news.last().unwrap().clone(), symbol.clone()).into();
+    let app_config_clone = app_config.clone();
+    let inputs = news
+        .iter()
+        .map(|news| format!("{}\n\n{}", news.headline, news.content))
+        .collect::<Vec<_>>();
+
+    let predictions: Vec<Prediction> = spawn_blocking(move || {
+        inputs
+            .chunks(app_config_clone.max_bert_inputs)
+            .flat_map(|inputs| {
+                app_config_clone
+                    .sequence_classifier
+                    .lock()
+                    .unwrap()
+                    .predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
+                    .into_iter()
+                    .map(|label| Prediction::try_from(label).unwrap())
+                    .collect::<Vec<_>>()
+            })
+            .collect()
+    })
+    .await
+    .unwrap();
+
+    let news = news
+        .into_iter()
+        .zip(predictions.into_iter())
+        .map(|(news, prediction)| News {
+            sentiment: prediction.sentiment,
+            confidence: prediction.confidence,
+            ..news
+        })
+        .collect::<Vec<_>>();
+
+    let backfill = (news[0].clone(), symbol.clone()).into();
+
     database::news::upsert_batch(&app_config.clickhouse_client, news).await;
     database::backfills::upsert(&app_config.clickhouse_client, &thread_type, &backfill).await;


@@ -27,16 +27,6 @@ pub struct Guard {
     pub pending_unsubscriptions: HashMap<String, Asset>,
 }
 
-impl Guard {
-    pub fn new() -> Self {
-        Self {
-            symbols: HashSet::new(),
-            pending_subscriptions: HashMap::new(),
-            pending_unsubscriptions: HashMap::new(),
-        }
-    }
-}
-
 #[derive(Clone, Copy, Debug)]
 pub enum ThreadType {
     Bars(Class),
@@ -86,7 +76,11 @@ async fn init_thread(
     mpsc::Sender<asset_status::Message>,
     mpsc::Sender<backfill::Message>,
 ) {
-    let guard = Arc::new(RwLock::new(Guard::new()));
+    let guard = Arc::new(RwLock::new(Guard {
+        symbols: HashSet::new(),
+        pending_subscriptions: HashMap::new(),
+        pending_unsubscriptions: HashMap::new(),
+    }));
 
     let websocket_url = match thread_type {
         ThreadType::Bars(Class::UsEquity) => format!(


@@ -2,7 +2,7 @@ use super::{backfill, Guard, ThreadType};
 use crate::{
     config::Config,
     database,
-    types::{alpaca::websocket, Bar, News, Subset},
+    types::{alpaca::websocket, news::Prediction, Bar, News, Subset},
 };
 use futures_util::{
     stream::{SplitSink, SplitStream},
@@ -19,6 +19,7 @@ use tokio::{
     net::TcpStream,
     spawn,
     sync::{mpsc, Mutex, RwLock},
+    task::spawn_blocking,
 };
 use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
@@ -93,6 +94,7 @@ async fn handle_websocket_message(
 }
 
 #[allow(clippy::significant_drop_tightening)]
+#[allow(clippy::too_many_lines)]
 async fn handle_parsed_websocket_message(
     app_config: Arc<Config>,
     thread_type: ThreadType,
@@ -195,6 +197,28 @@ async fn handle_parsed_websocket_message(
                 "{:?} - Received news for {:?}: {}.",
                 thread_type, news.symbols, news.time_created
             );
 
+            let app_config_clone = app_config.clone();
+            let input = format!("{}\n\n{}", news.headline, news.content);
+
+            let prediction = spawn_blocking(move || {
+                app_config_clone
+                    .sequence_classifier
+                    .lock()
+                    .unwrap()
+                    .predict(vec![input.as_str()])
+                    .into_iter()
+                    .map(|label| Prediction::try_from(label).unwrap())
+                    .collect::<Vec<_>>()[0]
+            })
+            .await
+            .unwrap();
+
+            let news = News {
+                sentiment: prediction.sentiment,
+                confidence: prediction.confidence,
+                ..news
+            };
+
             database::news::upsert(&app_config.clickhouse_client, &news).await;
         }
         websocket::incoming::Message::Success(_) => {}


@@ -1,7 +1,7 @@
 use crate::types::{self, alpaca::api::impl_from_enum};
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
 
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum Class {
     UsEquity,
@@ -10,7 +10,7 @@ pub enum Class {
 
 impl_from_enum!(types::Class, Class, UsEquity, Crypto);
 
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "UPPERCASE")]
 pub enum Exchange {
     Amex,
@@ -36,7 +36,7 @@ impl_from_enum!(
     Crypto
 );
 
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub enum Status {
     Active,
@@ -44,7 +44,7 @@ pub enum Status {
 }
 
 #[allow(clippy::struct_excessive_bools)]
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Deserialize)]
 pub struct Asset {
     pub id: String,
     pub class: Class,


@@ -1,9 +1,9 @@
 use crate::types;
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
 use std::collections::HashMap;
 use time::OffsetDateTime;
 
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Deserialize)]
 pub struct Bar {
     #[serde(rename = "t")]
     #[serde(with = "time::serde::rfc3339")]
@@ -40,7 +40,7 @@ impl From<(Bar, String)> for types::Bar {
     }
 }
 
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Deserialize)]
 pub struct Message {
     pub bars: HashMap<String, Vec<Bar>>,
     pub next_page_token: Option<String>,


@@ -1,7 +1,7 @@
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
 use time::OffsetDateTime;
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
 pub struct Clock {
     #[serde(with = "time::serde::rfc3339")]
     pub timestamp: OffsetDateTime,


@@ -1,9 +1,8 @@
 use crate::{types, utils::normalize_news_content};
-use serde::{Deserialize, Serialize};
-use serde_with::serde_as;
+use serde::Deserialize;
 use time::OffsetDateTime;
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub enum ImageSize {
     Thumb,
@@ -11,14 +10,13 @@ pub enum ImageSize {
     Large,
 }
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
 pub struct Image {
     pub size: ImageSize,
     pub url: String,
 }
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
-#[serde_as]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
 pub struct News {
     pub id: i64,
     #[serde(with = "time::serde::rfc3339")]
@@ -28,17 +26,11 @@ pub struct News {
     #[serde(rename = "updated_at")]
     pub time_updated: OffsetDateTime,
     pub symbols: Vec<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub headline: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub author: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub source: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub summary: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub content: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
+    pub headline: String,
+    pub author: String,
+    pub source: String,
+    pub summary: String,
+    pub content: String,
     pub url: Option<String>,
     pub images: Vec<Image>,
 }
@@ -50,17 +42,19 @@ impl From<News> for types::News {
             time_created: news.time_created,
             time_updated: news.time_updated,
             symbols: news.symbols,
-            headline: normalize_news_content(news.headline),
-            author: normalize_news_content(news.author),
-            source: normalize_news_content(news.source),
-            summary: normalize_news_content(news.summary),
-            content: normalize_news_content(news.content),
-            url: news.url,
+            headline: normalize_news_content(&news.headline),
+            author: normalize_news_content(&news.author),
+            source: normalize_news_content(&news.source),
+            summary: normalize_news_content(&news.summary),
+            content: normalize_news_content(&news.content),
+            sentiment: types::news::Sentiment::Neutral,
+            confidence: 0.0,
+            url: news.url.unwrap_or_default(),
         }
     }
 }
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
 pub struct Message {
     pub news: Vec<News>,
     pub next_page_token: Option<String>,


@@ -1,8 +1,8 @@
 use crate::types;
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
 use time::OffsetDateTime;
 
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Deserialize)]
 pub struct Message {
     #[serde(rename = "t")]
     #[serde(with = "time::serde::rfc3339")]


@@ -1,6 +1,6 @@
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct Message {
     pub code: u16,


@@ -4,9 +4,9 @@ pub mod news;
 pub mod subscription;
 pub mod success;
 
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
 
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Deserialize)]
 #[serde(tag = "T")]
 pub enum Message {
     #[serde(rename = "success")]


@@ -1,10 +1,8 @@
 use crate::{types, utils::normalize_news_content};
-use serde::{Deserialize, Serialize};
-use serde_with::serde_as;
+use serde::Deserialize;
 use time::OffsetDateTime;
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
-#[serde_as]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
 pub struct Message {
     pub id: i64,
     #[serde(with = "time::serde::rfc3339")]
@@ -14,17 +12,11 @@ pub struct Message {
     #[serde(rename = "updated_at")]
     pub time_updated: OffsetDateTime,
     pub symbols: Vec<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub headline: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub author: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub source: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub summary: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub content: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
+    pub headline: String,
+    pub author: String,
+    pub source: String,
+    pub summary: String,
+    pub content: String,
     pub url: Option<String>,
 }
@@ -35,12 +27,14 @@ impl From<Message> for types::News {
             time_created: news.time_created,
             time_updated: news.time_updated,
             symbols: news.symbols,
-            headline: normalize_news_content(news.headline),
-            author: normalize_news_content(news.author),
-            source: normalize_news_content(news.source),
-            summary: normalize_news_content(news.summary),
-            content: normalize_news_content(news.content),
-            url: news.url,
+            headline: normalize_news_content(&news.headline),
+            author: normalize_news_content(&news.author),
+            source: normalize_news_content(&news.source),
+            summary: normalize_news_content(&news.summary),
+            content: normalize_news_content(&news.content),
+            sentiment: types::news::Sentiment::Neutral,
+            confidence: 0.0,
+            url: news.url.unwrap_or_default(),
         }
     }
 }


@@ -1,6 +1,6 @@
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct MarketMessage {
     pub trades: Vec<String>,
@@ -14,13 +14,13 @@ pub struct MarketMessage {
     pub cancel_errors: Option<Vec<String>>,
 }
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct NewsMessage {
     pub news: Vec<String>,
 }
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
 #[serde(untagged)]
 pub enum Message {
     Market(MarketMessage),


@@ -1,6 +1,6 @@
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
 
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)]
 #[serde(tag = "msg")]
 #[serde(rename_all = "camelCase")]
 pub enum Message {


@@ -4,14 +4,14 @@ use serde_repr::{Deserialize_repr, Serialize_repr};
 use time::OffsetDateTime;
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
-#[repr(u8)]
+#[repr(i8)]
 pub enum Class {
     UsEquity = 1,
     Crypto = 2,
 }
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
-#[repr(u8)]
+#[repr(i8)]
 pub enum Exchange {
     Amex = 1,
     Arca = 2,


@@ -1,10 +1,49 @@
 use clickhouse::Row;
+use rust_bert::pipelines::sequence_classification::Label;
 use serde::{Deserialize, Serialize};
-use serde_with::serde_as;
+use serde_repr::{Deserialize_repr, Serialize_repr};
+use std::str::FromStr;
 use time::OffsetDateTime;
 
-#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Row)]
-#[serde_as]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)]
+#[repr(i8)]
+pub enum Sentiment {
+    Positive = 1,
+    Neutral = 0,
+    Negative = -1,
+}
+
+impl FromStr for Sentiment {
+    type Err = ();
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "positive" => Ok(Self::Positive),
+            "neutral" => Ok(Self::Neutral),
+            "negative" => Ok(Self::Negative),
+            _ => Err(()),
+        }
+    }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub struct Prediction {
+    pub sentiment: Sentiment,
+    pub confidence: f64,
+}
+
+impl TryFrom<Label> for Prediction {
+    type Error = ();
+
+    fn try_from(label: Label) -> Result<Self, Self::Error> {
+        Ok(Self {
+            sentiment: Sentiment::from_str(&label.text)?,
+            confidence: label.score,
+        })
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)]
 pub struct News {
     pub id: i64,
     #[serde(with = "clickhouse::serde::time::datetime")]
@@ -12,16 +51,12 @@ pub struct News {
     #[serde(with = "clickhouse::serde::time::datetime")]
     pub time_updated: OffsetDateTime,
     pub symbols: Vec<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub headline: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub author: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub source: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub summary: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub content: Option<String>,
-    #[serde_as(as = "NoneAsEmptyString")]
-    pub url: Option<String>,
+    pub headline: String,
+    pub author: String,
+    pub source: String,
+    pub summary: String,
+    pub content: String,
+    pub sentiment: Sentiment,
+    pub confidence: f64,
+    pub url: String,
 }
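
For context, Prediction::try_from consumes rust-bert's Label, whose text field carries the class name from id2label in models/finbert/config.json; that is why FromStr matches exactly "positive", "neutral", and "negative". A small sketch with illustrative values (Label's fields are text, score, id, and sentence):

use rust_bert::pipelines::sequence_classification::Label;

// A label shaped like what the FinBERT classifier returns for a bullish headline.
let label = Label {
    text: "positive".to_string(),
    score: 0.97,
    id: 0,
    sentence: 0,
};

let prediction = Prediction::try_from(label).unwrap();
assert_eq!(prediction.sentiment, Sentiment::Positive);
assert!((prediction.confidence - 0.97).abs() < f64::EPSILON);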


@@ -1,10 +1,7 @@
 use html_escape::decode_html_entities;
 use regex::Regex;
 
-pub fn normalize_news_content(content: Option<String>) -> Option<String> {
-    content.as_ref()?;
-    let content = content.unwrap();
-
+pub fn normalize_news_content(content: &str) -> String {
     let re_tags = Regex::new("<[^>]+>").unwrap();
     let re_spaces = Regex::new("[\\u00A0\\s]+").unwrap();
@@ -14,9 +11,5 @@ pub fn normalize_news_content(content: Option<String>) -> Option<String> {
     let content = decode_html_entities(&content);
     let content = content.trim();
 
-    if content.is_empty() {
-        None
-    } else {
-        Some(content.to_string())
-    }
+    content.to_string()
 }
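
The helper now takes &str and always returns a String; empty results are no longer collapsed to None, matching the non-optional news fields above. Expected behavior, assuming the elided middle of the function replaces each whitespace run with a single space, as the regex names suggest:

// Tags stripped, whitespace runs collapsed, ends trimmed.
let cleaned = normalize_news_content("  <p>Fed  holds   rates</p>  ");
assert_eq!(cleaned, "Fed holds rates");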


@@ -50,6 +50,8 @@ CREATE TABLE IF NOT EXISTS qrust.news (
     source String,
     summary String,
     content String,
+    sentiment Enum('positive' = 1, 'neutral' = 0, 'negative' = -1),
+    confidence Float64,
     url String,
     INDEX index_symbols symbols TYPE bloom_filter()
 )
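
Because the Enum values mirror Sentiment's repr(i8) encoding, sentiment can be aggregated as a signed score in ClickHouse. A hypothetical query using the clickhouse crate's typed fetch (the symbol, cast, and aggregation are illustrative, not part of this commit):

// Average signed sentiment (-1..1) over one symbol's news rows.
let avg_sentiment = clickhouse_client
    .query("SELECT avg(CAST(sentiment, 'Int8')) FROM qrust.news WHERE hasAny(symbols, ?)")
    .bind(vec!["AAPL"])
    .fetch_one::<f64>()
    .await
    .unwrap();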