Make sentiment predictions blocking

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2024-02-05 00:30:32 +00:00
parent caaa31133a
commit a2bcb6d17e
2 changed files with 19 additions and 16 deletions

View File

@@ -20,7 +20,7 @@ use time::OffsetDateTime;
use tokio::{ use tokio::{
join, spawn, join, spawn,
sync::{mpsc, oneshot, Mutex, RwLock}, sync::{mpsc, oneshot, Mutex, RwLock},
task::JoinHandle, task::{block_in_place, JoinHandle},
time::sleep, time::sleep,
}; };
@@ -397,13 +397,14 @@ async fn execute_backfill_news(
let predictions = join_all(inputs.chunks(app_config.max_bert_inputs).map(|inputs| { let predictions = join_all(inputs.chunks(app_config.max_bert_inputs).map(|inputs| {
let sequence_classifier = app_config.sequence_classifier.clone(); let sequence_classifier = app_config.sequence_classifier.clone();
async move { async move {
let sequence_classifier = sequence_classifier.lock().await;
block_in_place(|| {
sequence_classifier sequence_classifier
.lock()
.await
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>()) .predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
.into_iter() .into_iter()
.map(|label| Prediction::try_from(label).unwrap()) .map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>() .collect::<Vec<_>>()
})
} }
})) }))
.await .await

View File

@@ -17,6 +17,7 @@ use tokio::{
net::TcpStream, net::TcpStream,
spawn, spawn,
sync::{mpsc, Mutex, RwLock}, sync::{mpsc, Mutex, RwLock},
task::block_in_place,
}; };
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
@@ -209,14 +210,15 @@ async fn handle_parsed_websocket_message(
let input = format!("{}\n\n{}", news.headline, news.content); let input = format!("{}\n\n{}", news.headline, news.content);
let prediction = app_config let sequence_classifier = app_config.sequence_classifier.lock().await;
.sequence_classifier let prediction = block_in_place(|| {
.lock() sequence_classifier
.await
.predict(vec![input.as_str()]) .predict(vec![input.as_str()])
.into_iter() .into_iter()
.map(|label| Prediction::try_from(label).unwrap()) .map(|label| Prediction::try_from(label).unwrap())
.collect::<Vec<_>>()[0]; .collect::<Vec<_>>()[0]
});
drop(sequence_classifier);
let news = News { let news = News {
sentiment: prediction.sentiment, sentiment: prediction.sentiment,