// qrust/src/threads/data/backfill/news.rs

use super::Job;
use crate::{
    config::{Config, ALPACA_SOURCE, BERT_MAX_INPUTS},
    database,
};
use async_trait::async_trait;
use futures_util::future::join_all;
use log::{error, info};
use qrust::{
    types::{
        alpaca::{
            self,
            shared::{Sort, Source},
        },
        news::Prediction,
        Backfill, News,
    },
    utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{collections::{HashMap, HashSet}, sync::Arc};
use tokio::{task::block_in_place, time::sleep};
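
/// Backfill handler for news data; holds the shared application [`Config`].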
pub struct Handler {
    pub config: Arc<Config>,
}

#[async_trait]
impl super::Handler for Handler {
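    /// Fetch the latest news backfill marker for each of the given symbols.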
    async fn select_latest_backfills(
        &self,
        symbols: &[String],
    ) -> Result<Vec<Backfill>, clickhouse::error::Error> {
        database::backfills_news::select_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }
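
    /// Delete the news backfill markers for the given symbols.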
    async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::backfills_news::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }
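
    /// Delete the stored news items for the given symbols.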
    async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::news::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }
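
    /// Delay the backfill until the end of the requested range should be
    /// available upstream: non-SIP sources are treated as 15-minute-delayed,
    /// so wait until `fetch_to + FIFTEEN_MINUTES` plus a minute of slack.
    /// SIP data is real-time, so those runs start immediately.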
    async fn queue_backfill(&self, jobs: &HashMap<String, Job>) {
        if jobs.is_empty() || *ALPACA_SOURCE == Source::Sip {
            return;
        }

        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
        let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
        let symbols = jobs.keys().cloned().collect::<Vec<_>>();

        info!("Queuing news backfill for {:?} in {:?}.", symbols, run_delay);
        sleep(run_delay).await;
    }
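
    /// Page through the Alpaca news API for the combined symbol set, enrich
    /// each buffered batch with BERT sentiment predictions, and flush the
    /// results to ClickHouse together with per-symbol backfill markers.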
    async fn backfill(&self, jobs: HashMap<String, Job>) {
        if jobs.is_empty() {
            return;
        }

        let symbols = jobs.keys().cloned().collect::<Vec<_>>();
        let symbols_set = symbols.iter().collect::<HashSet<_>>();
        let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();

        info!("Backfilling news for {:?}.", symbols);

        let mut news = Vec::with_capacity(database::news::BATCH_FLUSH_SIZE);
        let mut last_times = HashMap::new();
        let mut next_page_token = None;
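
        // Keep requesting pages (sorted ascending by time) until the API
        // stops returning a next_page_token.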
        loop {
            let message = alpaca::api::incoming::news::get(
                &self.config.alpaca_client,
                &self.config.alpaca_rate_limiter,
                &alpaca::api::outgoing::news::News {
                    symbols: symbols.clone(),
                    start: Some(fetch_from),
                    end: Some(fetch_to),
                    page_token: next_page_token.clone(),
                    sort: Some(Sort::Asc),
                    ..Default::default()
                },
                None,
            )
            .await;

            if let Err(err) = message {
                error!("Failed to backfill news for {:?}: {:?}.", symbols, err);
                return;
            }

            let message = message.unwrap();
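
            // Buffer the page and remember the latest article time seen for
            // each requested symbol; pages arrive in ascending order, so the
            // last write holds the most recent timestamp.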
            for news_item in message.news {
                let news_item = News::from(news_item);

                for symbol in &news_item.symbols {
                    if symbols_set.contains(symbol) {
                        last_times.insert(symbol.clone(), news_item.time_created);
                    }
                }

                news.push(news_item);
            }
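
            // Once enough items are buffered to fill the classifier (or this
            // is the final page), run sentiment prediction over the buffer,
            // one BERT_MAX_INPUTS-sized chunk at a time, on headline + body.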
            if news.len() >= *BERT_MAX_INPUTS || message.next_page_token.is_none() {
                let inputs = news
                    .iter()
                    .map(|news| format!("{}\n\n{}", news.headline, news.content))
                    .collect::<Vec<_>>();

                let predictions =
                    join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
                        let sequence_classifier = self.config.sequence_classifier.lock().await;
                        block_in_place(|| {
                            sequence_classifier
                                .predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
                                .into_iter()
                                .map(|label| Prediction::try_from(label).unwrap())
                                .collect::<Vec<_>>()
                        })
                    }))
                    .await
                    .into_iter()
                    .flatten();

                news = news
                    .into_iter()
                    .zip(predictions)
                    .map(|(news, prediction)| News {
                        sentiment: prediction.sentiment,
                        confidence: prediction.confidence,
                        ..news
                    })
                    .collect::<Vec<_>>();
            }
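
            // Flush once the buffer reaches the batch size (or on the final
            // page): upsert the news rows, then record a backfill marker per
            // symbol at the latest article time seen so far.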
            if news.len() >= database::news::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
                database::news::upsert_batch(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &news,
                )
                .await
                .unwrap();

                let backfilled = last_times
                    .drain()
                    .map(|(symbol, time)| Backfill { symbol, time })
                    .collect::<Vec<_>>();

                database::backfills_news::upsert_batch(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &backfilled,
                )
                .await
                .unwrap();

                if message.next_page_token.is_none() {
                    break;
                }

                next_page_token = message.next_page_token;
                news.clear();
            }
        }
info!("Backfilled news for {:?}.", symbols);
}

    fn max_limit(&self) -> i64 {
        alpaca::api::outgoing::news::MAX_LIMIT
    }

    fn log_string(&self) -> &'static str {
        "news"
    }
}