Files
qrust/src/routes/assets.rs
2024-03-11 16:53:12 +00:00

202 lines
5.5 KiB
Rust

use crate::{
config::{Config, ALPACA_API_BASE},
create_send_await, database, threads,
};
use axum::{extract::Path, Extension, Json};
use http::StatusCode;
use qrust::types::{alpaca, Asset};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::mpsc;
/// Handler returning the assets yielded by `database::assets::select`.
///
/// # Errors
///
/// Responds with `500 Internal Server Error` when the database query fails;
/// the underlying error is discarded.
pub async fn get(
    Extension(config): Extension<Arc<Config>>,
) -> Result<(StatusCode, Json<Vec<Asset>>), StatusCode> {
    let query_result = database::assets::select(
        &config.clickhouse_client,
        &config.clickhouse_concurrency_limiter,
    )
    .await;

    match query_result {
        Ok(assets) => Ok((StatusCode::OK, Json(assets))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}
/// Handler looking up a single asset by the symbol taken from the request path.
///
/// # Errors
///
/// * `500 Internal Server Error` when the database query fails.
/// * `404 Not Found` when the query succeeds but yields no asset.
pub async fn get_where_symbol(
    Extension(config): Extension<Arc<Config>>,
    Path(symbol): Path<String>,
) -> Result<(StatusCode, Json<Asset>), StatusCode> {
    let lookup = database::assets::select_where_symbol(
        &config.clickhouse_client,
        &config.clickhouse_concurrency_limiter,
        &symbol,
    )
    .await;

    match lookup {
        Ok(Some(asset)) => Ok((StatusCode::OK, Json(asset))),
        Ok(None) => Err(StatusCode::NOT_FOUND),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}
/// JSON request body consumed by the `add` handler.
#[derive(Deserialize)]
pub struct AddAssetsRequest {
/// Symbols the caller asks to have added.
symbols: Vec<String>,
}
/// JSON response body produced by the `add` handler, sorting each requested
/// symbol into one of three outcomes.
#[derive(Serialize)]
pub struct AddAssetsResponse {
/// Symbols that were forwarded to the data thread for addition.
added: Vec<String>,
/// Symbols that were already present in the database.
skipped: Vec<String>,
/// Symbols that Alpaca did not return, or whose asset is not active,
/// tradable and fractionable.
failed: Vec<String>,
}
/// Handler adding a batch of symbols.
///
/// Every distinct requested symbol ends up in exactly one list of the response:
///
/// * `skipped` - the symbol is already present in the database;
/// * `added` - Alpaca returned an asset for it that is active, tradable and
///   fractionable; these are passed on through `data_sender` with
///   `threads::data::Action::Add`;
/// * `failed` - Alpaca returned no asset for it, or the asset is not eligible.
///
/// Symbols repeated in the request are only considered once (first occurrence
/// wins), so a duplicate can no longer be reported as both `added` and `failed`.
///
/// On success the response status is `201 Created`.
///
/// # Errors
///
/// * `500 Internal Server Error` when the database query fails.
/// * When the Alpaca request fails: the status carried by the error if it has
///   one that converts to a valid [`StatusCode`], otherwise
///   `500 Internal Server Error`.
pub async fn add(
    Extension(config): Extension<Arc<Config>>,
    Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
    Json(request): Json<AddAssetsRequest>,
) -> Result<(StatusCode, Json<AddAssetsResponse>), StatusCode> {
    // Drop repeated symbols while keeping the caller's ordering. Without this,
    // the second occurrence of a symbol would miss in `alpaca_assets` (the first
    // occurrence removes the entry) and be wrongly reported as failed.
    let mut seen = HashSet::with_capacity(request.symbols.len());
    let symbols = request
        .symbols
        .into_iter()
        .filter(|symbol| seen.insert(symbol.clone()))
        .collect::<Vec<String>>();

    let database_symbols = database::assets::select(
        &config.clickhouse_client,
        &config.clickhouse_concurrency_limiter,
    )
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
    .into_iter()
    .map(|asset| asset.symbol)
    .collect::<HashSet<_>>();

    let mut alpaca_assets = alpaca::api::incoming::asset::get_by_symbols(
        &config.alpaca_client,
        &config.alpaca_rate_limiter,
        &symbols,
        None,
        &ALPACA_API_BASE,
    )
    .await
    .map_err(|e| {
        // Forward the upstream status when there is one; never panic on an
        // unexpected code, fall back to 500 instead.
        e.status()
            .and_then(|status| StatusCode::from_u16(status.as_u16()).ok())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
    })?
    .into_iter()
    .map(|asset| (asset.symbol.clone(), asset))
    .collect::<HashMap<_, _>>();

    let mut assets = Vec::with_capacity(symbols.len());
    let mut skipped = Vec::new();
    let mut failed = Vec::new();

    for symbol in symbols {
        if database_symbols.contains(&symbol) {
            skipped.push(symbol);
            continue;
        }

        match alpaca_assets.remove(&symbol) {
            Some(asset)
                if asset.status == alpaca::shared::asset::Status::Active
                    && asset.tradable
                    && asset.fractionable =>
            {
                assets.push((asset.symbol, asset.class.into()));
            }
            // Either unknown to Alpaca or not eligible. The map is keyed by
            // the asset's own symbol, so `symbol` is the value to report.
            _ => failed.push(symbol),
        }
    }

    create_send_await!(
        data_sender,
        threads::data::Message::new,
        threads::data::Action::Add,
        assets.clone()
    );

    Ok((
        StatusCode::CREATED,
        Json(AddAssetsResponse {
            added: assets.into_iter().map(|asset| asset.0).collect(),
            skipped,
            failed,
        }),
    ))
}
/// Handler adding the single symbol taken from the request path.
///
/// The asset fetched from Alpaca is passed on through `data_sender` with
/// `threads::data::Action::Add`, and the handler answers `201 Created`.
///
/// # Errors
///
/// * `500 Internal Server Error` when the database query fails.
/// * `409 Conflict` when the symbol is already present in the database.
/// * When the Alpaca request fails: the status carried by the error if it has
///   one that converts to a valid [`StatusCode`], otherwise
///   `500 Internal Server Error`.
/// * `403 Forbidden` when the asset is not active, tradable and fractionable.
pub async fn add_symbol(
    Extension(config): Extension<Arc<Config>>,
    Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
    Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
    let already_tracked = database::assets::select_where_symbol(
        &config.clickhouse_client,
        &config.clickhouse_concurrency_limiter,
        &symbol,
    )
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
    .is_some();

    if already_tracked {
        return Err(StatusCode::CONFLICT);
    }

    let asset = alpaca::api::incoming::asset::get_by_symbol(
        &config.alpaca_client,
        &config.alpaca_rate_limiter,
        &symbol,
        None,
        &ALPACA_API_BASE,
    )
    .await
    .map_err(|e| {
        // Forward the upstream status when there is one; never panic on an
        // unexpected code, fall back to 500 instead.
        e.status()
            .and_then(|status| StatusCode::from_u16(status.as_u16()).ok())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
    })?;

    let eligible = asset.status == alpaca::shared::asset::Status::Active
        && asset.tradable
        && asset.fractionable;
    if !eligible {
        return Err(StatusCode::FORBIDDEN);
    }

    create_send_await!(
        data_sender,
        threads::data::Message::new,
        threads::data::Action::Add,
        vec![(asset.symbol, asset.class.into())]
    );

    Ok(StatusCode::CREATED)
}
/// Handler removing the asset whose symbol is taken from the request path.
///
/// The stored asset's symbol and class are passed on through `data_sender`
/// with `threads::data::Action::Remove`, and the handler answers
/// `204 No Content`.
///
/// # Errors
///
/// * `500 Internal Server Error` when the database query fails.
/// * `404 Not Found` when the symbol is not present in the database.
pub async fn delete(
    Extension(config): Extension<Arc<Config>>,
    Extension(data_sender): Extension<mpsc::Sender<threads::data::Message>>,
    Path(symbol): Path<String>,
) -> Result<StatusCode, StatusCode> {
    let lookup = database::assets::select_where_symbol(
        &config.clickhouse_client,
        &config.clickhouse_concurrency_limiter,
        &symbol,
    )
    .await;

    let asset = match lookup {
        Ok(Some(found)) => found,
        Ok(None) => return Err(StatusCode::NOT_FOUND),
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };

    create_send_await!(
        data_sender,
        threads::data::Message::new,
        threads::data::Action::Remove,
        vec![(asset.symbol, asset.class)]
    );

    Ok(StatusCode::NO_CONTENT)
}