From d5e1b1b437028ff3aa1e342c5d0bad3f612e1fab Mon Sep 17 00:00:00 2001 From: Nikolaos Karaolidis Date: Fri, 28 Mar 2025 14:34:42 +0000 Subject: [PATCH] feat: add session persistence Signed-off-by: Nikolaos Karaolidis --- Cargo.lock | 1 + Cargo.toml | 1 + manifest.yaml | 1 + src/config.rs | 2 + src/routes/auth.rs | 195 ++++++++++++++++++++++++++++++++++++++------- src/state.rs | 32 ++++---- 6 files changed, 189 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4651482..b69d536 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3005,6 +3005,7 @@ dependencies = [ "log4rs", "openidconnect", "serde", + "serde_json", "serde_yaml", "sqlx", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 7cfc018..68ea091 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ log = "0.4.27" log4rs = "1.3.0" openidconnect = { version = "4.0.0", features = ["reqwest"] } serde = "1.0.219" +serde_json = "1.0.140" serde_yaml = "0.9.34" sqlx = { version = "0.8.3", features = ["postgres", "runtime-tokio", "time"] } tokio = { version = "1.44.1", features = ["rt-multi-thread"] } diff --git a/manifest.yaml b/manifest.yaml index 73db240..d767ff5 100644 --- a/manifest.yaml +++ b/manifest.yaml @@ -84,6 +84,7 @@ data: issuer_url: "https://id.veil.local" client_id: "veil" client_secret: "insecure_secret" + admin_group: "admins" insecure: true log4rs.yml: | appenders: diff --git a/src/config.rs b/src/config.rs index fdeb02c..f95905b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -41,6 +41,8 @@ pub struct OAuthConfig { pub client_secret: String, #[serde(default)] pub insecure: bool, + #[serde(default)] + pub admin_group: Option, } #[derive(Clone, Deserialize)] diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 01147b7..22f7191 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -11,13 +11,18 @@ use axum::{ use axum_extra::{TypedHeader, headers::Cookie, typed_header::TypedHeaderRejectionReason}; use log::error; use openidconnect::{ - AccessTokenHash, AuthorizationCode, CsrfToken, EndUserEmail, EndUserUsername, Nonce, - OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, Scope, SubjectIdentifier, - TokenResponse, core::CoreAuthenticationFlow, reqwest, + AccessTokenHash, AdditionalClaims, AuthorizationCode, CsrfToken, EndUserEmail, EndUserUsername, + Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, Scope, SubjectIdentifier, + TokenResponse, UserInfoClaims, + core::{CoreAuthenticationFlow, CoreGenderClaim, CoreTokenResponse}, + reqwest, }; use serde::{Deserialize, Serialize}; -use crate::state::{OAuthClient, State}; +use crate::{ + config::Config, + state::{OAuthClient, State}, +}; static COOKIE_NAME: &str = "veil_session"; @@ -26,6 +31,16 @@ pub struct User { pub subject: SubjectIdentifier, pub username: EndUserUsername, pub email: Option, + pub is_admin: bool, +} + +impl IntoResponse for User { + fn into_response(self) -> Response { + let mut headers = HeaderMap::new(); + headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap()); + let body = serde_json::to_string(&self).unwrap(); + (headers, body).into_response() + } } async fn login( @@ -52,7 +67,19 @@ async fn login( .add_scope(Scope::new("groups".to_string())) .url(); + let session = create_login_session(pkce_verifier, csrf_token, nonce)?; + let headers = create_login_cookie(&session_store, session).await?; + + Ok((headers, Redirect::to(auth_url.as_str()))) +} + +fn create_login_session( + pkce_verifier: PkceCodeVerifier, + csrf_token: CsrfToken, + nonce: Nonce, +) -> Result { let mut session = Session::new(); + session .insert("pkce_verifier", pkce_verifier) .map_err(|e| { @@ -68,6 +95,13 @@ async fn login( StatusCode::INTERNAL_SERVER_ERROR })?; + Ok(session) +} + +async fn create_login_cookie( + session_store: &MemoryStore, + session: Session, +) -> Result { let cookie = session_store .store_session(session) .await @@ -80,8 +114,7 @@ async fn login( StatusCode::INTERNAL_SERVER_ERROR })?; - let cookie = - format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; HttpOnly; Secure; Path=/"); + let cookie = format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; Secure; Path=/"); let mut headers = HeaderMap::new(); headers.insert( @@ -92,7 +125,7 @@ async fn login( })?, ); - Ok((headers, Redirect::to(auth_url.as_str()))) + Ok(headers) } #[derive(Debug, Deserialize)] @@ -107,6 +140,7 @@ async fn callback( extract::State(http_client): extract::State, extract::State(oauth_client): extract::State, extract::State(session_store): extract::State, + extract::State(config): extract::State, TypedHeader(cookies): TypedHeader, ) -> Result { let cookie = cookies @@ -123,13 +157,49 @@ async fn callback( })? .ok_or(StatusCode::UNAUTHORIZED)?; - let csrf_token = session - .get::("csrf_token") - .ok_or_else(|| { - error!("failed to get csrf_token from session"); + let (csrf_token, pkce_verifier, nonce) = retrieve_login_session(&session)?; + + session_store.destroy_session(session).await.map_err(|e| { + error!("failed to destroy session: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let token_response = validate_claims( + &http_client, + &oauth_client, + params, + csrf_token, + pkce_verifier, + nonce, + ) + .await?; + + let claims = oauth_client + .user_info(token_response.access_token().to_owned(), None) + .map_err(|e| { + error!("failed to create userinfo request: {e}"); StatusCode::INTERNAL_SERVER_ERROR })? - .clone(); + .request_async(&http_client) + .await + .map_err(|e| { + error!("failed to request user info: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let session = create_user_session(&config, &claims)?; + let headers = create_user_cookie(&session_store, session).await?; + + Ok((headers, StatusCode::OK)) +} + +fn retrieve_login_session( + session: &Session, +) -> Result<(CsrfToken, PkceCodeVerifier, Nonce), StatusCode> { + let csrf_token = session.get::("csrf_token").ok_or_else(|| { + error!("failed to get csrf_token from session"); + StatusCode::INTERNAL_SERVER_ERROR + })?; let pkce_verifier = session .get::("pkce_verifier") @@ -138,32 +208,35 @@ async fn callback( StatusCode::INTERNAL_SERVER_ERROR })?; - let nonce = session - .get::("nonce") - .ok_or_else(|| { - error!("failed to get nonce from session"); - StatusCode::INTERNAL_SERVER_ERROR - })? - .clone(); - - session_store.destroy_session(session).await.map_err(|e| { - error!("failed to destroy session: {e}"); + let nonce = session.get::("nonce").ok_or_else(|| { + error!("failed to get nonce from session"); StatusCode::INTERNAL_SERVER_ERROR })?; + Ok((csrf_token, pkce_verifier, nonce)) +} + +async fn validate_claims( + http_client: &reqwest::Client, + oauth_client: &OAuthClient, + params: CallbackParams, + csrf_token: CsrfToken, + pkce_verifier: PkceCodeVerifier, + nonce: Nonce, +) -> Result { if *csrf_token.secret() != params.state { error!("csrf_token mismatch"); return Err(StatusCode::INTERNAL_SERVER_ERROR); } let token_response = oauth_client - .exchange_code(AuthorizationCode::new(params.code)) + .exchange_code(AuthorizationCode::new(params.code.clone())) .map_err(|e| { error!("failed to exchange code: {e}"); StatusCode::INTERNAL_SERVER_ERROR })? .set_pkce_verifier(pkce_verifier) - .request_async(&http_client) + .request_async(http_client) .await .map_err(|e| { error!("failed to request token: {e}"); @@ -176,10 +249,14 @@ async fn callback( })?; let id_token_verifier = oauth_client.id_token_verifier(); - let claims = id_token.claims(&id_token_verifier, &nonce).map_err(|e| { - error!("failed to verify id_token: {e}"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + + let claims = id_token + .claims(&id_token_verifier, &nonce) + .map_err(|e| { + error!("failed to verify id_token: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .to_owned(); if let Some(expected_access_token_hash) = claims.access_token_hash() { let actual_access_token_hash = AccessTokenHash::from_token( @@ -204,6 +281,20 @@ async fn callback( } } + Ok(token_response) +} + +#[derive(Debug, Serialize, Deserialize)] +struct ExtraClaims { + groups: Vec, +} + +impl AdditionalClaims for ExtraClaims {} + +fn create_user_session( + config: &Config, + claims: &UserInfoClaims, +) -> Result { let user = User { subject: claims.subject().to_owned(), username: claims.preferred_username().cloned().ok_or_else(|| { @@ -211,6 +302,17 @@ async fn callback( StatusCode::INTERNAL_SERVER_ERROR })?, email: claims.email().cloned(), + is_admin: config + .oauth + .admin_group + .as_ref() + .is_some_and(|admin_group| { + claims + .additional_claims() + .groups + .iter() + .any(|group| group == admin_group) + }), }; let mut session = Session::new(); @@ -219,7 +321,37 @@ async fn callback( StatusCode::INTERNAL_SERVER_ERROR })?; - Ok(StatusCode::OK) + Ok(session) +} + +async fn create_user_cookie( + session_store: &MemoryStore, + session: Session, +) -> Result { + let cookie = session_store + .store_session(session) + .await + .map_err(|e| { + error!("failed to store session: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? + .ok_or_else(|| { + error!("failed to retrieve stored session cookie"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let cookie = format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; Secure; Path=/"); + + let mut headers = HeaderMap::new(); + headers.insert( + header::SET_COOKIE, + cookie.parse().map_err(|e| { + error!("failed to parse cookie: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?, + ); + + Ok(headers) } async fn logout( @@ -247,11 +379,16 @@ async fn logout( Ok(StatusCode::OK) } +async fn session(user: User) -> Result { + Ok(user) +} + pub fn routes(state: State) -> Router { Router::new() .route("/auth/login", routing::get(login)) .route("/auth/callback", routing::get(callback)) .route("/auth/logout", routing::get(logout)) + .route("/auth/session", routing::get(session)) .with_state(state) } diff --git a/src/state.rs b/src/state.rs index fc414c7..74e4df5 100644 --- a/src/state.rs +++ b/src/state.rs @@ -47,6 +47,7 @@ pub type OAuthClient< #[derive(Clone)] pub struct State { + pub config: Config, pub pg_pool: sqlx::PgPool, pub oauth_http_client: reqwest::Client, pub oauth_client: OAuthClient, @@ -67,25 +68,21 @@ impl State { )) .await?; - let mut http_client = - reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none()); - - if config.oauth.insecure { - http_client = http_client.danger_accept_invalid_certs(true); - } - - let http_client = http_client.build()?; + let oauth_http_client = reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .danger_accept_invalid_certs(config.oauth.insecure) + .build()?; let provider_metadata = CoreProviderMetadata::discover_async( - IssuerUrl::new(config.oauth.issuer_url)?, - &http_client, + IssuerUrl::new(config.oauth.issuer_url.clone())?, + &oauth_http_client, ) .await?; - let oauth_client = openidconnect::core::CoreClient::from_provider_metadata( + let oauth_client = OAuthClient::from_provider_metadata( provider_metadata, - ClientId::new(config.oauth.client_id), - Some(ClientSecret::new(config.oauth.client_secret)), + ClientId::new(config.oauth.client_id.clone()), + Some(ClientSecret::new(config.oauth.client_secret.clone())), ) .set_redirect_uri(RedirectUrl::new(format!( "{}{}/api/auth/callback", @@ -106,14 +103,21 @@ impl State { }); Ok(Self { + config, pg_pool, - oauth_http_client: http_client, + oauth_http_client, oauth_client, session_store, }) } } +impl FromRef for Config { + fn from_ref(state: &State) -> Self { + state.config.clone() + } +} + impl FromRef for sqlx::PgPool { fn from_ref(state: &State) -> Self { state.pg_pool.clone()