diff --git a/src/config.rs b/src/config.rs index 6b1e889..9f76124 100644 --- a/src/config.rs +++ b/src/config.rs @@ -23,6 +23,7 @@ pub const ALPACA_STOCK_DATA_URL: &str = "https://data.alpaca.markets/v2/stocks/b pub const ALPACA_CRYPTO_DATA_URL: &str = "https://data.alpaca.markets/v1beta3/crypto/us/bars"; pub const ALPACA_NEWS_DATA_URL: &str = "https://data.alpaca.markets/v1beta1/news"; +pub const ALPACA_TRADING_WEBSOCKET_URL: &str = "wss://api.alpaca.markets/stream"; pub const ALPACA_STOCK_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v2"; pub const ALPACA_CRYPTO_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta3/crypto/us"; pub const ALPACA_NEWS_WEBSOCKET_URL: &str = "wss://stream.data.alpaca.markets/v1beta1/news"; diff --git a/src/database/assets.rs b/src/database/assets.rs index 94f7f87..04a9192 100644 --- a/src/database/assets.rs +++ b/src/database/assets.rs @@ -1,50 +1,14 @@ -use crate::types::Asset; +use crate::{ + delete_where_symbols, optimize, select, select_where_symbol, types::Asset, upsert_batch, +}; use clickhouse::{error::Error, Client}; use serde::Serialize; -pub async fn select(clickhouse_client: &Client) -> Result, Error> { - clickhouse_client - .query("SELECT ?fields FROM assets FINAL") - .fetch_all::() - .await -} - -pub async fn select_where_symbol( - clickhouse_client: &Client, - symbol: &T, -) -> Result, Error> -where - T: AsRef + Serialize + Send + Sync, -{ - clickhouse_client - .query("SELECT ?fields FROM assets FINAL WHERE symbol = ?") - .bind(symbol) - .fetch_optional::() - .await -} - -pub async fn upsert_batch(clickhouse_client: &Client, assets: T) -> Result<(), Error> -where - T: IntoIterator + Send + Sync, - T::IntoIter: Send, -{ - let mut insert = clickhouse_client.insert("assets")?; - for asset in assets { - insert.write(&asset).await?; - } - insert.end().await -} - -pub async fn delete_where_symbols(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error> -where - T: AsRef + Serialize + Send + Sync, -{ - clickhouse_client - .query("DELETE FROM assets WHERE symbol IN ?") - .bind(symbols) - .execute() - .await -} +select!(Asset, "assets"); +select_where_symbol!(Asset, "assets"); +upsert_batch!(Asset, "assets"); +delete_where_symbols!("assets"); +optimize!("assets"); pub async fn update_status_where_symbol( clickhouse_client: &Client, @@ -61,3 +25,19 @@ where .execute() .await } + +pub async fn update_qty_where_symbol( + clickhouse_client: &Client, + symbol: &T, + qty: f64, +) -> Result<(), Error> +where + T: AsRef + Serialize + Send + Sync, +{ + clickhouse_client + .query("ALTER TABLE assets UPDATE qty = ? WHERE symbol = ?") + .bind(qty) + .bind(symbol) + .execute() + .await +} diff --git a/src/database/backfills.rs b/src/database/backfills.rs deleted file mode 100644 index b8eaca6..0000000 --- a/src/database/backfills.rs +++ /dev/null @@ -1,79 +0,0 @@ -use crate::types::Backfill; -use clickhouse::{error::Error, Client}; -use serde::Serialize; -use std::fmt::{Display, Formatter}; -use tokio::try_join; - -pub enum Table { - Bars, - News, -} - -impl Display for Table { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Bars => write!(f, "backfills_bars"), - Self::News => write!(f, "backfills_news"), - } - } -} - -pub async fn select_latest_where_symbol( - clickhouse_client: &Client, - table: &Table, - symbol: &T, -) -> Result, Error> -where - T: AsRef + Serialize + Send + Sync, -{ - clickhouse_client - .query(&format!( - "SELECT ?fields FROM {table} FINAL WHERE symbol = ? ORDER BY time DESC LIMIT 1", - )) - .bind(symbol) - .fetch_optional::() - .await -} - -pub async fn upsert( - clickhouse_client: &Client, - table: &Table, - backfill: &Backfill, -) -> Result<(), Error> { - let mut insert = clickhouse_client.insert(&table.to_string())?; - insert.write(backfill).await?; - insert.end().await -} - -pub async fn delete_where_symbols( - clickhouse_client: &Client, - table: &Table, - symbols: &[T], -) -> Result<(), Error> -where - T: AsRef + Serialize + Send + Sync, -{ - clickhouse_client - .query(&format!("DELETE FROM {table} WHERE symbol IN ?")) - .bind(symbols) - .execute() - .await -} - -pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> { - let delete_bars_future = async { - clickhouse_client - .query("DELETE FROM backfills_bars WHERE symbol NOT IN (SELECT symbol FROM assets)") - .execute() - .await - }; - - let delete_news_future = async { - clickhouse_client - .query("DELETE FROM backfills_news WHERE symbol NOT IN (SELECT symbol FROM assets)") - .execute() - .await - }; - - try_join!(delete_bars_future, delete_news_future).map(|_| ()) -} diff --git a/src/database/backfills_bars.rs b/src/database/backfills_bars.rs new file mode 100644 index 0000000..2d8017c --- /dev/null +++ b/src/database/backfills_bars.rs @@ -0,0 +1,17 @@ +use crate::{ + cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert, +}; +use clickhouse::{error::Error, Client}; + +select_where_symbol!(Backfill, "backfills_bars"); +upsert!(Backfill, "backfills_bars"); +delete_where_symbols!("backfills_bars"); +cleanup!("backfills_bars"); +optimize!("backfills_bars"); + +pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> { + clickhouse_client + .query("ALTER TABLE backfills_bars UPDATE fresh = false WHERE true") + .execute() + .await +} diff --git a/src/database/backfills_news.rs b/src/database/backfills_news.rs new file mode 100644 index 0000000..9d4f2ef --- /dev/null +++ b/src/database/backfills_news.rs @@ -0,0 +1,17 @@ +use crate::{ + cleanup, delete_where_symbols, optimize, select_where_symbol, types::Backfill, upsert, +}; +use clickhouse::{error::Error, Client}; + +select_where_symbol!(Backfill, "backfills_news"); +upsert!(Backfill, "backfills_news"); +delete_where_symbols!("backfills_news"); +cleanup!("backfills_news"); +optimize!("backfills_news"); + +pub async fn unfresh(clickhouse_client: &Client) -> Result<(), Error> { + clickhouse_client + .query("ALTER TABLE backfills_news UPDATE fresh = false WHERE true") + .execute() + .await +} diff --git a/src/database/bars.rs b/src/database/bars.rs index d674dab..8be2374 100644 --- a/src/database/bars.rs +++ b/src/database/bars.rs @@ -1,39 +1,7 @@ -use crate::types::Bar; -use clickhouse::{error::Error, Client}; -use serde::Serialize; +use crate::{cleanup, delete_where_symbols, optimize, types::Bar, upsert, upsert_batch}; -pub async fn upsert(clickhouse_client: &Client, bar: &Bar) -> Result<(), Error> { - let mut insert = clickhouse_client.insert("bars")?; - insert.write(bar).await?; - insert.end().await -} - -pub async fn upsert_batch(clickhouse_client: &Client, bars: T) -> Result<(), Error> -where - T: IntoIterator + Send + Sync, - T::IntoIter: Send, -{ - let mut insert = clickhouse_client.insert("bars")?; - for bar in bars { - insert.write(&bar).await?; - } - insert.end().await -} - -pub async fn delete_where_symbols(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error> -where - T: AsRef + Serialize + Send + Sync, -{ - clickhouse_client - .query("DELETE FROM bars WHERE symbol IN ?") - .bind(symbols) - .execute() - .await -} - -pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> { - clickhouse_client - .query("DELETE FROM bars WHERE symbol NOT IN (SELECT symbol FROM assets)") - .execute() - .await -} +upsert!(Bar, "bars"); +upsert_batch!(Bar, "bars"); +delete_where_symbols!("bars"); +cleanup!("bars"); +optimize!("bars"); diff --git a/src/database/mod.rs b/src/database/mod.rs index 23d4e66..5255369 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,6 +1,150 @@ pub mod assets; -pub mod backfills; +pub mod backfills_bars; +pub mod backfills_news; pub mod bars; pub mod news; pub mod orders; -pub mod positions; + +use clickhouse::{error::Error, Client}; +use log::info; +use tokio::try_join; + +#[macro_export] +macro_rules! select { + ($record:ty, $table_name:expr) => { + pub async fn select( + client: &clickhouse::Client, + ) -> Result, clickhouse::error::Error> { + client + .query(&format!("SELECT ?fields FROM {} FINAL", $table_name)) + .fetch_all::<$record>() + .await + } + }; +} + +#[macro_export] +macro_rules! select_where_symbol { + ($record:ty, $table_name:expr) => { + pub async fn select_where_symbol( + client: &clickhouse::Client, + symbol: &T, + ) -> Result, clickhouse::error::Error> + where + T: AsRef + serde::Serialize + Send + Sync, + { + client + .query(&format!( + "SELECT ?fields FROM {} FINAL WHERE symbol = ?", + $table_name + )) + .bind(symbol) + .fetch_optional::<$record>() + .await + } + }; +} + +#[macro_export] +macro_rules! upsert { + ($record:ty, $table_name:expr) => { + pub async fn upsert( + client: &clickhouse::Client, + record: &$record, + ) -> Result<(), clickhouse::error::Error> { + let mut insert = client.insert($table_name)?; + insert.write(record).await?; + insert.end().await + } + }; +} + +#[macro_export] +macro_rules! upsert_batch { + ($record:ty, $table_name:expr) => { + pub async fn upsert_batch<'a, T>( + client: &clickhouse::Client, + records: T, + ) -> Result<(), clickhouse::error::Error> + where + T: IntoIterator + Send + Sync, + T::IntoIter: Send, + { + let mut insert = client.insert($table_name)?; + for record in records { + insert.write(record).await?; + } + insert.end().await + } + }; +} + +#[macro_export] +macro_rules! delete_where_symbols { + ($table_name:expr) => { + pub async fn delete_where_symbols( + client: &clickhouse::Client, + symbols: &[T], + ) -> Result<(), clickhouse::error::Error> + where + T: AsRef + serde::Serialize + Send + Sync, + { + client + .query(&format!("DELETE FROM {} WHERE symbol IN ?", $table_name)) + .bind(symbols) + .execute() + .await + } + }; +} + +#[macro_export] +macro_rules! cleanup { + ($table_name:expr) => { + pub async fn cleanup(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> { + client + .query(&format!( + "DELETE FROM {} WHERE symbol NOT IN (SELECT symbol FROM assets)", + $table_name + )) + .execute() + .await + } + }; +} + +#[macro_export] +macro_rules! optimize { + ($table_name:expr) => { + pub async fn optimize(client: &clickhouse::Client) -> Result<(), clickhouse::error::Error> { + client + .query(&format!("OPTIMIZE TABLE {} FINAL", $table_name)) + .execute() + .await + } + }; +} + +pub async fn cleanup_all(clickhouse_client: &Client) -> Result<(), Error> { + info!("Cleaning up database."); + try_join!( + bars::cleanup(clickhouse_client), + news::cleanup(clickhouse_client), + backfills_bars::cleanup(clickhouse_client), + backfills_news::cleanup(clickhouse_client) + ) + .map(|_| ()) +} + +pub async fn optimize_all(clickhouse_client: &Client) -> Result<(), Error> { + info!("Optimizing database."); + try_join!( + assets::optimize(clickhouse_client), + bars::optimize(clickhouse_client), + news::optimize(clickhouse_client), + backfills_bars::optimize(clickhouse_client), + backfills_news::optimize(clickhouse_client), + orders::optimize(clickhouse_client) + ) + .map(|_| ()) +} diff --git a/src/database/news.rs b/src/database/news.rs index 7a8c4d6..4f40072 100644 --- a/src/database/news.rs +++ b/src/database/news.rs @@ -1,24 +1,10 @@ -use crate::types::News; +use crate::{optimize, types::News, upsert, upsert_batch}; use clickhouse::{error::Error, Client}; use serde::Serialize; -pub async fn upsert(clickhouse_client: &Client, news: &News) -> Result<(), Error> { - let mut insert = clickhouse_client.insert("news")?; - insert.write(news).await?; - insert.end().await -} - -pub async fn upsert_batch(clickhouse_client: &Client, news: T) -> Result<(), Error> -where - T: IntoIterator + Send + Sync, - T::IntoIter: Send, -{ - let mut insert = clickhouse_client.insert("news")?; - for news in news { - insert.write(&news).await?; - } - insert.end().await -} +upsert!(News, "news"); +upsert_batch!(News, "news"); +optimize!("news"); pub async fn delete_where_symbols(clickhouse_client: &Client, symbols: &[T]) -> Result<(), Error> where diff --git a/src/database/orders.rs b/src/database/orders.rs index e69de29..faa79fd 100644 --- a/src/database/orders.rs +++ b/src/database/orders.rs @@ -0,0 +1,5 @@ +use crate::{optimize, types::Order, upsert, upsert_batch}; + +upsert!(Order, "orders"); +upsert_batch!(Order, "orders"); +optimize!("orders"); diff --git a/src/database/positions.rs b/src/database/positions.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/main.rs b/src/main.rs index 5654751..e00bde9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,11 +9,10 @@ mod threads; mod types; mod utils; -use crate::utils::cleanup; use config::Config; use dotenv::dotenv; use log4rs::config::Deserializers; -use tokio::{spawn, sync::mpsc}; +use tokio::{spawn, sync::mpsc, try_join}; #[tokio::main] async fn main() { @@ -21,7 +20,20 @@ async fn main() { log4rs::init_file("log4rs.yaml", Deserializers::default()).unwrap(); let config = Config::arc_from_env(); - cleanup(&config.clickhouse_client).await.unwrap(); + database::cleanup_all(&config.clickhouse_client) + .await + .unwrap(); + database::optimize_all(&config.clickhouse_client) + .await + .unwrap(); + + try_join!( + database::backfills_bars::unfresh(&config.clickhouse_client), + database::backfills_news::unfresh(&config.clickhouse_client) + ) + .unwrap(); + + spawn(threads::trading::run(config.clone())); let (data_sender, data_receiver) = mpsc::channel::(100); let (clock_sender, clock_receiver) = mpsc::channel::(1); diff --git a/src/routes/assets.rs b/src/routes/assets.rs index 4790796..533c656 100644 --- a/src/routes/assets.rs +++ b/src/routes/assets.rs @@ -62,13 +62,12 @@ pub async fn add( if !asset.tradable || !asset.fractionable { return Err(StatusCode::FORBIDDEN); } - let asset = Asset::from(asset); create_send_await!( data_sender, threads::data::Message::new, threads::data::Action::Add, - vec![(asset.symbol, asset.class)] + vec![(asset.symbol, asset.class.into())] ); Ok(StatusCode::CREATED) diff --git a/src/routes/mod.rs b/src/routes/mod.rs index abd7357..5d3cb5b 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -1,5 +1,5 @@ -pub mod assets; -pub mod health; +mod assets; +mod health; use crate::{config::Config, threads}; use axum::{ diff --git a/src/threads/data/backfill.rs b/src/threads/data/backfill.rs index fa72bf9..7d2947a 100644 --- a/src/threads/data/backfill.rs +++ b/src/threads/data/backfill.rs @@ -4,17 +4,13 @@ use crate::{ database, types::{ alpaca::{ - self, - api::{self, outgoing::Sort}, - shared::Source, + api, + shared::{Sort, Source}, }, news::Prediction, Backfill, Bar, Class, News, }, - utils::{ - duration_until, last_minute, remove_slash_from_pair, FIFTEEN_MINUTES, ONE_MINUTE, - ONE_SECOND, - }, + utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE, ONE_SECOND}, }; use async_trait::async_trait; use futures_util::future::join_all; @@ -216,21 +212,12 @@ impl Handler for BarHandler { &self, symbol: String, ) -> Result, clickhouse::error::Error> { - database::backfills::select_latest_where_symbol( - &self.config.clickhouse_client, - &database::backfills::Table::Bars, - &symbol, - ) - .await + database::backfills_bars::select_where_symbol(&self.config.clickhouse_client, &symbol).await } async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { - database::backfills::delete_where_symbols( - &self.config.clickhouse_client, - &database::backfills::Table::Bars, - symbols, - ) - .await + database::backfills_bars::delete_where_symbols(&self.config.clickhouse_client, symbols) + .await } async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { @@ -252,7 +239,7 @@ impl Handler for BarHandler { let mut next_page_token = None; loop { - let Ok(message) = alpaca::api::incoming::bar::get_historical( + let Ok(message) = api::incoming::bar::get_historical( &self.config, self.data_url, &(self.api_query_constructor)( @@ -289,16 +276,12 @@ impl Handler for BarHandler { let backfill = bars.last().unwrap().clone().into(); - database::bars::upsert_batch(&self.config.clickhouse_client, bars) + database::bars::upsert_batch(&self.config.clickhouse_client, &bars) + .await + .unwrap(); + database::backfills_bars::upsert(&self.config.clickhouse_client, &backfill) .await .unwrap(); - database::backfills::upsert( - &self.config.clickhouse_client, - &database::backfills::Table::Bars, - &backfill, - ) - .await - .unwrap(); info!("Backfilled bars for {}.", symbol); } @@ -318,21 +301,12 @@ impl Handler for NewsHandler { &self, symbol: String, ) -> Result, clickhouse::error::Error> { - database::backfills::select_latest_where_symbol( - &self.config.clickhouse_client, - &database::backfills::Table::News, - &symbol, - ) - .await + database::backfills_news::select_where_symbol(&self.config.clickhouse_client, &symbol).await } async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { - database::backfills::delete_where_symbols( - &self.config.clickhouse_client, - &database::backfills::Table::News, - symbols, - ) - .await + database::backfills_news::delete_where_symbols(&self.config.clickhouse_client, symbols) + .await } async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { @@ -352,10 +326,10 @@ impl Handler for NewsHandler { let mut next_page_token = None; loop { - let Ok(message) = alpaca::api::incoming::news::get_historical( + let Ok(message) = api::incoming::news::get_historical( &self.config, &api::outgoing::news::News { - symbols: vec![remove_slash_from_pair(&symbol)], + symbols: vec![symbol.clone()], start: Some(fetch_from), end: Some(fetch_to), limit: Some(50), @@ -421,16 +395,12 @@ impl Handler for NewsHandler { let backfill = (news.last().unwrap().clone(), symbol.clone()).into(); - database::news::upsert_batch(&self.config.clickhouse_client, news) + database::news::upsert_batch(&self.config.clickhouse_client, &news) + .await + .unwrap(); + database::backfills_news::upsert(&self.config.clickhouse_client, &backfill) .await .unwrap(); - database::backfills::upsert( - &self.config.clickhouse_client, - &database::backfills::Table::News, - &backfill, - ) - .await - .unwrap(); info!("Backfilled news for {}.", symbol); } diff --git a/src/threads/data/mod.rs b/src/threads/data/mod.rs index b950864..41557bc 100644 --- a/src/threads/data/mod.rs +++ b/src/threads/data/mod.rs @@ -1,5 +1,5 @@ -pub mod backfill; -pub mod websocket; +mod backfill; +mod websocket; use super::clock; use crate::{ @@ -8,7 +8,7 @@ use crate::{ }, create_send_await, database, types::{alpaca, Asset, Class}, - utils::{backoff, cleanup}, + utils::backoff, }; use futures_util::{future::join_all, StreamExt}; use itertools::{Either, Itertools}; @@ -128,6 +128,7 @@ async fn init_thread( } #[allow(clippy::too_many_arguments)] +#[allow(clippy::too_many_lines)] async fn handle_message( config: Arc, bars_us_equity_websocket_sender: mpsc::Sender, @@ -216,20 +217,33 @@ async fn handle_message( let assets = join_all(symbols.into_iter().map(|symbol| { let config = config.clone(); async move { - Asset::from( + let asset_future = async { alpaca::api::incoming::asset::get_by_symbol( &config, &symbol, Some(backoff::infinite()), ) .await - .unwrap(), - ) + .unwrap() + }; + + let position_future = async { + alpaca::api::incoming::position::get_by_symbol( + &config, + &symbol, + Some(backoff::infinite()), + ) + .await + .unwrap() + }; + + let (asset, position) = join!(asset_future, position_future); + Asset::from((asset, position)) } })) .await; - database::assets::upsert_batch(&config.clickhouse_client, assets) + database::assets::upsert_batch(&config.clickhouse_client, &assets) .await .unwrap(); } @@ -249,7 +263,9 @@ async fn handle_clock_message( bars_crypto_backfill_sender: mpsc::Sender, news_backfill_sender: mpsc::Sender, ) { - cleanup(&config.clickhouse_client).await.unwrap(); + database::cleanup_all(&config.clickhouse_client) + .await + .unwrap(); let assets = database::assets::select(&config.clickhouse_client) .await diff --git a/src/threads/data/websocket.rs b/src/threads/data/websocket.rs index 9d1f055..fa725df 100644 --- a/src/threads/data/websocket.rs +++ b/src/threads/data/websocket.rs @@ -3,7 +3,6 @@ use crate::{ config::Config, database, types::{alpaca::websocket, news::Prediction, Bar, Class, News}, - utils::add_slash_to_pair, }; use async_trait::async_trait; use futures_util::{ @@ -112,9 +111,7 @@ pub async fn run( async fn handle_message( handler: Arc>, pending: Arc>, - websocket_sender: Arc< - Mutex>, tungstenite::Message>>, - >, + sink: Arc>, tungstenite::Message>>>, message: Message, ) { match message.action { @@ -134,8 +131,7 @@ async fn handle_message( .subscriptions .extend(pending_subscriptions); - websocket_sender - .lock() + sink.lock() .await .send(tungstenite::Message::Text( to_string(&websocket::data::outgoing::Message::Subscribe( @@ -164,8 +160,7 @@ async fn handle_message( .unsubscriptions .extend(pending_unsubscriptions); - websocket_sender - .lock() + sink.lock() .await .send(tungstenite::Message::Text( to_string(&websocket::data::outgoing::Message::Unsubscribe( @@ -186,7 +181,7 @@ async fn handle_message( async fn handle_websocket_message( handler: Arc>, pending: Arc>, - sender: Arc>, tungstenite::Message>>>, + sink: Arc>, tungstenite::Message>>>, message: tungstenite::Message, ) { match message { @@ -208,11 +203,10 @@ async fn handle_websocket_message( error!("Failed to deserialize websocket message: {:?}", message); } } - tungstenite::Message::Ping(_) => { - sender - .lock() + tungstenite::Message::Ping(payload) => { + sink.lock() .await - .send(tungstenite::Message::Pong(vec![])) + .send(tungstenite::Message::Pong(payload)) .await .unwrap(); } @@ -358,11 +352,6 @@ impl Handler for NewsHandler { unreachable!() }; - let symbols = symbols - .into_iter() - .map(|symbol| add_slash_to_pair(&symbol)) - .collect::>(); - let mut pending = pending.write().await; let newly_subscribed = pending diff --git a/src/threads/trading/mod.rs b/src/threads/trading/mod.rs index e69de29..449cf96 100644 --- a/src/threads/trading/mod.rs +++ b/src/threads/trading/mod.rs @@ -0,0 +1,53 @@ +mod rehydrate; +mod websocket; + +use crate::{ + config::{Config, ALPACA_TRADING_WEBSOCKET_URL}, + database, + types::alpaca, +}; +use futures_util::StreamExt; +use log::warn; +use rehydrate::rehydrate; +use std::{collections::HashSet, sync::Arc}; +use tokio::spawn; +use tokio_tungstenite::connect_async; + +pub async fn run(config: Arc) { + let (websocket, _) = connect_async(ALPACA_TRADING_WEBSOCKET_URL).await.unwrap(); + let (mut websocket_sink, mut websocket_stream) = websocket.split(); + + alpaca::websocket::trading::authenticate(&config, &mut websocket_sink, &mut websocket_stream) + .await; + alpaca::websocket::trading::subscribe(&mut websocket_sink, &mut websocket_stream).await; + + rehydrate(&config).await; + check_positions(&config).await; + + spawn(websocket::run(config, websocket_stream, websocket_sink)); +} + +pub async fn check_positions(config: &Arc) { + let positions_future = async { + alpaca::api::incoming::position::get(config, None) + .await + .unwrap() + }; + + let assets_future = async { + database::assets::select(&config.clickhouse_client) + .await + .unwrap() + .into_iter() + .map(|asset| asset.symbol) + .collect::>() + }; + + let (positions, assets) = tokio::join!(positions_future, assets_future); + + for position in positions { + if !assets.contains(&position.symbol) { + warn!("Position for unmonitored asset: {:?}", position.symbol); + } + } +} diff --git a/src/threads/trading/rehydrate.rs b/src/threads/trading/rehydrate.rs new file mode 100644 index 0000000..8fa4e34 --- /dev/null +++ b/src/threads/trading/rehydrate.rs @@ -0,0 +1,48 @@ +use crate::{ + config::Config, + database, + types::alpaca::{api, shared::Sort}, +}; +use log::info; +use std::sync::Arc; +use time::OffsetDateTime; + +pub async fn rehydrate(config: &Arc) { + info!("Rehydrating trading data."); + + let mut orders = vec![]; + let mut after = OffsetDateTime::UNIX_EPOCH; + + while let Some(message) = api::incoming::order::get( + config, + &api::outgoing::order::Order { + status: Some(api::outgoing::order::Status::All), + limit: Some(500), + after: Some(after), + until: None, + direction: Some(Sort::Asc), + nested: Some(true), + symbols: None, + side: None, + }, + None, + ) + .await + .ok() + .filter(|message| !message.is_empty()) + { + orders.extend(message); + after = orders.last().unwrap().submitted_at; + } + + let orders = orders + .into_iter() + .flat_map(&api::incoming::order::Order::normalize) + .collect::>(); + + database::orders::upsert_batch(&config.clickhouse_client, &orders) + .await + .unwrap(); + + info!("Rehydrated trading data."); +} diff --git a/src/threads/trading/websocket.rs b/src/threads/trading/websocket.rs new file mode 100644 index 0000000..bace439 --- /dev/null +++ b/src/threads/trading/websocket.rs @@ -0,0 +1,96 @@ +use crate::{ + config::Config, + database, + types::{alpaca::websocket, Order}, +}; +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, +}; +use log::{debug, error}; +use serde_json::from_str; +use std::sync::Arc; +use tokio::{net::TcpStream, spawn, sync::Mutex}; +use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; + +pub async fn run( + config: Arc, + mut websocket_stream: SplitStream>>, + websocket_sink: SplitSink>, tungstenite::Message>, +) { + let websocket_sink = Arc::new(Mutex::new(websocket_sink)); + + loop { + let message = websocket_stream.next().await.unwrap().unwrap(); + let config = config.clone(); + + spawn(handle_websocket_message( + config, + websocket_sink.clone(), + message, + )); + } +} + +async fn handle_websocket_message( + config: Arc, + sink: Arc>, tungstenite::Message>>>, + message: tungstenite::Message, +) { + match message { + tungstenite::Message::Binary(message) => { + if let Ok(message) = from_str::( + &String::from_utf8_lossy(&message), + ) { + spawn(handle_parsed_websocket_message(config.clone(), message)); + } else { + error!("Failed to deserialize websocket message: {:?}", message); + } + } + tungstenite::Message::Ping(payload) => { + sink.lock() + .await + .send(tungstenite::Message::Pong(payload)) + .await + .unwrap(); + } + _ => error!("Unexpected websocket message: {:?}", message), + } +} + +async fn handle_parsed_websocket_message( + config: Arc, + message: websocket::trading::incoming::Message, +) { + match message { + websocket::trading::incoming::Message::Order(message) => { + debug!( + "Received order message for {}: {:?}", + message.order.symbol, message.event + ); + + let order = Order::from(message.order); + + database::orders::upsert(&config.clickhouse_client, &order) + .await + .unwrap(); + + match message.event { + websocket::trading::incoming::order::Event::Fill { position_qty, .. } + | websocket::trading::incoming::order::Event::PartialFill { + position_qty, .. + } => { + database::assets::update_qty_where_symbol( + &config.clickhouse_client, + &order.symbol, + position_qty, + ) + .await + .unwrap(); + } + _ => (), + } + } + _ => unreachable!(), + } +} diff --git a/src/types/alpaca/api/incoming/asset.rs b/src/types/alpaca/api/incoming/asset.rs index 8ae5d57..d6d4872 100644 --- a/src/types/alpaca/api/incoming/asset.rs +++ b/src/types/alpaca/api/incoming/asset.rs @@ -1,3 +1,4 @@ +use super::position::Position; use crate::{ config::{Config, ALPACA_ASSET_API_URL}, types::{ @@ -30,14 +31,15 @@ pub struct Asset { pub attributes: Option>, } -impl From for types::Asset { - fn from(asset: Asset) -> Self { +impl From<(Asset, Option)> for types::Asset { + fn from((asset, position): (Asset, Option)) -> Self { Self { symbol: asset.symbol, class: asset.class.into(), exchange: asset.exchange.into(), status: asset.status.into(), time_added: time::OffsetDateTime::now_utc(), + qty: position.map(|position| position.qty).unwrap_or_default(), } } } diff --git a/src/types/alpaca/api/incoming/news.rs b/src/types/alpaca/api/incoming/news.rs index e4bd315..042672f 100644 --- a/src/types/alpaca/api/incoming/news.rs +++ b/src/types/alpaca/api/incoming/news.rs @@ -1,7 +1,10 @@ use crate::{ config::{Config, ALPACA_NEWS_DATA_URL}, - types::{self, alpaca::api::outgoing}, - utils::{add_slash_to_pair, normalize_news_content}, + types::{ + self, + alpaca::{api::outgoing, shared::news::normalize_html_content}, + }, + utils::de, }; use backoff::{future::retry_notify, ExponentialBackoff}; use log::warn; @@ -33,6 +36,7 @@ pub struct News { #[serde(with = "time::serde::rfc3339")] #[serde(rename = "updated_at")] pub time_updated: OffsetDateTime, + #[serde(deserialize_with = "de::add_slash_to_symbols")] pub symbols: Vec, pub headline: String, pub author: String, @@ -49,16 +53,12 @@ impl From for types::News { id: news.id, time_created: news.time_created, time_updated: news.time_updated, - symbols: news - .symbols - .into_iter() - .map(|symbol| add_slash_to_pair(&symbol)) - .collect(), - headline: normalize_news_content(&news.headline), - author: normalize_news_content(&news.author), - source: normalize_news_content(&news.source), - summary: normalize_news_content(&news.summary), - content: normalize_news_content(&news.content), + symbols: news.symbols, + headline: normalize_html_content(&news.headline), + author: normalize_html_content(&news.author), + source: normalize_html_content(&news.source), + summary: normalize_html_content(&news.summary), + content: normalize_html_content(&news.content), sentiment: types::news::Sentiment::Neutral, confidence: 0.0, url: news.url.unwrap_or_default(), diff --git a/src/types/alpaca/api/incoming/position.rs b/src/types/alpaca/api/incoming/position.rs index f71d16d..ee4fa27 100644 --- a/src/types/alpaca/api/incoming/position.rs +++ b/src/types/alpaca/api/incoming/position.rs @@ -1,12 +1,10 @@ use crate::{ config::{Config, ALPACA_POSITION_API_URL}, - types::{ + types::alpaca::shared::{ self, - alpaca::shared::{ - self, - asset::{Class, Exchange}, - }, + asset::{Class, Exchange}, }, + utils::de, }; use backoff::{future::retry_notify, ExponentialBackoff}; use log::warn; @@ -33,6 +31,7 @@ impl From for shared::order::Side { #[derive(Deserialize)] pub struct Position { pub asset_id: Uuid, + #[serde(deserialize_with = "de::add_slash_to_symbol")] pub symbol: String, pub exchange: Exchange, pub asset_class: Class, @@ -52,15 +51,6 @@ pub struct Position { pub asset_marginable: bool, } -impl From for types::Position { - fn from(position: Position) -> Self { - Self { - symbol: position.symbol, - qty: position.qty_available, - } - } -} - pub async fn get( config: &Arc, backoff: Option, @@ -93,3 +83,44 @@ pub async fn get( ) .await } + +pub async fn get_by_symbol( + config: &Arc, + symbol: &str, + backoff: Option, +) -> Result, reqwest::Error> { + retry_notify( + backoff.unwrap_or_default(), + || async { + config.alpaca_rate_limit.until_ready().await; + let response = config + .alpaca_client + .get(&format!("{ALPACA_POSITION_API_URL}/{symbol}")) + .send() + .await?; + + if response.status() == reqwest::StatusCode::NOT_FOUND { + return Ok(None); + } + + response + .error_for_status() + .map_err(|e| match e.status() { + Some(reqwest::StatusCode::FORBIDDEN) => backoff::Error::Permanent(e), + _ => e.into(), + })? + .json::() + .await + .map_err(backoff::Error::Permanent) + .map(Some) + }, + |e, duration: Duration| { + warn!( + "Failed to get position, will retry in {} seconds: {}", + duration.as_secs(), + e + ); + }, + ) + .await +} diff --git a/src/types/alpaca/api/outgoing/bar.rs b/src/types/alpaca/api/outgoing/bar.rs index eceb03c..51c2086 100644 --- a/src/types/alpaca/api/outgoing/bar.rs +++ b/src/types/alpaca/api/outgoing/bar.rs @@ -1,41 +1,11 @@ -use super::{serialize_symbols, Sort}; -use crate::types::alpaca::shared::Source; +use crate::{ + types::alpaca::shared::{Sort, Source}, + utils::ser, +}; use serde::Serialize; use std::time::Duration; use time::OffsetDateTime; -fn serialize_timeframe(timeframe: &Duration, serializer: S) -> Result -where - S: serde::Serializer, -{ - let mins = timeframe.as_secs() / 60; - if mins < 60 { - return serializer.serialize_str(&format!("{mins}Min")); - } - - let hours = mins / 60; - if hours < 24 { - return serializer.serialize_str(&format!("{hours}Hour")); - } - - let days = hours / 24; - if days == 1 { - return serializer.serialize_str("1Day"); - } - - let weeks = days / 7; - if weeks == 1 { - return serializer.serialize_str("1Week"); - } - - let months = days / 30; - if [1, 2, 3, 4, 6, 12].contains(&months) { - return serializer.serialize_str(&format!("{months}Month")); - }; - - Err(serde::ser::Error::custom("Invalid timeframe duration")) -} - #[derive(Serialize)] #[allow(dead_code)] pub enum Adjustment { @@ -49,9 +19,9 @@ pub enum Adjustment { #[serde(untagged)] pub enum Bar { UsEquity { - #[serde(serialize_with = "serialize_symbols")] + #[serde(serialize_with = "ser::join_symbols")] symbols: Vec, - #[serde(serialize_with = "serialize_timeframe")] + #[serde(serialize_with = "ser::timeframe")] timeframe: Duration, #[serde(skip_serializing_if = "Option::is_none")] #[serde(with = "time::serde::rfc3339::option")] @@ -76,9 +46,9 @@ pub enum Bar { sort: Option, }, Crypto { - #[serde(serialize_with = "serialize_symbols")] + #[serde(serialize_with = "ser::join_symbols")] symbols: Vec, - #[serde(serialize_with = "serialize_timeframe")] + #[serde(serialize_with = "ser::timeframe")] timeframe: Duration, #[serde(skip_serializing_if = "Option::is_none")] #[serde(with = "time::serde::rfc3339::option")] diff --git a/src/types/alpaca/api/outgoing/mod.rs b/src/types/alpaca/api/outgoing/mod.rs index 83e799c..7c65f0a 100644 --- a/src/types/alpaca/api/outgoing/mod.rs +++ b/src/types/alpaca/api/outgoing/mod.rs @@ -1,34 +1,3 @@ pub mod bar; pub mod news; pub mod order; - -use serde::{Serialize, Serializer}; - -#[derive(Serialize)] -#[serde(rename_all = "snake_case")] -#[allow(dead_code)] -pub enum Sort { - Asc, - Desc, -} - -fn serialize_symbols(symbols: &[String], serializer: S) -> Result -where - S: Serializer, -{ - let string = symbols.join(","); - serializer.serialize_str(&string) -} - -fn serialize_symbols_option( - symbols: &Option>, - serializer: S, -) -> Result -where - S: Serializer, -{ - match symbols { - Some(symbols) => serialize_symbols(symbols, serializer), - None => serializer.serialize_none(), - } -} diff --git a/src/types/alpaca/api/outgoing/news.rs b/src/types/alpaca/api/outgoing/news.rs index 2d57614..b50df11 100644 --- a/src/types/alpaca/api/outgoing/news.rs +++ b/src/types/alpaca/api/outgoing/news.rs @@ -1,10 +1,10 @@ -use super::{serialize_symbols, Sort}; +use crate::{types::alpaca::shared::Sort, utils::ser}; use serde::Serialize; use time::OffsetDateTime; #[derive(Serialize)] pub struct News { - #[serde(serialize_with = "serialize_symbols")] + #[serde(serialize_with = "ser::remove_slash_from_pairs_join_symbols")] pub symbols: Vec, #[serde(skip_serializing_if = "Option::is_none")] #[serde(with = "time::serde::rfc3339::option")] diff --git a/src/types/alpaca/api/outgoing/order.rs b/src/types/alpaca/api/outgoing/order.rs index c4bca0b..4d31f83 100644 --- a/src/types/alpaca/api/outgoing/order.rs +++ b/src/types/alpaca/api/outgoing/order.rs @@ -1,10 +1,13 @@ -use super::{serialize_symbols_option, Sort}; -use crate::types::alpaca::shared::order::Side; +use crate::{ + types::alpaca::shared::{order::Side, Sort}, + utils::ser, +}; use serde::Serialize; use time::OffsetDateTime; #[derive(Serialize)] #[serde(rename_all = "snake_case")] +#[allow(dead_code)] pub enum Status { Open, Closed, @@ -28,7 +31,7 @@ pub struct Order { #[serde(skip_serializing_if = "Option::is_none")] pub nested: Option, #[serde(skip_serializing_if = "Option::is_none")] - #[serde(serialize_with = "serialize_symbols_option")] + #[serde(serialize_with = "ser::join_symbols_option")] pub symbols: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub side: Option, diff --git a/src/types/alpaca/shared/mod.rs b/src/types/alpaca/shared/mod.rs index b345356..ccda22c 100644 --- a/src/types/alpaca/shared/mod.rs +++ b/src/types/alpaca/shared/mod.rs @@ -1,5 +1,8 @@ pub mod asset; +pub mod news; pub mod order; +pub mod sort; pub mod source; +pub use sort::Sort; pub use source::Source; diff --git a/src/utils/news.rs b/src/types/alpaca/shared/news.rs similarity index 56% rename from src/utils/news.rs rename to src/types/alpaca/shared/news.rs index 62d5dc5..3c9f986 100644 --- a/src/utils/news.rs +++ b/src/types/alpaca/shared/news.rs @@ -5,10 +5,9 @@ use regex::Regex; lazy_static! { static ref RE_TAGS: Regex = Regex::new("<[^>]+>").unwrap(); static ref RE_SPACES: Regex = Regex::new("[\\u00A0\\s]+").unwrap(); - static ref RE_SLASH: Regex = Regex::new(r"^(.+)(BTC|USD.?)$").unwrap(); } -pub fn normalize_news_content(content: &str) -> String { +pub fn normalize_html_content(content: &str) -> String { let content = content.replace('\n', " "); let content = RE_TAGS.replace_all(&content, ""); let content = RE_SPACES.replace_all(&content, " "); @@ -17,14 +16,3 @@ pub fn normalize_news_content(content: &str) -> String { content.to_string() } - -pub fn add_slash_to_pair(pair: &str) -> String { - RE_SLASH.captures(pair).map_or_else( - || pair.to_string(), - |caps| format!("{}/{}", &caps[1], &caps[2]), - ) -} - -pub fn remove_slash_from_pair(pair: &str) -> String { - pair.replace('/', "") -} diff --git a/src/types/alpaca/shared/sort.rs b/src/types/alpaca/shared/sort.rs new file mode 100644 index 0000000..8a02674 --- /dev/null +++ b/src/types/alpaca/shared/sort.rs @@ -0,0 +1,9 @@ +use serde::Serialize; + +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +#[allow(dead_code)] +pub enum Sort { + Asc, + Desc, +} diff --git a/src/types/alpaca/websocket/data/incoming/bar.rs b/src/types/alpaca/websocket/data/incoming/bar.rs index 31692d9..db3f188 100644 --- a/src/types/alpaca/websocket/data/incoming/bar.rs +++ b/src/types/alpaca/websocket/data/incoming/bar.rs @@ -1,4 +1,4 @@ -use crate::types; +use crate::types::Bar; use serde::Deserialize; use time::OffsetDateTime; @@ -25,7 +25,7 @@ pub struct Message { pub vwap: f64, } -impl From for types::Bar { +impl From for Bar { fn from(bar: Message) -> Self { Self { time: bar.time, diff --git a/src/types/alpaca/websocket/data/incoming/news.rs b/src/types/alpaca/websocket/data/incoming/news.rs index 9d3023a..3915a33 100644 --- a/src/types/alpaca/websocket/data/incoming/news.rs +++ b/src/types/alpaca/websocket/data/incoming/news.rs @@ -1,6 +1,6 @@ use crate::{ - types, - utils::{add_slash_to_pair, normalize_news_content}, + types::{alpaca::shared::news::normalize_html_content, news::Sentiment, News}, + utils::de, }; use serde::Deserialize; use time::OffsetDateTime; @@ -14,6 +14,7 @@ pub struct Message { #[serde(with = "time::serde::rfc3339")] #[serde(rename = "updated_at")] pub time_updated: OffsetDateTime, + #[serde(deserialize_with = "de::add_slash_to_symbols")] pub symbols: Vec, pub headline: String, pub author: String, @@ -23,23 +24,19 @@ pub struct Message { pub url: Option, } -impl From for types::News { +impl From for News { fn from(news: Message) -> Self { Self { id: news.id, time_created: news.time_created, time_updated: news.time_updated, - symbols: news - .symbols - .into_iter() - .map(|symbol| add_slash_to_pair(&symbol)) - .collect(), - headline: normalize_news_content(&news.headline), - author: normalize_news_content(&news.author), - source: normalize_news_content(&news.source), - summary: normalize_news_content(&news.summary), - content: normalize_news_content(&news.content), - sentiment: types::news::Sentiment::Neutral, + symbols: news.symbols, + headline: normalize_html_content(&news.headline), + author: normalize_html_content(&news.author), + source: normalize_html_content(&news.source), + summary: normalize_html_content(&news.summary), + content: normalize_html_content(&news.content), + sentiment: Sentiment::Neutral, confidence: 0.0, url: news.url.unwrap_or_default(), } diff --git a/src/types/alpaca/websocket/data/incoming/subscription.rs b/src/types/alpaca/websocket/data/incoming/subscription.rs index 184e9ef..6e05afd 100644 --- a/src/types/alpaca/websocket/data/incoming/subscription.rs +++ b/src/types/alpaca/websocket/data/incoming/subscription.rs @@ -1,3 +1,4 @@ +use crate::utils::de; use serde::Deserialize; #[derive(Deserialize, Debug, PartialEq, Eq)] @@ -16,6 +17,7 @@ pub enum Message { cancel_errors: Option>, }, News { + #[serde(deserialize_with = "de::add_slash_to_symbols")] news: Vec, }, } diff --git a/src/types/alpaca/websocket/data/outgoing/subscribe.rs b/src/types/alpaca/websocket/data/outgoing/subscribe.rs index bd953aa..e942d4b 100644 --- a/src/types/alpaca/websocket/data/outgoing/subscribe.rs +++ b/src/types/alpaca/websocket/data/outgoing/subscribe.rs @@ -1,4 +1,4 @@ -use crate::utils::remove_slash_from_pair; +use crate::utils::ser; use serde::Serialize; #[derive(Serialize)] @@ -22,6 +22,7 @@ pub enum Market { pub enum Message { Market(Market), News { + #[serde(serialize_with = "ser::remove_slash_from_symbols")] news: Vec, }, } @@ -43,11 +44,6 @@ impl Message { } pub fn new_news(symbols: Vec) -> Self { - Self::News { - news: symbols - .into_iter() - .map(|symbol| remove_slash_from_pair(&symbol)) - .collect(), - } + Self::News { news: symbols } } } diff --git a/src/types/alpaca/websocket/trading/incoming/order.rs b/src/types/alpaca/websocket/trading/incoming/order.rs index f518d5c..f589778 100644 --- a/src/types/alpaca/websocket/trading/incoming/order.rs +++ b/src/types/alpaca/websocket/trading/incoming/order.rs @@ -8,79 +8,45 @@ pub use shared::order::Order; #[derive(Deserialize, Debug, PartialEq)] #[serde(rename_all = "snake_case")] #[serde(tag = "event")] -pub enum Message { - New { - execution_id: Uuid, - order: Order, - }, +pub enum Event { + New, Fill { - execution_id: Uuid, - order: Order, timestamp: OffsetDateTime, position_qty: f64, price: f64, }, PartialFill { - execution_id: Uuid, - order: Order, timestamp: OffsetDateTime, position_qty: f64, price: f64, }, Canceled { - execution_id: Uuid, - order: Order, timestamp: OffsetDateTime, }, Expired { - execution_id: Uuid, - order: Order, timestamp: OffsetDateTime, }, - DoneForDay { - execution_id: Uuid, - order: Order, - }, + DoneForDay, Replaced { - execution_id: Uuid, - order: Order, timestamp: OffsetDateTime, }, Rejected { - execution_id: Uuid, - order: Order, timestamp: OffsetDateTime, }, - PendingNew { - execution_id: Uuid, - order: Order, - }, - Stopped { - execution_id: Uuid, - order: Order, - }, - PendingCancel { - execution_id: Uuid, - order: Order, - }, - PendingReplace { - execution_id: Uuid, - order: Order, - }, - Calculated { - execution_id: Uuid, - order: Order, - }, - Suspended { - execution_id: Uuid, - order: Order, - }, - OrderReplaceRejected { - execution_id: Uuid, - order: Order, - }, - OrderCancelRejected { - execution_id: Uuid, - order: Order, - }, + PendingNew, + Stopped, + PendingCancel, + PendingReplace, + Calculated, + Suspended, + OrderReplaceRejected, + OrderCancelRejected, +} + +#[derive(Deserialize, Debug, PartialEq)] +pub struct Message { + pub execution_id: Uuid, + pub order: Order, + #[serde(flatten)] + pub event: Event, } diff --git a/src/types/alpaca/websocket/trading/mod.rs b/src/types/alpaca/websocket/trading/mod.rs index bb18eb0..638f6e8 100644 --- a/src/types/alpaca/websocket/trading/mod.rs +++ b/src/types/alpaca/websocket/trading/mod.rs @@ -33,15 +33,13 @@ pub async fn authenticate( Message::Binary(data) => { let data = String::from_utf8(data).unwrap(); - if from_str::>(&data) - .unwrap() - .first() - != Some(&websocket::trading::incoming::Message::Auth( + if from_str::(&data).unwrap() + != websocket::trading::incoming::Message::Auth( websocket::trading::incoming::auth::Message { status: websocket::trading::incoming::auth::Status::Authorized, action: websocket::trading::incoming::auth::Action::Auth, }, - )) + ) { panic!("Failed to authenticate with Alpaca websocket."); } @@ -49,3 +47,36 @@ pub async fn authenticate( _ => panic!("Failed to authenticate with Alpaca websocket."), }; } + +pub async fn subscribe( + sink: &mut SplitSink>, Message>, + stream: &mut SplitStream>>, +) { + sink.send(Message::Text( + to_string(&websocket::trading::outgoing::Message::Subscribe { + data: websocket::trading::outgoing::subscribe::Message { + streams: vec![String::from("trade_updates")], + }, + }) + .unwrap(), + )) + .await + .unwrap(); + + match stream.next().await.unwrap().unwrap() { + Message::Binary(data) => { + let data = String::from_utf8(data).unwrap(); + + if from_str::(&data).unwrap() + != websocket::trading::incoming::Message::Subscription( + websocket::trading::incoming::subscription::Message { + streams: vec![String::from("trade_updates")], + }, + ) + { + panic!("Failed to subscribe to Alpaca websocket."); + } + } + _ => panic!("Failed to subscribe to Alpaca websocket."), + }; +} diff --git a/src/types/alpaca/websocket/trading/outgoing/subscribe.rs b/src/types/alpaca/websocket/trading/outgoing/subscribe.rs index 46f3922..70b8237 100644 --- a/src/types/alpaca/websocket/trading/outgoing/subscribe.rs +++ b/src/types/alpaca/websocket/trading/outgoing/subscribe.rs @@ -2,5 +2,5 @@ use serde::Serialize; #[derive(Serialize)] pub struct Message { - streams: Vec, + pub streams: Vec, } diff --git a/src/types/asset.rs b/src/types/asset.rs index 2d05b92..0d037ee 100644 --- a/src/types/asset.rs +++ b/src/types/asset.rs @@ -24,7 +24,7 @@ pub enum Exchange { Crypto = 8, } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Row)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)] pub struct Asset { pub symbol: String, pub class: Class, @@ -32,6 +32,7 @@ pub struct Asset { pub status: bool, #[serde(with = "clickhouse::serde::time::datetime")] pub time_added: OffsetDateTime, + pub qty: f64, } impl Hash for Asset { diff --git a/src/types/mod.rs b/src/types/mod.rs index 850a554..9791365 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -4,11 +4,9 @@ pub mod backfill; pub mod bar; pub mod news; pub mod order; -pub mod position; pub use asset::{Asset, Class, Exchange}; pub use backfill::Backfill; pub use bar::Bar; pub use news::News; pub use order::Order; -pub use position::Position; diff --git a/src/types/position.rs b/src/types/position.rs deleted file mode 100644 index 45368eb..0000000 --- a/src/types/position.rs +++ /dev/null @@ -1,8 +0,0 @@ -use clickhouse::Row; -use serde::{Deserialize, Serialize}; - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Row)] -pub struct Position { - pub symbol: String, - pub qty: f64, -} diff --git a/src/utils/cleanup.rs b/src/utils/cleanup.rs deleted file mode 100644 index f9fa08e..0000000 --- a/src/utils/cleanup.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::database; -use clickhouse::{error::Error, Client}; -use tokio::try_join; - -pub async fn cleanup(clickhouse_client: &Client) -> Result<(), Error> { - try_join!( - database::bars::cleanup(clickhouse_client), - database::news::cleanup(clickhouse_client), - database::backfills::cleanup(clickhouse_client) - ) - .map(|_| ()) -} diff --git a/src/utils/de.rs b/src/utils/de.rs new file mode 100644 index 0000000..b82244d --- /dev/null +++ b/src/utils/de.rs @@ -0,0 +1,77 @@ +use lazy_static::lazy_static; +use regex::Regex; +use serde::{ + de::{self, SeqAccess, Visitor}, + Deserializer, +}; +use std::fmt; + +lazy_static! { + static ref RE_SLASH: Regex = Regex::new(r"^(.+)(BTC|USD.?)$").unwrap(); +} + +fn add_slash(pair: &str) -> String { + RE_SLASH.captures(pair).map_or_else( + || pair.to_string(), + |caps| format!("{}/{}", &caps[1], &caps[2]), + ) +} + +pub fn add_slash_to_symbol<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + struct StringVisitor; + + impl<'de> Visitor<'de> for StringVisitor { + type Value = String; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string without a slash") + } + + fn visit_str(self, pair: &str) -> Result + where + E: de::Error, + { + Ok(add_slash(pair)) + } + + fn visit_string(self, pair: String) -> Result + where + E: de::Error, + { + Ok(add_slash(&pair)) + } + } + + deserializer.deserialize_string(StringVisitor) +} + +pub fn add_slash_to_symbols<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + struct VecStringVisitor; + + impl<'de> Visitor<'de> for VecStringVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a list of strings without a slash") + } + + fn visit_seq(self, mut seq: A) -> Result, A::Error> + where + A: SeqAccess<'de>, + { + let mut vec = Vec::with_capacity(seq.size_hint().unwrap_or(25)); + while let Some(value) = seq.next_element::()? { + vec.push(add_slash(&value)); + } + Ok(vec) + } + } + + deserializer.deserialize_seq(VecStringVisitor) +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index f0bb3f3..b850504 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,9 +1,7 @@ pub mod backoff; -pub mod cleanup; +pub mod de; pub mod macros; -pub mod news; +pub mod ser; pub mod time; -pub use cleanup::cleanup; -pub use news::{add_slash_to_pair, normalize_news_content, remove_slash_from_pair}; pub use time::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE, ONE_SECOND}; diff --git a/src/utils/ser.rs b/src/utils/ser.rs new file mode 100644 index 0000000..42053d9 --- /dev/null +++ b/src/utils/ser.rs @@ -0,0 +1,90 @@ +use serde::{ser::SerializeSeq, Serializer}; +use std::time::Duration; + +pub fn timeframe(timeframe: &Duration, serializer: S) -> Result +where + S: serde::Serializer, +{ + let mins = timeframe.as_secs() / 60; + if mins < 60 { + return serializer.serialize_str(&format!("{mins}Min")); + } + + let hours = mins / 60; + if hours < 24 { + return serializer.serialize_str(&format!("{hours}Hour")); + } + + let days = hours / 24; + if days == 1 { + return serializer.serialize_str("1Day"); + } + + let weeks = days / 7; + if weeks == 1 { + return serializer.serialize_str("1Week"); + } + + let months = days / 30; + if [1, 2, 3, 4, 6, 12].contains(&months) { + return serializer.serialize_str(&format!("{months}Month")); + }; + + Err(serde::ser::Error::custom("Invalid timeframe duration")) +} + +fn remove_slash(pair: &str) -> String { + pair.replace('/', "") +} + +pub fn join_symbols(symbols: &[String], serializer: S) -> Result +where + S: Serializer, +{ + let string = symbols.join(","); + serializer.serialize_str(&string) +} + +pub fn join_symbols_option( + symbols: &Option>, + serializer: S, +) -> Result +where + S: Serializer, +{ + match symbols { + Some(symbols) => join_symbols(symbols, serializer), + None => serializer.serialize_none(), + } +} + +pub fn remove_slash_from_symbols(pairs: &[String], serializer: S) -> Result +where + S: Serializer, +{ + let symbols = pairs + .iter() + .map(|pair| remove_slash(pair)) + .collect::>(); + + let mut seq = serializer.serialize_seq(Some(symbols.len()))?; + for symbol in symbols { + seq.serialize_element(&symbol)?; + } + seq.end() +} + +pub fn remove_slash_from_pairs_join_symbols( + symbols: &[String], + serializer: S, +) -> Result +where + S: Serializer, +{ + let symbols = symbols + .iter() + .map(|symbol| remove_slash(symbol)) + .collect::>(); + + join_symbols(&symbols, serializer) +} diff --git a/support/clickhouse/docker-entrypoint-initdb.d/0000_init.sql b/support/clickhouse/docker-entrypoint-initdb.d/0000_init.sql index 1475a49..c7b8f6d 100644 --- a/support/clickhouse/docker-entrypoint-initdb.d/0000_init.sql +++ b/support/clickhouse/docker-entrypoint-initdb.d/0000_init.sql @@ -13,6 +13,7 @@ CREATE TABLE IF NOT EXISTS qrust.assets ( ), status Boolean, time_added DateTime DEFAULT now(), + qty Float64 ) ENGINE = ReplacingMergeTree() PRIMARY KEY symbol; @@ -34,7 +35,8 @@ PARTITION BY toYYYYMM(time); CREATE TABLE IF NOT EXISTS qrust.backfills_bars ( symbol LowCardinality(String), - time DateTime + time DateTime, + fresh Boolean ) ENGINE = ReplacingMergeTree() PRIMARY KEY symbol; @@ -60,7 +62,8 @@ PRIMARY KEY id; CREATE TABLE IF NOT EXISTS qrust.backfills_news ( symbol LowCardinality(String), - time DateTime + time DateTime, + fresh Boolean ) ENGINE = ReplacingMergeTree() PRIMARY KEY symbol; @@ -117,10 +120,3 @@ CREATE TABLE IF NOT EXISTS qrust.orders ( ENGINE = ReplacingMergeTree() PARTITION BY toYYYYMM(time_submitted) PRIMARY KEY id; - -CREATE TABLE IF NOT EXISTS qrust.positions ( - symbol LowCardinality(String), - qty Float64 -) -ENGINE = ReplacingMergeTree() -PRIMARY KEY symbol;