Add redis session store

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2025-06-04 23:43:41 +01:00
parent ec7055d5ff
commit 455bf7b88d
8 changed files with 507 additions and 85 deletions

View File

@@ -40,11 +40,20 @@ pub struct AutheliaConfig {
pub user_database: PathBuf,
}
#[derive(Clone, Deserialize)]
pub struct RedisConfig {
pub host: String,
pub port: u16,
#[serde(default)]
pub database: u8,
}
#[derive(Clone, Deserialize)]
pub struct Config {
pub server: ServerConfig,
pub oauth: OAuthConfig,
pub authelia: AutheliaConfig,
pub redis: RedisConfig,
}
impl Config {

View File

@@ -1,6 +1,7 @@
use std::{borrow::Cow, convert::Infallible};
use async_session::{MemoryStore, Session, SessionStore};
use async_redis_session::RedisSessionStore;
use async_session::{Session, SessionStore};
use axum::{
Json, RequestPartsExt, Router,
extract::{self, FromRef, FromRequestParts, OptionalFromRequestParts},
@@ -35,7 +36,7 @@ pub struct User {
async fn login(
extract::State(oauth_client): extract::State<OAuthClient>,
extract::State(session_store): extract::State<MemoryStore>,
extract::State(session_store): extract::State<RedisSessionStore>,
) -> Result<impl IntoResponse, StatusCode> {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
@@ -89,7 +90,7 @@ fn create_login_session(
}
async fn create_login_cookie(
session_store: &MemoryStore,
session_store: &RedisSessionStore,
session: Session,
) -> Result<HeaderMap, StatusCode> {
let cookie = session_store
@@ -127,9 +128,9 @@ struct CallbackParams {
async fn callback(
extract::Query(params): extract::Query<CallbackParams>,
extract::State(http_client): extract::State<reqwest::Client>,
extract::State(oauth_http_client): extract::State<reqwest::Client>,
extract::State(oauth_client): extract::State<OAuthClient>,
extract::State(session_store): extract::State<MemoryStore>,
extract::State(session_store): extract::State<RedisSessionStore>,
extract::State(config): extract::State<Config>,
TypedHeader(cookies): TypedHeader<Cookie>,
) -> Result<impl IntoResponse, StatusCode> {
@@ -155,7 +156,7 @@ async fn callback(
})?;
let token_response = validate_claims(
&http_client,
&oauth_http_client,
&oauth_client,
params,
csrf_token,
@@ -170,7 +171,7 @@ async fn callback(
error!("failed to create userinfo request: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.request_async(&http_client)
.request_async(&oauth_http_client)
.await
.map_err(|e| {
error!("failed to request user info: {e}");
@@ -207,7 +208,7 @@ fn retrieve_login_session(
}
async fn validate_claims(
http_client: &reqwest::Client,
oauth_http_client: &reqwest::Client,
oauth_client: &OAuthClient,
params: CallbackParams,
csrf_token: CsrfToken,
@@ -226,7 +227,7 @@ async fn validate_claims(
StatusCode::INTERNAL_SERVER_ERROR
})?
.set_pkce_verifier(pkce_verifier)
.request_async(http_client)
.request_async(oauth_http_client)
.await
.map_err(|e| {
error!("failed to request token: {e}");
@@ -313,7 +314,7 @@ fn create_user_session(
}
async fn create_user_cookie(
session_store: &MemoryStore,
session_store: &RedisSessionStore,
session: Session,
) -> Result<HeaderMap, StatusCode> {
let cookie = session_store
@@ -343,7 +344,7 @@ async fn create_user_cookie(
}
async fn logout(
extract::State(session_store): extract::State<MemoryStore>,
extract::State(session_store): extract::State<RedisSessionStore>,
TypedHeader(cookies): TypedHeader<Cookie>,
) -> Result<impl IntoResponse, StatusCode> {
let cookie = cookies.get(COOKIE_NAME).ok_or(StatusCode::UNAUTHORIZED)?;
@@ -382,13 +383,13 @@ pub fn routes(state: State) -> Router {
impl<S> FromRequestParts<S> for User
where
MemoryStore: FromRef<S>,
RedisSessionStore: FromRef<S>,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let store = MemoryStore::from_ref(state);
let store = RedisSessionStore::from_ref(state);
let cookies = parts.extract::<TypedHeader<Cookie>>().await.map_err(|e| {
if *e.name() == header::COOKIE {
@@ -427,7 +428,7 @@ where
impl<S> OptionalFromRequestParts<S> for User
where
MemoryStore: FromRef<S>,
RedisSessionStore: FromRef<S>,
S: Send + Sync,
{
type Rejection = Infallible;

View File

@@ -1,8 +1,7 @@
use std::error::Error;
use async_session::MemoryStore;
use async_redis_session::RedisSessionStore;
use axum::extract::FromRef;
use log::error;
use openidconnect::{
ClientId, ClientSecret, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet,
IssuerUrl, RedirectUrl, StandardErrorResponse,
@@ -13,12 +12,9 @@ use openidconnect::{
},
reqwest,
};
use tokio::{
spawn,
time::{Duration, sleep},
};
use redis::aio::MultiplexedConnection;
use crate::{config::Config, models::authelia};
use crate::config::Config;
pub type OAuthClient<
HasAuthUrl = EndpointSet,
@@ -52,52 +48,24 @@ pub struct State {
pub config: Config,
pub oauth_http_client: reqwest::Client,
pub oauth_client: OAuthClient,
pub session_store: MemoryStore,
pub redis_client: MultiplexedConnection,
pub session_store: RedisSessionStore,
}
impl State {
pub async fn from_config(config: Config) -> Result<Self, Box<dyn Error + Send + Sync>> {
let (oauth_http_client, oauth_client) = oauth(&config).await?;
let session_store = session_store();
let (oauth_http_client, oauth_client) = oauth_client(&config).await?;
let redis_client = redis_client(&config).await?;
let session_store = session_store(&config)?;
Ok(Self {
config,
oauth_http_client,
oauth_client,
redis_client,
session_store,
})
}
pub fn load_users(&self) -> Result<authelia::Users, Box<dyn Error + Send + Sync>> {
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
let users_file: authelia::UsersFile = serde_yaml::from_str(&file_contents)?;
let users = authelia::Users::from(users_file);
Ok(users)
}
pub fn load_groups(&self) -> Result<authelia::Groups, Box<dyn Error + Send + Sync>> {
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
let users_file = serde_yaml::from_str::<authelia::UsersFile>(&file_contents)?;
let groups = authelia::Groups::from(users_file);
Ok(groups)
}
pub fn load_users_and_groups(
&self,
) -> Result<(authelia::Users, authelia::Groups), Box<dyn Error + Send + Sync>> {
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
let users_file = serde_yaml::from_str::<authelia::UsersFile>(&file_contents)?;
let users = authelia::Users::from(users_file.clone());
let groups = authelia::Groups::from(users_file);
Ok((users, groups))
}
pub fn save_users(&self, users: authelia::Users) -> Result<(), Box<dyn Error + Send + Sync>> {
let users_file = authelia::UsersFile::from(users);
let file_contents = serde_yaml::to_string(&users_file)?;
std::fs::write(&self.config.authelia.user_database, file_contents)?;
Ok(())
}
}
impl FromRef<State> for Config {
@@ -118,13 +86,19 @@ impl FromRef<State> for OAuthClient {
}
}
impl FromRef<State> for MemoryStore {
impl FromRef<State> for MultiplexedConnection {
fn from_ref(state: &State) -> Self {
state.redis_client.clone()
}
}
impl FromRef<State> for RedisSessionStore {
fn from_ref(state: &State) -> Self {
state.session_store.clone()
}
}
async fn oauth(
async fn oauth_client(
config: &Config,
) -> Result<(reqwest::Client, OAuthClient), Box<dyn Error + Send + Sync>> {
let oauth_http_client = reqwest::ClientBuilder::new()
@@ -151,19 +125,26 @@ async fn oauth(
Ok((oauth_http_client, oauth_client))
}
fn session_store() -> MemoryStore {
let session_store = MemoryStore::new();
async fn redis_client(
config: &Config,
) -> Result<MultiplexedConnection, Box<dyn Error + Send + Sync>> {
let url = format!(
"redis://{}:{}/{}",
config.redis.host, config.redis.port, config.redis.database
);
let session_store_clone = session_store.clone();
spawn(async move {
loop {
match session_store_clone.cleanup().await {
Ok(()) => {}
Err(e) => error!("Failed to clean up session store: {e}"),
}
sleep(Duration::from_secs(60)).await;
}
});
let client = redis::Client::open(url)?;
let connection = client.get_multiplexed_async_connection().await?;
session_store
Ok(connection)
}
fn session_store(config: &Config) -> Result<RedisSessionStore, Box<dyn Error + Send + Sync>> {
let url = format!(
"redis://{}:{}/{}",
config.redis.host, config.redis.port, config.redis.database
);
let session_store = RedisSessionStore::new(url)?.with_prefix("session:");
Ok(session_store)
}

40
src/utils/authelia.rs Normal file
View File

@@ -0,0 +1,40 @@
use std::error::Error;
use crate::{models, state::State};
impl State {
pub fn load_users(&self) -> Result<models::authelia::Users, Box<dyn Error + Send + Sync>> {
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
let users_file: models::authelia::UsersFile = serde_yaml::from_str(&file_contents)?;
let users = models::authelia::Users::from(users_file);
Ok(users)
}
pub fn load_groups(&self) -> Result<models::authelia::Groups, Box<dyn Error + Send + Sync>> {
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
let users_file = serde_yaml::from_str::<models::authelia::UsersFile>(&file_contents)?;
let groups = models::authelia::Groups::from(users_file);
Ok(groups)
}
pub fn load_users_and_groups(
&self,
) -> Result<(models::authelia::Users, models::authelia::Groups), Box<dyn Error + Send + Sync>>
{
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
let users_file = serde_yaml::from_str::<models::authelia::UsersFile>(&file_contents)?;
let users = models::authelia::Users::from(users_file.clone());
let groups = models::authelia::Groups::from(users_file);
Ok((users, groups))
}
pub fn save_users(
&self,
users: models::authelia::Users,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let users_file = models::authelia::UsersFile::from(users);
let file_contents = serde_yaml::to_string(&users_file)?;
std::fs::write(&self.config.authelia.user_database, file_contents)?;
Ok(())
}
}

View File

@@ -1 +1,2 @@
pub mod authelia;
pub mod crypto;