Add redis session store
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
@@ -40,11 +40,20 @@ pub struct AutheliaConfig {
|
||||
pub user_database: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct RedisConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
#[serde(default)]
|
||||
pub database: u8,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct Config {
|
||||
pub server: ServerConfig,
|
||||
pub oauth: OAuthConfig,
|
||||
pub authelia: AutheliaConfig,
|
||||
pub redis: RedisConfig,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
|
@@ -1,6 +1,7 @@
|
||||
use std::{borrow::Cow, convert::Infallible};
|
||||
|
||||
use async_session::{MemoryStore, Session, SessionStore};
|
||||
use async_redis_session::RedisSessionStore;
|
||||
use async_session::{Session, SessionStore};
|
||||
use axum::{
|
||||
Json, RequestPartsExt, Router,
|
||||
extract::{self, FromRef, FromRequestParts, OptionalFromRequestParts},
|
||||
@@ -35,7 +36,7 @@ pub struct User {
|
||||
|
||||
async fn login(
|
||||
extract::State(oauth_client): extract::State<OAuthClient>,
|
||||
extract::State(session_store): extract::State<MemoryStore>,
|
||||
extract::State(session_store): extract::State<RedisSessionStore>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
@@ -89,7 +90,7 @@ fn create_login_session(
|
||||
}
|
||||
|
||||
async fn create_login_cookie(
|
||||
session_store: &MemoryStore,
|
||||
session_store: &RedisSessionStore,
|
||||
session: Session,
|
||||
) -> Result<HeaderMap, StatusCode> {
|
||||
let cookie = session_store
|
||||
@@ -127,9 +128,9 @@ struct CallbackParams {
|
||||
|
||||
async fn callback(
|
||||
extract::Query(params): extract::Query<CallbackParams>,
|
||||
extract::State(http_client): extract::State<reqwest::Client>,
|
||||
extract::State(oauth_http_client): extract::State<reqwest::Client>,
|
||||
extract::State(oauth_client): extract::State<OAuthClient>,
|
||||
extract::State(session_store): extract::State<MemoryStore>,
|
||||
extract::State(session_store): extract::State<RedisSessionStore>,
|
||||
extract::State(config): extract::State<Config>,
|
||||
TypedHeader(cookies): TypedHeader<Cookie>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
@@ -155,7 +156,7 @@ async fn callback(
|
||||
})?;
|
||||
|
||||
let token_response = validate_claims(
|
||||
&http_client,
|
||||
&oauth_http_client,
|
||||
&oauth_client,
|
||||
params,
|
||||
csrf_token,
|
||||
@@ -170,7 +171,7 @@ async fn callback(
|
||||
error!("failed to create userinfo request: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
.request_async(&http_client)
|
||||
.request_async(&oauth_http_client)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("failed to request user info: {e}");
|
||||
@@ -207,7 +208,7 @@ fn retrieve_login_session(
|
||||
}
|
||||
|
||||
async fn validate_claims(
|
||||
http_client: &reqwest::Client,
|
||||
oauth_http_client: &reqwest::Client,
|
||||
oauth_client: &OAuthClient,
|
||||
params: CallbackParams,
|
||||
csrf_token: CsrfToken,
|
||||
@@ -226,7 +227,7 @@ async fn validate_claims(
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
.set_pkce_verifier(pkce_verifier)
|
||||
.request_async(http_client)
|
||||
.request_async(oauth_http_client)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("failed to request token: {e}");
|
||||
@@ -313,7 +314,7 @@ fn create_user_session(
|
||||
}
|
||||
|
||||
async fn create_user_cookie(
|
||||
session_store: &MemoryStore,
|
||||
session_store: &RedisSessionStore,
|
||||
session: Session,
|
||||
) -> Result<HeaderMap, StatusCode> {
|
||||
let cookie = session_store
|
||||
@@ -343,7 +344,7 @@ async fn create_user_cookie(
|
||||
}
|
||||
|
||||
async fn logout(
|
||||
extract::State(session_store): extract::State<MemoryStore>,
|
||||
extract::State(session_store): extract::State<RedisSessionStore>,
|
||||
TypedHeader(cookies): TypedHeader<Cookie>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let cookie = cookies.get(COOKIE_NAME).ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
@@ -382,13 +383,13 @@ pub fn routes(state: State) -> Router {
|
||||
|
||||
impl<S> FromRequestParts<S> for User
|
||||
where
|
||||
MemoryStore: FromRef<S>,
|
||||
RedisSessionStore: FromRef<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Response;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let store = MemoryStore::from_ref(state);
|
||||
let store = RedisSessionStore::from_ref(state);
|
||||
|
||||
let cookies = parts.extract::<TypedHeader<Cookie>>().await.map_err(|e| {
|
||||
if *e.name() == header::COOKIE {
|
||||
@@ -427,7 +428,7 @@ where
|
||||
|
||||
impl<S> OptionalFromRequestParts<S> for User
|
||||
where
|
||||
MemoryStore: FromRef<S>,
|
||||
RedisSessionStore: FromRef<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Infallible;
|
||||
|
93
src/state.rs
93
src/state.rs
@@ -1,8 +1,7 @@
|
||||
use std::error::Error;
|
||||
|
||||
use async_session::MemoryStore;
|
||||
use async_redis_session::RedisSessionStore;
|
||||
use axum::extract::FromRef;
|
||||
use log::error;
|
||||
use openidconnect::{
|
||||
ClientId, ClientSecret, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet,
|
||||
IssuerUrl, RedirectUrl, StandardErrorResponse,
|
||||
@@ -13,12 +12,9 @@ use openidconnect::{
|
||||
},
|
||||
reqwest,
|
||||
};
|
||||
use tokio::{
|
||||
spawn,
|
||||
time::{Duration, sleep},
|
||||
};
|
||||
use redis::aio::MultiplexedConnection;
|
||||
|
||||
use crate::{config::Config, models::authelia};
|
||||
use crate::config::Config;
|
||||
|
||||
pub type OAuthClient<
|
||||
HasAuthUrl = EndpointSet,
|
||||
@@ -52,52 +48,24 @@ pub struct State {
|
||||
pub config: Config,
|
||||
pub oauth_http_client: reqwest::Client,
|
||||
pub oauth_client: OAuthClient,
|
||||
pub session_store: MemoryStore,
|
||||
pub redis_client: MultiplexedConnection,
|
||||
pub session_store: RedisSessionStore,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub async fn from_config(config: Config) -> Result<Self, Box<dyn Error + Send + Sync>> {
|
||||
let (oauth_http_client, oauth_client) = oauth(&config).await?;
|
||||
let session_store = session_store();
|
||||
let (oauth_http_client, oauth_client) = oauth_client(&config).await?;
|
||||
let redis_client = redis_client(&config).await?;
|
||||
let session_store = session_store(&config)?;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
oauth_http_client,
|
||||
oauth_client,
|
||||
redis_client,
|
||||
session_store,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_users(&self) -> Result<authelia::Users, Box<dyn Error + Send + Sync>> {
|
||||
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
||||
let users_file: authelia::UsersFile = serde_yaml::from_str(&file_contents)?;
|
||||
let users = authelia::Users::from(users_file);
|
||||
Ok(users)
|
||||
}
|
||||
|
||||
pub fn load_groups(&self) -> Result<authelia::Groups, Box<dyn Error + Send + Sync>> {
|
||||
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
||||
let users_file = serde_yaml::from_str::<authelia::UsersFile>(&file_contents)?;
|
||||
let groups = authelia::Groups::from(users_file);
|
||||
Ok(groups)
|
||||
}
|
||||
|
||||
pub fn load_users_and_groups(
|
||||
&self,
|
||||
) -> Result<(authelia::Users, authelia::Groups), Box<dyn Error + Send + Sync>> {
|
||||
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
||||
let users_file = serde_yaml::from_str::<authelia::UsersFile>(&file_contents)?;
|
||||
let users = authelia::Users::from(users_file.clone());
|
||||
let groups = authelia::Groups::from(users_file);
|
||||
Ok((users, groups))
|
||||
}
|
||||
|
||||
pub fn save_users(&self, users: authelia::Users) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
let users_file = authelia::UsersFile::from(users);
|
||||
let file_contents = serde_yaml::to_string(&users_file)?;
|
||||
std::fs::write(&self.config.authelia.user_database, file_contents)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for Config {
|
||||
@@ -118,13 +86,19 @@ impl FromRef<State> for OAuthClient {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for MemoryStore {
|
||||
impl FromRef<State> for MultiplexedConnection {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.redis_client.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for RedisSessionStore {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.session_store.clone()
|
||||
}
|
||||
}
|
||||
|
||||
async fn oauth(
|
||||
async fn oauth_client(
|
||||
config: &Config,
|
||||
) -> Result<(reqwest::Client, OAuthClient), Box<dyn Error + Send + Sync>> {
|
||||
let oauth_http_client = reqwest::ClientBuilder::new()
|
||||
@@ -151,19 +125,26 @@ async fn oauth(
|
||||
Ok((oauth_http_client, oauth_client))
|
||||
}
|
||||
|
||||
fn session_store() -> MemoryStore {
|
||||
let session_store = MemoryStore::new();
|
||||
async fn redis_client(
|
||||
config: &Config,
|
||||
) -> Result<MultiplexedConnection, Box<dyn Error + Send + Sync>> {
|
||||
let url = format!(
|
||||
"redis://{}:{}/{}",
|
||||
config.redis.host, config.redis.port, config.redis.database
|
||||
);
|
||||
|
||||
let session_store_clone = session_store.clone();
|
||||
spawn(async move {
|
||||
loop {
|
||||
match session_store_clone.cleanup().await {
|
||||
Ok(()) => {}
|
||||
Err(e) => error!("Failed to clean up session store: {e}"),
|
||||
}
|
||||
sleep(Duration::from_secs(60)).await;
|
||||
}
|
||||
});
|
||||
let client = redis::Client::open(url)?;
|
||||
let connection = client.get_multiplexed_async_connection().await?;
|
||||
|
||||
session_store
|
||||
Ok(connection)
|
||||
}
|
||||
|
||||
fn session_store(config: &Config) -> Result<RedisSessionStore, Box<dyn Error + Send + Sync>> {
|
||||
let url = format!(
|
||||
"redis://{}:{}/{}",
|
||||
config.redis.host, config.redis.port, config.redis.database
|
||||
);
|
||||
let session_store = RedisSessionStore::new(url)?.with_prefix("session:");
|
||||
|
||||
Ok(session_store)
|
||||
}
|
||||
|
40
src/utils/authelia.rs
Normal file
40
src/utils/authelia.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use std::error::Error;
|
||||
|
||||
use crate::{models, state::State};
|
||||
|
||||
impl State {
|
||||
pub fn load_users(&self) -> Result<models::authelia::Users, Box<dyn Error + Send + Sync>> {
|
||||
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
||||
let users_file: models::authelia::UsersFile = serde_yaml::from_str(&file_contents)?;
|
||||
let users = models::authelia::Users::from(users_file);
|
||||
Ok(users)
|
||||
}
|
||||
|
||||
pub fn load_groups(&self) -> Result<models::authelia::Groups, Box<dyn Error + Send + Sync>> {
|
||||
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
||||
let users_file = serde_yaml::from_str::<models::authelia::UsersFile>(&file_contents)?;
|
||||
let groups = models::authelia::Groups::from(users_file);
|
||||
Ok(groups)
|
||||
}
|
||||
|
||||
pub fn load_users_and_groups(
|
||||
&self,
|
||||
) -> Result<(models::authelia::Users, models::authelia::Groups), Box<dyn Error + Send + Sync>>
|
||||
{
|
||||
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
||||
let users_file = serde_yaml::from_str::<models::authelia::UsersFile>(&file_contents)?;
|
||||
let users = models::authelia::Users::from(users_file.clone());
|
||||
let groups = models::authelia::Groups::from(users_file);
|
||||
Ok((users, groups))
|
||||
}
|
||||
|
||||
pub fn save_users(
|
||||
&self,
|
||||
users: models::authelia::Users,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
let users_file = models::authelia::UsersFile::from(users);
|
||||
let file_contents = serde_yaml::to_string(&users_file)?;
|
||||
std::fs::write(&self.config.authelia.user_database, file_contents)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
@@ -1 +1,2 @@
|
||||
pub mod authelia;
|
||||
pub mod crypto;
|
||||
|
Reference in New Issue
Block a user