Add redis

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2025-06-05 08:29:58 +01:00
parent b81a49af3d
commit 050f25bba9
7 changed files with 103 additions and 35 deletions

13
Cargo.lock generated
View File

@@ -1139,7 +1139,9 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"serde_yaml", "serde_yaml",
"time",
"tokio", "tokio",
"uuid",
] ]
[[package]] [[package]]
@@ -3084,6 +3086,17 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d"
dependencies = [
"js-sys",
"serde",
"wasm-bindgen",
]
[[package]] [[package]]
name = "value-bag" name = "value-bag"
version = "1.11.1" version = "1.11.1"

View File

@@ -30,4 +30,6 @@ redis-macros = "0.5.4"
serde = "1.0.219" serde = "1.0.219"
serde_json = "1.0.140" serde_json = "1.0.140"
serde_yaml = "0.9.34" serde_yaml = "0.9.34"
time = { version = "0.3.41", features = ["serde"] }
tokio = { version = "1.45.1", features = ["rt-multi-thread", "process"] } tokio = { version = "1.45.1", features = ["rt-multi-thread", "process"] }
uuid = { version = "1.17.0", features = ["serde"] }

View File

@@ -1,6 +1,7 @@
use clap::Parser; use clap::Parser;
use serde::Deserialize; use serde::Deserialize;
use std::{ use std::{
error::Error,
fs, fs,
net::{IpAddr, Ipv4Addr}, net::{IpAddr, Ipv4Addr},
path::PathBuf, path::PathBuf,
@@ -56,8 +57,10 @@ pub struct Config {
pub redis: RedisConfig, pub redis: RedisConfig,
} }
impl Config { impl TryFrom<&PathBuf> for Config {
pub fn from_yaml(path: &PathBuf) -> Result<Self, Box<dyn std::error::Error>> { type Error = Box<dyn Error + Send + Sync>;
fn try_from(path: &PathBuf) -> Result<Self, Self::Error> {
let contents = fs::read_to_string(path)?; let contents = fs::read_to_string(path)?;
let config = serde_yaml::from_str(&contents)?; let config = serde_yaml::from_str(&contents)?;
Ok(config) Ok(config)

View File

@@ -22,8 +22,8 @@ async fn main() {
let args = Args::parse(); let args = Args::parse();
log4rs::init_file(args.log_config, Deserializers::default()).unwrap(); log4rs::init_file(args.log_config, Deserializers::default()).unwrap();
let config = Config::from_yaml(&args.config).unwrap(); let config = Config::try_from(&args.config).unwrap();
let state = State::from_config(config.clone()).await.unwrap(); let state = State::from_config(config.clone()).await;
let routes = routes::routes(state); let routes = routes::routes(state);
let app = axum::Router::new().nest(&format!("{}/api", config.server.subpath), routes); let app = axum::Router::new().nest(&format!("{}/api", config.server.subpath), routes);

15
src/models/invites.rs Normal file
View File

@@ -0,0 +1,15 @@
use redis_macros::{FromRedisValue, ToRedisArgs};
use serde::{Deserialize, Serialize};
use time::UtcDateTime;
use uuid::Uuid;
#[derive(Serialize, Deserialize, FromRedisValue, ToRedisArgs)]
struct Invite {
id: Uuid,
groups: Vec<String>,
emails: Vec<String>,
uses: i64,
max_uses: Option<i64>,
created_at: UtcDateTime,
expires_at: Option<UtcDateTime>,
}

View File

@@ -1,3 +1,4 @@
pub mod authelia; pub mod authelia;
pub mod groups; pub mod groups;
pub mod invites;
pub mod users; pub mod users;

View File

@@ -1,5 +1,3 @@
use std::error::Error;
use async_redis_session::RedisSessionStore; use async_redis_session::RedisSessionStore;
use axum::extract::FromRef; use axum::extract::FromRef;
use openidconnect::{ use openidconnect::{
@@ -12,7 +10,8 @@ use openidconnect::{
}, },
reqwest, reqwest,
}; };
use redis::aio::MultiplexedConnection; use redis::{self, AsyncCommands};
use tokio::spawn;
use crate::config::Config; use crate::config::Config;
@@ -48,23 +47,23 @@ pub struct State {
pub config: Config, pub config: Config,
pub oauth_http_client: reqwest::Client, pub oauth_http_client: reqwest::Client,
pub oauth_client: OAuthClient, pub oauth_client: OAuthClient,
pub redis_client: MultiplexedConnection, pub redis_client: redis::aio::MultiplexedConnection,
pub session_store: RedisSessionStore, pub session_store: RedisSessionStore,
} }
impl State { impl State {
pub async fn from_config(config: Config) -> Result<Self, Box<dyn Error + Send + Sync>> { pub async fn from_config(config: Config) -> Self {
let (oauth_http_client, oauth_client) = oauth_client(&config).await?; let (oauth_http_client, oauth_client) = oauth_client(&config).await;
let redis_client = redis_client(&config).await?; let redis_client = redis_client(&config).await;
let session_store = session_store(&config)?; let session_store = session_store(&config);
Ok(Self { Self {
config, config,
oauth_http_client, oauth_http_client,
oauth_client, oauth_client,
redis_client, redis_client,
session_store, session_store,
}) }
} }
} }
@@ -86,7 +85,7 @@ impl FromRef<State> for OAuthClient {
} }
} }
impl FromRef<State> for MultiplexedConnection { impl FromRef<State> for redis::aio::MultiplexedConnection {
fn from_ref(state: &State) -> Self { fn from_ref(state: &State) -> Self {
state.redis_client.clone() state.redis_client.clone()
} }
@@ -98,53 +97,88 @@ impl FromRef<State> for RedisSessionStore {
} }
} }
async fn oauth_client( async fn oauth_client(config: &Config) -> (reqwest::Client, OAuthClient) {
config: &Config,
) -> Result<(reqwest::Client, OAuthClient), Box<dyn Error + Send + Sync>> {
let oauth_http_client = reqwest::ClientBuilder::new() let oauth_http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none()) .redirect(reqwest::redirect::Policy::none())
.danger_accept_invalid_certs(config.oauth.insecure) .danger_accept_invalid_certs(config.oauth.insecure)
.build()?; .build()
.unwrap();
let provider_metadata = CoreProviderMetadata::discover_async( let provider_metadata = CoreProviderMetadata::discover_async(
IssuerUrl::new(config.oauth.issuer_url.clone())?, IssuerUrl::new(config.oauth.issuer_url.clone()).unwrap(),
&oauth_http_client, &oauth_http_client,
) )
.await?; .await
.unwrap();
let oauth_client = OAuthClient::from_provider_metadata( let oauth_client = OAuthClient::from_provider_metadata(
provider_metadata, provider_metadata,
ClientId::new(config.oauth.client_id.clone()), ClientId::new(config.oauth.client_id.clone()),
Some(ClientSecret::new(config.oauth.client_secret.clone())), Some(ClientSecret::new(config.oauth.client_secret.clone())),
) )
.set_redirect_uri(RedirectUrl::new(format!( .set_redirect_uri(
"{}{}/api/auth/callback", RedirectUrl::new(format!(
config.server.host, config.server.subpath "{}{}/api/auth/callback",
))?); config.server.host, config.server.subpath
))
.unwrap(),
);
Ok((oauth_http_client, oauth_client)) (oauth_http_client, oauth_client)
} }
async fn redis_client( async fn redis_client(config: &Config) -> redis::aio::MultiplexedConnection {
config: &Config,
) -> Result<MultiplexedConnection, Box<dyn Error + Send + Sync>> {
let url = format!( let url = format!(
"redis://{}:{}/{}", "redis://{}:{}/{}",
config.redis.host, config.redis.port, config.redis.database config.redis.host, config.redis.port, config.redis.database
); );
let client = redis::Client::open(url)?; let client = redis::Client::open(url).unwrap();
let connection = client.get_multiplexed_async_connection().await?; let mut connection = client.get_multiplexed_async_connection().await.unwrap();
Ok(connection) let _: () = redis::cmd("CONFIG")
.arg("SET")
.arg("notify-keyspace-events")
.arg("Ex")
.query_async(&mut connection)
.await
.unwrap();
let database = config.redis.database.to_string();
spawn(async move {
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
let rconfig = redis::AsyncConnectionConfig::new().set_push_sender(tx);
let mut connection = client
.get_multiplexed_async_connection_with_config(&rconfig)
.await
.unwrap();
let channel = format!("__keyevent@{}__:expired", database);
connection.subscribe(&[channel]).await.unwrap();
while let Some(msg) = rx.recv().await {
if let Some(msg) = redis::Msg::from_push_info(msg) {
if let Ok(key) = msg.get_payload::<String>() {
if !key.starts_with("invite:") {
continue;
}
let id = key.trim_start_matches("invite:").to_string();
let _: i64 = connection.srem("invite:all", id).await.unwrap();
}
}
}
});
connection
} }
fn session_store(config: &Config) -> Result<RedisSessionStore, Box<dyn Error + Send + Sync>> { fn session_store(config: &Config) -> RedisSessionStore {
let url = format!( let url = format!(
"redis://{}:{}/{}", "redis://{}:{}/{}",
config.redis.host, config.redis.port, config.redis.database config.redis.host, config.redis.port, config.redis.database
); );
let session_store = RedisSessionStore::new(url)?.with_prefix("session:"); let session_store = RedisSessionStore::new(url).unwrap().with_prefix("session:");
Ok(session_store) session_store
} }