diff --git a/src/threads/data/websocket.rs b/src/threads/data/websocket.rs index 6375793..4ef2762 100644 --- a/src/threads/data/websocket.rs +++ b/src/threads/data/websocket.rs @@ -66,7 +66,7 @@ pub trait Handler: Send + Sync { &self, symbols: Vec, ) -> websocket::data::outgoing::subscribe::Message; - async fn handle_parsed_websocket_message( + async fn handle_websocket_message( &self, pending: Arc>, message: websocket::data::incoming::Message, @@ -96,12 +96,26 @@ pub async fn run( )); } Some(Ok(message)) = websocket_stream.next() => { - spawn(handle_websocket_message( - handler.clone(), - pending.clone(), - websocket_sink.clone(), - message, - )); + match message { + tungstenite::Message::Text(message) => { + let parsed_message = from_str::>(&message); + + if parsed_message.is_err() { + error!("Failed to deserialize websocket message: {:?}", message); + continue; + } + + for message in parsed_message.unwrap() { + let handler = handler.clone(); + let pending = pending.clone(); + spawn(async move { + handler.handle_websocket_message(pending, message).await; + }); + } + } + tungstenite::Message::Ping(_) => {} + _ => error!("Unexpected websocket message: {:?}", message), + } } else => panic!("Communication channel unexpectedly closed.") } @@ -179,40 +193,6 @@ async fn handle_message( message.response.send(()).unwrap(); } -async fn handle_websocket_message( - handler: Arc>, - pending: Arc>, - sink: Arc>, tungstenite::Message>>>, - message: tungstenite::Message, -) { - match message { - tungstenite::Message::Text(message) => { - if let Ok(message) = from_str::>(&message) { - for message in message { - let handler = handler.clone(); - let pending = pending.clone(); - - spawn(async move { - handler - .handle_parsed_websocket_message(pending, message) - .await; - }); - } - } else { - error!("Failed to deserialize websocket message: {:?}", message); - } - } - tungstenite::Message::Ping(payload) => { - sink.lock() - .await - .send(tungstenite::Message::Pong(payload)) - .await - .unwrap(); - } - _ => error!("Unexpected websocket message: {:?}", message), - } -} - struct BarsHandler { config: Arc, subscription_message_constructor: @@ -228,7 +208,7 @@ impl Handler for BarsHandler { (self.subscription_message_constructor)(symbols) } - async fn handle_parsed_websocket_message( + async fn handle_websocket_message( &self, pending: Arc>, message: websocket::data::incoming::Message, @@ -338,7 +318,7 @@ impl Handler for NewsHandler { websocket::data::outgoing::subscribe::Message::new_news(symbols) } - async fn handle_parsed_websocket_message( + async fn handle_websocket_message( &self, pending: Arc>, message: websocket::data::incoming::Message, diff --git a/src/threads/trading/mod.rs b/src/threads/trading/mod.rs index 5b4840c..ad88d58 100644 --- a/src/threads/trading/mod.rs +++ b/src/threads/trading/mod.rs @@ -16,5 +16,5 @@ pub async fn run(config: Arc) { alpaca::websocket::trading::authenticate(&mut websocket_sink, &mut websocket_stream).await; alpaca::websocket::trading::subscribe(&mut websocket_sink, &mut websocket_stream).await; - spawn(websocket::run(config, websocket_stream, websocket_sink)); + spawn(websocket::run(config, websocket_stream)); } diff --git a/src/threads/trading/websocket.rs b/src/threads/trading/websocket.rs index ff66bc8..9fb5690 100644 --- a/src/threads/trading/websocket.rs +++ b/src/threads/trading/websocket.rs @@ -3,60 +3,43 @@ use crate::{ database, types::{alpaca::websocket, Order}, }; -use futures_util::{ - stream::{SplitSink, SplitStream}, - SinkExt, StreamExt, -}; +use futures_util::{stream::SplitStream, StreamExt}; use log::{debug, error}; use serde_json::from_str; use std::sync::Arc; -use tokio::{net::TcpStream, spawn, sync::Mutex}; +use tokio::{net::TcpStream, spawn}; use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; pub async fn run( config: Arc, mut websocket_stream: SplitStream>>, - websocket_sink: SplitSink>, tungstenite::Message>, ) { - let websocket_sink = Arc::new(Mutex::new(websocket_sink)); - loop { let message = websocket_stream.next().await.unwrap().unwrap(); - spawn(handle_websocket_message( - config.clone(), - websocket_sink.clone(), - message, - )); + + match message { + tungstenite::Message::Binary(message) => { + let parsed_message = from_str::( + &String::from_utf8_lossy(&message), + ); + + if parsed_message.is_err() { + error!("Failed to deserialize websocket message: {:?}", message); + continue; + } + + spawn(handle_websocket_message( + config.clone(), + parsed_message.unwrap(), + )); + } + tungstenite::Message::Ping(_) => {} + _ => error!("Unexpected websocket message: {:?}", message), + } } } async fn handle_websocket_message( - config: Arc, - sink: Arc>, tungstenite::Message>>>, - message: tungstenite::Message, -) { - match message { - tungstenite::Message::Binary(message) => { - if let Ok(message) = from_str::( - &String::from_utf8_lossy(&message), - ) { - handle_parsed_websocket_message(config, message).await; - } else { - error!("Failed to deserialize websocket message: {:?}", message); - } - } - tungstenite::Message::Ping(payload) => { - sink.lock() - .await - .send(tungstenite::Message::Pong(payload)) - .await - .unwrap(); - } - _ => error!("Unexpected websocket message: {:?}", message), - } -} - -async fn handle_parsed_websocket_message( config: Arc, message: websocket::trading::incoming::Message, ) {