use super::Job; use crate::{ config::{Config, ALPACA_SOURCE, BERT_MAX_INPUTS}, database, }; use async_trait::async_trait; use futures_util::future::join_all; use log::{error, info}; use qrust::{ types::{ alpaca::{ self, shared::{Sort, Source}, }, news::Prediction, Backfill, News, }, utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE}, }; use std::{collections::HashMap, sync::Arc}; use tokio::{task::block_in_place, time::sleep}; pub struct Handler { pub config: Arc, } #[async_trait] impl super::Handler for Handler { async fn select_latest_backfills( &self, symbols: &[String], ) -> Result, clickhouse::error::Error> { database::backfills_news::select_where_symbols( &self.config.clickhouse_client, &self.config.clickhouse_concurrency_limiter, symbols, ) .await } async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { database::backfills_news::delete_where_symbols( &self.config.clickhouse_client, &self.config.clickhouse_concurrency_limiter, symbols, ) .await } async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> { database::news::delete_where_symbols( &self.config.clickhouse_client, &self.config.clickhouse_concurrency_limiter, symbols, ) .await } async fn queue_backfill(&self, jobs: &HashMap) { if jobs.is_empty() || *ALPACA_SOURCE == Source::Sip { return; } let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap(); let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE); let symbols = jobs.keys().cloned().collect::>(); info!("Queing news backfill for {:?} in {:?}.", symbols, run_delay); sleep(run_delay).await; } async fn backfill(&self, jobs: HashMap) { if jobs.is_empty() { return; } let symbols = jobs.keys().cloned().collect::>(); let symbols_set = symbols.iter().collect::>(); let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap(); let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap(); info!("Backfilling news for {:?}.", symbols); let mut news = Vec::with_capacity(database::news::BATCH_FLUSH_SIZE); let mut last_times = HashMap::new(); let mut next_page_token = None; loop { let message = alpaca::api::incoming::news::get( &self.config.alpaca_client, &self.config.alpaca_rate_limiter, &alpaca::api::outgoing::news::News { symbols: symbols.clone(), start: Some(fetch_from), end: Some(fetch_to), page_token: next_page_token.clone(), sort: Some(Sort::Asc), ..Default::default() }, None, ) .await; if let Err(err) = message { error!("Failed to backfill news for {:?}: {:?}.", symbols, err); return; } let message = message.unwrap(); for news_item in message.news { let news_item = News::from(news_item); for symbol in &news_item.symbols { if symbols_set.contains(symbol) { last_times.insert(symbol.clone(), news_item.time_created); } } news.push(news_item); } if news.len() >= *BERT_MAX_INPUTS || message.next_page_token.is_none() { let inputs = news .iter() .map(|news| format!("{}\n\n{}", news.headline, news.content)) .collect::>(); let predictions = join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move { let sequence_classifier = self.config.sequence_classifier.lock().await; block_in_place(|| { sequence_classifier .predict(inputs.iter().map(String::as_str).collect::>()) .into_iter() .map(|label| Prediction::try_from(label).unwrap()) .collect::>() }) })) .await .into_iter() .flatten(); news = news .into_iter() .zip(predictions) .map(|(news, prediction)| News { sentiment: prediction.sentiment, confidence: prediction.confidence, ..news }) .collect::>(); } if news.len() >= database::news::BATCH_FLUSH_SIZE || message.next_page_token.is_none() { database::news::upsert_batch( &self.config.clickhouse_client, &self.config.clickhouse_concurrency_limiter, &news, ) .await .unwrap(); let backfilled = last_times .drain() .map(|(symbol, time)| Backfill { symbol, time }) .collect::>(); database::backfills_news::upsert_batch( &self.config.clickhouse_client, &self.config.clickhouse_concurrency_limiter, &backfilled, ) .await .unwrap(); if message.next_page_token.is_none() { break; } next_page_token = message.next_page_token; news.clear(); } } info!("Backfilled news for {:?}.", symbols); } fn max_limit(&self) -> i64 { alpaca::api::outgoing::news::MAX_LIMIT } fn log_string(&self) -> &'static str { "news" } }