diff --git a/Cargo.lock b/Cargo.lock index 31285ac..6298fd0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -233,7 +233,6 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", - "serde", "wasm-bindgen", "windows-targets 0.52.0", ] @@ -326,41 +325,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "darling" -version = "0.20.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0209d94da627ab5605dcccf08bb18afa5009cfbef48d8a8b7d7bdbc79be25c5e" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.20.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "177e3443818124b357d8e76f53be906d60937f0d3a90773a664fa63fa253e621" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn 2.0.48", -] - -[[package]] -name = "darling_macro" -version = "0.20.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" -dependencies = [ - "darling_core", - "quote", - "syn 2.0.48", -] - [[package]] name = "dashmap" version = "5.5.3" @@ -689,12 +653,6 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d3d0e0f38255e7fa3cf31335b3a56f05febd18025f4db5ef7a0cfb4f8da651f" -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - [[package]] name = "html-escape" version = "0.2.13" @@ -875,12 +833,6 @@ dependencies = [ "cc", ] -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "idna" version = "0.5.0" @@ -899,7 +851,6 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", - "serde", ] [[package]] @@ -910,7 +861,6 @@ checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" dependencies = [ "equivalent", "hashbrown 0.14.3", - "serde", ] [[package]] @@ -1309,7 +1259,6 @@ dependencies = [ "serde", "serde_json", "serde_repr", - "serde_with", "time", "tokio", "tokio-tungstenite", @@ -1539,9 +1488,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.195" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" +checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" dependencies = [ "serde_derive", ] @@ -1558,9 +1507,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.195" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" +checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ "proc-macro2", "quote", @@ -1580,9 +1529,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.111" +version = "1.0.112" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176e46fa42316f18edd598015a5166857fc835ec732f5215eac6b7bdbf0a84f4" +checksum = "4d1bd37ce2324cf3bf85e5a25f96eb4baf0d5aa6eba43e7ae8958870c4ec48ed" dependencies = [ "itoa", "ryu", @@ -1622,35 +1571,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_with" -version = "3.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5c9fdb6b00a489875b22efd4b78fe2b363b72265cc5f6eb2e2b9ee270e6140c" -dependencies = [ - "base64", - "chrono", - "hex", - "indexmap 1.9.3", - "indexmap 2.1.0", - "serde", - "serde_json", - "serde_with_macros", - "time", -] - -[[package]] -name = "serde_with_macros" -version = "3.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbff351eb4b33600a2e138dfa0b10b65a238ea8ff8fb2387c422c5022a3e8298" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn 2.0.48", -] - [[package]] name = "serde_yaml" version = "0.8.26" @@ -1716,12 +1636,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "strsim" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" - [[package]] name = "syn" version = "1.0.109" diff --git a/Cargo.toml b/Cargo.toml index e24cf28..9023d83 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ log4rs = "1.2.0" serde = "1.0.188" serde_json = "1.0.105" serde_repr = "0.1.18" -serde_with = "3.5.1" futures-util = "0.3.28" reqwest = { version = "0.11.20", features = [ "json", diff --git a/docker-compose.yml b/docker-compose.yml index b7d110f..ce283a8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,6 +4,11 @@ services: file: support/clickhouse/docker-compose.yml service: clickhouse + ollama: + extends: + file: support/ollama/docker-compose.yml + service: ollama + grafana: extends: file: support/grafana/docker-compose.yml @@ -19,10 +24,12 @@ services: - 7878:7878 depends_on: - clickhouse + - ollama env_file: - .env.docker volumes: clickhouse-lib: clickhouse-log: + ollama: grafana-lib: diff --git a/src/config.rs b/src/config.rs index 906d762..ef463a9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,7 +4,7 @@ use reqwest::{ header::{HeaderMap, HeaderName, HeaderValue}, Client, }; -use std::{env, num::NonZeroU32, sync::Arc}; +use std::{env, num::NonZeroU32, sync::Arc, time::Duration}; pub const ALPACA_ASSET_API_URL: &str = "https://api.alpaca.markets/v2/assets"; pub const ALPACA_CLOCK_API_URL: &str = "https://api.alpaca.markets/v2/clock"; @@ -19,9 +19,12 @@ pub const ALPACA_NEWS_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1 pub struct Config { pub alpaca_api_key: String, pub alpaca_api_secret: String, - pub alpaca_client: Client, pub alpaca_rate_limit: DefaultDirectRateLimiter, pub alpaca_source: Source, + pub alpaca_client: Client, + pub ollama_url: String, + pub ollama_model: String, + pub ollama_client: Client, pub clickhouse_client: clickhouse::Client, } @@ -35,6 +38,9 @@ impl Config { .parse() .expect("ALPACA_SOURCE must be a either 'iex' or 'sip'."); + let ollama_url = env::var("OLLAMA_URL").expect("OLLAMA_URL must be set."); + let ollama_model = env::var("OLLAMA_MODEL").expect("OLLAMA_MODEL must be set."); + let clickhouse_url = env::var("CLICKHOUSE_URL").expect("CLICKHOUSE_URL must be set."); let clickhouse_user = env::var("CLICKHOUSE_USER").expect("CLICKHOUSE_USER must be set."); let clickhouse_password = @@ -55,20 +61,27 @@ impl Config { .expect("Alpaca API secret must not contain invalid characters."), ), ])) + .timeout(Duration::from_secs(60)) .build() .unwrap(), alpaca_rate_limit: RateLimiter::direct(Quota::per_minute(match alpaca_source { - Source::Iex => unsafe { NonZeroU32::new_unchecked(200) }, - Source::Sip => unsafe { NonZeroU32::new_unchecked(10000) }, + Source::Iex => unsafe { NonZeroU32::new_unchecked(190) }, + Source::Sip => unsafe { NonZeroU32::new_unchecked(9990) }, })), alpaca_source, + alpaca_api_key, + alpaca_api_secret, + ollama_url, + ollama_model, + ollama_client: Client::builder() + .timeout(Duration::from_secs(15)) + .build() + .unwrap(), clickhouse_client: clickhouse::Client::default() .with_url(clickhouse_url) .with_user(clickhouse_user) .with_password(clickhouse_password) .with_database(clickhouse_db), - alpaca_api_key, - alpaca_api_secret, } } diff --git a/src/main.rs b/src/main.rs index 25667cc..06337b7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,11 +9,11 @@ mod threads; mod types; mod utils; -use crate::utils::cleanup; use config::Config; use dotenv::dotenv; use log4rs::config::Deserializers; use tokio::{spawn, sync::mpsc}; +use utils::{cleanup::cleanup, init}; #[tokio::main] async fn main() { @@ -22,9 +22,10 @@ async fn main() { let app_config = Config::arc_from_env(); cleanup(&app_config.clickhouse_client).await; + init::ollama(&app_config).await; let (asset_status_sender, asset_status_receiver) = - mpsc::channel::(100); + mpsc::channel::(10); let (clock_sender, clock_receiver) = mpsc::channel::(1); spawn(threads::data::run( diff --git a/src/routes/assets.rs b/src/routes/assets.rs index 681311a..9d0d6b8 100644 --- a/src/routes/assets.rs +++ b/src/routes/assets.rs @@ -56,10 +56,7 @@ pub async fn add( .send() .await? .error_for_status() - .map_err(|e| match e.status() { - Some(reqwest::StatusCode::NOT_FOUND) => backoff::Error::Permanent(e), - _ => e.into(), - })? + .map_err(backoff::Error::Permanent)? .json::() .await .map_err(backoff::Error::Permanent) diff --git a/src/threads/clock.rs b/src/threads/clock.rs index 34b7853..ba8e9ab 100644 --- a/src/threads/clock.rs +++ b/src/threads/clock.rs @@ -44,7 +44,8 @@ pub async fn run(app_config: Arc, clock_sender: mpsc::Sender) { .get(ALPACA_CLOCK_API_URL) .send() .await? - .error_for_status()? + .error_for_status() + .map_err(backoff::Error::Permanent)? .json::() .await .map_err(backoff::Error::Permanent) diff --git a/src/threads/data/backfill.rs b/src/threads/data/backfill.rs index 06fd10f..1536964 100644 --- a/src/threads/data/backfill.rs +++ b/src/threads/data/backfill.rs @@ -3,13 +3,14 @@ use crate::{ config::{Config, ALPACA_CRYPTO_DATA_URL, ALPACA_NEWS_DATA_URL, ALPACA_STOCK_DATA_URL}, database, types::{ - alpaca::{api, Source}, - Asset, Bar, Class, News, Subset, + alpaca::{self, Source}, + ollama, Asset, Bar, Class, News, Subset, }, utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE}, }; use backoff::{future::retry, ExponentialBackoff}; use log::{error, info}; +use serde_json::{from_str, to_string}; use std::{collections::HashMap, sync::Arc}; use time::OffsetDateTime; use tokio::{ @@ -244,18 +245,19 @@ async fn execute_backfill_bars( app_config .alpaca_client .get(&data_url) - .query(&api::outgoing::bar::Bar::new( - vec![symbol.clone()], - ONE_MINUTE, - fetch_from, - fetch_to, - 10000, - next_page_token.clone(), - )) + .query(&alpaca::api::outgoing::bar::Bar { + symbols: vec![symbol.clone()], + timeframe: ONE_MINUTE, + start: fetch_from, + end: fetch_to, + limit: 10000, + page_token: next_page_token.clone(), + }) .send() .await? - .error_for_status()? - .json::() + .error_for_status() + .map_err(backoff::Error::Permanent)? + .json::() .await .map_err(backoff::Error::Permanent) }) @@ -318,19 +320,20 @@ async fn execute_backfill_news( app_config .alpaca_client .get(&data_url) - .query(&api::outgoing::news::News::new( - vec![symbol.clone()], - fetch_from, - fetch_to, - 50, - true, - false, - next_page_token.clone(), - )) + .query(&alpaca::api::outgoing::news::News { + symbols: vec![symbol.clone()], + start: fetch_from, + end: fetch_to, + limit: 50, + include_content: true, + exclude_contentless: false, + page_token: next_page_token.clone(), + }) .send() .await? - .error_for_status()? - .json::() + .error_for_status() + .map_err(backoff::Error::Permanent)? + .json::() .await .map_err(backoff::Error::Permanent) }) @@ -361,6 +364,53 @@ async fn execute_backfill_news( return; } + for news in &mut news { + info!( + "{:?} - Getting sentiment for news: {}.", + thread_type, news.headline + ); + + let prediction = retry(ExponentialBackoff::default(), || async { + let response = app_config + .ollama_client + .post(format!("{}/api/chat", app_config.ollama_url)) + .body( + to_string(&ollama::outgoing::sentiment::Sentiment::new( + app_config.ollama_model.clone(), + &news.clone().into(), + )) + .unwrap(), + ) + .send() + .await + .unwrap() + .json::() + .await + .unwrap(); + + from_str::(&response.message.content) + .map_err(Into::into) + }) + .await; + + match prediction { + Ok(prediction) => { + info!( + "{:?} - Received sentiment for news: {:?}.", + thread_type, prediction + ); + news.sentiment = prediction.sentiment.into(); + news.confidence = prediction.confidence.into(); + } + Err(e) => { + error!( + "{:?} - Failed to get sentiment for news: {:?}.", + thread_type, e + ); + } + } + } + let backfill = (news.last().unwrap().clone(), symbol.clone()).into(); database::news::upsert_batch(&app_config.clickhouse_client, news).await; database::backfills::upsert(&app_config.clickhouse_client, &thread_type, &backfill).await; diff --git a/src/threads/data/mod.rs b/src/threads/data/mod.rs index 4bad2f5..d973c3a 100644 --- a/src/threads/data/mod.rs +++ b/src/threads/data/mod.rs @@ -27,16 +27,6 @@ pub struct Guard { pub pending_unsubscriptions: HashMap, } -impl Guard { - pub fn new() -> Self { - Self { - symbols: HashSet::new(), - pending_subscriptions: HashMap::new(), - pending_unsubscriptions: HashMap::new(), - } - } -} - #[derive(Clone, Copy, Debug)] pub enum ThreadType { Bars(Class), @@ -86,7 +76,11 @@ async fn init_thread( mpsc::Sender, mpsc::Sender, ) { - let guard = Arc::new(RwLock::new(Guard::new())); + let guard = Arc::new(RwLock::new(Guard { + symbols: HashSet::new(), + pending_subscriptions: HashMap::new(), + pending_unsubscriptions: HashMap::new(), + })); let websocket_url = match thread_type { ThreadType::Bars(Class::UsEquity) => format!( @@ -102,7 +96,7 @@ async fn init_thread( authenticate(&app_config, &mut websocket_sender, &mut websocket_receiver).await; let websocket_sender = Arc::new(Mutex::new(websocket_sender)); - let (asset_status_sender, asset_status_receiver) = mpsc::channel(100); + let (asset_status_sender, asset_status_receiver) = mpsc::channel(10); spawn(asset_status::run( app_config.clone(), thread_type, @@ -111,7 +105,7 @@ async fn init_thread( websocket_sender.clone(), )); - let (backfill_sender, backfill_receiver) = mpsc::channel(100); + let (backfill_sender, backfill_receiver) = mpsc::channel(10); spawn(backfill::run( app_config.clone(), thread_type, diff --git a/src/threads/data/websocket.rs b/src/threads/data/websocket.rs index dd94faf..8dd6c94 100644 --- a/src/threads/data/websocket.rs +++ b/src/threads/data/websocket.rs @@ -2,14 +2,15 @@ use super::{backfill, Guard, ThreadType}; use crate::{ config::Config, database, - types::{alpaca::websocket, Bar, News, Subset}, + types::{alpaca::websocket, ollama, Bar, News, Subset}, }; +use backoff::{future::retry, ExponentialBackoff}; use futures_util::{ stream::{SplitSink, SplitStream}, SinkExt, StreamExt, }; use log::{error, info, warn}; -use serde_json::from_str; +use serde_json::{from_str, to_string}; use std::{ collections::{HashMap, HashSet}, sync::Arc, @@ -93,6 +94,7 @@ async fn handle_websocket_message( } #[allow(clippy::significant_drop_tightening)] +#[allow(clippy::too_many_lines)] async fn handle_parsed_websocket_message( app_config: Arc, thread_type: ThreadType, @@ -179,7 +181,7 @@ async fn handle_parsed_websocket_message( database::bars::upsert(&app_config.clickhouse_client, &bar).await; } websocket::incoming::Message::News(message) => { - let news = News::from(message); + let mut news = News::from(message); let symbols = news.symbols.clone().into_iter().collect::>(); let guard = guard.read().await; @@ -195,6 +197,52 @@ async fn handle_parsed_websocket_message( "{:?} - Received news for {:?}: {}.", thread_type, news.symbols, news.time_created ); + + info!( + "{:?} - Getting sentiment for news: {}.", + thread_type, news.headline + ); + + let prediction = retry(ExponentialBackoff::default(), || async { + let response = app_config + .ollama_client + .post(format!("{}/api/chat", app_config.ollama_url)) + .body( + to_string(&ollama::outgoing::sentiment::Sentiment::new( + app_config.ollama_model.clone(), + &news.clone().into(), + )) + .unwrap(), + ) + .send() + .await + .unwrap() + .json::() + .await + .unwrap(); + + from_str::(&response.message.content) + .map_err(Into::into) + }) + .await; + + match prediction { + Ok(prediction) => { + info!( + "{:?} - Received sentiment for news: {:?}.", + thread_type, prediction + ); + news.sentiment = prediction.sentiment.into(); + news.confidence = prediction.confidence.into(); + } + Err(e) => { + error!( + "{:?} - Failed to get sentiment for news: {:?}.", + thread_type, e + ); + } + } + database::news::upsert(&app_config.clickhouse_client, &news).await; } websocket::incoming::Message::Success(_) => {} diff --git a/src/types/alpaca/api/incoming/asset.rs b/src/types/alpaca/api/incoming/asset.rs index 8f59ca8..ebc8edd 100644 --- a/src/types/alpaca/api/incoming/asset.rs +++ b/src/types/alpaca/api/incoming/asset.rs @@ -1,7 +1,7 @@ use crate::types::{self, alpaca::api::impl_from_enum}; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)] #[serde(rename_all = "snake_case")] pub enum Class { UsEquity, @@ -10,7 +10,7 @@ pub enum Class { impl_from_enum!(types::Class, Class, UsEquity, Crypto); -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)] #[serde(rename_all = "UPPERCASE")] pub enum Exchange { Amex, @@ -36,7 +36,7 @@ impl_from_enum!( Crypto ); -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)] #[serde(rename_all = "camelCase")] pub enum Status { Active, @@ -44,7 +44,7 @@ pub enum Status { } #[allow(clippy::struct_excessive_bools)] -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Deserialize)] pub struct Asset { pub id: String, pub class: Class, diff --git a/src/types/alpaca/api/incoming/bar.rs b/src/types/alpaca/api/incoming/bar.rs index e078d38..37feb60 100644 --- a/src/types/alpaca/api/incoming/bar.rs +++ b/src/types/alpaca/api/incoming/bar.rs @@ -1,9 +1,9 @@ use crate::types; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use std::collections::HashMap; use time::OffsetDateTime; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Deserialize)] pub struct Bar { #[serde(rename = "t")] #[serde(with = "time::serde::rfc3339")] @@ -40,7 +40,7 @@ impl From<(Bar, String)> for types::Bar { } } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Deserialize)] pub struct Message { pub bars: HashMap>, pub next_page_token: Option, diff --git a/src/types/alpaca/api/incoming/clock.rs b/src/types/alpaca/api/incoming/clock.rs index 51bafce..e808b99 100644 --- a/src/types/alpaca/api/incoming/clock.rs +++ b/src/types/alpaca/api/incoming/clock.rs @@ -1,7 +1,7 @@ -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use time::OffsetDateTime; -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] pub struct Clock { #[serde(with = "time::serde::rfc3339")] pub timestamp: OffsetDateTime, diff --git a/src/types/alpaca/api/incoming/news.rs b/src/types/alpaca/api/incoming/news.rs index 60715a3..8e3fb9d 100644 --- a/src/types/alpaca/api/incoming/news.rs +++ b/src/types/alpaca/api/incoming/news.rs @@ -1,9 +1,14 @@ -use crate::{types, utils::normalize_news_content}; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; +use crate::{ + types::{ + self, + news::{Confidence, Sentiment}, + }, + utils::news, +}; +use serde::Deserialize; use time::OffsetDateTime; -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] #[serde(rename_all = "camelCase")] pub enum ImageSize { Thumb, @@ -11,14 +16,13 @@ pub enum ImageSize { Large, } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] pub struct Image { pub size: ImageSize, pub url: String, } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde_as] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] pub struct News { pub id: i64, #[serde(with = "time::serde::rfc3339")] @@ -28,17 +32,11 @@ pub struct News { #[serde(rename = "updated_at")] pub time_updated: OffsetDateTime, pub symbols: Vec, - #[serde_as(as = "NoneAsEmptyString")] - pub headline: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub author: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub source: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub summary: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub content: Option, - #[serde_as(as = "NoneAsEmptyString")] + pub headline: String, + pub author: String, + pub source: String, + pub summary: String, + pub content: String, pub url: Option, pub images: Vec, } @@ -50,17 +48,16 @@ impl From for types::News { time_created: news.time_created, time_updated: news.time_updated, symbols: news.symbols, - headline: normalize_news_content(news.headline), - author: normalize_news_content(news.author), - source: normalize_news_content(news.source), - summary: normalize_news_content(news.summary), - content: normalize_news_content(news.content), - url: news.url, + headline: news::normalize(&news.headline), + author: news::normalize(&news.author), + content: news::normalize(&news.content), + sentiment: Sentiment::Neutral, + confidence: Confidence::VeryUncertain, } } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] pub struct Message { pub news: Vec, pub next_page_token: Option, diff --git a/src/types/alpaca/api/outgoing/bar.rs b/src/types/alpaca/api/outgoing/bar.rs index 48b758f..d3577b4 100644 --- a/src/types/alpaca/api/outgoing/bar.rs +++ b/src/types/alpaca/api/outgoing/bar.rs @@ -49,23 +49,3 @@ pub struct Bar { #[serde(skip_serializing_if = "Option::is_none")] pub page_token: Option, } - -impl Bar { - pub const fn new( - symbols: Vec, - timeframe: Duration, - start: OffsetDateTime, - end: OffsetDateTime, - limit: i64, - page_token: Option, - ) -> Self { - Self { - symbols, - timeframe, - start, - end, - limit, - page_token, - } - } -} diff --git a/src/types/alpaca/api/outgoing/news.rs b/src/types/alpaca/api/outgoing/news.rs index 8bc64c2..f5d4dc3 100644 --- a/src/types/alpaca/api/outgoing/news.rs +++ b/src/types/alpaca/api/outgoing/news.rs @@ -16,25 +16,3 @@ pub struct News { #[serde(skip_serializing_if = "Option::is_none")] pub page_token: Option, } - -impl News { - pub const fn new( - symbols: Vec, - start: OffsetDateTime, - end: OffsetDateTime, - limit: i64, - include_content: bool, - exclude_contentless: bool, - page_token: Option, - ) -> Self { - Self { - symbols, - start, - end, - limit, - include_content, - exclude_contentless, - page_token, - } - } -} diff --git a/src/types/alpaca/websocket/incoming/bar.rs b/src/types/alpaca/websocket/incoming/bar.rs index 7a4d986..7f71ef6 100644 --- a/src/types/alpaca/websocket/incoming/bar.rs +++ b/src/types/alpaca/websocket/incoming/bar.rs @@ -1,8 +1,8 @@ use crate::types; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use time::OffsetDateTime; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Deserialize)] pub struct Message { #[serde(rename = "t")] #[serde(with = "time::serde::rfc3339")] diff --git a/src/types/alpaca/websocket/incoming/error.rs b/src/types/alpaca/websocket/incoming/error.rs index 714ba8e..4b2ba29 100644 --- a/src/types/alpaca/websocket/incoming/error.rs +++ b/src/types/alpaca/websocket/incoming/error.rs @@ -1,6 +1,6 @@ -use serde::{Deserialize, Serialize}; +use serde::Deserialize; -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Message { pub code: u16, diff --git a/src/types/alpaca/websocket/incoming/mod.rs b/src/types/alpaca/websocket/incoming/mod.rs index 7309e2c..c955f9d 100644 --- a/src/types/alpaca/websocket/incoming/mod.rs +++ b/src/types/alpaca/websocket/incoming/mod.rs @@ -4,9 +4,9 @@ pub mod news; pub mod subscription; pub mod success; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Deserialize)] #[serde(tag = "T")] pub enum Message { #[serde(rename = "success")] diff --git a/src/types/alpaca/websocket/incoming/news.rs b/src/types/alpaca/websocket/incoming/news.rs index 5563cb5..1f75f3c 100644 --- a/src/types/alpaca/websocket/incoming/news.rs +++ b/src/types/alpaca/websocket/incoming/news.rs @@ -1,10 +1,14 @@ -use crate::{types, utils::normalize_news_content}; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; +use crate::{ + types::{ + self, + news::{Confidence, Sentiment}, + }, + utils::news, +}; +use serde::Deserialize; use time::OffsetDateTime; -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde_as] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] pub struct Message { pub id: i64, #[serde(with = "time::serde::rfc3339")] @@ -14,17 +18,11 @@ pub struct Message { #[serde(rename = "updated_at")] pub time_updated: OffsetDateTime, pub symbols: Vec, - #[serde_as(as = "NoneAsEmptyString")] - pub headline: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub author: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub source: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub summary: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub content: Option, - #[serde_as(as = "NoneAsEmptyString")] + pub headline: String, + pub author: String, + pub source: String, + pub summary: String, + pub content: String, pub url: Option, } @@ -35,12 +33,11 @@ impl From for types::News { time_created: news.time_created, time_updated: news.time_updated, symbols: news.symbols, - headline: normalize_news_content(news.headline), - author: normalize_news_content(news.author), - source: normalize_news_content(news.source), - summary: normalize_news_content(news.summary), - content: normalize_news_content(news.content), - url: news.url, + headline: news::normalize(&news.headline), + author: news::normalize(&news.author), + content: news::normalize(&news.content), + sentiment: Sentiment::Neutral, + confidence: Confidence::VeryUncertain, } } } diff --git a/src/types/alpaca/websocket/incoming/subscription.rs b/src/types/alpaca/websocket/incoming/subscription.rs index 92b5d91..7e44d02 100644 --- a/src/types/alpaca/websocket/incoming/subscription.rs +++ b/src/types/alpaca/websocket/incoming/subscription.rs @@ -1,6 +1,6 @@ -use serde::{Deserialize, Serialize}; +use serde::Deserialize; -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] #[serde(rename_all = "camelCase")] pub struct MarketMessage { pub trades: Vec, @@ -14,13 +14,13 @@ pub struct MarketMessage { pub cancel_errors: Option>, } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] #[serde(rename_all = "camelCase")] pub struct NewsMessage { pub news: Vec, } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] #[serde(untagged)] pub enum Message { Market(MarketMessage), diff --git a/src/types/alpaca/websocket/incoming/success.rs b/src/types/alpaca/websocket/incoming/success.rs index c130169..be6592e 100644 --- a/src/types/alpaca/websocket/incoming/success.rs +++ b/src/types/alpaca/websocket/incoming/success.rs @@ -1,6 +1,6 @@ -use serde::{Deserialize, Serialize}; +use serde::Deserialize; -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)] #[serde(tag = "msg")] #[serde(rename_all = "camelCase")] pub enum Message { diff --git a/src/types/alpaca/websocket/outgoing/auth.rs b/src/types/alpaca/websocket/outgoing/auth.rs index a8933e5..94b10fb 100644 --- a/src/types/alpaca/websocket/outgoing/auth.rs +++ b/src/types/alpaca/websocket/outgoing/auth.rs @@ -2,12 +2,6 @@ use serde::Serialize; #[derive(Serialize)] pub struct Message { - key: String, - secret: String, -} - -impl Message { - pub const fn new(key: String, secret: String) -> Self { - Self { key, secret } - } + pub key: String, + pub secret: String, } diff --git a/src/types/asset.rs b/src/types/asset.rs index 13809b3..67bebac 100644 --- a/src/types/asset.rs +++ b/src/types/asset.rs @@ -4,14 +4,14 @@ use serde_repr::{Deserialize_repr, Serialize_repr}; use time::OffsetDateTime; #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)] -#[repr(u8)] +#[repr(i8)] pub enum Class { UsEquity = 1, Crypto = 2, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)] -#[repr(u8)] +#[repr(i8)] pub enum Exchange { Amex = 1, Arca = 2, diff --git a/src/types/backfill.rs b/src/types/backfill.rs index e940d76..92ad587 100644 --- a/src/types/backfill.rs +++ b/src/types/backfill.rs @@ -10,20 +10,20 @@ pub struct Backfill { pub time: OffsetDateTime, } -impl Backfill { - pub const fn new(symbol: String, time: OffsetDateTime) -> Self { - Self { symbol, time } - } -} - impl From for Backfill { fn from(bar: Bar) -> Self { - Self::new(bar.symbol, bar.time) + Self { + symbol: bar.symbol, + time: bar.time, + } } } impl From<(News, String)> for Backfill { fn from((news, symbol): (News, String)) -> Self { - Self::new(symbol, news.time_created) + Self { + symbol, + time: news.time_created, + } } } diff --git a/src/types/mod.rs b/src/types/mod.rs index b19bfab..0cd85f8 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -4,6 +4,7 @@ pub mod asset; pub mod backfill; pub mod bar; pub mod news; +pub mod ollama; pub use algebraic::Subset; pub use asset::{Asset, Class, Exchange}; diff --git a/src/types/news.rs b/src/types/news.rs index 1900329..4085b44 100644 --- a/src/types/news.rs +++ b/src/types/news.rs @@ -1,10 +1,33 @@ use clickhouse::Row; use serde::{Deserialize, Serialize}; -use serde_with::serde_as; +use serde_repr::{Deserialize_repr, Serialize_repr}; use time::OffsetDateTime; +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)] +#[repr(i8)] +pub enum Sentiment { + VeryNegative = -3, + Negative = -2, + MildlyNegative = -1, + Neutral = 0, + MildlyPositive = 1, + Positive = 2, + VeryPositive = 3, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize_repr, Deserialize_repr)] +#[repr(i8)] +pub enum Confidence { + VeryUncertain = -3, + Uncertain = -2, + MildlyUncertain = -1, + Neutral = 0, + MildlyCertain = 1, + Certain = 2, + VeryCertain = 3, +} + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Row)] -#[serde_as] pub struct News { pub id: i64, #[serde(with = "clickhouse::serde::time::datetime")] @@ -12,16 +35,9 @@ pub struct News { #[serde(with = "clickhouse::serde::time::datetime")] pub time_updated: OffsetDateTime, pub symbols: Vec, - #[serde_as(as = "NoneAsEmptyString")] - pub headline: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub author: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub source: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub summary: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub content: Option, - #[serde_as(as = "NoneAsEmptyString")] - pub url: Option, + pub headline: String, + pub author: String, + pub content: String, + pub sentiment: Sentiment, + pub confidence: Confidence, } diff --git a/src/types/ollama/chat.rs b/src/types/ollama/chat.rs new file mode 100644 index 0000000..220050e --- /dev/null +++ b/src/types/ollama/chat.rs @@ -0,0 +1,15 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Role { + System, + User, + Assistant, +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + pub content: String, +} diff --git a/src/types/ollama/incoming/mod.rs b/src/types/ollama/incoming/mod.rs new file mode 100644 index 0000000..656dfb5 --- /dev/null +++ b/src/types/ollama/incoming/mod.rs @@ -0,0 +1,2 @@ +pub mod pull; +pub mod sentiment; diff --git a/src/types/ollama/incoming/pull.rs b/src/types/ollama/incoming/pull.rs new file mode 100644 index 0000000..ec325a5 --- /dev/null +++ b/src/types/ollama/incoming/pull.rs @@ -0,0 +1,6 @@ +use serde::Deserialize; + +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] +pub struct Response { + pub status: String, +} diff --git a/src/types/ollama/incoming/sentiment.rs b/src/types/ollama/incoming/sentiment.rs new file mode 100644 index 0000000..808a2f0 --- /dev/null +++ b/src/types/ollama/incoming/sentiment.rs @@ -0,0 +1,75 @@ +use crate::types::{self, ollama::chat::Message}; +use serde::Deserialize; +use time::OffsetDateTime; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Sentiment { + VeryNegative, + Negative, + MildlyNegative, + Neutral, + MildlyPositive, + Positive, + VeryPositive, +} + +impl From for types::news::Sentiment { + fn from(sentiment: Sentiment) -> Self { + match sentiment { + Sentiment::VeryNegative => Self::VeryNegative, + Sentiment::Negative => Self::Negative, + Sentiment::MildlyNegative => Self::MildlyNegative, + Sentiment::Neutral => Self::Neutral, + Sentiment::MildlyPositive => Self::MildlyPositive, + Sentiment::Positive => Self::Positive, + Sentiment::VeryPositive => Self::VeryPositive, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Confidence { + VeryUncertain, + Uncertain, + MildlyUncertain, + Neutral, + MildlyCertain, + Certain, + VeryCertain, +} + +impl From for types::news::Confidence { + fn from(confidence: Confidence) -> Self { + match confidence { + Confidence::VeryUncertain => Self::VeryUncertain, + Confidence::Uncertain => Self::Uncertain, + Confidence::MildlyUncertain => Self::MildlyUncertain, + Confidence::Neutral => Self::Neutral, + Confidence::MildlyCertain => Self::MildlyCertain, + Confidence::Certain => Self::Certain, + Confidence::VeryCertain => Self::VeryCertain, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)] +pub struct Prediction { + pub sentiment: Sentiment, + pub confidence: Confidence, +} + +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] +pub struct Response { + pub model: String, + #[serde(with = "time::serde::rfc3339")] + pub created_at: OffsetDateTime, + pub message: Message, + pub done: bool, + pub total_duration: i64, + pub load_duration: i64, + pub prompt_eval_duration: i64, + pub eval_count: i64, + pub eval_duration: i64, +} diff --git a/src/types/ollama/mod.rs b/src/types/ollama/mod.rs new file mode 100644 index 0000000..50b1c56 --- /dev/null +++ b/src/types/ollama/mod.rs @@ -0,0 +1,3 @@ +pub mod chat; +pub mod incoming; +pub mod outgoing; diff --git a/src/types/ollama/outgoing/mod.rs b/src/types/ollama/outgoing/mod.rs new file mode 100644 index 0000000..656dfb5 --- /dev/null +++ b/src/types/ollama/outgoing/mod.rs @@ -0,0 +1,2 @@ +pub mod pull; +pub mod sentiment; diff --git a/src/types/ollama/outgoing/pull.rs b/src/types/ollama/outgoing/pull.rs new file mode 100644 index 0000000..0740733 --- /dev/null +++ b/src/types/ollama/outgoing/pull.rs @@ -0,0 +1,16 @@ +use serde::Serialize; + +#[derive(Serialize)] +pub struct Pull { + name: String, + stream: bool, +} + +impl Pull { + pub const fn new(name: String) -> Self { + Self { + name, + stream: false, + } + } +} diff --git a/src/types/ollama/outgoing/sentiment.rs b/src/types/ollama/outgoing/sentiment.rs new file mode 100644 index 0000000..454053f --- /dev/null +++ b/src/types/ollama/outgoing/sentiment.rs @@ -0,0 +1,64 @@ +use crate::types::{ + self, + ollama::chat::{self, Message}, +}; +use serde::Serialize; +use serde_json::to_string; + +const PROMPT: &str = r#"You are SentimentLLama, a news classification AI. Users will input a news headline or article, and you will output a sentiment and confidence in one line of JSON. + +For sentiment, pick out of ["very_negative", "negative", "mildly_negative", "neutral", "mildly_positive", "positive", "very_positive"]. + +For confidence, pick out of ["very_uncertain", "uncertain", "mildly_uncertain", "neutral", "mildly_certain", "certain", "very_certain"]."#; + +#[derive(Serialize)] +struct ModelOptions { + temperature: f64, + seed: i64, +} + +#[derive(Serialize)] +pub struct News { + headline: String, +} + +impl From for News { + fn from(news: types::News) -> Self { + Self { + headline: news.headline, + } + } +} + +#[derive(Serialize)] +pub struct Sentiment { + model: String, + messages: Vec, + format: String, + stream: bool, + options: ModelOptions, +} + +impl Sentiment { + pub fn new(model: String, content: &News) -> Self { + Self { + model, + messages: vec![ + Message { + role: chat::Role::System, + content: PROMPT.to_string(), + }, + Message { + role: chat::Role::User, + content: to_string(content).unwrap(), + }, + ], + format: "json".to_string(), + stream: false, + options: ModelOptions { + temperature: 0.0, + seed: 0, + }, + } + } +} diff --git a/src/utils/init.rs b/src/utils/init.rs new file mode 100644 index 0000000..9ac8d34 --- /dev/null +++ b/src/utils/init.rs @@ -0,0 +1,24 @@ +use crate::{config::Config, types::ollama}; +use serde_json::to_string; +use std::time::Duration; + +pub async fn ollama(app_config: &Config) { + let response = app_config + .ollama_client + .post(format!("{}/api/pull", app_config.ollama_url)) + .body( + to_string(&ollama::outgoing::pull::Pull::new( + app_config.ollama_model.clone(), + )) + .unwrap(), + ) + .timeout(Duration::MAX) + .send() + .await + .unwrap() + .json::() + .await + .unwrap(); + + assert!(response.status == "success", "Failed to pull Ollama model."); +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 9d111eb..33bc578 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,9 +1,8 @@ pub mod cleanup; +pub mod init; pub mod news; pub mod time; pub mod websocket; -pub use cleanup::cleanup; -pub use news::normalize_news_content; pub use time::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE}; pub use websocket::authenticate; diff --git a/src/utils/news.rs b/src/utils/news.rs index 084e217..d3b90e2 100644 --- a/src/utils/news.rs +++ b/src/utils/news.rs @@ -1,10 +1,7 @@ use html_escape::decode_html_entities; use regex::Regex; -pub fn normalize_news_content(content: Option) -> Option { - content.as_ref()?; - let content = content.unwrap(); - +pub fn normalize(content: &str) -> String { let re_tags = Regex::new("<[^>]+>").unwrap(); let re_spaces = Regex::new("[\\u00A0\\s]+").unwrap(); @@ -12,11 +9,5 @@ pub fn normalize_news_content(content: Option) -> Option { let content = re_tags.replace_all(&content, ""); let content = re_spaces.replace_all(&content, " "); let content = decode_html_entities(&content); - let content = content.trim(); - - if content.is_empty() { - None - } else { - Some(content.to_string()) - } + content.trim().to_string() } diff --git a/src/utils/websocket.rs b/src/utils/websocket.rs index 5aef8ee..6983bff 100644 --- a/src/utils/websocket.rs +++ b/src/utils/websocket.rs @@ -28,10 +28,10 @@ pub async fn authenticate( sender .send(Message::Text( to_string(&websocket::outgoing::Message::Auth( - websocket::outgoing::auth::Message::new( - app_config.alpaca_api_key.clone(), - app_config.alpaca_api_secret.clone(), - ), + websocket::outgoing::auth::Message { + key: app_config.alpaca_api_key.clone(), + secret: app_config.alpaca_api_secret.clone(), + }, )) .unwrap(), )) diff --git a/support/clickhouse/docker-entrypoint-initdb.d/0000_init.sql b/support/clickhouse/docker-entrypoint-initdb.d/0000_init.sql index 2426bee..d70ee48 100644 --- a/support/clickhouse/docker-entrypoint-initdb.d/0000_init.sql +++ b/support/clickhouse/docker-entrypoint-initdb.d/0000_init.sql @@ -47,10 +47,25 @@ CREATE TABLE IF NOT EXISTS qrust.news ( symbols Array(LowCardinality(String)), headline String, author String, - source String, - summary String, content String, - url String, + sentiment Enum( + 'very_negative' = -3, + 'negative' = -2, + 'mildly_negative' = -1, + 'neutral' = 0, + 'mildly_positive' = 1, + 'positive' = 2, + 'very_positive' = 3 + ), + confidence Enum( + 'very_uncertain' = -3, + 'uncertain' = -2, + 'mildly_uncertain' = -1, + 'neutral' = 0, + 'mildly_certain' = 1, + 'certain' = 2, + 'very_certain' = 3 + ), INDEX index_symbols symbols TYPE bloom_filter() ) ENGINE = ReplacingMergeTree() diff --git a/support/ollama/docker-compose.yml b/support/ollama/docker-compose.yml new file mode 100644 index 0000000..c192393 --- /dev/null +++ b/support/ollama/docker-compose.yml @@ -0,0 +1,20 @@ +services: + ollama: + image: ollama/ollama + hostname: ollama + restart: unless-stopped + volumes: + - ollama:/root/.ollama + ports: + - target: 11434 + published: 11434 + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + +volumes: + ollama: