Remove stored abbreviation

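- Alpaca's symbol formats are inconsistent (news subscription symbols need the slash re-added), so symbols are normalized instead of storing a separate abbreviation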

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-02-04 21:24:14 +00:00
parent 65c9ae8b25
commit 61c573cbc7
22 changed files with 180 additions and 153 deletions

View File

@@ -78,20 +78,20 @@ async fn handle_asset_status_message(
.assets
.clone()
.into_iter()
.map(|asset| match thread_type {
ThreadType::Bars(_) => asset.symbol,
ThreadType::News => asset.abbreviation,
})
.map(|asset| asset.symbol)
.collect::<Vec<_>>();
match message.action {
Action::Add => {
let mut guard = guard.write().await;
guard.symbols.extend(symbols.clone());
guard
.pending_subscriptions
.extend(symbols.clone().into_iter().zip(message.assets.clone()));
guard.assets.extend(
message
.assets
.iter()
.map(|asset| (asset.clone(), asset.symbol.clone())),
);
guard.pending_subscriptions.extend(message.assets.clone());
info!("{:?} - Added {:?}.", thread_type, symbols);
@@ -108,7 +108,7 @@ async fn handle_asset_status_message(
.await
.send(tungstenite::Message::Text(
to_string(&websocket::outgoing::Message::Subscribe(
websocket_market_message_factory(thread_type, symbols),
create_websocket_market_message(thread_type, symbols),
))
.unwrap(),
))
@@ -121,10 +121,10 @@ async fn handle_asset_status_message(
Action::Remove => {
let mut guard = guard.write().await;
guard.symbols.retain(|symbol| !symbols.contains(symbol));
guard
.pending_unsubscriptions
.extend(symbols.clone().into_iter().zip(message.assets.clone()));
.assets
.retain(|asset, _| !message.assets.contains(asset));
guard.pending_unsubscriptions.extend(message.assets);
info!("{:?} - Removed {:?}.", thread_type, symbols);
@@ -140,7 +140,7 @@ async fn handle_asset_status_message(
.await
.send(tungstenite::Message::Text(
to_string(&websocket::outgoing::Message::Unsubscribe(
websocket_market_message_factory(thread_type, symbols),
create_websocket_market_message(thread_type, symbols),
))
.unwrap(),
))
@@ -155,7 +155,7 @@ async fn handle_asset_status_message(
message.response.send(()).unwrap();
}
fn websocket_market_message_factory(
fn create_websocket_market_message(
thread_type: ThreadType,
symbols: Vec<String>,
) -> websocket::outgoing::subscribe::Message {

View File

@@ -10,13 +10,14 @@ use crate::{
utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE},
};
use backoff::{future::retry, ExponentialBackoff};
use futures_util::future::join_all;
use log::{error, info, warn};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::{
join, spawn,
sync::{mpsc, oneshot, Mutex, RwLock},
task::{spawn_blocking, JoinHandle},
task::JoinHandle,
time::sleep,
};
@@ -87,16 +88,18 @@ async fn handle_backfill_message(
let mut backfill_jobs = backfill_jobs.lock().await;
let symbols = match message.assets {
Subset::All => guard.symbols.clone().into_iter().collect::<Vec<_>>(),
Subset::All => guard
.assets
.clone()
.into_iter()
.map(|(_, symbol)| symbol)
.collect(),
Subset::Some(assets) => assets
.into_iter()
.map(|asset| match thread_type {
ThreadType::Bars(_) => asset.symbol,
ThreadType::News => asset.abbreviation,
})
.map(|asset| asset.symbol)
.filter(|symbol| match message.action {
Action::Backfill => guard.symbols.contains(symbol),
Action::Purge => !guard.symbols.contains(symbol),
Action::Backfill => guard.assets.contains_right(symbol),
Action::Purge => !guard.assets.contains_right(symbol),
})
.collect::<Vec<_>>(),
};
@@ -365,33 +368,30 @@ async fn execute_backfill_news(
return;
}
let app_config_clone = app_config.clone();
let inputs = news
.iter()
.map(|news| format!("{}\n\n{}", news.headline, news.content))
.collect::<Vec<_>>();
let predictions: Vec<Prediction> = spawn_blocking(move || {
inputs
.chunks(app_config_clone.max_bert_inputs)
.flat_map(|inputs| {
app_config_clone
.sequence_classifier
.lock()
.unwrap()
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()
})
.collect()
})
let predictions = join_all(inputs.chunks(app_config.max_bert_inputs).map(|inputs| {
let sequence_classifier = app_config.sequence_classifier.clone();
async move {
sequence_classifier
.lock()
.await
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()
}
}))
.await
.unwrap();
.into_iter()
.flatten();
let news = news
.into_iter()
.zip(predictions.into_iter())
.zip(predictions)
.map(|(news, prediction)| News {
sentiment: prediction.sentiment,
confidence: prediction.confidence,

View File

@@ -2,31 +2,22 @@ pub mod asset_status;
pub mod backfill;
pub mod websocket;
use super::clock;
use super::{clock, guard::Guard};
use crate::{
config::{
Config, ALPACA_CRYPTO_WEBSOCKET_URL, ALPACA_NEWS_WEBSOCKET_URL, ALPACA_STOCK_WEBSOCKET_URL,
},
types::{Asset, Class, Subset},
types::{Class, Subset},
utils::authenticate,
};
use futures_util::StreamExt;
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use std::sync::Arc;
use tokio::{
join, select, spawn,
sync::{mpsc, Mutex, RwLock},
};
use tokio_tungstenite::connect_async;
pub struct Guard {
pub symbols: HashSet<String>,
pub pending_subscriptions: HashMap<String, Asset>,
pub pending_unsubscriptions: HashMap<String, Asset>,
}
#[derive(Clone, Copy, Debug)]
pub enum ThreadType {
Bars(Class),
@@ -76,11 +67,7 @@ async fn init_thread(
mpsc::Sender<asset_status::Message>,
mpsc::Sender<backfill::Message>,
) {
let guard = Arc::new(RwLock::new(Guard {
symbols: HashSet::new(),
pending_subscriptions: HashMap::new(),
pending_unsubscriptions: HashMap::new(),
}));
let guard = Arc::new(RwLock::new(Guard::new()));
let websocket_url = match thread_type {
ThreadType::Bars(Class::UsEquity) => format!(

View File

@@ -3,6 +3,7 @@ use crate::{
config::Config,
database,
types::{alpaca::websocket, news::Prediction, Bar, News, Subset},
utils::add_slash_to_pair,
};
use futures_util::{
stream::{SplitSink, SplitStream},
@@ -10,16 +11,12 @@ use futures_util::{
};
use log::{error, info, warn};
use serde_json::from_str;
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use std::{collections::HashSet, sync::Arc};
use tokio::{
join,
net::TcpStream,
spawn,
sync::{mpsc, Mutex, RwLock},
task::spawn_blocking,
};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
@@ -106,20 +103,24 @@ async fn handle_parsed_websocket_message(
websocket::incoming::Message::Subscription(message) => {
let symbols = match message {
websocket::incoming::subscription::Message::Market(message) => message.bars,
websocket::incoming::subscription::Message::News(message) => message.news,
websocket::incoming::subscription::Message::News(message) => message
.news
.into_iter()
.map(|symbol| add_slash_to_pair(&symbol))
.collect(),
};
let mut guard = guard.write().await;
let newly_subscribed = guard
.pending_subscriptions
.extract_if(|symbol, _| symbols.contains(symbol))
.collect::<HashMap<_, _>>();
.extract_if(|asset| symbols.contains(&asset.symbol))
.collect::<HashSet<_>>();
let newly_unsubscribed = guard
.pending_unsubscriptions
.extract_if(|symbol, _| !symbols.contains(symbol))
.collect::<HashMap<_, _>>();
.extract_if(|asset| !symbols.contains(&asset.symbol))
.collect::<HashSet<_>>();
drop(guard);
@@ -128,12 +129,15 @@ async fn handle_parsed_websocket_message(
info!(
"{:?} - Subscribed to {:?}.",
thread_type,
newly_subscribed.keys().collect::<Vec<_>>()
newly_subscribed
.iter()
.map(|asset| asset.symbol.clone())
.collect::<Vec<_>>()
);
let (backfill_message, backfill_receiver) = backfill::Message::new(
backfill::Action::Backfill,
Subset::Some(newly_subscribed.into_values().collect::<Vec<_>>()),
Subset::Some(newly_subscribed.into_iter().collect::<Vec<_>>()),
);
backfill_sender.send(backfill_message).await.unwrap();
@@ -146,12 +150,15 @@ async fn handle_parsed_websocket_message(
info!(
"{:?} - Unsubscribed from {:?}.",
thread_type,
newly_unsubscribed.keys().collect::<Vec<_>>()
newly_unsubscribed
.iter()
.map(|asset| asset.symbol.clone())
.collect::<Vec<_>>()
);
let (purge_message, purge_receiver) = backfill::Message::new(
backfill::Action::Purge,
Subset::Some(newly_unsubscribed.into_values().collect::<Vec<_>>()),
Subset::Some(newly_unsubscribed.into_iter().collect::<Vec<_>>()),
);
backfill_sender.send(purge_message).await.unwrap();
@@ -166,7 +173,7 @@ async fn handle_parsed_websocket_message(
let bar = Bar::from(message);
let guard = guard.read().await;
if guard.symbols.get(&bar.symbol).is_none() {
if !guard.assets.contains_right(&bar.symbol) {
warn!(
"{:?} - Race condition: received bar for unsubscribed symbol: {:?}.",
thread_type, bar.symbol
@@ -182,10 +189,13 @@ async fn handle_parsed_websocket_message(
}
websocket::incoming::Message::News(message) => {
let news = News::from(message);
let symbols = news.symbols.clone().into_iter().collect::<HashSet<_>>();
let guard = guard.read().await;
if !guard.symbols.iter().any(|symbol| symbols.contains(symbol)) {
if !news
.symbols
.iter()
.any(|symbol| guard.assets.contains_right(symbol))
{
warn!(
"{:?} - Race condition: received news for unsubscribed symbols: {:?}.",
thread_type, news.symbols
@@ -198,21 +208,16 @@ async fn handle_parsed_websocket_message(
thread_type, news.symbols, news.time_created
);
let app_config_clone = app_config.clone();
let input = format!("{}\n\n{}", news.headline, news.content);
let prediction = spawn_blocking(move || {
app_config_clone
.sequence_classifier
.lock()
.unwrap()
.predict(vec![input.as_str()])
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()[0]
})
.await
.unwrap();
let prediction = app_config
.sequence_classifier
.lock()
.await
.predict(vec![input.as_str()])
.into_iter()
.map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()[0];
let news = News {
sentiment: prediction.sentiment,

19
src/threads/guard.rs Normal file
View File

@@ -0,0 +1,19 @@
use crate::types::Asset;
use bimap::BiMap;
use std::collections::HashSet;
/// Subscription bookkeeping for a data thread.
///
/// Holds the assets currently tracked plus the assets whose subscribe or
/// unsubscribe request is still pending.
// NOTE(review): callers appear to share this behind `Arc<RwLock<Guard>>` — confirm at the use sites.
pub struct Guard {
/// Bidirectional map between each tracked asset and its symbol string,
/// so membership can be checked from either side.
pub assets: BiMap<Asset, String>,
/// Assets that have been requested for subscription and are still pending.
// NOTE(review): presumably drained when a subscription message lists their symbol — verify in the websocket handler.
pub pending_subscriptions: HashSet<Asset>,
/// Assets that have been requested for unsubscription and are still pending.
// NOTE(review): presumably drained when a subscription message no longer lists their symbol — verify in the websocket handler.
pub pending_unsubscriptions: HashSet<Asset>,
}
impl Guard {
    /// Creates an empty `Guard`: no tracked assets and no pending
    /// subscriptions or unsubscriptions.
    pub fn new() -> Self {
        Self {
            assets: BiMap::new(),
            pending_subscriptions: HashSet::new(),
            pending_unsubscriptions: HashSet::new(),
        }
    }
}

impl Default for Guard {
    /// Equivalent to [`Guard::new`]; provided so a type with a public
    /// zero-argument constructor also works with `Default`-based APIs.
    fn default() -> Self {
        Self::new()
    }
}

View File

@@ -1,2 +1,3 @@
pub mod clock;
pub mod data;
pub mod guard;