Initial commit
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
67
src/config.rs
Normal file
67
src/config.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use clap::Parser;
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
fs,
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
path::PathBuf,
|
||||
};
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
#[serde(default = "default_server_address")]
|
||||
pub address: std::net::IpAddr,
|
||||
#[serde(default = "default_server_port")]
|
||||
pub port: u16,
|
||||
#[serde(default)]
|
||||
pub subpath: String,
|
||||
}
|
||||
|
||||
const fn default_server_address() -> IpAddr {
|
||||
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))
|
||||
}
|
||||
|
||||
const fn default_server_port() -> u16 {
|
||||
8080
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct OAuthConfig {
|
||||
pub issuer_url: String,
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
#[serde(default)]
|
||||
pub insecure: bool,
|
||||
pub admin_group: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct AutheliaConfig {
|
||||
pub user_database: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct Config {
|
||||
pub server: ServerConfig,
|
||||
pub oauth: OAuthConfig,
|
||||
pub authelia: AutheliaConfig,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn from_yaml(path: &PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let contents = fs::read_to_string(path)?;
|
||||
let config = serde_yaml::from_str(&contents)?;
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(version, about, long_about = None, author)]
|
||||
pub struct Args {
|
||||
/// Path to the YAML config file
|
||||
#[arg(short, long, value_name = "FILE", default_value = "config.yaml")]
|
||||
pub config: PathBuf,
|
||||
/// Path to the log4rs config file
|
||||
#[arg(short, long, value_name = "FILE", default_value = "log4rs.yaml")]
|
||||
pub log_config: PathBuf,
|
||||
}
|
36
src/main.rs
Normal file
36
src/main.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
|
||||
#![allow(clippy::missing_docs_in_private_items)]
|
||||
|
||||
mod config;
|
||||
mod models;
|
||||
mod routes;
|
||||
mod state;
|
||||
mod utils;
|
||||
|
||||
use axum::serve;
|
||||
use clap::Parser;
|
||||
use log::info;
|
||||
use log4rs::config::Deserializers;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use config::{Args, Config};
|
||||
use state::State;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let args = Args::parse();
|
||||
log4rs::init_file(args.log_config, Deserializers::default()).unwrap();
|
||||
|
||||
let config = Config::from_yaml(&args.config).unwrap();
|
||||
let state = State::from_config(config.clone()).await.unwrap();
|
||||
|
||||
let routes = routes::routes(state);
|
||||
let app = axum::Router::new().nest(&format!("{}/api", config.server.subpath), routes);
|
||||
|
||||
let addr = SocketAddr::from((config.server.address, config.server.port));
|
||||
let listener = TcpListener::bind(addr).await.unwrap();
|
||||
|
||||
info!("Listening on {}", listener.local_addr().unwrap());
|
||||
serve(listener, app).await.unwrap();
|
||||
}
|
160
src/models/authelia.rs
Normal file
160
src/models/authelia.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use non_empty_string::NonEmptyString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UsersFile {
|
||||
pub users: HashMap<NonEmptyString, UserFile>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub extra: Option<HashMap<NonEmptyString, Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UserFile {
|
||||
pub displayname: NonEmptyString,
|
||||
pub password: NonEmptyString,
|
||||
pub email: Option<String>,
|
||||
pub disabled: Option<bool>,
|
||||
pub groups: Option<Vec<NonEmptyString>>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub extra: Option<HashMap<NonEmptyString, Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
pub displayname: NonEmptyString,
|
||||
pub email: Option<String>,
|
||||
pub password: NonEmptyString,
|
||||
pub disabled: bool,
|
||||
pub groups: Vec<NonEmptyString>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub extra: Option<HashMap<NonEmptyString, Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Users {
|
||||
pub users: HashMap<NonEmptyString, User>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub extra: Option<HashMap<NonEmptyString, Value>>,
|
||||
}
|
||||
|
||||
impl Deref for Users {
|
||||
type Target = HashMap<NonEmptyString, User>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.users
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for Users {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.users
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UserFile> for User {
|
||||
fn from(user_file: UserFile) -> Self {
|
||||
Self {
|
||||
displayname: user_file.displayname,
|
||||
email: user_file.email,
|
||||
password: user_file.password,
|
||||
disabled: user_file.disabled.unwrap_or(false),
|
||||
groups: user_file.groups.unwrap_or_default(),
|
||||
extra: user_file.extra,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UsersFile> for Users {
|
||||
fn from(users_file: UsersFile) -> Self {
|
||||
Self {
|
||||
users: users_file
|
||||
.users
|
||||
.into_iter()
|
||||
.map(|(key, user)| (key, User::from(user)))
|
||||
.collect(),
|
||||
extra: users_file.extra,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<User> for UserFile {
|
||||
fn from(user: User) -> Self {
|
||||
Self {
|
||||
displayname: user.displayname,
|
||||
email: user.email,
|
||||
password: user.password,
|
||||
disabled: if user.disabled { Some(true) } else { None },
|
||||
groups: if user.groups.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(user.groups)
|
||||
},
|
||||
extra: user.extra,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Users> for UsersFile {
|
||||
fn from(users: Users) -> Self {
|
||||
Self {
|
||||
users: users
|
||||
.users
|
||||
.into_iter()
|
||||
.map(|(key, user)| (key, UserFile::from(user)))
|
||||
.collect(),
|
||||
extra: users.extra,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Group {
|
||||
pub users: Vec<NonEmptyString>,
|
||||
}
|
||||
|
||||
pub struct Groups {
|
||||
pub groups: HashMap<NonEmptyString, Group>,
|
||||
}
|
||||
|
||||
impl Deref for Groups {
|
||||
type Target = HashMap<NonEmptyString, Group>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.groups
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for Groups {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.groups
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UsersFile> for Groups {
|
||||
fn from(users_file: UsersFile) -> Self {
|
||||
users_file.users.into_iter().fold(
|
||||
Self {
|
||||
groups: HashMap::new(),
|
||||
},
|
||||
|mut acc, (key, user)| {
|
||||
for group in user.groups.unwrap_or_default() {
|
||||
acc.entry(group)
|
||||
.or_insert_with(|| Group { users: Vec::new() })
|
||||
.users
|
||||
.push(key.clone());
|
||||
}
|
||||
acc
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
1
src/models/mod.rs
Normal file
1
src/models/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod authelia;
|
442
src/routes/auth.rs
Normal file
442
src/routes/auth.rs
Normal file
@@ -0,0 +1,442 @@
|
||||
use std::{borrow::Cow, convert::Infallible};
|
||||
|
||||
use async_session::{MemoryStore, Session, SessionStore};
|
||||
use axum::{
|
||||
Json, RequestPartsExt, Router,
|
||||
extract::{self, FromRef, FromRequestParts, OptionalFromRequestParts},
|
||||
http::{HeaderMap, StatusCode, header, request::Parts},
|
||||
response::{IntoResponse, Redirect, Response},
|
||||
routing,
|
||||
};
|
||||
use axum_extra::{TypedHeader, headers::Cookie, typed_header::TypedHeaderRejectionReason};
|
||||
use log::error;
|
||||
use openidconnect::{
|
||||
AccessTokenHash, AdditionalClaims, AuthorizationCode, CsrfToken, EndUserEmail, EndUserUsername,
|
||||
Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, Scope, SubjectIdentifier,
|
||||
TokenResponse, UserInfoClaims,
|
||||
core::{CoreAuthenticationFlow, CoreGenderClaim, CoreTokenResponse},
|
||||
reqwest,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
config::Config,
|
||||
state::{OAuthClient, State},
|
||||
};
|
||||
|
||||
static COOKIE_NAME: &str = "glyph_session";
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub struct User {
|
||||
pub subject: SubjectIdentifier,
|
||||
pub username: EndUserUsername,
|
||||
pub email: Option<EndUserEmail>,
|
||||
}
|
||||
|
||||
async fn login(
|
||||
extract::State(oauth_client): extract::State<OAuthClient>,
|
||||
extract::State(session_store): extract::State<MemoryStore>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
let (auth_url, csrf_token, nonce) = oauth_client
|
||||
.authorize_url(
|
||||
CoreAuthenticationFlow::AuthorizationCode,
|
||||
CsrfToken::new_random,
|
||||
Nonce::new_random,
|
||||
)
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.set_redirect_uri(Cow::Borrowed(oauth_client.redirect_uri().ok_or_else(
|
||||
|| {
|
||||
error!("missing redirect URI");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
},
|
||||
)?))
|
||||
.add_scope(Scope::new("profile".to_string()))
|
||||
.add_scope(Scope::new("email".to_string()))
|
||||
.add_scope(Scope::new("groups".to_string()))
|
||||
.url();
|
||||
|
||||
let session = create_login_session(pkce_verifier, csrf_token, nonce)?;
|
||||
let headers = create_login_cookie(&session_store, session).await?;
|
||||
|
||||
Ok((headers, Redirect::to(auth_url.as_str())))
|
||||
}
|
||||
|
||||
fn create_login_session(
|
||||
pkce_verifier: PkceCodeVerifier,
|
||||
csrf_token: CsrfToken,
|
||||
nonce: Nonce,
|
||||
) -> Result<Session, StatusCode> {
|
||||
let mut session = Session::new();
|
||||
|
||||
session
|
||||
.insert("pkce_verifier", pkce_verifier)
|
||||
.map_err(|e| {
|
||||
error!("failed to insert pkce_verifier into session: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
session.insert("csrf_token", csrf_token).map_err(|e| {
|
||||
error!("failed to insert csrf_token into session: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
session.insert("nonce", nonce).map_err(|e| {
|
||||
error!("failed to insert nonce into session: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
async fn create_login_cookie(
|
||||
session_store: &MemoryStore,
|
||||
session: Session,
|
||||
) -> Result<HeaderMap, StatusCode> {
|
||||
let cookie = session_store
|
||||
.store_session(session)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("failed to store session: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
error!("failed to retrieve stored session cookie");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let cookie = format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; Secure; Path=/");
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
header::SET_COOKIE,
|
||||
cookie.parse().map_err(|e| {
|
||||
error!("failed to parse cookie: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?,
|
||||
);
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct CallbackParams {
|
||||
code: String,
|
||||
state: String,
|
||||
}
|
||||
|
||||
async fn callback(
|
||||
extract::Query(params): extract::Query<CallbackParams>,
|
||||
extract::State(http_client): extract::State<reqwest::Client>,
|
||||
extract::State(oauth_client): extract::State<OAuthClient>,
|
||||
extract::State(session_store): extract::State<MemoryStore>,
|
||||
extract::State(config): extract::State<Config>,
|
||||
TypedHeader(cookies): TypedHeader<Cookie>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let cookie = cookies
|
||||
.get(COOKIE_NAME)
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?
|
||||
.to_string();
|
||||
|
||||
let session = session_store
|
||||
.load_session(cookie)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("failed to load session: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let (csrf_token, pkce_verifier, nonce) = retrieve_login_session(&session)?;
|
||||
|
||||
session_store.destroy_session(session).await.map_err(|e| {
|
||||
error!("failed to destroy session: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let token_response = validate_claims(
|
||||
&http_client,
|
||||
&oauth_client,
|
||||
params,
|
||||
csrf_token,
|
||||
pkce_verifier,
|
||||
nonce,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let claims = oauth_client
|
||||
.user_info(token_response.access_token().to_owned(), None)
|
||||
.map_err(|e| {
|
||||
error!("failed to create userinfo request: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
.request_async(&http_client)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("failed to request user info: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let session = create_user_session(&config, &claims)?;
|
||||
let headers = create_user_cookie(&session_store, session).await?;
|
||||
|
||||
Ok((headers, StatusCode::OK))
|
||||
}
|
||||
|
||||
fn retrieve_login_session(
|
||||
session: &Session,
|
||||
) -> Result<(CsrfToken, PkceCodeVerifier, Nonce), StatusCode> {
|
||||
let csrf_token = session.get::<CsrfToken>("csrf_token").ok_or_else(|| {
|
||||
error!("failed to get csrf_token from session");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let pkce_verifier = session
|
||||
.get::<PkceCodeVerifier>("pkce_verifier")
|
||||
.ok_or_else(|| {
|
||||
error!("failed to get pkce_verifier from session");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let nonce = session.get::<Nonce>("nonce").ok_or_else(|| {
|
||||
error!("failed to get nonce from session");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok((csrf_token, pkce_verifier, nonce))
|
||||
}
|
||||
|
||||
async fn validate_claims(
|
||||
http_client: &reqwest::Client,
|
||||
oauth_client: &OAuthClient,
|
||||
params: CallbackParams,
|
||||
csrf_token: CsrfToken,
|
||||
pkce_verifier: PkceCodeVerifier,
|
||||
nonce: Nonce,
|
||||
) -> Result<CoreTokenResponse, StatusCode> {
|
||||
if *csrf_token.secret() != params.state {
|
||||
error!("csrf_token mismatch");
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
let token_response = oauth_client
|
||||
.exchange_code(AuthorizationCode::new(params.code.clone()))
|
||||
.map_err(|e| {
|
||||
error!("failed to exchange code: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
.set_pkce_verifier(pkce_verifier)
|
||||
.request_async(http_client)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("failed to request token: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let id_token = token_response.id_token().ok_or_else(|| {
|
||||
error!("missing id_token in token response");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let id_token_verifier = oauth_client.id_token_verifier();
|
||||
|
||||
let claims = id_token
|
||||
.claims(&id_token_verifier, &nonce)
|
||||
.map_err(|e| {
|
||||
error!("failed to verify id_token: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
.to_owned();
|
||||
|
||||
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
||||
let actual_access_token_hash = AccessTokenHash::from_token(
|
||||
token_response.access_token(),
|
||||
id_token.signing_alg().map_err(|e| {
|
||||
error!("failed to get signing algorithm: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?,
|
||||
id_token.signing_key(&id_token_verifier).map_err(|e| {
|
||||
error!("failed to get signing key: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?,
|
||||
)
|
||||
.map_err(|e| {
|
||||
error!("failed to compute access token hash: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
if actual_access_token_hash != *expected_access_token_hash {
|
||||
error!("access token hash mismatch");
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(token_response)
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ExtraClaims {
|
||||
groups: Vec<String>,
|
||||
}
|
||||
|
||||
impl AdditionalClaims for ExtraClaims {}
|
||||
|
||||
fn create_user_session(
|
||||
config: &Config,
|
||||
claims: &UserInfoClaims<ExtraClaims, CoreGenderClaim>,
|
||||
) -> Result<Session, StatusCode> {
|
||||
if !claims
|
||||
.additional_claims()
|
||||
.groups
|
||||
.iter()
|
||||
.any(|group| group == &config.oauth.admin_group)
|
||||
{
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
let user = User {
|
||||
subject: claims.subject().to_owned(),
|
||||
username: claims.preferred_username().cloned().ok_or_else(|| {
|
||||
error!("missing preferred_username in claims");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?,
|
||||
email: claims.email().cloned(),
|
||||
};
|
||||
|
||||
let mut session = Session::new();
|
||||
session.insert("user", user).map_err(|e| {
|
||||
error!("failed to insert user into session: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
async fn create_user_cookie(
|
||||
session_store: &MemoryStore,
|
||||
session: Session,
|
||||
) -> Result<HeaderMap, StatusCode> {
|
||||
let cookie = session_store
|
||||
.store_session(session)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("failed to store session: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
error!("failed to retrieve stored session cookie");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let cookie = format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; Secure; Path=/");
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
header::SET_COOKIE,
|
||||
cookie.parse().map_err(|e| {
|
||||
error!("failed to parse cookie: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?,
|
||||
);
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
async fn logout(
|
||||
extract::State(session_store): extract::State<MemoryStore>,
|
||||
TypedHeader(cookies): TypedHeader<Cookie>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let cookie = cookies.get(COOKIE_NAME).ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let Some(session) = session_store
|
||||
.load_session(cookie.to_string())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("failed to load session: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
else {
|
||||
return Ok(StatusCode::OK);
|
||||
};
|
||||
|
||||
session_store.destroy_session(session).await.map_err(|e| {
|
||||
error!("failed to destroy session: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
async fn session(user: User) -> Result<impl IntoResponse, StatusCode> {
|
||||
Ok(Json(user))
|
||||
}
|
||||
|
||||
pub fn routes(state: State) -> Router {
|
||||
Router::new()
|
||||
.route("/auth/login", routing::get(login))
|
||||
.route("/auth/callback", routing::get(callback))
|
||||
.route("/auth/logout", routing::get(logout))
|
||||
.route("/auth/session", routing::get(session))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for User
|
||||
where
|
||||
MemoryStore: FromRef<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Response;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let store = MemoryStore::from_ref(state);
|
||||
|
||||
let cookies = parts.extract::<TypedHeader<Cookie>>().await.map_err(|e| {
|
||||
if *e.name() == header::COOKIE {
|
||||
if matches!(e.reason(), TypedHeaderRejectionReason::Missing) {
|
||||
StatusCode::UNAUTHORIZED.into_response()
|
||||
} else {
|
||||
error!("failed to extract cookies: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
} else {
|
||||
error!("failed to extract cookies: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
})?;
|
||||
|
||||
let session_cookie = cookies
|
||||
.get(COOKIE_NAME)
|
||||
.ok_or_else(|| StatusCode::UNAUTHORIZED.into_response())?;
|
||||
|
||||
let session = store
|
||||
.load_session(session_cookie.to_string())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("failed to load session: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
})?
|
||||
.ok_or_else(|| StatusCode::UNAUTHORIZED.into_response())?;
|
||||
|
||||
let user = session
|
||||
.get::<Self>("user")
|
||||
.ok_or_else(|| StatusCode::UNAUTHORIZED.into_response())?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> OptionalFromRequestParts<S> for User
|
||||
where
|
||||
MemoryStore: FromRef<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> Result<Option<Self>, Self::Rejection> {
|
||||
(<Self as FromRequestParts<S>>::from_request_parts(parts, state).await)
|
||||
.map_or(Ok(None), |user| Ok(Some(user)))
|
||||
}
|
||||
}
|
240
src/routes/groups.rs
Normal file
240
src/routes/groups.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use axum::{
|
||||
Json, Router, extract,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Redirect},
|
||||
routing,
|
||||
};
|
||||
use log::error;
|
||||
|
||||
use non_empty_string::NonEmptyString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{models::authelia, routes::auth, state::State};
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GroupResponse {
|
||||
groupname: NonEmptyString,
|
||||
users: Vec<NonEmptyString>,
|
||||
}
|
||||
|
||||
impl From<(NonEmptyString, authelia::Group)> for GroupResponse {
|
||||
fn from((groupname, group): (NonEmptyString, authelia::Group)) -> Self {
|
||||
Self {
|
||||
groupname,
|
||||
users: group.users,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type GroupsResponse = HashMap<NonEmptyString, GroupResponse>;
|
||||
|
||||
impl From<authelia::Groups> for GroupsResponse {
|
||||
fn from(groups: authelia::Groups) -> Self {
|
||||
groups
|
||||
.groups
|
||||
.into_iter()
|
||||
.map(|(key, group)| (key.clone(), GroupResponse::from((key, group))))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_all(
|
||||
_user: auth::User,
|
||||
extract::State(state): extract::State<State>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let groups = state.load_groups().map_err(|e| {
|
||||
error!("Failed to read users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(Json(GroupsResponse::from(groups)))
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
_user: auth::User,
|
||||
extract::Path(groupname): extract::Path<NonEmptyString>,
|
||||
extract::State(state): extract::State<State>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let groups = state.load_groups().map_err(|e| {
|
||||
error!("Failed to read users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
groups.get(&groupname).cloned().map_or_else(
|
||||
|| Err(StatusCode::NOT_FOUND),
|
||||
|group| Ok(Json(GroupResponse::from((groupname, group))).into_response()),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GroupCreate {
|
||||
groupname: NonEmptyString,
|
||||
users: Vec<NonEmptyString>,
|
||||
}
|
||||
|
||||
impl From<GroupCreate> for authelia::Group {
|
||||
fn from(update: GroupCreate) -> Self {
|
||||
Self {
|
||||
users: update.users,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
_user: auth::User,
|
||||
extract::State(state): extract::State<State>,
|
||||
extract::Json(group_create): extract::Json<GroupCreate>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let (mut users, groups) = state.load_users_and_groups().map_err(|e| {
|
||||
error!("Failed to read users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let groupname = group_create.groupname.clone();
|
||||
if groups.contains_key(&groupname) {
|
||||
return Err(StatusCode::CONFLICT);
|
||||
}
|
||||
|
||||
let group_created = authelia::Group::from(group_create);
|
||||
|
||||
for username in &group_created.users {
|
||||
if !users.contains_key(username) {
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
users
|
||||
.get_mut(username)
|
||||
.unwrap()
|
||||
.groups
|
||||
.push(groupname.clone());
|
||||
}
|
||||
|
||||
state.save_users(users).map_err(|e| {
|
||||
error!("Failed to save users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(Json(GroupResponse::from((groupname, group_created))).into_response())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GroupUpdate {
|
||||
groupname: Option<NonEmptyString>,
|
||||
users: Vec<NonEmptyString>,
|
||||
}
|
||||
|
||||
impl From<GroupUpdate> for authelia::Group {
|
||||
fn from(update: GroupUpdate) -> Self {
|
||||
Self {
|
||||
users: update.users,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn update(
|
||||
user: auth::User,
|
||||
extract::Path(groupname): extract::Path<NonEmptyString>,
|
||||
extract::State(state): extract::State<State>,
|
||||
extract::Json(group_update): extract::Json<GroupUpdate>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let (mut users, groups) = state.load_users_and_groups().map_err(|e| {
|
||||
error!("Failed to read users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let new_groupname = group_update
|
||||
.groupname
|
||||
.clone()
|
||||
.unwrap_or_else(|| groupname.clone());
|
||||
|
||||
let group_existing = groups.get(&groupname).ok_or(StatusCode::NOT_FOUND)?;
|
||||
let group_updated = authelia::Group::from(group_update);
|
||||
|
||||
if groupname != new_groupname
|
||||
&& (groupname == state.config.oauth.admin_group
|
||||
|| new_groupname == state.config.oauth.admin_group)
|
||||
{
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
if groupname != new_groupname && groups.contains_key(&new_groupname) {
|
||||
return Err(StatusCode::CONFLICT);
|
||||
}
|
||||
|
||||
for user in &group_existing.users {
|
||||
let user = users.get_mut(user).unwrap();
|
||||
let pos = user.groups.iter().position(|g| g == &groupname).unwrap();
|
||||
user.groups.remove(pos);
|
||||
}
|
||||
|
||||
for username in &group_updated.users {
|
||||
if !users.contains_key(username) {
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
let user = users.get_mut(username).unwrap();
|
||||
if !user.groups.contains(&new_groupname) {
|
||||
user.groups.push(new_groupname.clone());
|
||||
}
|
||||
}
|
||||
|
||||
state.save_users(users).map_err(|e| {
|
||||
error!("Failed to save users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
if new_groupname == state.config.oauth.admin_group
|
||||
&& !group_updated
|
||||
.users
|
||||
.iter()
|
||||
.any(|group_user| *group_user == *user.username.to_string())
|
||||
{
|
||||
return Ok(Redirect::to("/api/auth/logout").into_response());
|
||||
}
|
||||
|
||||
Ok(Json(GroupResponse::from((new_groupname, group_updated))).into_response())
|
||||
}
|
||||
|
||||
pub async fn delete(
|
||||
_user: auth::User,
|
||||
extract::Path(groupname): extract::Path<String>,
|
||||
extract::State(state): extract::State<State>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let (mut users, groups) = state.load_users_and_groups().map_err(|e| {
|
||||
error!("Failed to read users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
if groupname == state.config.oauth.admin_group {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
if let Some(old_group) = groups.get(&groupname) {
|
||||
for user in &old_group.users {
|
||||
let user = users.get_mut(user).unwrap();
|
||||
let pos = user.groups.iter().position(|g| g == &groupname).unwrap();
|
||||
user.groups.remove(pos);
|
||||
}
|
||||
} else {
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
state.save_users(users).map_err(|e| {
|
||||
error!("Failed to save users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT.into_response())
|
||||
}
|
||||
|
||||
pub fn routes(state: State) -> Router {
|
||||
Router::new()
|
||||
.route("/groups", routing::get(get_all))
|
||||
.route("/groups/{username}", routing::get(get))
|
||||
.route("/groups", routing::post(create))
|
||||
.route("/groups/{username}", routing::put(update))
|
||||
.route("/groups/{username}", routing::delete(delete))
|
||||
.with_state(state)
|
||||
}
|
13
src/routes/health.rs
Normal file
13
src/routes/health.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use axum::{Router, http::StatusCode, response::IntoResponse, routing};
|
||||
|
||||
use crate::state::State;
|
||||
|
||||
pub async fn get() -> Result<impl IntoResponse, StatusCode> {
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
pub fn routes(state: State) -> Router {
|
||||
Router::new()
|
||||
.route("/health", routing::get(get))
|
||||
.with_state(state)
|
||||
}
|
21
src/routes/mod.rs
Normal file
21
src/routes/mod.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
mod auth;
|
||||
mod groups;
|
||||
mod health;
|
||||
mod users;
|
||||
|
||||
use axum::Router;
|
||||
|
||||
use crate::state::State;
|
||||
|
||||
pub fn routes(state: State) -> Router {
|
||||
let auth = auth::routes(state.clone());
|
||||
let health = health::routes(state.clone());
|
||||
let users = users::routes(state.clone());
|
||||
let groups = groups::routes(state);
|
||||
|
||||
Router::new()
|
||||
.merge(auth)
|
||||
.merge(health)
|
||||
.merge(users)
|
||||
.merge(groups)
|
||||
}
|
218
src/routes/users.rs
Normal file
218
src/routes/users.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use axum::{
|
||||
Json, Router, extract,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Redirect},
|
||||
routing,
|
||||
};
|
||||
use log::error;
|
||||
|
||||
use non_empty_string::NonEmptyString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
models::authelia, routes::auth, state::State, utils::crypto::generate_random_password_hash,
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct UserResponse {
|
||||
username: NonEmptyString,
|
||||
displayname: NonEmptyString,
|
||||
email: Option<String>,
|
||||
groups: Option<Vec<NonEmptyString>>,
|
||||
}
|
||||
|
||||
impl From<(NonEmptyString, authelia::User)> for UserResponse {
|
||||
fn from((username, user): (NonEmptyString, authelia::User)) -> Self {
|
||||
Self {
|
||||
username,
|
||||
displayname: user.displayname,
|
||||
email: user.email,
|
||||
groups: Some(user.groups),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type UsersResponse = HashMap<NonEmptyString, UserResponse>;
|
||||
|
||||
impl From<authelia::Users> for UsersResponse {
|
||||
fn from(users: authelia::Users) -> Self {
|
||||
users
|
||||
.users
|
||||
.into_iter()
|
||||
.map(|(key, user)| (key.clone(), UserResponse::from((key, user))))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_all(
|
||||
_user: auth::User,
|
||||
extract::State(state): extract::State<State>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let users = state.load_users().map_err(|e| {
|
||||
error!("Failed to read users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(Json(UsersResponse::from(users)))
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
_user: auth::User,
|
||||
extract::Path(username): extract::Path<NonEmptyString>,
|
||||
extract::State(state): extract::State<State>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let users = state.load_users().map_err(|e| {
|
||||
error!("Failed to read users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
users.get(&username).cloned().map_or_else(
|
||||
|| Err(StatusCode::NOT_FOUND),
|
||||
|user| Ok(Json(UserResponse::from((username, user))).into_response()),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UserCreate {
|
||||
username: NonEmptyString,
|
||||
displayname: NonEmptyString,
|
||||
email: NonEmptyString,
|
||||
disabled: Option<bool>,
|
||||
groups: Option<Vec<NonEmptyString>>,
|
||||
}
|
||||
|
||||
#[allow(clippy::fallible_impl_from)]
|
||||
impl From<UserCreate> for authelia::User {
|
||||
fn from(user_create: UserCreate) -> Self {
|
||||
Self {
|
||||
displayname: user_create.displayname,
|
||||
email: Some(user_create.email.to_string()),
|
||||
password: NonEmptyString::new(generate_random_password_hash()).unwrap(),
|
||||
disabled: user_create.disabled.unwrap_or(false),
|
||||
groups: user_create.groups.unwrap_or_default(),
|
||||
extra: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
_user: auth::User,
|
||||
extract::State(state): extract::State<State>,
|
||||
extract::Json(user_create): extract::Json<UserCreate>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let mut users = state.load_users().map_err(|e| {
|
||||
error!("Failed to read users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let username = user_create.username.clone();
|
||||
if users.contains_key(&username) {
|
||||
return Err(StatusCode::CONFLICT);
|
||||
}
|
||||
|
||||
let user_created = authelia::User::from(user_create);
|
||||
users.users.insert(username.clone(), user_created.clone());
|
||||
|
||||
state.save_users(users).map_err(|e| {
|
||||
error!("Failed to save users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(Json(UserResponse::from((username, user_created))).into_response())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UserUpdate {
|
||||
username: Option<NonEmptyString>,
|
||||
displayname: NonEmptyString,
|
||||
email: NonEmptyString,
|
||||
disabled: Option<bool>,
|
||||
groups: Option<Vec<NonEmptyString>>,
|
||||
}
|
||||
|
||||
impl From<(Self, UserUpdate)> for authelia::User {
|
||||
fn from((user_existing, user_update): (Self, UserUpdate)) -> Self {
|
||||
Self {
|
||||
displayname: user_update.displayname,
|
||||
email: Some(user_update.email.to_string()),
|
||||
password: user_existing.password,
|
||||
disabled: user_update.disabled.unwrap_or(user_existing.disabled),
|
||||
groups: user_update.groups.unwrap_or(user_existing.groups),
|
||||
extra: user_existing.extra,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn update(
|
||||
user: auth::User,
|
||||
extract::Path(username): extract::Path<NonEmptyString>,
|
||||
extract::State(state): extract::State<State>,
|
||||
extract::Json(user_update): extract::Json<UserUpdate>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let mut users = state.load_users().map_err(|e| {
|
||||
error!("Failed to read users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let new_username = user_update
|
||||
.username
|
||||
.clone()
|
||||
.unwrap_or_else(|| username.clone());
|
||||
|
||||
let user_existing = users.remove(&username).ok_or(StatusCode::NOT_FOUND)?;
|
||||
let user_updated = authelia::User::from((user_existing, user_update));
|
||||
|
||||
users
|
||||
.users
|
||||
.insert(new_username.clone(), user_updated.clone());
|
||||
|
||||
state.save_users(users).map_err(|e| {
|
||||
error!("Failed to save users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
if user.username.to_string() == username && (username != new_username || user_updated.disabled)
|
||||
{
|
||||
return Ok(Redirect::to("/api/auth/logout").into_response());
|
||||
}
|
||||
|
||||
Ok(Json(UserResponse::from((new_username, user_updated))).into_response())
|
||||
}
|
||||
|
||||
pub async fn delete(
|
||||
user: auth::User,
|
||||
extract::Path(username): extract::Path<String>,
|
||||
extract::State(state): extract::State<State>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let mut users = state.load_users().map_err(|e| {
|
||||
error!("Failed to read users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
if users.remove(&username).is_none() {
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
state.save_users(users).map_err(|e| {
|
||||
error!("Failed to save users file: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
if user.username.to_string() == username {
|
||||
return Ok(Redirect::to("/api/auth/logout").into_response());
|
||||
}
|
||||
|
||||
Ok(StatusCode::NO_CONTENT.into_response())
|
||||
}
|
||||
|
||||
pub fn routes(state: State) -> Router {
|
||||
Router::new()
|
||||
.route("/users", routing::get(get_all))
|
||||
.route("/users/{username}", routing::get(get))
|
||||
.route("/users", routing::post(create))
|
||||
.route("/users/{username}", routing::put(update))
|
||||
.route("/users/{username}", routing::delete(delete))
|
||||
.with_state(state)
|
||||
}
|
169
src/state.rs
Normal file
169
src/state.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use std::error::Error;
|
||||
|
||||
use async_session::MemoryStore;
|
||||
use axum::extract::FromRef;
|
||||
use log::error;
|
||||
use openidconnect::{
|
||||
ClientId, ClientSecret, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet,
|
||||
IssuerUrl, RedirectUrl, StandardErrorResponse,
|
||||
core::{
|
||||
CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey,
|
||||
CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreRevocableToken,
|
||||
CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenResponse,
|
||||
},
|
||||
reqwest,
|
||||
};
|
||||
use tokio::{
|
||||
spawn,
|
||||
time::{Duration, sleep},
|
||||
};
|
||||
|
||||
use crate::{config::Config, models::authelia};
|
||||
|
||||
pub type OAuthClient<
|
||||
HasAuthUrl = EndpointSet,
|
||||
HasDeviceAuthUrl = EndpointNotSet,
|
||||
HasIntrospectionUrl = EndpointNotSet,
|
||||
HasRevocationUrl = EndpointNotSet,
|
||||
HasTokenUrl = EndpointMaybeSet,
|
||||
HasUserInfoUrl = EndpointMaybeSet,
|
||||
> = openidconnect::Client<
|
||||
EmptyAdditionalClaims,
|
||||
CoreAuthDisplay,
|
||||
CoreGenderClaim,
|
||||
CoreJweContentEncryptionAlgorithm,
|
||||
CoreJsonWebKey,
|
||||
CoreAuthPrompt,
|
||||
StandardErrorResponse<CoreErrorResponseType>,
|
||||
CoreTokenResponse,
|
||||
CoreTokenIntrospectionResponse,
|
||||
CoreRevocableToken,
|
||||
CoreRevocationErrorResponse,
|
||||
HasAuthUrl,
|
||||
HasDeviceAuthUrl,
|
||||
HasIntrospectionUrl,
|
||||
HasRevocationUrl,
|
||||
HasTokenUrl,
|
||||
HasUserInfoUrl,
|
||||
>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct State {
|
||||
pub config: Config,
|
||||
pub oauth_http_client: reqwest::Client,
|
||||
pub oauth_client: OAuthClient,
|
||||
pub session_store: MemoryStore,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub async fn from_config(config: Config) -> Result<Self, Box<dyn Error + Send + Sync>> {
|
||||
let (oauth_http_client, oauth_client) = oauth(&config).await?;
|
||||
let session_store = session_store();
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
oauth_http_client,
|
||||
oauth_client,
|
||||
session_store,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_users(&self) -> Result<authelia::Users, Box<dyn Error + Send + Sync>> {
|
||||
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
||||
let users_file: authelia::UsersFile = serde_yaml::from_str(&file_contents)?;
|
||||
let users = authelia::Users::from(users_file);
|
||||
Ok(users)
|
||||
}
|
||||
|
||||
pub fn load_groups(&self) -> Result<authelia::Groups, Box<dyn Error + Send + Sync>> {
|
||||
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
||||
let users_file = serde_yaml::from_str::<authelia::UsersFile>(&file_contents)?;
|
||||
let groups = authelia::Groups::from(users_file);
|
||||
Ok(groups)
|
||||
}
|
||||
|
||||
pub fn load_users_and_groups(
|
||||
&self,
|
||||
) -> Result<(authelia::Users, authelia::Groups), Box<dyn Error + Send + Sync>> {
|
||||
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
||||
let users_file = serde_yaml::from_str::<authelia::UsersFile>(&file_contents)?;
|
||||
let users = authelia::Users::from(users_file.clone());
|
||||
let groups = authelia::Groups::from(users_file);
|
||||
Ok((users, groups))
|
||||
}
|
||||
|
||||
pub fn save_users(&self, users: authelia::Users) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
let users_file = authelia::UsersFile::from(users);
|
||||
let file_contents = serde_yaml::to_string(&users_file)?;
|
||||
std::fs::write(&self.config.authelia.user_database, file_contents)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for Config {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for reqwest::Client {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.oauth_http_client.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for OAuthClient {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.oauth_client.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for MemoryStore {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.session_store.clone()
|
||||
}
|
||||
}
|
||||
|
||||
async fn oauth(
|
||||
config: &Config,
|
||||
) -> Result<(reqwest::Client, OAuthClient), Box<dyn Error + Send + Sync>> {
|
||||
let oauth_http_client = reqwest::ClientBuilder::new()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.danger_accept_invalid_certs(config.oauth.insecure)
|
||||
.build()?;
|
||||
|
||||
let provider_metadata = CoreProviderMetadata::discover_async(
|
||||
IssuerUrl::new(config.oauth.issuer_url.clone())?,
|
||||
&oauth_http_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let oauth_client = OAuthClient::from_provider_metadata(
|
||||
provider_metadata,
|
||||
ClientId::new(config.oauth.client_id.clone()),
|
||||
Some(ClientSecret::new(config.oauth.client_secret.clone())),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(format!(
|
||||
"{}{}/api/auth/callback",
|
||||
config.server.host, config.server.subpath
|
||||
))?);
|
||||
|
||||
Ok((oauth_http_client, oauth_client))
|
||||
}
|
||||
|
||||
fn session_store() -> MemoryStore {
|
||||
let session_store = MemoryStore::new();
|
||||
|
||||
let session_store_clone = session_store.clone();
|
||||
spawn(async move {
|
||||
loop {
|
||||
match session_store_clone.cleanup().await {
|
||||
Ok(()) => {}
|
||||
Err(e) => error!("Failed to clean up session store: {e}"),
|
||||
}
|
||||
sleep(Duration::from_secs(60)).await;
|
||||
}
|
||||
});
|
||||
|
||||
session_store
|
||||
}
|
35
src/utils/crypto.rs
Normal file
35
src/utils/crypto.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use argon2::{
|
||||
Argon2,
|
||||
password_hash::{PasswordHasher, SaltString, rand_core::OsRng},
|
||||
};
|
||||
use passwords::PasswordGenerator;
|
||||
|
||||
pub fn generate_random_password_hash() -> String {
|
||||
let password = (PasswordGenerator {
|
||||
length: 64,
|
||||
numbers: true,
|
||||
lowercase_letters: true,
|
||||
uppercase_letters: true,
|
||||
symbols: true,
|
||||
spaces: false,
|
||||
exclude_similar_characters: false,
|
||||
strict: true,
|
||||
})
|
||||
.generate_one()
|
||||
.unwrap();
|
||||
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
|
||||
let argon2 = Argon2::new(
|
||||
argon2::Algorithm::Argon2id,
|
||||
argon2::Version::V0x13,
|
||||
argon2::Params::new(65536, 3, 4, Some(32)).unwrap(),
|
||||
);
|
||||
|
||||
let password_hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.unwrap()
|
||||
.to_string();
|
||||
|
||||
password_hash
|
||||
}
|
1
src/utils/mod.rs
Normal file
1
src/utils/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod crypto;
|
Reference in New Issue
Block a user