diff --git a/src/threads/data/backfill.rs b/src/threads/data/backfill.rs index 0faf277..9c0e662 100644 --- a/src/threads/data/backfill.rs +++ b/src/threads/data/backfill.rs @@ -20,7 +20,7 @@ use time::OffsetDateTime; use tokio::{ join, spawn, sync::{mpsc, oneshot, Mutex, RwLock}, - task::JoinHandle, + task::{block_in_place, JoinHandle}, time::sleep, }; @@ -397,13 +397,14 @@ async fn execute_backfill_news( let predictions = join_all(inputs.chunks(app_config.max_bert_inputs).map(|inputs| { let sequence_classifier = app_config.sequence_classifier.clone(); async move { - sequence_classifier - .lock() - .await - .predict(inputs.iter().map(String::as_str).collect::>()) - .into_iter() - .map(|label| Prediction::try_from(label).unwrap()) - .collect::>() + let sequence_classifier = sequence_classifier.lock().await; + block_in_place(|| { + sequence_classifier + .predict(inputs.iter().map(String::as_str).collect::>()) + .into_iter() + .map(|label| Prediction::try_from(label).unwrap()) + .collect::>() + }) } })) .await diff --git a/src/threads/data/websocket.rs b/src/threads/data/websocket.rs index e5b6ba5..0ce8cbc 100644 --- a/src/threads/data/websocket.rs +++ b/src/threads/data/websocket.rs @@ -17,6 +17,7 @@ use tokio::{ net::TcpStream, spawn, sync::{mpsc, Mutex, RwLock}, + task::block_in_place, }; use tokio_tungstenite::{tungstenite, MaybeTlsStream, WebSocketStream}; @@ -209,14 +210,15 @@ async fn handle_parsed_websocket_message( let input = format!("{}\n\n{}", news.headline, news.content); - let prediction = app_config - .sequence_classifier - .lock() - .await - .predict(vec![input.as_str()]) - .into_iter() - .map(|label| Prediction::try_from(label).unwrap()) - .collect::>()[0]; + let sequence_classifier = app_config.sequence_classifier.lock().await; + let prediction = block_in_place(|| { + sequence_classifier + .predict(vec![input.as_str()]) + .into_iter() + .map(|label| Prediction::try_from(label).unwrap()) + .collect::>()[0] + }); + drop(sequence_classifier); let news = News { sentiment: prediction.sentiment,