feat: add session persistence
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -3005,6 +3005,7 @@ dependencies = [
|
|||||||
"log4rs",
|
"log4rs",
|
||||||
"openidconnect",
|
"openidconnect",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"serde_yaml",
|
"serde_yaml",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
@@ -22,6 +22,7 @@ log = "0.4.27"
|
|||||||
log4rs = "1.3.0"
|
log4rs = "1.3.0"
|
||||||
openidconnect = { version = "4.0.0", features = ["reqwest"] }
|
openidconnect = { version = "4.0.0", features = ["reqwest"] }
|
||||||
serde = "1.0.219"
|
serde = "1.0.219"
|
||||||
|
serde_json = "1.0.140"
|
||||||
serde_yaml = "0.9.34"
|
serde_yaml = "0.9.34"
|
||||||
sqlx = { version = "0.8.3", features = ["postgres", "runtime-tokio", "time"] }
|
sqlx = { version = "0.8.3", features = ["postgres", "runtime-tokio", "time"] }
|
||||||
tokio = { version = "1.44.1", features = ["rt-multi-thread"] }
|
tokio = { version = "1.44.1", features = ["rt-multi-thread"] }
|
||||||
|
@@ -84,6 +84,7 @@ data:
|
|||||||
issuer_url: "https://id.veil.local"
|
issuer_url: "https://id.veil.local"
|
||||||
client_id: "veil"
|
client_id: "veil"
|
||||||
client_secret: "insecure_secret"
|
client_secret: "insecure_secret"
|
||||||
|
admin_group: "admins"
|
||||||
insecure: true
|
insecure: true
|
||||||
log4rs.yml: |
|
log4rs.yml: |
|
||||||
appenders:
|
appenders:
|
||||||
|
@@ -41,6 +41,8 @@ pub struct OAuthConfig {
|
|||||||
pub client_secret: String,
|
pub client_secret: String,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub insecure: bool,
|
pub insecure: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub admin_group: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize)]
|
#[derive(Clone, Deserialize)]
|
||||||
|
@@ -11,13 +11,18 @@ use axum::{
|
|||||||
use axum_extra::{TypedHeader, headers::Cookie, typed_header::TypedHeaderRejectionReason};
|
use axum_extra::{TypedHeader, headers::Cookie, typed_header::TypedHeaderRejectionReason};
|
||||||
use log::error;
|
use log::error;
|
||||||
use openidconnect::{
|
use openidconnect::{
|
||||||
AccessTokenHash, AuthorizationCode, CsrfToken, EndUserEmail, EndUserUsername, Nonce,
|
AccessTokenHash, AdditionalClaims, AuthorizationCode, CsrfToken, EndUserEmail, EndUserUsername,
|
||||||
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, Scope, SubjectIdentifier,
|
Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, Scope, SubjectIdentifier,
|
||||||
TokenResponse, core::CoreAuthenticationFlow, reqwest,
|
TokenResponse, UserInfoClaims,
|
||||||
|
core::{CoreAuthenticationFlow, CoreGenderClaim, CoreTokenResponse},
|
||||||
|
reqwest,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::state::{OAuthClient, State};
|
use crate::{
|
||||||
|
config::Config,
|
||||||
|
state::{OAuthClient, State},
|
||||||
|
};
|
||||||
|
|
||||||
static COOKIE_NAME: &str = "veil_session";
|
static COOKIE_NAME: &str = "veil_session";
|
||||||
|
|
||||||
@@ -26,6 +31,16 @@ pub struct User {
|
|||||||
pub subject: SubjectIdentifier,
|
pub subject: SubjectIdentifier,
|
||||||
pub username: EndUserUsername,
|
pub username: EndUserUsername,
|
||||||
pub email: Option<EndUserEmail>,
|
pub email: Option<EndUserEmail>,
|
||||||
|
pub is_admin: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for User {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap());
|
||||||
|
let body = serde_json::to_string(&self).unwrap();
|
||||||
|
(headers, body).into_response()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn login(
|
async fn login(
|
||||||
@@ -52,7 +67,19 @@ async fn login(
|
|||||||
.add_scope(Scope::new("groups".to_string()))
|
.add_scope(Scope::new("groups".to_string()))
|
||||||
.url();
|
.url();
|
||||||
|
|
||||||
|
let session = create_login_session(pkce_verifier, csrf_token, nonce)?;
|
||||||
|
let headers = create_login_cookie(&session_store, session).await?;
|
||||||
|
|
||||||
|
Ok((headers, Redirect::to(auth_url.as_str())))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_login_session(
|
||||||
|
pkce_verifier: PkceCodeVerifier,
|
||||||
|
csrf_token: CsrfToken,
|
||||||
|
nonce: Nonce,
|
||||||
|
) -> Result<Session, StatusCode> {
|
||||||
let mut session = Session::new();
|
let mut session = Session::new();
|
||||||
|
|
||||||
session
|
session
|
||||||
.insert("pkce_verifier", pkce_verifier)
|
.insert("pkce_verifier", pkce_verifier)
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
@@ -68,6 +95,13 @@ async fn login(
|
|||||||
StatusCode::INTERNAL_SERVER_ERROR
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
Ok(session)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_login_cookie(
|
||||||
|
session_store: &MemoryStore,
|
||||||
|
session: Session,
|
||||||
|
) -> Result<HeaderMap, StatusCode> {
|
||||||
let cookie = session_store
|
let cookie = session_store
|
||||||
.store_session(session)
|
.store_session(session)
|
||||||
.await
|
.await
|
||||||
@@ -80,8 +114,7 @@ async fn login(
|
|||||||
StatusCode::INTERNAL_SERVER_ERROR
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let cookie =
|
let cookie = format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; Secure; Path=/");
|
||||||
format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; HttpOnly; Secure; Path=/");
|
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(
|
headers.insert(
|
||||||
@@ -92,7 +125,7 @@ async fn login(
|
|||||||
})?,
|
})?,
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok((headers, Redirect::to(auth_url.as_str())))
|
Ok(headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@@ -107,6 +140,7 @@ async fn callback(
|
|||||||
extract::State(http_client): extract::State<reqwest::Client>,
|
extract::State(http_client): extract::State<reqwest::Client>,
|
||||||
extract::State(oauth_client): extract::State<OAuthClient>,
|
extract::State(oauth_client): extract::State<OAuthClient>,
|
||||||
extract::State(session_store): extract::State<MemoryStore>,
|
extract::State(session_store): extract::State<MemoryStore>,
|
||||||
|
extract::State(config): extract::State<Config>,
|
||||||
TypedHeader(cookies): TypedHeader<Cookie>,
|
TypedHeader(cookies): TypedHeader<Cookie>,
|
||||||
) -> Result<impl IntoResponse, StatusCode> {
|
) -> Result<impl IntoResponse, StatusCode> {
|
||||||
let cookie = cookies
|
let cookie = cookies
|
||||||
@@ -123,13 +157,49 @@ async fn callback(
|
|||||||
})?
|
})?
|
||||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||||
|
|
||||||
let csrf_token = session
|
let (csrf_token, pkce_verifier, nonce) = retrieve_login_session(&session)?;
|
||||||
.get::<CsrfToken>("csrf_token")
|
|
||||||
.ok_or_else(|| {
|
session_store.destroy_session(session).await.map_err(|e| {
|
||||||
error!("failed to get csrf_token from session");
|
error!("failed to destroy session: {e}");
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let token_response = validate_claims(
|
||||||
|
&http_client,
|
||||||
|
&oauth_client,
|
||||||
|
params,
|
||||||
|
csrf_token,
|
||||||
|
pkce_verifier,
|
||||||
|
nonce,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let claims = oauth_client
|
||||||
|
.user_info(token_response.access_token().to_owned(), None)
|
||||||
|
.map_err(|e| {
|
||||||
|
error!("failed to create userinfo request: {e}");
|
||||||
StatusCode::INTERNAL_SERVER_ERROR
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
})?
|
})?
|
||||||
.clone();
|
.request_async(&http_client)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
error!("failed to request user info: {e}");
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let session = create_user_session(&config, &claims)?;
|
||||||
|
let headers = create_user_cookie(&session_store, session).await?;
|
||||||
|
|
||||||
|
Ok((headers, StatusCode::OK))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn retrieve_login_session(
|
||||||
|
session: &Session,
|
||||||
|
) -> Result<(CsrfToken, PkceCodeVerifier, Nonce), StatusCode> {
|
||||||
|
let csrf_token = session.get::<CsrfToken>("csrf_token").ok_or_else(|| {
|
||||||
|
error!("failed to get csrf_token from session");
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
let pkce_verifier = session
|
let pkce_verifier = session
|
||||||
.get::<PkceCodeVerifier>("pkce_verifier")
|
.get::<PkceCodeVerifier>("pkce_verifier")
|
||||||
@@ -138,32 +208,35 @@ async fn callback(
|
|||||||
StatusCode::INTERNAL_SERVER_ERROR
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let nonce = session
|
let nonce = session.get::<Nonce>("nonce").ok_or_else(|| {
|
||||||
.get::<Nonce>("nonce")
|
error!("failed to get nonce from session");
|
||||||
.ok_or_else(|| {
|
|
||||||
error!("failed to get nonce from session");
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR
|
|
||||||
})?
|
|
||||||
.clone();
|
|
||||||
|
|
||||||
session_store.destroy_session(session).await.map_err(|e| {
|
|
||||||
error!("failed to destroy session: {e}");
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
Ok((csrf_token, pkce_verifier, nonce))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn validate_claims(
|
||||||
|
http_client: &reqwest::Client,
|
||||||
|
oauth_client: &OAuthClient,
|
||||||
|
params: CallbackParams,
|
||||||
|
csrf_token: CsrfToken,
|
||||||
|
pkce_verifier: PkceCodeVerifier,
|
||||||
|
nonce: Nonce,
|
||||||
|
) -> Result<CoreTokenResponse, StatusCode> {
|
||||||
if *csrf_token.secret() != params.state {
|
if *csrf_token.secret() != params.state {
|
||||||
error!("csrf_token mismatch");
|
error!("csrf_token mismatch");
|
||||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
}
|
}
|
||||||
|
|
||||||
let token_response = oauth_client
|
let token_response = oauth_client
|
||||||
.exchange_code(AuthorizationCode::new(params.code))
|
.exchange_code(AuthorizationCode::new(params.code.clone()))
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
error!("failed to exchange code: {e}");
|
error!("failed to exchange code: {e}");
|
||||||
StatusCode::INTERNAL_SERVER_ERROR
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
})?
|
})?
|
||||||
.set_pkce_verifier(pkce_verifier)
|
.set_pkce_verifier(pkce_verifier)
|
||||||
.request_async(&http_client)
|
.request_async(http_client)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
error!("failed to request token: {e}");
|
error!("failed to request token: {e}");
|
||||||
@@ -176,10 +249,14 @@ async fn callback(
|
|||||||
})?;
|
})?;
|
||||||
|
|
||||||
let id_token_verifier = oauth_client.id_token_verifier();
|
let id_token_verifier = oauth_client.id_token_verifier();
|
||||||
let claims = id_token.claims(&id_token_verifier, &nonce).map_err(|e| {
|
|
||||||
error!("failed to verify id_token: {e}");
|
let claims = id_token
|
||||||
StatusCode::INTERNAL_SERVER_ERROR
|
.claims(&id_token_verifier, &nonce)
|
||||||
})?;
|
.map_err(|e| {
|
||||||
|
error!("failed to verify id_token: {e}");
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?
|
||||||
|
.to_owned();
|
||||||
|
|
||||||
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
||||||
let actual_access_token_hash = AccessTokenHash::from_token(
|
let actual_access_token_hash = AccessTokenHash::from_token(
|
||||||
@@ -204,6 +281,20 @@ async fn callback(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(token_response)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct ExtraClaims {
|
||||||
|
groups: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AdditionalClaims for ExtraClaims {}
|
||||||
|
|
||||||
|
fn create_user_session(
|
||||||
|
config: &Config,
|
||||||
|
claims: &UserInfoClaims<ExtraClaims, CoreGenderClaim>,
|
||||||
|
) -> Result<Session, StatusCode> {
|
||||||
let user = User {
|
let user = User {
|
||||||
subject: claims.subject().to_owned(),
|
subject: claims.subject().to_owned(),
|
||||||
username: claims.preferred_username().cloned().ok_or_else(|| {
|
username: claims.preferred_username().cloned().ok_or_else(|| {
|
||||||
@@ -211,6 +302,17 @@ async fn callback(
|
|||||||
StatusCode::INTERNAL_SERVER_ERROR
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
})?,
|
})?,
|
||||||
email: claims.email().cloned(),
|
email: claims.email().cloned(),
|
||||||
|
is_admin: config
|
||||||
|
.oauth
|
||||||
|
.admin_group
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|admin_group| {
|
||||||
|
claims
|
||||||
|
.additional_claims()
|
||||||
|
.groups
|
||||||
|
.iter()
|
||||||
|
.any(|group| group == admin_group)
|
||||||
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut session = Session::new();
|
let mut session = Session::new();
|
||||||
@@ -219,7 +321,37 @@ async fn callback(
|
|||||||
StatusCode::INTERNAL_SERVER_ERROR
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(StatusCode::OK)
|
Ok(session)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_user_cookie(
|
||||||
|
session_store: &MemoryStore,
|
||||||
|
session: Session,
|
||||||
|
) -> Result<HeaderMap, StatusCode> {
|
||||||
|
let cookie = session_store
|
||||||
|
.store_session(session)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
error!("failed to store session: {e}");
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?
|
||||||
|
.ok_or_else(|| {
|
||||||
|
error!("failed to retrieve stored session cookie");
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let cookie = format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; Secure; Path=/");
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
header::SET_COOKIE,
|
||||||
|
cookie.parse().map_err(|e| {
|
||||||
|
error!("failed to parse cookie: {e}");
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn logout(
|
async fn logout(
|
||||||
@@ -247,11 +379,16 @@ async fn logout(
|
|||||||
Ok(StatusCode::OK)
|
Ok(StatusCode::OK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn session(user: User) -> Result<impl IntoResponse, StatusCode> {
|
||||||
|
Ok(user)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn routes(state: State) -> Router {
|
pub fn routes(state: State) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/auth/login", routing::get(login))
|
.route("/auth/login", routing::get(login))
|
||||||
.route("/auth/callback", routing::get(callback))
|
.route("/auth/callback", routing::get(callback))
|
||||||
.route("/auth/logout", routing::get(logout))
|
.route("/auth/logout", routing::get(logout))
|
||||||
|
.route("/auth/session", routing::get(session))
|
||||||
.with_state(state)
|
.with_state(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
32
src/state.rs
32
src/state.rs
@@ -47,6 +47,7 @@ pub type OAuthClient<
|
|||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct State {
|
pub struct State {
|
||||||
|
pub config: Config,
|
||||||
pub pg_pool: sqlx::PgPool,
|
pub pg_pool: sqlx::PgPool,
|
||||||
pub oauth_http_client: reqwest::Client,
|
pub oauth_http_client: reqwest::Client,
|
||||||
pub oauth_client: OAuthClient,
|
pub oauth_client: OAuthClient,
|
||||||
@@ -67,25 +68,21 @@ impl State {
|
|||||||
))
|
))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut http_client =
|
let oauth_http_client = reqwest::ClientBuilder::new()
|
||||||
reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none());
|
.redirect(reqwest::redirect::Policy::none())
|
||||||
|
.danger_accept_invalid_certs(config.oauth.insecure)
|
||||||
if config.oauth.insecure {
|
.build()?;
|
||||||
http_client = http_client.danger_accept_invalid_certs(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
let http_client = http_client.build()?;
|
|
||||||
|
|
||||||
let provider_metadata = CoreProviderMetadata::discover_async(
|
let provider_metadata = CoreProviderMetadata::discover_async(
|
||||||
IssuerUrl::new(config.oauth.issuer_url)?,
|
IssuerUrl::new(config.oauth.issuer_url.clone())?,
|
||||||
&http_client,
|
&oauth_http_client,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let oauth_client = openidconnect::core::CoreClient::from_provider_metadata(
|
let oauth_client = OAuthClient::from_provider_metadata(
|
||||||
provider_metadata,
|
provider_metadata,
|
||||||
ClientId::new(config.oauth.client_id),
|
ClientId::new(config.oauth.client_id.clone()),
|
||||||
Some(ClientSecret::new(config.oauth.client_secret)),
|
Some(ClientSecret::new(config.oauth.client_secret.clone())),
|
||||||
)
|
)
|
||||||
.set_redirect_uri(RedirectUrl::new(format!(
|
.set_redirect_uri(RedirectUrl::new(format!(
|
||||||
"{}{}/api/auth/callback",
|
"{}{}/api/auth/callback",
|
||||||
@@ -106,14 +103,21 @@ impl State {
|
|||||||
});
|
});
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
config,
|
||||||
pg_pool,
|
pg_pool,
|
||||||
oauth_http_client: http_client,
|
oauth_http_client,
|
||||||
oauth_client,
|
oauth_client,
|
||||||
session_store,
|
session_store,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FromRef<State> for Config {
|
||||||
|
fn from_ref(state: &State) -> Self {
|
||||||
|
state.config.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl FromRef<State> for sqlx::PgPool {
|
impl FromRef<State> for sqlx::PgPool {
|
||||||
fn from_ref(state: &State) -> Self {
|
fn from_ref(state: &State) -> Self {
|
||||||
state.pg_pool.clone()
|
state.pg_pool.clone()
|
||||||
|
Reference in New Issue
Block a user