use super::Job;
use crate::{
    config::{Config, ALPACA_SOURCE, BERT_MAX_INPUTS},
    database,
};
use async_trait::async_trait;
use futures_util::future::join_all;
use log::{error, info};
use qrust::{
    types::{
        alpaca::{
            self,
            shared::{Sort, Source},
        },
        news::Prediction,
        Backfill, News,
    },
    utils::{duration_until, FIFTEEN_MINUTES, ONE_MINUTE},
};
use std::{collections::HashMap, sync::Arc};
use tokio::{task::block_in_place, time::sleep};

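/// Handles news backfill jobs: fetches articles from Alpaca, classifies their
/// sentiment, and stores the results in ClickHouse.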
pub struct Handler {
    pub config: Arc<Config>,
}

#[async_trait]
impl super::Handler for Handler {
    async fn select_latest_backfills(
        &self,
        symbols: &[String],
    ) -> Result<Vec<Backfill>, clickhouse::error::Error> {
        database::backfills_news::select_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn delete_backfills(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::backfills_news::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn delete_data(&self, symbols: &[String]) -> Result<(), clickhouse::error::Error> {
        database::news::delete_where_symbols(
            &self.config.clickhouse_client,
            &self.config.clickhouse_concurrency_limiter,
            symbols,
        )
        .await
    }

    async fn queue_backfill(&self, jobs: &HashMap<String, Job>) {
        if jobs.is_empty() || *ALPACA_SOURCE == Source::Sip {
            return;
        }

        // Wait until fifteen minutes (plus a one-minute buffer) past the latest
        // `fetch_to`, so that delayed (non-SIP) news has become available.
        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();
        let run_delay = duration_until(fetch_to + FIFTEEN_MINUTES + ONE_MINUTE);
        let symbols = jobs.keys().cloned().collect::<Vec<_>>();

        info!("Queuing news backfill for {:?} in {:?}.", symbols, run_delay);
        sleep(run_delay).await;
    }

    async fn backfill(&self, jobs: HashMap<String, Job>) {
        if jobs.is_empty() {
            return;
        }

        let symbols = jobs.keys().cloned().collect::<Vec<_>>();
        let symbols_set = symbols.iter().collect::<std::collections::HashSet<_>>();
        let fetch_from = jobs.values().map(|job| job.fetch_from).min().unwrap();
        let fetch_to = jobs.values().map(|job| job.fetch_to).max().unwrap();

        info!("Backfilling news for {:?}.", symbols);

        let mut news = Vec::with_capacity(database::news::BATCH_FLUSH_SIZE);
        // Latest article time seen per requested symbol, used to record backfill progress.
        let mut last_times = HashMap::new();
        let mut next_page_token = None;

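        // Page through the news API in ascending order until no next_page_token is returned.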
        loop {
            let message = alpaca::api::incoming::news::get(
                &self.config.alpaca_client,
                &self.config.alpaca_rate_limiter,
                &alpaca::api::outgoing::news::News {
                    symbols: symbols.clone(),
                    start: Some(fetch_from),
                    end: Some(fetch_to),
                    page_token: next_page_token.clone(),
                    sort: Some(Sort::Asc),
                    ..Default::default()
                },
                None,
            )
            .await;

            if let Err(err) = message {
                error!("Failed to backfill news for {:?}: {:?}.", symbols, err);
                return;
            }

            let message = message.unwrap();

            for news_item in message.news {
                let news_item = News::from(news_item);

                for symbol in &news_item.symbols {
                    if symbols_set.contains(symbol) {
                        last_times.insert(symbol.clone(), news_item.time_created);
                    }
                }

                news.push(news_item);
            }

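            // Once enough articles have accumulated (or the final page has been reached),
            // run the sequence classifier over "headline\n\ncontent" in chunks of
            // BERT_MAX_INPUTS and attach sentiment/confidence predictions.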
            if news.len() >= *BERT_MAX_INPUTS || message.next_page_token.is_none() {
                let inputs = news
                    .iter()
                    .map(|news| format!("{}\n\n{}", news.headline, news.content))
                    .collect::<Vec<_>>();

                let predictions =
                    join_all(inputs.chunks(*BERT_MAX_INPUTS).map(|inputs| async move {
                        let sequence_classifier = self.config.sequence_classifier.lock().await;
                        block_in_place(|| {
                            sequence_classifier
                                .predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
                                .into_iter()
                                .map(|label| Prediction::try_from(label).unwrap())
                                .collect::<Vec<_>>()
                        })
                    }))
                    .await
                    .into_iter()
                    .flatten();

                news = news
                    .into_iter()
                    .zip(predictions)
                    .map(|(news, prediction)| News {
                        sentiment: prediction.sentiment,
                        confidence: prediction.confidence,
                        ..news
                    })
                    .collect::<Vec<_>>();
            }

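            // Flush once a full batch has accumulated (or the final page has been reached):
            // upsert the articles, record per-symbol backfill progress, then advance to the
            // next page.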
            if news.len() >= database::news::BATCH_FLUSH_SIZE || message.next_page_token.is_none() {
                database::news::upsert_batch(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &news,
                )
                .await
                .unwrap();

                let backfilled = last_times
                    .drain()
                    .map(|(symbol, time)| Backfill { symbol, time })
                    .collect::<Vec<_>>();

                database::backfills_news::upsert_batch(
                    &self.config.clickhouse_client,
                    &self.config.clickhouse_concurrency_limiter,
                    &backfilled,
                )
                .await
                .unwrap();

                if message.next_page_token.is_none() {
                    break;
                }

                next_page_token = message.next_page_token;
                news.clear();
            }
        }

info!("Backfilled news for {:?}.", symbols);
|
|
}
|
|
|
|
fn max_limit(&self) -> i64 {
|
|
alpaca::api::outgoing::news::MAX_LIMIT
|
|
}
|
|
|
|
fn log_string(&self) -> &'static str {
|
|
"news"
|
|
}
|
|
}
|