Separate data management code
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
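
Split the monolithic `src/threads/data/backfill.rs` and
`src/threads/data/websocket.rs` into `backfill/{mod,bars,news}.rs` and
`websocket/{mod,bars,news}.rs`: the shared `Handler` traits and message
plumbing stay in each `mod.rs`, while the bar- and news-specific handlers
move into their own files.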
src/threads/data/backfill.rs (deleted, 596 lines)
@@ -1,596 +0,0 @@
use super::ThreadType;
use crate::{
    config::{
        Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_SOURCE, ALPACA_STOCK_DATA_API_URL,
        BERT_MAX_INPUTS,
    },
    database,
    types::{
        alpaca::{
            self,
            shared::{Sort, Source},
        },
        news::Prediction,
        Backfill, Bar, Class, News,
    },
    utils::{duration_until, last_minute, FIFTEEN_MINUTES, ONE_MINUTE, ONE_SECOND},
};
use async_trait::async_trait;
use futures_util::future::join_all;
use itertools::{Either, Itertools};
use log::{error, info, warn};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::{
    spawn,
    sync::{mpsc, oneshot, Mutex},
    task::{block_in_place, JoinHandle},
    time::sleep,
    try_join,
};
use uuid::Uuid;

pub enum Action {
    Backfill,
    Purge,
}

pub struct Message {
    pub action: Action,
    pub symbols: Vec<String>,
    pub response: oneshot::Sender<()>,
}

impl Message {
    pub fn new(action: Action, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
        let (sender, receiver) = oneshot::channel::<()>();
        (
            Self {
                action,
                symbols,
                response: sender,
            },
            receiver,
        )
    }
}

#[derive(Clone)]
pub struct Job {
    pub fetch_from: OffsetDateTime,
    pub fetch_to: OffsetDateTime,
}

#[async_trait]
pub trait Handler: Send + Sync {
    async fn select_latest_backfills(
        &self,
        symbols: &[String],
    ) -> Result<Vec<Backfill>, clickhouse::error::Error>;
    async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
    async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
    async fn queue_backfill(&self, jobs: &HashMap<String, Job>);
    async fn backfill(&self, jobs: HashMap<String, Job>);
    fn max_limit(&self) -> i64;
    fn log_string(&self) -> &'static str;
}

pub struct Jobs {
    pub symbol_to_uuid: HashMap<String, Uuid>,
    pub uuid_to_job: HashMap<Uuid, JoinHandle<()>>,
}

impl Jobs {
    pub fn insert(&mut self, jobs: Vec<String>, fut: JoinHandle<()>) {
        let uuid = Uuid::new_v4();
        for symbol in jobs {
            self.symbol_to_uuid.insert(symbol.clone(), uuid);
        }
        self.uuid_to_job.insert(uuid, fut);
    }

    pub fn get(&self, symbol: &str) -> Option<&JoinHandle<()>> {
        self.symbol_to_uuid
            .get(symbol)
            .and_then(|uuid| self.uuid_to_job.get(uuid))
    }

    pub fn remove(&mut self, symbol: &str) -> Option<JoinHandle<()>> {
        self.symbol_to_uuid
            .remove(symbol)
            .and_then(|uuid| self.uuid_to_job.remove(&uuid))
    }
}

pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
    let backfill_jobs = Arc::new(Mutex::new(Jobs {
        symbol_to_uuid: HashMap::new(),
        uuid_to_job: HashMap::new(),
    }));

    loop {
        let message = receiver.recv().await.unwrap();
        spawn(handle_backfill_message(
            handler.clone(),
            backfill_jobs.clone(),
            message,
        ));
    }
}

async fn handle_backfill_message(
    handler: Arc<Box<dyn Handler>>,
    backfill_jobs: Arc<Mutex<Jobs>>,
    message: Message,
) {
    let mut backfill_jobs = backfill_jobs.lock().await;

    match message.action {
        Action::Backfill => {
            let log_string = handler.log_string();
            let max_limit = handler.max_limit();

            let backfills = handler
                .select_latest_backfills(&message.symbols)
                .await
                .unwrap()
                .into_iter()
                .map(|backfill| (backfill.symbol.clone(), backfill))
                .collect::<HashMap<_, _>>();

            let mut jobs = vec![];

            for symbol in message.symbols {
                if let Some(job) = backfill_jobs.get(&symbol) {
                    if !job.is_finished() {
                        warn!(
                            "Backfill for {} {} is already running, skipping.",
                            symbol, log_string
                        );
                        continue;
                    }
                }

                let fetch_from = backfills
                    .get(&symbol)
                    .map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
                        backfill.time + ONE_SECOND
                    });

                let fetch_to = last_minute();

                if fetch_from > fetch_to {
                    info!("No need to backfill {} {}.", symbol, log_string);
                    return;
                }

                jobs.push((
                    symbol,
                    Job {
                        fetch_from,
                        fetch_to,
                    },
                ));
            }

            let jobs = jobs
                .into_iter()
                .sorted_by_key(|job| job.1.fetch_from)
                .collect::<Vec<_>>();

            let mut job_groups = vec![HashMap::new()];
            let mut current_minutes = 0;

            for job in jobs {
                let minutes = (job.1.fetch_to - job.1.fetch_from).whole_minutes();

                if job_groups.last().unwrap().is_empty() || (current_minutes + minutes) <= max_limit
                {
                    let job_group = job_groups.last_mut().unwrap();
                    job_group.insert(job.0, job.1);
                    current_minutes += minutes;
                } else {
                    let mut job_group = HashMap::new();
                    job_group.insert(job.0, job.1);
                    job_groups.push(job_group);
                    current_minutes = minutes;
                }
            }

            for job_group in job_groups {
                let symbols = job_group.keys().cloned().collect::<Vec<_>>();

                let handler = handler.clone();
                let fut = spawn(async move {
                    handler.queue_backfill(&job_group).await;
                    handler.backfill(job_group).await;
                });

                backfill_jobs.insert(symbols, fut);
            }
        }
        Action::Purge => {
            for symbol in &message.symbols {
                if let Some(job) = backfill_jobs.remove(symbol) {
                    if !job.is_finished() {
                        job.abort();
                    }
                    let _ = job.await;
                }
            }

            try_join!(
                handler.delete_backfills(&message.symbols),
                handler.delete_data(&message.symbols)
            )
            .unwrap();
        }
    }

    message.response.send(()).unwrap();
}

struct BarHandler {
    config: Arc<Config>,
    data_url: &'static str,
    api_query_constructor: fn(
        symbols: Vec<String>,
        fetch_from: OffsetDateTime,
        fetch_to: OffsetDateTime,
        next_page_token: Option<String>,
    ) -> alpaca::api::outgoing::bar::Bar,
}

fn us_equity_query_constructor(
    symbols: Vec<String>,
    fetch_from: OffsetDateTime,
    fetch_to: OffsetDateTime,
    next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
    alpaca::api::outgoing::bar::Bar::UsEquity(alpaca::api::outgoing::bar::UsEquity {
        symbols,
        start: Some(fetch_from),
        end: Some(fetch_to),
        page_token: next_page_token,
        sort: Some(Sort::Asc),
        ..Default::default()
    })
}

fn crypto_query_constructor(
    symbols: Vec<String>,
    fetch_from: OffsetDateTime,
    fetch_to: OffsetDateTime,
    next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
    alpaca::api::outgoing::bar::Bar::Crypto(alpaca::api::outgoing::bar::Crypto {
        symbols,
        start: Some(fetch_from),
        end: Some(fetch_to),
        page_token: next_page_token,
        sort: Some(Sort::Asc),
        ..Default::default()
    })
}

#[async_trait]
impl Handler for BarHandler {
    async fn select_latest_backfills(
        &self,
        symbols: &[String],
    ) -> Result<Vec<Backfill>, clickhouse::error::Error> {
        database::backfills_bars::select_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::backfills_bars::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::bars::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn queue_backfill(&self, jobs: &HashMap<String, Job>) {
        if *ALPACA_SOURCE == Source::Sip {
            return;
        }

        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
        let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
        let symbols = jobs.keys().collect::<Vec<_>>();

        info!("Queuing bar backfill for {:?} in {:?}.", symbols, run_delay);
        sleep(run_delay).await;
    }

    async fn backfill(&self, jobs: HashMap<String, Job>) {
        let symbols = jobs.keys().cloned().collect::<Vec<_>>();
        let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();

        info!("Backfilling bars for {:?}.", symbols);

        let mut bars = vec![];
        let mut last_time = symbols
            .iter()
            .map(|symbol| (symbol.clone(), None))
            .collect::<HashMap<_, _>>();
        let mut next_page_token = None;

        loop {
            let Ok(message) = alpaca::api::incoming::bar::get(
                &self.config.alpaca_client,
                &self.config.alpaca_rate_limiter,
                self.data_url,
                &(self.api_query_constructor)(
                    symbols.clone(),
                    fetch_from,
                    fetch_to,
                    next_page_token.clone(),
                ),
                None,
            )
            .await
            else {
                error!("Failed to backfill bars for {:?}.", symbols);
                return;
            };

            for (symbol, bar_vec) in message.bars {
                if let Some(last) = bar_vec.last() {
                    last_time.insert(symbol.clone(), Some(last.time));
                }

                for bar in bar_vec {
                    bars.push(Bar::from((bar, symbol.clone())));
                }
            }

            if bars.len() >= database::bars::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
                database::bars::upsert_batch(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &bars,
                )
                .await
                .unwrap();
                bars = vec![];
            }

            if message.next_page_token.is_none() {
                break;
            }
            next_page_token = message.next_page_token;
        }

        let (backfilled, skipped): (Vec<_>, Vec<_>) =
            last_time.into_iter().partition_map(|(symbol, time)| {
                if let Some(time) = time {
                    Either::Left(Backfill { symbol, time })
                } else {
                    Either::Right(symbol)
                }
            });

        database::backfills_bars::upsert_batch(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            &backfilled,
        )
        .await
        .unwrap();

        info!("No bars to backfill for {:?}.", skipped);
        info!("Backfilled bars for {:?}.", backfilled);
    }

    fn max_limit(&self) -> i64 {
        alpaca::api::outgoing::bar::MAX_LIMIT
    }

    fn log_string(&self) -> &'static str {
        "bars"
    }
}

struct NewsHandler {
    config: Arc<Config>,
}

#[async_trait]
impl Handler for NewsHandler {
    async fn select_latest_backfills(
        &self,
        symbols: &[String],
    ) -> Result<Vec<Backfill>, clickhouse::error::Error> {
        database::backfills_news::select_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::backfills_news::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::news::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn queue_backfill(&self, jobs: &HashMap<String, Job>) {
        if *ALPACA_SOURCE == Source::Sip {
            return;
        }

        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
        let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
        let symbols = jobs.keys().cloned().collect::<Vec<_>>();

        info!("Queuing news backfill for {:?} in {:?}.", symbols, run_delay);
        sleep(run_delay).await;
    }

    async fn backfill(&self, jobs: HashMap<String, Job>) {
        let symbols = jobs.keys().cloned().collect::<Vec<_>>();
        let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();

        info!("Backfilling news for {:?}.", symbols);

        let mut news = vec![];
        let mut last_time = symbols
            .iter()
            .map(|symbol| (symbol.clone(), None))
            .collect::<HashMap<_, _>>();
        let mut next_page_token = None;

        loop {
            let Ok(message) = alpaca::api::incoming::news::get(
                &self.config.alpaca_client,
                &self.config.alpaca_rate_limiter,
                &alpaca::api::outgoing::news::News {
                    symbols: symbols.clone(),
                    start: Some(fetch_from),
                    end: Some(fetch_to),
                    page_token: next_page_token.clone(),
                    ..Default::default()
                },
                None,
            )
            .await
            else {
                error!("Failed to backfill news for {:?}.", symbols);
                return;
            };

            for news_item in message.news {
                let news_item = News::from(news_item);

                for symbol in &news_item.symbols {
                    last_time.insert(symbol.clone(), Some(news_item.time_created));
                }

                news.push(news_item);
            }

            if news.len() >= *BERT_MAX_INPUTS || message.next_page_token.is_none() {
                let inputs = news
                    .iter()
                    .map(|news| format!("{}\n\n{}", news.headline, news.content))
                    .collect::<Vec<_>>();

                let predictions =
                    join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
                        let sequence_classifier = self.config.sequence_classifier.lock().await;
                        block_in_place(|| {
                            sequence_classifier
                                .predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
                                .into_iter()
                                .map(|label| Prediction::try_from(label).unwrap())
                                .collect::<Vec<_>>()
                        })
                    }))
                    .await
                    .into_iter()
                    .flatten();

                news = news
                    .into_iter()
                    .zip(predictions)
                    .map(|(news, prediction)| News {
                        sentiment: prediction.sentiment,
                        confidence: prediction.confidence,
                        ..news
                    })
                    .collect::<Vec<_>>();
            }

            if news.len() >= database::news::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
                database::news::upsert_batch(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &news,
                )
                .await
                .unwrap();
                news = vec![];
            }

            if message.next_page_token.is_none() {
                break;
            }
            next_page_token = message.next_page_token;
        }

        let (backfilled, skipped): (Vec<_>, Vec<_>) =
            last_time.into_iter().partition_map(|(symbol, time)| {
                if let Some(time) = time {
                    Either::Left(Backfill { symbol, time })
                } else {
                    Either::Right(symbol)
                }
            });

        database::backfills_news::upsert_batch(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            &backfilled,
        )
        .await
        .unwrap();

        info!("No news to backfill for {:?}.", skipped);
        info!("Backfilled news for {:?}.", backfilled);
    }

    fn max_limit(&self) -> i64 {
        alpaca::api::outgoing::news::MAX_LIMIT
    }

    fn log_string(&self) -> &'static str {
        "news"
    }
}

pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
    match thread_type {
        ThreadType::Bars(Class::UsEquity) => Box::new(BarHandler {
            config,
            data_url: ALPACA_STOCK_DATA_API_URL,
            api_query_constructor: us_equity_query_constructor,
        }),
        ThreadType::Bars(Class::Crypto) => Box::new(BarHandler {
            config,
            data_url: ALPACA_CRYPTO_DATA_API_URL,
            api_query_constructor: crypto_query_constructor,
        }),
        ThreadType::News => Box::new(NewsHandler { config }),
    }
}
src/threads/data/backfill/bars.rs (new file, 207 lines)
@@ -0,0 +1,207 @@
use super::Job;
use crate::{
    config::{Config, ALPACA_SOURCE},
    database,
    types::{
        alpaca::{
            self,
            shared::{Sort, Source},
        },
        Backfill, Bar,
    },
    utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use async_trait::async_trait;
use itertools::{Either, Itertools};
use log::{error, info};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::time::sleep;

pub struct Handler {
    pub config: Arc<Config>,
    pub data_url: &'static str,
    pub api_query_constructor: fn(
        symbols: Vec<String>,
        fetch_from: OffsetDateTime,
        fetch_to: OffsetDateTime,
        next_page_token: Option<String>,
    ) -> alpaca::api::outgoing::bar::Bar,
}

pub fn us_equity_query_constructor(
    symbols: Vec<String>,
    fetch_from: OffsetDateTime,
    fetch_to: OffsetDateTime,
    next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
    alpaca::api::outgoing::bar::Bar::UsEquity(alpaca::api::outgoing::bar::UsEquity {
        symbols,
        start: Some(fetch_from),
        end: Some(fetch_to),
        page_token: next_page_token,
        sort: Some(Sort::Asc),
        ..Default::default()
    })
}

pub fn crypto_query_constructor(
    symbols: Vec<String>,
    fetch_from: OffsetDateTime,
    fetch_to: OffsetDateTime,
    next_page_token: Option<String>,
) -> alpaca::api::outgoing::bar::Bar {
    alpaca::api::outgoing::bar::Bar::Crypto(alpaca::api::outgoing::bar::Crypto {
        symbols,
        start: Some(fetch_from),
        end: Some(fetch_to),
        page_token: next_page_token,
        sort: Some(Sort::Asc),
        ..Default::default()
    })
}

#[async_trait]
impl super::Handler for Handler {
    async fn select_latest_backfills(
        &self,
        symbols: &[String],
    ) -> Result<Vec<Backfill>, clickhouse::error::Error> {
        database::backfills_bars::select_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::backfills_bars::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::bars::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

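    // Non-SIP sources lag real time, so wait until roughly fifteen minutes
    // past `fetch_to` (plus a one-minute margin, presumably for safety)
    // before fetching; with the SIP feed the data is available immediately.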
    async fn queue_backfill(&self, jobs: &HashMap<String, Job>) {
        if *ALPACA_SOURCE == Source::Sip {
            return;
        }

        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
        let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
        let symbols = jobs.keys().collect::<Vec<_>>();

        info!("Queuing bar backfill for {:?} in {:?}.", symbols, run_delay);
        sleep(run_delay).await;
    }

    async fn backfill(&self, jobs: HashMap<String, Job>) {
        let symbols = jobs.keys().cloned().collect::<Vec<_>>();
        let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();

        info!("Backfilling bars for {:?}.", symbols);

        let mut bars = vec![];
        let mut last_times = symbols
            .iter()
            .map(|symbol| (symbol.clone(), None))
            .collect::<HashMap<_, _>>();
        let mut next_page_token = None;

        loop {
            let Ok(message) = alpaca::api::incoming::bar::get(
                &self.config.alpaca_client,
                &self.config.alpaca_rate_limiter,
                self.data_url,
                &(self.api_query_constructor)(
                    symbols.clone(),
                    fetch_from,
                    fetch_to,
                    next_page_token.clone(),
                ),
                None,
            )
            .await
            else {
                error!("Failed to backfill bars for {:?}.", symbols);
                return;
            };

            for (symbol, bar_vec) in message.bars {
                if let Some(last) = bar_vec.last() {
                    last_times.insert(symbol.clone(), Some(last.time));
                }

                for bar in bar_vec {
                    bars.push(Bar::from((bar, symbol.clone())));
                }
            }

            if bars.len() >= database::bars::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
                database::bars::upsert_batch(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &bars,
                )
                .await
                .unwrap();
                bars = vec![];
            }

            if message.next_page_token.is_none() {
                break;
            }
            next_page_token = message.next_page_token;
        }

        let (backfilled, skipped): (Vec<_>, Vec<_>) =
            last_times.into_iter().partition_map(|(symbol, time)| {
                if let Some(time) = time {
                    Either::Left(Backfill { symbol, time })
                } else {
                    Either::Right(symbol)
                }
            });

        database::backfills_bars::upsert_batch(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            &backfilled,
        )
        .await
        .unwrap();

        let backfilled = backfilled
            .into_iter()
            .map(|backfill| backfill.symbol)
            .collect::<Vec<_>>();

        if !skipped.is_empty() {
            info!("No bars to backfill for {:?}.", skipped);
        }

        if !backfilled.is_empty() {
            info!("Backfilled bars for {:?}.", backfilled);
        }
    }

    fn max_limit(&self) -> i64 {
        alpaca::api::outgoing::bar::MAX_LIMIT
    }

    fn log_string(&self) -> &'static str {
        "bars"
    }
}
src/threads/data/backfill/mod.rs (new file, 237 lines)
@@ -0,0 +1,237 @@
mod bars;
mod news;

use super::ThreadType;
use crate::{
    config::{Config, ALPACA_CRYPTO_DATA_API_URL, ALPACA_STOCK_DATA_API_URL},
    types::{Backfill, Class},
    utils::{last_minute, ONE_SECOND},
};
use async_trait::async_trait;
use itertools::Itertools;
use log::{info, warn};
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
use tokio::{
    spawn,
    sync::{mpsc, oneshot, Mutex},
    task::JoinHandle,
    try_join,
};
use uuid::Uuid;

pub enum Action {
    Backfill,
    Purge,
}

pub struct Message {
    pub action: Action,
    pub symbols: Vec<String>,
    pub response: oneshot::Sender<()>,
}

impl Message {
    pub fn new(action: Action, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
        let (sender, receiver) = oneshot::channel::<()>();
        (
            Self {
                action,
                symbols,
                response: sender,
            },
            receiver,
        )
    }
}

#[derive(Clone)]
pub struct Job {
    pub fetch_from: OffsetDateTime,
    pub fetch_to: OffsetDateTime,
}

#[async_trait]
pub trait Handler: Send + Sync {
    async fn select_latest_backfills(
        &self,
        symbols: &[String],
    ) -> Result<Vec<Backfill>, clickhouse::error::Error>;
    async fn delete_backfills(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
    async fn delete_data(&self, symbol: &[String]) -> Result<(), clickhouse::error::Error>;
    async fn queue_backfill(&self, jobs: &HashMap<String, Job>);
    async fn backfill(&self, jobs: HashMap<String, Job>);
    fn max_limit(&self) -> i64;
    fn log_string(&self) -> &'static str;
}

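// Backfills run as one spawned task per job group, so several symbols can
// share a single `JoinHandle`: `symbol_to_uuid` maps a symbol to its group's
// UUID, and `uuid_to_job` maps that UUID to the group's task handle.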
pub struct Jobs {
    pub symbol_to_uuid: HashMap<String, Uuid>,
    pub uuid_to_job: HashMap<Uuid, JoinHandle<()>>,
}

impl Jobs {
    pub fn insert(&mut self, jobs: Vec<String>, fut: JoinHandle<()>) {
        let uuid = Uuid::new_v4();
        for symbol in jobs {
            self.symbol_to_uuid.insert(symbol.clone(), uuid);
        }
        self.uuid_to_job.insert(uuid, fut);
    }

    pub fn get(&self, symbol: &str) -> Option<&JoinHandle<()>> {
        self.symbol_to_uuid
            .get(symbol)
            .and_then(|uuid| self.uuid_to_job.get(uuid))
    }

    pub fn remove(&mut self, symbol: &str) -> Option<JoinHandle<()>> {
        self.symbol_to_uuid
            .remove(symbol)
            .and_then(|uuid| self.uuid_to_job.remove(&uuid))
    }
}

pub async fn run(handler: Arc<Box<dyn Handler>>, mut receiver: mpsc::Receiver<Message>) {
    let backfill_jobs = Arc::new(Mutex::new(Jobs {
        symbol_to_uuid: HashMap::new(),
        uuid_to_job: HashMap::new(),
    }));

    loop {
        let message = receiver.recv().await.unwrap();
        spawn(handle_backfill_message(
            handler.clone(),
            backfill_jobs.clone(),
            message,
        ));
    }
}

async fn handle_backfill_message(
    handler: Arc<Box<dyn Handler>>,
    backfill_jobs: Arc<Mutex<Jobs>>,
    message: Message,
) {
    let mut backfill_jobs = backfill_jobs.lock().await;

    match message.action {
        Action::Backfill => {
            let log_string = handler.log_string();
            let max_limit = handler.max_limit();

            let backfills = handler
                .select_latest_backfills(&message.symbols)
                .await
                .unwrap()
                .into_iter()
                .map(|backfill| (backfill.symbol.clone(), backfill))
                .collect::<HashMap<_, _>>();

            let mut jobs = vec![];

            for symbol in message.symbols {
                if let Some(job) = backfill_jobs.get(&symbol) {
                    if !job.is_finished() {
                        warn!(
                            "Backfill for {} {} is already running, skipping.",
                            symbol, log_string
                        );
                        continue;
                    }
                }

                let fetch_from = backfills
                    .get(&symbol)
                    .map_or(OffsetDateTime::UNIX_EPOCH, |backfill| {
                        backfill.time + ONE_SECOND
                    });

                let fetch_to = last_minute();

                if fetch_from > fetch_to {
                    info!("No need to backfill {} {}.", symbol, log_string);
                    return;
                }

                jobs.push((
                    symbol,
                    Job {
                        fetch_from,
                        fetch_to,
                    },
                ));
            }

            let jobs = jobs
                .into_iter()
                .sorted_by_key(|job| job.1.fetch_from)
                .collect::<Vec<_>>();

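            // Greedily pack the jobs, sorted by start time, into groups whose
            // combined fetch window (in minutes) stays within the API's
            // `max_limit` per request, so each group can be backfilled by a
            // single paginated fetch.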
            let mut job_groups = vec![HashMap::new()];
            let mut current_minutes = 0;

            for job in jobs {
                let minutes = (job.1.fetch_to - job.1.fetch_from).whole_minutes();

                if job_groups.last().unwrap().is_empty() || (current_minutes + minutes) <= max_limit
                {
                    let job_group = job_groups.last_mut().unwrap();
                    job_group.insert(job.0, job.1);
                    current_minutes += minutes;
                } else {
                    let mut job_group = HashMap::new();
                    job_group.insert(job.0, job.1);
                    job_groups.push(job_group);
                    current_minutes = minutes;
                }
            }

            for job_group in job_groups {
                let symbols = job_group.keys().cloned().collect::<Vec<_>>();

                let handler = handler.clone();
                let fut = spawn(async move {
                    handler.queue_backfill(&job_group).await;
                    handler.backfill(job_group).await;
                });

                backfill_jobs.insert(symbols, fut);
            }
        }
        Action::Purge => {
            for symbol in &message.symbols {
                if let Some(job) = backfill_jobs.remove(symbol) {
                    if !job.is_finished() {
                        job.abort();
                    }
                    let _ = job.await;
                }
            }

            try_join!(
                handler.delete_backfills(&message.symbols),
                handler.delete_data(&message.symbols)
            )
            .unwrap();
        }
    }

    message.response.send(()).unwrap();
}

pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
    match thread_type {
        ThreadType::Bars(Class::UsEquity) => Box::new(bars::Handler {
            config,
            data_url: ALPACA_STOCK_DATA_API_URL,
            api_query_constructor: bars::us_equity_query_constructor,
        }),
        ThreadType::Bars(Class::Crypto) => Box::new(bars::Handler {
            config,
            data_url: ALPACA_CRYPTO_DATA_API_URL,
            api_query_constructor: bars::crypto_query_constructor,
        }),
        ThreadType::News => Box::new(news::Handler { config }),
    }
}
src/threads/data/backfill/news.rs (new file, 205 lines)
@@ -0,0 +1,205 @@
use super::Job;
use crate::{
    config::{Config, ALPACA_SOURCE, BERT_MAX_INPUTS},
    database,
    types::{
        alpaca::{
            self,
            shared::{Sort, Source},
        },
        news::Prediction,
        Backfill, News,
    },
    utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use async_trait::async_trait;
use futures_util::future::join_all;
use itertools::{Either, Itertools};
use log::{error, info};
use std::{collections::HashMap, sync::Arc};
use tokio::{task::block_in_place, time::sleep};

pub struct Handler {
    pub config: Arc<Config>,
}

#[async_trait]
impl super::Handler for Handler {
    async fn select_latest_backfills(
        &self,
        symbols: &[String],
    ) -> Result<Vec<Backfill>, clickhouse::error::Error> {
        database::backfills_news::select_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::backfills_news::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::news::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn queue_backfill(&self, jobs: &HashMap<String, Job>) {
        if *ALPACA_SOURCE == Source::Sip {
            return;
        }

        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
        let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
        let symbols = jobs.keys().cloned().collect::<Vec<_>>();

        info!("Queuing news backfill for {:?} in {:?}.", symbols, run_delay);
        sleep(run_delay).await;
    }

    #[allow(clippy::too_many_lines)]
    async fn backfill(&self, jobs: HashMap<String, Job>) {
        let symbols = jobs.keys().cloned().collect::<Vec<_>>();
        let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();

        info!("Backfilling news for {:?}.", symbols);

        let mut news = vec![];
        let mut last_times = symbols
            .iter()
            .map(|symbol| (symbol.clone(), None))
            .collect::<HashMap<_, _>>();
        let mut next_page_token = None;

        loop {
            let Ok(message) = alpaca::api::incoming::news::get(
                &self.config.alpaca_client,
                &self.config.alpaca_rate_limiter,
                &alpaca::api::outgoing::news::News {
                    symbols: symbols.clone(),
                    start: Some(fetch_from),
                    end: Some(fetch_to),
                    page_token: next_page_token.clone(),
                    sort: Some(Sort::Asc),
                    ..Default::default()
                },
                None,
            )
            .await
            else {
                error!("Failed to backfill news for {:?}.", symbols);
                return;
            };

            for news_item in message.news {
                let news_item = News::from(news_item);

                for symbol in &news_item.symbols {
                    if last_times.contains_key(symbol) {
                        last_times.insert(symbol.clone(), Some(news_item.time_created));
                    }
                }

                news.push(news_item);
            }

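            // Once enough articles have accumulated (or this is the last
            // page), classify headline + content with the BERT sequence
            // classifier in chunks of `BERT_MAX_INPUTS`, attaching the
            // predicted sentiment and confidence to each article;
            // `block_in_place` keeps model inference from blocking the
            // async executor.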
            if news.len() >= *BERT_MAX_INPUTS || message.next_page_token.is_none() {
                let inputs = news
                    .iter()
                    .map(|news| format!("{}\n\n{}", news.headline, news.content))
                    .collect::<Vec<_>>();

                let predictions =
                    join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
                        let sequence_classifier = self.config.sequence_classifier.lock().await;
                        block_in_place(|| {
                            sequence_classifier
                                .predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
                                .into_iter()
                                .map(|label| Prediction::try_from(label).unwrap())
                                .collect::<Vec<_>>()
                        })
                    }))
                    .await
                    .into_iter()
                    .flatten();

                news = news
                    .into_iter()
                    .zip(predictions)
                    .map(|(news, prediction)| News {
                        sentiment: prediction.sentiment,
                        confidence: prediction.confidence,
                        ..news
                    })
                    .collect::<Vec<_>>();
            }

            if news.len() >= database::news::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
                database::news::upsert_batch(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &news,
                )
                .await
                .unwrap();
                news = vec![];
            }

            if message.next_page_token.is_none() {
                break;
            }
            next_page_token = message.next_page_token;
        }

        let (backfilled, skipped): (Vec<_>, Vec<_>) =
            last_times.into_iter().partition_map(|(symbol, time)| {
                if let Some(time) = time {
                    Either::Left(Backfill { symbol, time })
                } else {
                    Either::Right(symbol)
                }
            });

        database::backfills_news::upsert_batch(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            &backfilled,
        )
        .await
        .unwrap();

        let backfilled = backfilled
            .into_iter()
            .map(|backfill| backfill.symbol)
            .collect::<Vec<_>>();

        if !skipped.is_empty() {
            info!("No news to backfill for {:?}.", skipped);
        }

        if !backfilled.is_empty() {
            info!("Backfilled news for {:?}.", backfilled);
        }
    }

    fn max_limit(&self) -> i64 {
        alpaca::api::outgoing::news::MAX_LIMIT
    }

    fn log_string(&self) -> &'static str {
        "news"
    }
}
src/threads/data/websocket.rs (deleted, 437 lines)
@@ -1,437 +0,0 @@
use super::ThreadType;
use crate::{
    config::Config,
    database,
    types::{alpaca::websocket, news::Prediction, Bar, Class, News},
};
use async_trait::async_trait;
use futures_util::{
    future::join_all,
    stream::{SplitSink, SplitStream},
    SinkExt, StreamExt,
};
use log::{debug, error, info};
use serde_json::{from_str, to_string};
use std::{collections::HashMap, sync::Arc};
use tokio::{
    net::TcpStream,
    select, spawn,
    sync::{mpsc, oneshot, Mutex, RwLock},
    task::block_in_place,
};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};

pub enum Action {
    Subscribe,
    Unsubscribe,
}

impl From<super::Action> for Option<Action> {
    fn from(action: super::Action) -> Self {
        match action {
            super::Action::Add | super::Action::Enable => Some(Action::Subscribe),
            super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe),
        }
    }
}

pub struct Message {
    pub action: Option<Action>,
    pub symbols: Vec<String>,
    pub response: oneshot::Sender<()>,
}

impl Message {
    pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
        let (sender, receiver) = oneshot::channel();
        (
            Self {
                action,
                symbols,
                response: sender,
            },
            receiver,
        )
    }
}

pub struct Pending {
    pub subscriptions: HashMap<String, oneshot::Sender<()>>,
    pub unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}

#[async_trait]
pub trait Handler: Send + Sync {
    fn create_subscription_message(
        &self,
        symbols: Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message;
    async fn handle_websocket_message(
        &self,
        pending: Arc<RwLock<Pending>>,
        message: websocket::data::incoming::Message,
    );
}

pub async fn run(
    handler: Arc<Box<dyn Handler>>,
    mut receiver: mpsc::Receiver<Message>,
    mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>,
) {
    let pending = Arc::new(RwLock::new(Pending {
        subscriptions: HashMap::new(),
        unsubscriptions: HashMap::new(),
    }));
    let websocket_sink = Arc::new(Mutex::new(websocket_sink));

    loop {
        select! {
            Some(message) = receiver.recv() => {
                spawn(handle_message(
                    handler.clone(),
                    pending.clone(),
                    websocket_sink.clone(),
                    message,
                ));
            }
            Some(Ok(message)) = websocket_stream.next() => {
                match message {
                    tungstenite::Message::Text(message) => {
                        let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);

                        if parsed_message.is_err() {
                            error!("Failed to deserialize websocket message: {:?}", message);
                            continue;
                        }

                        for message in parsed_message.unwrap() {
                            let handler = handler.clone();
                            let pending = pending.clone();
                            spawn(async move {
                                handler.handle_websocket_message(pending, message).await;
                            });
                        }
                    }
                    tungstenite::Message::Ping(_) => {}
                    _ => error!("Unexpected websocket message: {:?}", message),
                }
            }
            else => panic!("Communication channel unexpectedly closed.")
        }
    }
}

async fn handle_message(
    handler: Arc<Box<dyn Handler>>,
    pending: Arc<RwLock<Pending>>,
    sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
    message: Message,
) {
    if message.symbols.is_empty() {
        message.response.send(()).unwrap();
        return;
    }

    match message.action {
        Some(Action::Subscribe) => {
            let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message
                .symbols
                .iter()
                .map(|symbol| {
                    let (sender, receiver) = oneshot::channel();
                    ((symbol.clone(), sender), receiver)
                })
                .unzip();

            pending
                .write()
                .await
                .subscriptions
                .extend(pending_subscriptions);

            sink.lock()
                .await
                .send(tungstenite::Message::Text(
                    to_string(&websocket::data::outgoing::Message::Subscribe(
                        handler.create_subscription_message(message.symbols),
                    ))
                    .unwrap(),
                ))
                .await
                .unwrap();

            join_all(receivers).await;
        }
        Some(Action::Unsubscribe) => {
            let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
                .symbols
                .iter()
                .map(|symbol| {
                    let (sender, receiver) = oneshot::channel();
                    ((symbol.clone(), sender), receiver)
                })
                .unzip();

            pending
                .write()
                .await
                .unsubscriptions
                .extend(pending_unsubscriptions);

            sink.lock()
                .await
                .send(tungstenite::Message::Text(
                    to_string(&websocket::data::outgoing::Message::Unsubscribe(
                        handler.create_subscription_message(message.symbols.clone()),
                    ))
                    .unwrap(),
                ))
                .await
                .unwrap();

            join_all(receivers).await;
        }
        None => {}
    }

    message.response.send(()).unwrap();
}

struct BarsHandler {
    config: Arc<Config>,
    subscription_message_constructor:
        fn(Vec<String>) -> websocket::data::outgoing::subscribe::Message,
}

#[async_trait]
impl Handler for BarsHandler {
    fn create_subscription_message(
        &self,
        symbols: Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message {
        (self.subscription_message_constructor)(symbols)
    }

    async fn handle_websocket_message(
        &self,
        pending: Arc<RwLock<Pending>>,
        message: websocket::data::incoming::Message,
    ) {
        match message {
            websocket::data::incoming::Message::Subscription(message) => {
                let websocket::data::incoming::subscription::Message::Market {
                    bars: symbols, ..
                } = message
                else {
                    unreachable!()
                };

                let mut pending = pending.write().await;

                let newly_subscribed = pending
                    .subscriptions
                    .extract_if(|symbol, _| symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();

                let newly_unsubscribed = pending
                    .unsubscriptions
                    .extract_if(|symbol, _| !symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();

                drop(pending);

                if !newly_subscribed.is_empty() {
                    info!(
                        "Subscribed to bars for {:?}.",
                        newly_subscribed.keys().collect::<Vec<_>>()
                    );

                    for sender in newly_subscribed.into_values() {
                        sender.send(()).unwrap();
                    }
                }

                if !newly_unsubscribed.is_empty() {
                    info!(
                        "Unsubscribed from bars for {:?}.",
                        newly_unsubscribed.keys().collect::<Vec<_>>()
                    );

                    for sender in newly_unsubscribed.into_values() {
                        sender.send(()).unwrap();
                    }
                }
            }
            websocket::data::incoming::Message::Bar(message)
            | websocket::data::incoming::Message::UpdatedBar(message) => {
                let bar = Bar::from(message);
                debug!("Received bar for {}: {}.", bar.symbol, bar.time);

                database::bars::upsert(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &bar,
                )
                .await
                .unwrap();
            }
            websocket::data::incoming::Message::Status(message) => {
                debug!(
                    "Received status message for {}: {:?}.",
                    message.symbol, message.status
                );

                match message.status {
                    websocket::data::incoming::status::Status::TradingHalt(_)
                    | websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
                        database::assets::update_status_where_symbol(
                            &self.config.clickhouse_client,
                            &self.config.clickhouse_concurrency_limiter,
                            &message.symbol,
                            false,
                        )
                        .await
                        .unwrap();
                    }
                    websocket::data::incoming::status::Status::Resume(_)
                    | websocket::data::incoming::status::Status::TradingResumption(_) => {
                        database::assets::update_status_where_symbol(
                            &self.config.clickhouse_client,
                            &self.config.clickhouse_concurrency_limiter,
                            &message.symbol,
                            true,
                        )
                        .await
                        .unwrap();
                    }
                    _ => {}
                }
            }
            websocket::data::incoming::Message::Error(message) => {
                error!("Received error message: {}.", message.message);
            }
            _ => unreachable!(),
        }
    }
}

struct NewsHandler {
    config: Arc<Config>,
}

#[async_trait]
impl Handler for NewsHandler {
    fn create_subscription_message(
        &self,
        symbols: Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message {
        websocket::data::outgoing::subscribe::Message::new_news(symbols)
    }

    async fn handle_websocket_message(
        &self,
        pending: Arc<RwLock<Pending>>,
        message: websocket::data::incoming::Message,
    ) {
        match message {
            websocket::data::incoming::Message::Subscription(message) => {
                let websocket::data::incoming::subscription::Message::News { news: symbols } =
                    message
                else {
                    unreachable!()
                };

                let mut pending = pending.write().await;

                let newly_subscribed = pending
                    .subscriptions
                    .extract_if(|symbol, _| symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();

                let newly_unsubscribed = pending
                    .unsubscriptions
                    .extract_if(|symbol, _| !symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();

                drop(pending);

                if !newly_subscribed.is_empty() {
                    info!(
                        "Subscribed to news for {:?}.",
                        newly_subscribed.keys().collect::<Vec<_>>()
                    );

                    for sender in newly_subscribed.into_values() {
                        sender.send(()).unwrap();
                    }
                }

                if !newly_unsubscribed.is_empty() {
                    info!(
                        "Unsubscribed from news for {:?}.",
                        newly_unsubscribed.keys().collect::<Vec<_>>()
                    );

                    for sender in newly_unsubscribed.into_values() {
                        sender.send(()).unwrap();
                    }
                }
            }
            websocket::data::incoming::Message::News(message) => {
                let news = News::from(message);

                debug!(
                    "Received news for {:?}: {}.",
                    news.symbols, news.time_created
                );

                let input = format!("{}\n\n{}", news.headline, news.content);

                let sequence_classifier = self.config.sequence_classifier.lock().await;
                let prediction = block_in_place(|| {
                    sequence_classifier
                        .predict(vec![input.as_str()])
                        .into_iter()
                        .map(|label| Prediction::try_from(label).unwrap())
                        .collect::<Vec<_>>()[0]
                });
                drop(sequence_classifier);

                let news = News {
                    sentiment: prediction.sentiment,
                    confidence: prediction.confidence,
                    ..news
                };

                database::news::upsert(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &news,
                )
                .await
                .unwrap();
            }
            websocket::data::incoming::Message::Error(message) => {
                error!("Received error message: {}.", message.message);
            }
            _ => unreachable!(),
        }
    }
}

pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
    match thread_type {
        ThreadType::Bars(Class::UsEquity) => Box::new(BarsHandler {
            config,
            subscription_message_constructor:
                websocket::data::outgoing::subscribe::Message::new_market_us_equity,
        }),
        ThreadType::Bars(Class::Crypto) => Box::new(BarsHandler {
            config,
            subscription_message_constructor:
                websocket::data::outgoing::subscribe::Message::new_market_crypto,
        }),
        ThreadType::News => Box::new(NewsHandler { config }),
    }
}
src/threads/data/websocket/bars.rs (new file, 128 lines)
@@ -0,0 +1,128 @@
use super::Pending;
use crate::{
    config::Config,
    database,
    types::{alpaca::websocket, Bar},
};
use async_trait::async_trait;
use log::{debug, error, info};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;

pub struct Handler {
    pub config: Arc<Config>,
    pub subscription_message_constructor:
        fn(Vec<String>) -> websocket::data::outgoing::subscribe::Message,
}

#[async_trait]
impl super::Handler for Handler {
    fn create_subscription_message(
        &self,
        symbols: Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message {
        (self.subscription_message_constructor)(symbols)
    }

    async fn handle_websocket_message(
        &self,
        pending: Arc<RwLock<Pending>>,
        message: websocket::data::incoming::Message,
    ) {
        match message {
            websocket::data::incoming::Message::Subscription(message) => {
                let websocket::data::incoming::subscription::Message::Market {
                    bars: symbols, ..
                } = message
                else {
                    unreachable!()
                };

                let mut pending = pending.write().await;

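                // The subscription confirmation lists every symbol currently
                // active on the stream; use it to resolve any pending
                // subscription and unsubscription acknowledgements.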
                let newly_subscribed = pending
                    .subscriptions
                    .extract_if(|symbol, _| symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();

                let newly_unsubscribed = pending
                    .unsubscriptions
                    .extract_if(|symbol, _| !symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();

                drop(pending);

                if !newly_subscribed.is_empty() {
                    info!(
                        "Subscribed to bars for {:?}.",
                        newly_subscribed.keys().collect::<Vec<_>>()
                    );

                    for sender in newly_subscribed.into_values() {
                        sender.send(()).unwrap();
                    }
                }

                if !newly_unsubscribed.is_empty() {
                    info!(
                        "Unsubscribed from bars for {:?}.",
                        newly_unsubscribed.keys().collect::<Vec<_>>()
                    );

                    for sender in newly_unsubscribed.into_values() {
                        sender.send(()).unwrap();
                    }
                }
            }
            websocket::data::incoming::Message::Bar(message)
            | websocket::data::incoming::Message::UpdatedBar(message) => {
                let bar = Bar::from(message);
                debug!("Received bar for {}: {}.", bar.symbol, bar.time);

                database::bars::upsert(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &bar,
                )
                .await
                .unwrap();
            }
            websocket::data::incoming::Message::Status(message) => {
                debug!(
                    "Received status message for {}: {:?}.",
                    message.symbol, message.status
                );

                match message.status {
                    websocket::data::incoming::status::Status::TradingHalt(_)
                    | websocket::data::incoming::status::Status::VolatilityTradingPause(_) => {
                        database::assets::update_status_where_symbol(
                            &self.config.clickhouse_client,
                            &self.config.clickhouse_concurrency_limiter,
                            &message.symbol,
                            false,
                        )
                        .await
                        .unwrap();
                    }
                    websocket::data::incoming::status::Status::Resume(_)
                    | websocket::data::incoming::status::Status::TradingResumption(_) => {
                        database::assets::update_status_where_symbol(
                            &self.config.clickhouse_client,
                            &self.config.clickhouse_concurrency_limiter,
                            &message.symbol,
                            true,
                        )
                        .await
                        .unwrap();
                    }
                    _ => {}
                }
            }
            websocket::data::incoming::Message::Error(message) => {
                error!("Received error message: {}.", message.message);
            }
            _ => unreachable!(),
        }
    }
}
src/threads/data/websocket/mod.rs (new file, 216 lines)
@@ -0,0 +1,216 @@
mod bars;
mod news;

use super::ThreadType;
use crate::{
    config::Config,
    types::{alpaca::websocket, Class},
};
use async_trait::async_trait;
use futures_util::{
    future::join_all,
    stream::{SplitSink, SplitStream},
    SinkExt, StreamExt,
};
use log::error;
use serde_json::{from_str, to_string};
use std::{collections::HashMap, sync::Arc};
use tokio::{
    net::TcpStream,
    select, spawn,
    sync::{mpsc, oneshot, Mutex, RwLock},
};
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};

pub enum Action {
    Subscribe,
    Unsubscribe,
}

impl From<super::Action> for Option<Action> {
    fn from(action: super::Action) -> Self {
        match action {
            super::Action::Add | super::Action::Enable => Some(Action::Subscribe),
            super::Action::Remove | super::Action::Disable => Some(Action::Unsubscribe),
        }
    }
}

pub struct Message {
    pub action: Option<Action>,
    pub symbols: Vec<String>,
    pub response: oneshot::Sender<()>,
}

impl Message {
    pub fn new(action: Option<Action>, symbols: Vec<String>) -> (Self, oneshot::Receiver<()>) {
        let (sender, receiver) = oneshot::channel();
        (
            Self {
                action,
                symbols,
                response: sender,
            },
            receiver,
        )
    }
}

pub struct Pending {
    pub subscriptions: HashMap<String, oneshot::Sender<()>>,
    pub unsubscriptions: HashMap<String, oneshot::Sender<()>>,
}

#[async_trait]
pub trait Handler: Send + Sync {
    fn create_subscription_message(
        &self,
        symbols: Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message;
    async fn handle_websocket_message(
        &self,
        pending: Arc<RwLock<Pending>>,
        message: websocket::data::incoming::Message,
    );
}

pub async fn run(
    handler: Arc<Box<dyn Handler>>,
    mut receiver: mpsc::Receiver<Message>,
    mut websocket_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    websocket_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>,
) {
    let pending = Arc::new(RwLock::new(Pending {
        subscriptions: HashMap::new(),
        unsubscriptions: HashMap::new(),
    }));
    let websocket_sink = Arc::new(Mutex::new(websocket_sink));

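    // Multiplex between control messages from the rest of the application and
    // frames arriving on the Alpaca websocket; each is handled on its own task.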
    loop {
        select! {
            Some(message) = receiver.recv() => {
                spawn(handle_message(
                    handler.clone(),
                    pending.clone(),
                    websocket_sink.clone(),
                    message,
                ));
            }
            Some(Ok(message)) = websocket_stream.next() => {
                match message {
                    tungstenite::Message::Text(message) => {
                        let parsed_message = from_str::<Vec<websocket::data::incoming::Message>>(&message);

                        if parsed_message.is_err() {
                            error!("Failed to deserialize websocket message: {:?}", message);
                            continue;
                        }

                        for message in parsed_message.unwrap() {
                            let handler = handler.clone();
                            let pending = pending.clone();
                            spawn(async move {
                                handler.handle_websocket_message(pending, message).await;
                            });
                        }
                    }
                    tungstenite::Message::Ping(_) => {}
                    _ => error!("Unexpected websocket message: {:?}", message),
                }
            }
            else => panic!("Communication channel unexpectedly closed.")
        }
    }
}

async fn handle_message(
    handler: Arc<Box<dyn Handler>>,
    pending: Arc<RwLock<Pending>>,
    sink: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>>>,
    message: Message,
) {
    if message.symbols.is_empty() {
        message.response.send(()).unwrap();
        return;
    }

    match message.action {
        Some(Action::Subscribe) => {
            let (pending_subscriptions, receivers): (Vec<_>, Vec<_>) = message
                .symbols
                .iter()
                .map(|symbol| {
                    let (sender, receiver) = oneshot::channel();
                    ((symbol.clone(), sender), receiver)
                })
                .unzip();

            pending
                .write()
                .await
                .subscriptions
                .extend(pending_subscriptions);

            sink.lock()
                .await
                .send(tungstenite::Message::Text(
                    to_string(&websocket::data::outgoing::Message::Subscribe(
                        handler.create_subscription_message(message.symbols),
                    ))
                    .unwrap(),
                ))
                .await
                .unwrap();

            join_all(receivers).await;
        }
        Some(Action::Unsubscribe) => {
            let (pending_unsubscriptions, receivers): (Vec<_>, Vec<_>) = message
                .symbols
                .iter()
                .map(|symbol| {
                    let (sender, receiver) = oneshot::channel();
                    ((symbol.clone(), sender), receiver)
                })
                .unzip();

            pending
                .write()
                .await
                .unsubscriptions
                .extend(pending_unsubscriptions);

            sink.lock()
                .await
                .send(tungstenite::Message::Text(
                    to_string(&websocket::data::outgoing::Message::Unsubscribe(
                        handler.create_subscription_message(message.symbols.clone()),
                    ))
                    .unwrap(),
                ))
                .await
                .unwrap();

            join_all(receivers).await;
        }
        None => {}
    }

    message.response.send(()).unwrap();
}

pub fn create_handler(thread_type: ThreadType, config: Arc<Config>) -> Box<dyn Handler> {
    match thread_type {
        ThreadType::Bars(Class::UsEquity) => Box::new(bars::Handler {
            config,
            subscription_message_constructor:
                websocket::data::outgoing::subscribe::Message::new_market_us_equity,
        }),
        ThreadType::Bars(Class::Crypto) => Box::new(bars::Handler {
            config,
            subscription_message_constructor:
                websocket::data::outgoing::subscribe::Message::new_market_crypto,
        }),
        ThreadType::News => Box::new(news::Handler { config }),
    }
}
src/threads/data/websocket/news.rs (new file, 114 lines)
@@ -0,0 +1,114 @@
use super::Pending;
use crate::{
    config::Config,
    database,
    types::{alpaca::websocket, news::Prediction, News},
};
use async_trait::async_trait;
use log::{debug, error, info};
use std::{collections::HashMap, sync::Arc};
use tokio::{sync::RwLock, task::block_in_place};

pub struct Handler {
    pub config: Arc<Config>,
}

#[async_trait]
impl super::Handler for Handler {
    fn create_subscription_message(
        &self,
        symbols: Vec<String>,
    ) -> websocket::data::outgoing::subscribe::Message {
        websocket::data::outgoing::subscribe::Message::new_news(symbols)
    }

    async fn handle_websocket_message(
        &self,
        pending: Arc<RwLock<Pending>>,
        message: websocket::data::incoming::Message,
    ) {
        match message {
            websocket::data::incoming::Message::Subscription(message) => {
                let websocket::data::incoming::subscription::Message::News { news: symbols } =
                    message
                else {
                    unreachable!()
                };

                let mut pending = pending.write().await;

                let newly_subscribed = pending
                    .subscriptions
                    .extract_if(|symbol, _| symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();

                let newly_unsubscribed = pending
                    .unsubscriptions
                    .extract_if(|symbol, _| !symbols.contains(symbol))
                    .collect::<HashMap<_, _>>();

                drop(pending);

                if !newly_subscribed.is_empty() {
                    info!(
                        "Subscribed to news for {:?}.",
                        newly_subscribed.keys().collect::<Vec<_>>()
                    );

                    for sender in newly_subscribed.into_values() {
                        sender.send(()).unwrap();
                    }
                }

                if !newly_unsubscribed.is_empty() {
                    info!(
                        "Unsubscribed from news for {:?}.",
                        newly_unsubscribed.keys().collect::<Vec<_>>()
                    );

                    for sender in newly_unsubscribed.into_values() {
                        sender.send(()).unwrap();
                    }
                }
            }
            websocket::data::incoming::Message::News(message) => {
                let news = News::from(message);

                debug!(
                    "Received news for {:?}: {}.",
                    news.symbols, news.time_created
                );

                let input = format!("{}\n\n{}", news.headline, news.content);

                let sequence_classifier = self.config.sequence_classifier.lock().await;
                let prediction = block_in_place(|| {
                    sequence_classifier
                        .predict(vec![input.as_str()])
                        .into_iter()
                        .map(|label| Prediction::try_from(label).unwrap())
                        .collect::<Vec<_>>()[0]
                });
                drop(sequence_classifier);

                let news = News {
                    sentiment: prediction.sentiment,
                    confidence: prediction.confidence,
                    ..news
                };

                database::news::upsert(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &news,
                )
                .await
                .unwrap();
            }
            websocket::data::incoming::Message::Error(message) => {
                error!("Received error message: {}.", message.message);
            }
            _ => unreachable!(),
        }
    }
}