feat: add session persistence

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2025-03-28 14:34:42 +00:00
parent d2a115f389
commit d5e1b1b437
6 changed files with 189 additions and 43 deletions

1
Cargo.lock generated
View File

@@ -3005,6 +3005,7 @@ dependencies = [
"log4rs", "log4rs",
"openidconnect", "openidconnect",
"serde", "serde",
"serde_json",
"serde_yaml", "serde_yaml",
"sqlx", "sqlx",
"tokio", "tokio",

View File

@@ -22,6 +22,7 @@ log = "0.4.27"
log4rs = "1.3.0" log4rs = "1.3.0"
openidconnect = { version = "4.0.0", features = ["reqwest"] } openidconnect = { version = "4.0.0", features = ["reqwest"] }
serde = "1.0.219" serde = "1.0.219"
serde_json = "1.0.140"
serde_yaml = "0.9.34" serde_yaml = "0.9.34"
sqlx = { version = "0.8.3", features = ["postgres", "runtime-tokio", "time"] } sqlx = { version = "0.8.3", features = ["postgres", "runtime-tokio", "time"] }
tokio = { version = "1.44.1", features = ["rt-multi-thread"] } tokio = { version = "1.44.1", features = ["rt-multi-thread"] }

View File

@@ -84,6 +84,7 @@ data:
issuer_url: "https://id.veil.local" issuer_url: "https://id.veil.local"
client_id: "veil" client_id: "veil"
client_secret: "insecure_secret" client_secret: "insecure_secret"
admin_group: "admins"
insecure: true insecure: true
log4rs.yml: | log4rs.yml: |
appenders: appenders:

View File

@@ -41,6 +41,8 @@ pub struct OAuthConfig {
pub client_secret: String, pub client_secret: String,
#[serde(default)] #[serde(default)]
pub insecure: bool, pub insecure: bool,
#[serde(default)]
pub admin_group: Option<String>,
} }
#[derive(Clone, Deserialize)] #[derive(Clone, Deserialize)]

View File

@@ -11,13 +11,18 @@ use axum::{
use axum_extra::{TypedHeader, headers::Cookie, typed_header::TypedHeaderRejectionReason}; use axum_extra::{TypedHeader, headers::Cookie, typed_header::TypedHeaderRejectionReason};
use log::error; use log::error;
use openidconnect::{ use openidconnect::{
AccessTokenHash, AuthorizationCode, CsrfToken, EndUserEmail, EndUserUsername, Nonce, AccessTokenHash, AdditionalClaims, AuthorizationCode, CsrfToken, EndUserEmail, EndUserUsername,
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, Scope, SubjectIdentifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, Scope, SubjectIdentifier,
TokenResponse, core::CoreAuthenticationFlow, reqwest, TokenResponse, UserInfoClaims,
core::{CoreAuthenticationFlow, CoreGenderClaim, CoreTokenResponse},
reqwest,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::state::{OAuthClient, State}; use crate::{
config::Config,
state::{OAuthClient, State},
};
static COOKIE_NAME: &str = "veil_session"; static COOKIE_NAME: &str = "veil_session";
@@ -26,6 +31,16 @@ pub struct User {
pub subject: SubjectIdentifier, pub subject: SubjectIdentifier,
pub username: EndUserUsername, pub username: EndUserUsername,
pub email: Option<EndUserEmail>, pub email: Option<EndUserEmail>,
pub is_admin: bool,
}
impl IntoResponse for User {
fn into_response(self) -> Response {
let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap());
let body = serde_json::to_string(&self).unwrap();
(headers, body).into_response()
}
} }
async fn login( async fn login(
@@ -52,7 +67,19 @@ async fn login(
.add_scope(Scope::new("groups".to_string())) .add_scope(Scope::new("groups".to_string()))
.url(); .url();
let session = create_login_session(pkce_verifier, csrf_token, nonce)?;
let headers = create_login_cookie(&session_store, session).await?;
Ok((headers, Redirect::to(auth_url.as_str())))
}
fn create_login_session(
pkce_verifier: PkceCodeVerifier,
csrf_token: CsrfToken,
nonce: Nonce,
) -> Result<Session, StatusCode> {
let mut session = Session::new(); let mut session = Session::new();
session session
.insert("pkce_verifier", pkce_verifier) .insert("pkce_verifier", pkce_verifier)
.map_err(|e| { .map_err(|e| {
@@ -68,6 +95,13 @@ async fn login(
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
})?; })?;
Ok(session)
}
async fn create_login_cookie(
session_store: &MemoryStore,
session: Session,
) -> Result<HeaderMap, StatusCode> {
let cookie = session_store let cookie = session_store
.store_session(session) .store_session(session)
.await .await
@@ -80,8 +114,7 @@ async fn login(
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
})?; })?;
let cookie = let cookie = format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; Secure; Path=/");
format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; HttpOnly; Secure; Path=/");
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert( headers.insert(
@@ -92,7 +125,7 @@ async fn login(
})?, })?,
); );
Ok((headers, Redirect::to(auth_url.as_str()))) Ok(headers)
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@@ -107,6 +140,7 @@ async fn callback(
extract::State(http_client): extract::State<reqwest::Client>, extract::State(http_client): extract::State<reqwest::Client>,
extract::State(oauth_client): extract::State<OAuthClient>, extract::State(oauth_client): extract::State<OAuthClient>,
extract::State(session_store): extract::State<MemoryStore>, extract::State(session_store): extract::State<MemoryStore>,
extract::State(config): extract::State<Config>,
TypedHeader(cookies): TypedHeader<Cookie>, TypedHeader(cookies): TypedHeader<Cookie>,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, StatusCode> {
let cookie = cookies let cookie = cookies
@@ -123,13 +157,49 @@ async fn callback(
})? })?
.ok_or(StatusCode::UNAUTHORIZED)?; .ok_or(StatusCode::UNAUTHORIZED)?;
let csrf_token = session let (csrf_token, pkce_verifier, nonce) = retrieve_login_session(&session)?;
.get::<CsrfToken>("csrf_token")
.ok_or_else(|| { session_store.destroy_session(session).await.map_err(|e| {
error!("failed to get csrf_token from session"); error!("failed to destroy session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let token_response = validate_claims(
&http_client,
&oauth_client,
params,
csrf_token,
pkce_verifier,
nonce,
)
.await?;
let claims = oauth_client
.user_info(token_response.access_token().to_owned(), None)
.map_err(|e| {
error!("failed to create userinfo request: {e}");
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
})? })?
.clone(); .request_async(&http_client)
.await
.map_err(|e| {
error!("failed to request user info: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let session = create_user_session(&config, &claims)?;
let headers = create_user_cookie(&session_store, session).await?;
Ok((headers, StatusCode::OK))
}
fn retrieve_login_session(
session: &Session,
) -> Result<(CsrfToken, PkceCodeVerifier, Nonce), StatusCode> {
let csrf_token = session.get::<CsrfToken>("csrf_token").ok_or_else(|| {
error!("failed to get csrf_token from session");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let pkce_verifier = session let pkce_verifier = session
.get::<PkceCodeVerifier>("pkce_verifier") .get::<PkceCodeVerifier>("pkce_verifier")
@@ -138,32 +208,35 @@ async fn callback(
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
})?; })?;
let nonce = session let nonce = session.get::<Nonce>("nonce").ok_or_else(|| {
.get::<Nonce>("nonce")
.ok_or_else(|| {
error!("failed to get nonce from session"); error!("failed to get nonce from session");
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
})?
.clone();
session_store.destroy_session(session).await.map_err(|e| {
error!("failed to destroy session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?; })?;
Ok((csrf_token, pkce_verifier, nonce))
}
async fn validate_claims(
http_client: &reqwest::Client,
oauth_client: &OAuthClient,
params: CallbackParams,
csrf_token: CsrfToken,
pkce_verifier: PkceCodeVerifier,
nonce: Nonce,
) -> Result<CoreTokenResponse, StatusCode> {
if *csrf_token.secret() != params.state { if *csrf_token.secret() != params.state {
error!("csrf_token mismatch"); error!("csrf_token mismatch");
return Err(StatusCode::INTERNAL_SERVER_ERROR); return Err(StatusCode::INTERNAL_SERVER_ERROR);
} }
let token_response = oauth_client let token_response = oauth_client
.exchange_code(AuthorizationCode::new(params.code)) .exchange_code(AuthorizationCode::new(params.code.clone()))
.map_err(|e| { .map_err(|e| {
error!("failed to exchange code: {e}"); error!("failed to exchange code: {e}");
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
})? })?
.set_pkce_verifier(pkce_verifier) .set_pkce_verifier(pkce_verifier)
.request_async(&http_client) .request_async(http_client)
.await .await
.map_err(|e| { .map_err(|e| {
error!("failed to request token: {e}"); error!("failed to request token: {e}");
@@ -176,10 +249,14 @@ async fn callback(
})?; })?;
let id_token_verifier = oauth_client.id_token_verifier(); let id_token_verifier = oauth_client.id_token_verifier();
let claims = id_token.claims(&id_token_verifier, &nonce).map_err(|e| {
let claims = id_token
.claims(&id_token_verifier, &nonce)
.map_err(|e| {
error!("failed to verify id_token: {e}"); error!("failed to verify id_token: {e}");
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
})?; })?
.to_owned();
if let Some(expected_access_token_hash) = claims.access_token_hash() { if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token( let actual_access_token_hash = AccessTokenHash::from_token(
@@ -204,6 +281,20 @@ async fn callback(
} }
} }
Ok(token_response)
}
#[derive(Debug, Serialize, Deserialize)]
struct ExtraClaims {
groups: Vec<String>,
}
impl AdditionalClaims for ExtraClaims {}
fn create_user_session(
config: &Config,
claims: &UserInfoClaims<ExtraClaims, CoreGenderClaim>,
) -> Result<Session, StatusCode> {
let user = User { let user = User {
subject: claims.subject().to_owned(), subject: claims.subject().to_owned(),
username: claims.preferred_username().cloned().ok_or_else(|| { username: claims.preferred_username().cloned().ok_or_else(|| {
@@ -211,6 +302,17 @@ async fn callback(
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
})?, })?,
email: claims.email().cloned(), email: claims.email().cloned(),
is_admin: config
.oauth
.admin_group
.as_ref()
.is_some_and(|admin_group| {
claims
.additional_claims()
.groups
.iter()
.any(|group| group == admin_group)
}),
}; };
let mut session = Session::new(); let mut session = Session::new();
@@ -219,7 +321,37 @@ async fn callback(
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
})?; })?;
Ok(StatusCode::OK) Ok(session)
}
async fn create_user_cookie(
session_store: &MemoryStore,
session: Session,
) -> Result<HeaderMap, StatusCode> {
let cookie = session_store
.store_session(session)
.await
.map_err(|e| {
error!("failed to store session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.ok_or_else(|| {
error!("failed to retrieve stored session cookie");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let cookie = format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; Secure; Path=/");
let mut headers = HeaderMap::new();
headers.insert(
header::SET_COOKIE,
cookie.parse().map_err(|e| {
error!("failed to parse cookie: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?,
);
Ok(headers)
} }
async fn logout( async fn logout(
@@ -247,11 +379,16 @@ async fn logout(
Ok(StatusCode::OK) Ok(StatusCode::OK)
} }
async fn session(user: User) -> Result<impl IntoResponse, StatusCode> {
Ok(user)
}
pub fn routes(state: State) -> Router { pub fn routes(state: State) -> Router {
Router::new() Router::new()
.route("/auth/login", routing::get(login)) .route("/auth/login", routing::get(login))
.route("/auth/callback", routing::get(callback)) .route("/auth/callback", routing::get(callback))
.route("/auth/logout", routing::get(logout)) .route("/auth/logout", routing::get(logout))
.route("/auth/session", routing::get(session))
.with_state(state) .with_state(state)
} }

View File

@@ -47,6 +47,7 @@ pub type OAuthClient<
#[derive(Clone)] #[derive(Clone)]
pub struct State { pub struct State {
pub config: Config,
pub pg_pool: sqlx::PgPool, pub pg_pool: sqlx::PgPool,
pub oauth_http_client: reqwest::Client, pub oauth_http_client: reqwest::Client,
pub oauth_client: OAuthClient, pub oauth_client: OAuthClient,
@@ -67,25 +68,21 @@ impl State {
)) ))
.await?; .await?;
let mut http_client = let oauth_http_client = reqwest::ClientBuilder::new()
reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none()); .redirect(reqwest::redirect::Policy::none())
.danger_accept_invalid_certs(config.oauth.insecure)
if config.oauth.insecure { .build()?;
http_client = http_client.danger_accept_invalid_certs(true);
}
let http_client = http_client.build()?;
let provider_metadata = CoreProviderMetadata::discover_async( let provider_metadata = CoreProviderMetadata::discover_async(
IssuerUrl::new(config.oauth.issuer_url)?, IssuerUrl::new(config.oauth.issuer_url.clone())?,
&http_client, &oauth_http_client,
) )
.await?; .await?;
let oauth_client = openidconnect::core::CoreClient::from_provider_metadata( let oauth_client = OAuthClient::from_provider_metadata(
provider_metadata, provider_metadata,
ClientId::new(config.oauth.client_id), ClientId::new(config.oauth.client_id.clone()),
Some(ClientSecret::new(config.oauth.client_secret)), Some(ClientSecret::new(config.oauth.client_secret.clone())),
) )
.set_redirect_uri(RedirectUrl::new(format!( .set_redirect_uri(RedirectUrl::new(format!(
"{}{}/api/auth/callback", "{}{}/api/auth/callback",
@@ -106,14 +103,21 @@ impl State {
}); });
Ok(Self { Ok(Self {
config,
pg_pool, pg_pool,
oauth_http_client: http_client, oauth_http_client,
oauth_client, oauth_client,
session_store, session_store,
}) })
} }
} }
impl FromRef<State> for Config {
fn from_ref(state: &State) -> Self {
state.config.clone()
}
}
impl FromRef<State> for sqlx::PgPool { impl FromRef<State> for sqlx::PgPool {
fn from_ref(state: &State) -> Self { fn from_ref(state: &State) -> Self {
state.pg_pool.clone() state.pg_pool.clone()