Remove stored abbreviation
- Alpaca is fuck Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
@@ -3,6 +3,7 @@ use crate::{
|
||||
config::Config,
|
||||
database,
|
||||
types::{alpaca::websocket, news::Prediction, Bar, News, Subset},
|
||||
utils::add_slash_to_pair,
|
||||
};
|
||||
use futures_util::{
|
||||
stream::{SplitSink, SplitStream},
|
||||
@@ -10,16 +11,12 @@ use futures_util::{
|
||||
};
|
||||
use log::{error, info, warn};
|
||||
use serde_json::from_str;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
use tokio::{
|
||||
join,
|
||||
net::TcpStream,
|
||||
spawn,
|
||||
sync::{mpsc, Mutex, RwLock},
|
||||
task::spawn_blocking,
|
||||
};
|
||||
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
|
||||
|
||||
@@ -106,20 +103,24 @@ async fn handle_parsed_websocket_message(
|
||||
websocket::incoming::Message::Subscription(message) => {
|
||||
let symbols = match message {
|
||||
websocket::incoming::subscription::Message::Market(message) => message.bars,
|
||||
websocket::incoming::subscription::Message::News(message) => message.news,
|
||||
websocket::incoming::subscription::Message::News(message) => message
|
||||
.news
|
||||
.into_iter()
|
||||
.map(|symbol| add_slash_to_pair(&symbol))
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let mut guard = guard.write().await;
|
||||
|
||||
let newly_subscribed = guard
|
||||
.pending_subscriptions
|
||||
.extract_if(|symbol, _| symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
.extract_if(|asset| symbols.contains(&asset.symbol))
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let newly_unsubscribed = guard
|
||||
.pending_unsubscriptions
|
||||
.extract_if(|symbol, _| !symbols.contains(symbol))
|
||||
.collect::<HashMap<_, _>>();
|
||||
.extract_if(|asset| !symbols.contains(&asset.symbol))
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
drop(guard);
|
||||
|
||||
@@ -128,12 +129,15 @@ async fn handle_parsed_websocket_message(
|
||||
info!(
|
||||
"{:?} - Subscribed to {:?}.",
|
||||
thread_type,
|
||||
newly_subscribed.keys().collect::<Vec<_>>()
|
||||
newly_subscribed
|
||||
.iter()
|
||||
.map(|asset| asset.symbol.clone())
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
let (backfill_message, backfill_receiver) = backfill::Message::new(
|
||||
backfill::Action::Backfill,
|
||||
Subset::Some(newly_subscribed.into_values().collect::<Vec<_>>()),
|
||||
Subset::Some(newly_subscribed.into_iter().collect::<Vec<_>>()),
|
||||
);
|
||||
|
||||
backfill_sender.send(backfill_message).await.unwrap();
|
||||
@@ -146,12 +150,15 @@ async fn handle_parsed_websocket_message(
|
||||
info!(
|
||||
"{:?} - Unsubscribed from {:?}.",
|
||||
thread_type,
|
||||
newly_unsubscribed.keys().collect::<Vec<_>>()
|
||||
newly_unsubscribed
|
||||
.iter()
|
||||
.map(|asset| asset.symbol.clone())
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
let (purge_message, purge_receiver) = backfill::Message::new(
|
||||
backfill::Action::Purge,
|
||||
Subset::Some(newly_unsubscribed.into_values().collect::<Vec<_>>()),
|
||||
Subset::Some(newly_unsubscribed.into_iter().collect::<Vec<_>>()),
|
||||
);
|
||||
|
||||
backfill_sender.send(purge_message).await.unwrap();
|
||||
@@ -166,7 +173,7 @@ async fn handle_parsed_websocket_message(
|
||||
let bar = Bar::from(message);
|
||||
|
||||
let guard = guard.read().await;
|
||||
if guard.symbols.get(&bar.symbol).is_none() {
|
||||
if !guard.assets.contains_right(&bar.symbol) {
|
||||
warn!(
|
||||
"{:?} - Race condition: received bar for unsubscribed symbol: {:?}.",
|
||||
thread_type, bar.symbol
|
||||
@@ -182,10 +189,13 @@ async fn handle_parsed_websocket_message(
|
||||
}
|
||||
websocket::incoming::Message::News(message) => {
|
||||
let news = News::from(message);
|
||||
let symbols = news.symbols.clone().into_iter().collect::<HashSet<_>>();
|
||||
|
||||
let guard = guard.read().await;
|
||||
if !guard.symbols.iter().any(|symbol| symbols.contains(symbol)) {
|
||||
if !news
|
||||
.symbols
|
||||
.iter()
|
||||
.any(|symbol| guard.assets.contains_right(symbol))
|
||||
{
|
||||
warn!(
|
||||
"{:?} - Race condition: received news for unsubscribed symbols: {:?}.",
|
||||
thread_type, news.symbols
|
||||
@@ -198,21 +208,16 @@ async fn handle_parsed_websocket_message(
|
||||
thread_type, news.symbols, news.time_created
|
||||
);
|
||||
|
||||
let app_config_clone = app_config.clone();
|
||||
let input = format!("{}\n\n{}", news.headline, news.content);
|
||||
|
||||
let prediction = spawn_blocking(move || {
|
||||
app_config_clone
|
||||
.sequence_classifier
|
||||
.lock()
|
||||
.unwrap()
|
||||
.predict(vec![input.as_str()])
|
||||
.into_iter()
|
||||
.map(|label| Prediction::try_from(label).unwrap())
|
||||
.collect::<Vec<_>>()[0]
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let prediction = app_config
|
||||
.sequence_classifier
|
||||
.lock()
|
||||
.await
|
||||
.predict(vec![input.as_str()])
|
||||
.into_iter()
|
||||
.map(|label| Prediction::try_from(label).unwrap())
|
||||
.collect::<Vec<_>>()[0];
|
||||
|
||||
let news = News {
|
||||
sentiment: prediction.sentiment,
|
||||
|
Reference in New Issue
Block a user