Make sentiment predictions blocking
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
@@ -20,7 +20,7 @@ use time::OffsetDateTime;
|
|||||||
use tokio::{
|
use tokio::{
|
||||||
join, spawn,
|
join, spawn,
|
||||||
sync::{mpsc, oneshot, Mutex, RwLock},
|
sync::{mpsc, oneshot, Mutex, RwLock},
|
||||||
task::JoinHandle,
|
task::{block_in_place, JoinHandle},
|
||||||
time::sleep,
|
time::sleep,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -397,13 +397,14 @@ async fn execute_backfill_news(
|
|||||||
let predictions = join_all(inputs.chunks(app_config.max_bert_inputs).map(|inputs| {
|
let predictions = join_all(inputs.chunks(app_config.max_bert_inputs).map(|inputs| {
|
||||||
let sequence_classifier = app_config.sequence_classifier.clone();
|
let sequence_classifier = app_config.sequence_classifier.clone();
|
||||||
async move {
|
async move {
|
||||||
sequence_classifier
|
let sequence_classifier = sequence_classifier.lock().await;
|
||||||
.lock()
|
block_in_place(|| {
|
||||||
.await
|
sequence_classifier
|
||||||
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
|
.predict(inputs.iter().map(String::as_str).collect::<Vec<_>>())
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|label| Prediction::try_from(label).unwrap())
|
.map(|label| Prediction::try_from(label).unwrap())
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
.await
|
.await
|
||||||
|
@@ -17,6 +17,7 @@ use tokio::{
|
|||||||
net::TcpStream,
|
net::TcpStream,
|
||||||
spawn,
|
spawn,
|
||||||
sync::{mpsc, Mutex, RwLock},
|
sync::{mpsc, Mutex, RwLock},
|
||||||
|
task::block_in_place,
|
||||||
};
|
};
|
||||||
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
|
use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream};
|
||||||
|
|
||||||
@@ -209,14 +210,15 @@ async fn handle_parsed_websocket_message(
|
|||||||
|
|
||||||
let input = format!("{}\n\n{}", news.headline, news.content);
|
let input = format!("{}\n\n{}", news.headline, news.content);
|
||||||
|
|
||||||
let prediction = app_config
|
let sequence_classifier = app_config.sequence_classifier.lock().await;
|
||||||
.sequence_classifier
|
let prediction = block_in_place(|| {
|
||||||
.lock()
|
sequence_classifier
|
||||||
.await
|
.predict(vec![input.as_str()])
|
||||||
.predict(vec![input.as_str()])
|
.into_iter()
|
||||||
.into_iter()
|
.map(|label| Prediction::try_from(label).unwrap())
|
||||||
.map(|label| Prediction::try_from(label).unwrap())
|
.collect::<Vec<_>>()[0]
|
||||||
.collect::<Vec<_>>()[0];
|
});
|
||||||
|
drop(sequence_classifier);
|
||||||
|
|
||||||
let news = News {
|
let news = News {
|
||||||
sentiment: prediction.sentiment,
|
sentiment: prediction.sentiment,
|
||||||
|
Reference in New Issue
Block a user