Initial commit

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2025-06-04 22:50:18 +01:00
commit ec7055d5ff
22 changed files with 5558 additions and 0 deletions

67
src/config.rs Normal file
View File

@@ -0,0 +1,67 @@
use clap::Parser;
use serde::Deserialize;
use std::{
fs,
net::{IpAddr, Ipv4Addr},
path::PathBuf,
};
#[derive(Clone, Deserialize)]
pub struct ServerConfig {
pub host: String,
#[serde(default = "default_server_address")]
pub address: std::net::IpAddr,
#[serde(default = "default_server_port")]
pub port: u16,
#[serde(default)]
pub subpath: String,
}
const fn default_server_address() -> IpAddr {
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))
}
const fn default_server_port() -> u16 {
8080
}
#[derive(Clone, Deserialize)]
pub struct OAuthConfig {
pub issuer_url: String,
pub client_id: String,
pub client_secret: String,
#[serde(default)]
pub insecure: bool,
pub admin_group: String,
}
#[derive(Clone, Deserialize)]
pub struct AutheliaConfig {
pub user_database: PathBuf,
}
#[derive(Clone, Deserialize)]
pub struct Config {
pub server: ServerConfig,
pub oauth: OAuthConfig,
pub authelia: AutheliaConfig,
}
impl Config {
pub fn from_yaml(path: &PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let contents = fs::read_to_string(path)?;
let config = serde_yaml::from_str(&contents)?;
Ok(config)
}
}
#[derive(Parser, Debug)]
#[command(version, about, long_about = None, author)]
pub struct Args {
/// Path to the YAML config file
#[arg(short, long, value_name = "FILE", default_value = "config.yaml")]
pub config: PathBuf,
/// Path to the log4rs config file
#[arg(short, long, value_name = "FILE", default_value = "log4rs.yaml")]
pub log_config: PathBuf,
}

36
src/main.rs Normal file
View File

@@ -0,0 +1,36 @@
#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
#![allow(clippy::missing_docs_in_private_items)]
mod config;
mod models;
mod routes;
mod state;
mod utils;
use axum::serve;
use clap::Parser;
use log::info;
use log4rs::config::Deserializers;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use config::{Args, Config};
use state::State;
#[tokio::main]
async fn main() {
let args = Args::parse();
log4rs::init_file(args.log_config, Deserializers::default()).unwrap();
let config = Config::from_yaml(&args.config).unwrap();
let state = State::from_config(config.clone()).await.unwrap();
let routes = routes::routes(state);
let app = axum::Router::new().nest(&format!("{}/api", config.server.subpath), routes);
let addr = SocketAddr::from((config.server.address, config.server.port));
let listener = TcpListener::bind(addr).await.unwrap();
info!("Listening on {}", listener.local_addr().unwrap());
serve(listener, app).await.unwrap();
}

160
src/models/authelia.rs Normal file
View File

@@ -0,0 +1,160 @@
use non_empty_string::NonEmptyString;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
collections::HashMap,
ops::{Deref, DerefMut},
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsersFile {
pub users: HashMap<NonEmptyString, UserFile>,
#[serde(flatten)]
pub extra: Option<HashMap<NonEmptyString, Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserFile {
pub displayname: NonEmptyString,
pub password: NonEmptyString,
pub email: Option<String>,
pub disabled: Option<bool>,
pub groups: Option<Vec<NonEmptyString>>,
#[serde(flatten)]
pub extra: Option<HashMap<NonEmptyString, Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub displayname: NonEmptyString,
pub email: Option<String>,
pub password: NonEmptyString,
pub disabled: bool,
pub groups: Vec<NonEmptyString>,
#[serde(flatten)]
pub extra: Option<HashMap<NonEmptyString, Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Users {
pub users: HashMap<NonEmptyString, User>,
#[serde(flatten)]
pub extra: Option<HashMap<NonEmptyString, Value>>,
}
impl Deref for Users {
type Target = HashMap<NonEmptyString, User>;
fn deref(&self) -> &Self::Target {
&self.users
}
}
impl DerefMut for Users {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.users
}
}
impl From<UserFile> for User {
fn from(user_file: UserFile) -> Self {
Self {
displayname: user_file.displayname,
email: user_file.email,
password: user_file.password,
disabled: user_file.disabled.unwrap_or(false),
groups: user_file.groups.unwrap_or_default(),
extra: user_file.extra,
}
}
}
impl From<UsersFile> for Users {
fn from(users_file: UsersFile) -> Self {
Self {
users: users_file
.users
.into_iter()
.map(|(key, user)| (key, User::from(user)))
.collect(),
extra: users_file.extra,
}
}
}
impl From<User> for UserFile {
fn from(user: User) -> Self {
Self {
displayname: user.displayname,
email: user.email,
password: user.password,
disabled: if user.disabled { Some(true) } else { None },
groups: if user.groups.is_empty() {
None
} else {
Some(user.groups)
},
extra: user.extra,
}
}
}
impl From<Users> for UsersFile {
fn from(users: Users) -> Self {
Self {
users: users
.users
.into_iter()
.map(|(key, user)| (key, UserFile::from(user)))
.collect(),
extra: users.extra,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Group {
pub users: Vec<NonEmptyString>,
}
pub struct Groups {
pub groups: HashMap<NonEmptyString, Group>,
}
impl Deref for Groups {
type Target = HashMap<NonEmptyString, Group>;
fn deref(&self) -> &Self::Target {
&self.groups
}
}
impl DerefMut for Groups {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.groups
}
}
impl From<UsersFile> for Groups {
fn from(users_file: UsersFile) -> Self {
users_file.users.into_iter().fold(
Self {
groups: HashMap::new(),
},
|mut acc, (key, user)| {
for group in user.groups.unwrap_or_default() {
acc.entry(group)
.or_insert_with(|| Group { users: Vec::new() })
.users
.push(key.clone());
}
acc
},
)
}
}

1
src/models/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod authelia;

442
src/routes/auth.rs Normal file
View File

@@ -0,0 +1,442 @@
use std::{borrow::Cow, convert::Infallible};
use async_session::{MemoryStore, Session, SessionStore};
use axum::{
Json, RequestPartsExt, Router,
extract::{self, FromRef, FromRequestParts, OptionalFromRequestParts},
http::{HeaderMap, StatusCode, header, request::Parts},
response::{IntoResponse, Redirect, Response},
routing,
};
use axum_extra::{TypedHeader, headers::Cookie, typed_header::TypedHeaderRejectionReason};
use log::error;
use openidconnect::{
AccessTokenHash, AdditionalClaims, AuthorizationCode, CsrfToken, EndUserEmail, EndUserUsername,
Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, Scope, SubjectIdentifier,
TokenResponse, UserInfoClaims,
core::{CoreAuthenticationFlow, CoreGenderClaim, CoreTokenResponse},
reqwest,
};
use serde::{Deserialize, Serialize};
use crate::{
config::Config,
state::{OAuthClient, State},
};
static COOKIE_NAME: &str = "glyph_session";
#[derive(Clone, Deserialize, Serialize)]
pub struct User {
pub subject: SubjectIdentifier,
pub username: EndUserUsername,
pub email: Option<EndUserEmail>,
}
async fn login(
extract::State(oauth_client): extract::State<OAuthClient>,
extract::State(session_store): extract::State<MemoryStore>,
) -> Result<impl IntoResponse, StatusCode> {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token, nonce) = oauth_client
.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
)
.set_pkce_challenge(pkce_challenge)
.set_redirect_uri(Cow::Borrowed(oauth_client.redirect_uri().ok_or_else(
|| {
error!("missing redirect URI");
StatusCode::INTERNAL_SERVER_ERROR
},
)?))
.add_scope(Scope::new("profile".to_string()))
.add_scope(Scope::new("email".to_string()))
.add_scope(Scope::new("groups".to_string()))
.url();
let session = create_login_session(pkce_verifier, csrf_token, nonce)?;
let headers = create_login_cookie(&session_store, session).await?;
Ok((headers, Redirect::to(auth_url.as_str())))
}
fn create_login_session(
pkce_verifier: PkceCodeVerifier,
csrf_token: CsrfToken,
nonce: Nonce,
) -> Result<Session, StatusCode> {
let mut session = Session::new();
session
.insert("pkce_verifier", pkce_verifier)
.map_err(|e| {
error!("failed to insert pkce_verifier into session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
session.insert("csrf_token", csrf_token).map_err(|e| {
error!("failed to insert csrf_token into session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
session.insert("nonce", nonce).map_err(|e| {
error!("failed to insert nonce into session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(session)
}
async fn create_login_cookie(
session_store: &MemoryStore,
session: Session,
) -> Result<HeaderMap, StatusCode> {
let cookie = session_store
.store_session(session)
.await
.map_err(|e| {
error!("failed to store session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.ok_or_else(|| {
error!("failed to retrieve stored session cookie");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let cookie = format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; Secure; Path=/");
let mut headers = HeaderMap::new();
headers.insert(
header::SET_COOKIE,
cookie.parse().map_err(|e| {
error!("failed to parse cookie: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?,
);
Ok(headers)
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct CallbackParams {
code: String,
state: String,
}
async fn callback(
extract::Query(params): extract::Query<CallbackParams>,
extract::State(http_client): extract::State<reqwest::Client>,
extract::State(oauth_client): extract::State<OAuthClient>,
extract::State(session_store): extract::State<MemoryStore>,
extract::State(config): extract::State<Config>,
TypedHeader(cookies): TypedHeader<Cookie>,
) -> Result<impl IntoResponse, StatusCode> {
let cookie = cookies
.get(COOKIE_NAME)
.ok_or(StatusCode::UNAUTHORIZED)?
.to_string();
let session = session_store
.load_session(cookie)
.await
.map_err(|e| {
error!("failed to load session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.ok_or(StatusCode::UNAUTHORIZED)?;
let (csrf_token, pkce_verifier, nonce) = retrieve_login_session(&session)?;
session_store.destroy_session(session).await.map_err(|e| {
error!("failed to destroy session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let token_response = validate_claims(
&http_client,
&oauth_client,
params,
csrf_token,
pkce_verifier,
nonce,
)
.await?;
let claims = oauth_client
.user_info(token_response.access_token().to_owned(), None)
.map_err(|e| {
error!("failed to create userinfo request: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.request_async(&http_client)
.await
.map_err(|e| {
error!("failed to request user info: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let session = create_user_session(&config, &claims)?;
let headers = create_user_cookie(&session_store, session).await?;
Ok((headers, StatusCode::OK))
}
fn retrieve_login_session(
session: &Session,
) -> Result<(CsrfToken, PkceCodeVerifier, Nonce), StatusCode> {
let csrf_token = session.get::<CsrfToken>("csrf_token").ok_or_else(|| {
error!("failed to get csrf_token from session");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let pkce_verifier = session
.get::<PkceCodeVerifier>("pkce_verifier")
.ok_or_else(|| {
error!("failed to get pkce_verifier from session");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let nonce = session.get::<Nonce>("nonce").ok_or_else(|| {
error!("failed to get nonce from session");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok((csrf_token, pkce_verifier, nonce))
}
async fn validate_claims(
http_client: &reqwest::Client,
oauth_client: &OAuthClient,
params: CallbackParams,
csrf_token: CsrfToken,
pkce_verifier: PkceCodeVerifier,
nonce: Nonce,
) -> Result<CoreTokenResponse, StatusCode> {
if *csrf_token.secret() != params.state {
error!("csrf_token mismatch");
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
let token_response = oauth_client
.exchange_code(AuthorizationCode::new(params.code.clone()))
.map_err(|e| {
error!("failed to exchange code: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.set_pkce_verifier(pkce_verifier)
.request_async(http_client)
.await
.map_err(|e| {
error!("failed to request token: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let id_token = token_response.id_token().ok_or_else(|| {
error!("missing id_token in token response");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let id_token_verifier = oauth_client.id_token_verifier();
let claims = id_token
.claims(&id_token_verifier, &nonce)
.map_err(|e| {
error!("failed to verify id_token: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.to_owned();
if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token(
token_response.access_token(),
id_token.signing_alg().map_err(|e| {
error!("failed to get signing algorithm: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?,
id_token.signing_key(&id_token_verifier).map_err(|e| {
error!("failed to get signing key: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?,
)
.map_err(|e| {
error!("failed to compute access token hash: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
if actual_access_token_hash != *expected_access_token_hash {
error!("access token hash mismatch");
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
}
Ok(token_response)
}
#[derive(Debug, Serialize, Deserialize)]
struct ExtraClaims {
groups: Vec<String>,
}
impl AdditionalClaims for ExtraClaims {}
fn create_user_session(
config: &Config,
claims: &UserInfoClaims<ExtraClaims, CoreGenderClaim>,
) -> Result<Session, StatusCode> {
if !claims
.additional_claims()
.groups
.iter()
.any(|group| group == &config.oauth.admin_group)
{
return Err(StatusCode::FORBIDDEN);
}
let user = User {
subject: claims.subject().to_owned(),
username: claims.preferred_username().cloned().ok_or_else(|| {
error!("missing preferred_username in claims");
StatusCode::INTERNAL_SERVER_ERROR
})?,
email: claims.email().cloned(),
};
let mut session = Session::new();
session.insert("user", user).map_err(|e| {
error!("failed to insert user into session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(session)
}
async fn create_user_cookie(
session_store: &MemoryStore,
session: Session,
) -> Result<HeaderMap, StatusCode> {
let cookie = session_store
.store_session(session)
.await
.map_err(|e| {
error!("failed to store session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.ok_or_else(|| {
error!("failed to retrieve stored session cookie");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let cookie = format!("{COOKIE_NAME}={cookie}; HttpOnly; SameSite=Lax; Secure; Path=/");
let mut headers = HeaderMap::new();
headers.insert(
header::SET_COOKIE,
cookie.parse().map_err(|e| {
error!("failed to parse cookie: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?,
);
Ok(headers)
}
async fn logout(
extract::State(session_store): extract::State<MemoryStore>,
TypedHeader(cookies): TypedHeader<Cookie>,
) -> Result<impl IntoResponse, StatusCode> {
let cookie = cookies.get(COOKIE_NAME).ok_or(StatusCode::UNAUTHORIZED)?;
let Some(session) = session_store
.load_session(cookie.to_string())
.await
.map_err(|e| {
error!("failed to load session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
else {
return Ok(StatusCode::OK);
};
session_store.destroy_session(session).await.map_err(|e| {
error!("failed to destroy session: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(StatusCode::OK)
}
async fn session(user: User) -> Result<impl IntoResponse, StatusCode> {
Ok(Json(user))
}
pub fn routes(state: State) -> Router {
Router::new()
.route("/auth/login", routing::get(login))
.route("/auth/callback", routing::get(callback))
.route("/auth/logout", routing::get(logout))
.route("/auth/session", routing::get(session))
.with_state(state)
}
impl<S> FromRequestParts<S> for User
where
MemoryStore: FromRef<S>,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let store = MemoryStore::from_ref(state);
let cookies = parts.extract::<TypedHeader<Cookie>>().await.map_err(|e| {
if *e.name() == header::COOKIE {
if matches!(e.reason(), TypedHeaderRejectionReason::Missing) {
StatusCode::UNAUTHORIZED.into_response()
} else {
error!("failed to extract cookies: {e}");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
} else {
error!("failed to extract cookies: {e}");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
})?;
let session_cookie = cookies
.get(COOKIE_NAME)
.ok_or_else(|| StatusCode::UNAUTHORIZED.into_response())?;
let session = store
.load_session(session_cookie.to_string())
.await
.map_err(|e| {
error!("failed to load session: {e}");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
})?
.ok_or_else(|| StatusCode::UNAUTHORIZED.into_response())?;
let user = session
.get::<Self>("user")
.ok_or_else(|| StatusCode::UNAUTHORIZED.into_response())?;
Ok(user)
}
}
impl<S> OptionalFromRequestParts<S> for User
where
MemoryStore: FromRef<S>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Option<Self>, Self::Rejection> {
(<Self as FromRequestParts<S>>::from_request_parts(parts, state).await)
.map_or(Ok(None), |user| Ok(Some(user)))
}
}

240
src/routes/groups.rs Normal file
View File

@@ -0,0 +1,240 @@
use std::collections::HashMap;
use axum::{
Json, Router, extract,
http::StatusCode,
response::{IntoResponse, Redirect},
routing,
};
use log::error;
use non_empty_string::NonEmptyString;
use serde::{Deserialize, Serialize};
use crate::{models::authelia, routes::auth, state::State};
#[derive(Debug, Serialize)]
struct GroupResponse {
groupname: NonEmptyString,
users: Vec<NonEmptyString>,
}
impl From<(NonEmptyString, authelia::Group)> for GroupResponse {
fn from((groupname, group): (NonEmptyString, authelia::Group)) -> Self {
Self {
groupname,
users: group.users,
}
}
}
type GroupsResponse = HashMap<NonEmptyString, GroupResponse>;
impl From<authelia::Groups> for GroupsResponse {
fn from(groups: authelia::Groups) -> Self {
groups
.groups
.into_iter()
.map(|(key, group)| (key.clone(), GroupResponse::from((key, group))))
.collect()
}
}
pub async fn get_all(
_user: auth::User,
extract::State(state): extract::State<State>,
) -> Result<impl IntoResponse, StatusCode> {
let groups = state.load_groups().map_err(|e| {
error!("Failed to read users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(GroupsResponse::from(groups)))
}
pub async fn get(
_user: auth::User,
extract::Path(groupname): extract::Path<NonEmptyString>,
extract::State(state): extract::State<State>,
) -> Result<impl IntoResponse, StatusCode> {
let groups = state.load_groups().map_err(|e| {
error!("Failed to read users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
groups.get(&groupname).cloned().map_or_else(
|| Err(StatusCode::NOT_FOUND),
|group| Ok(Json(GroupResponse::from((groupname, group))).into_response()),
)
}
#[derive(Debug, Deserialize)]
pub struct GroupCreate {
groupname: NonEmptyString,
users: Vec<NonEmptyString>,
}
impl From<GroupCreate> for authelia::Group {
fn from(update: GroupCreate) -> Self {
Self {
users: update.users,
}
}
}
pub async fn create(
_user: auth::User,
extract::State(state): extract::State<State>,
extract::Json(group_create): extract::Json<GroupCreate>,
) -> Result<impl IntoResponse, StatusCode> {
let (mut users, groups) = state.load_users_and_groups().map_err(|e| {
error!("Failed to read users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let groupname = group_create.groupname.clone();
if groups.contains_key(&groupname) {
return Err(StatusCode::CONFLICT);
}
let group_created = authelia::Group::from(group_create);
for username in &group_created.users {
if !users.contains_key(username) {
return Err(StatusCode::NOT_FOUND);
}
users
.get_mut(username)
.unwrap()
.groups
.push(groupname.clone());
}
state.save_users(users).map_err(|e| {
error!("Failed to save users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(GroupResponse::from((groupname, group_created))).into_response())
}
#[derive(Debug, Deserialize)]
pub struct GroupUpdate {
groupname: Option<NonEmptyString>,
users: Vec<NonEmptyString>,
}
impl From<GroupUpdate> for authelia::Group {
fn from(update: GroupUpdate) -> Self {
Self {
users: update.users,
}
}
}
pub async fn update(
user: auth::User,
extract::Path(groupname): extract::Path<NonEmptyString>,
extract::State(state): extract::State<State>,
extract::Json(group_update): extract::Json<GroupUpdate>,
) -> Result<impl IntoResponse, StatusCode> {
let (mut users, groups) = state.load_users_and_groups().map_err(|e| {
error!("Failed to read users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let new_groupname = group_update
.groupname
.clone()
.unwrap_or_else(|| groupname.clone());
let group_existing = groups.get(&groupname).ok_or(StatusCode::NOT_FOUND)?;
let group_updated = authelia::Group::from(group_update);
if groupname != new_groupname
&& (groupname == state.config.oauth.admin_group
|| new_groupname == state.config.oauth.admin_group)
{
return Err(StatusCode::FORBIDDEN);
}
if groupname != new_groupname && groups.contains_key(&new_groupname) {
return Err(StatusCode::CONFLICT);
}
for user in &group_existing.users {
let user = users.get_mut(user).unwrap();
let pos = user.groups.iter().position(|g| g == &groupname).unwrap();
user.groups.remove(pos);
}
for username in &group_updated.users {
if !users.contains_key(username) {
return Err(StatusCode::NOT_FOUND);
}
let user = users.get_mut(username).unwrap();
if !user.groups.contains(&new_groupname) {
user.groups.push(new_groupname.clone());
}
}
state.save_users(users).map_err(|e| {
error!("Failed to save users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
if new_groupname == state.config.oauth.admin_group
&& !group_updated
.users
.iter()
.any(|group_user| *group_user == *user.username.to_string())
{
return Ok(Redirect::to("/api/auth/logout").into_response());
}
Ok(Json(GroupResponse::from((new_groupname, group_updated))).into_response())
}
pub async fn delete(
_user: auth::User,
extract::Path(groupname): extract::Path<String>,
extract::State(state): extract::State<State>,
) -> Result<impl IntoResponse, StatusCode> {
let (mut users, groups) = state.load_users_and_groups().map_err(|e| {
error!("Failed to read users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
if groupname == state.config.oauth.admin_group {
return Err(StatusCode::FORBIDDEN);
}
if let Some(old_group) = groups.get(&groupname) {
for user in &old_group.users {
let user = users.get_mut(user).unwrap();
let pos = user.groups.iter().position(|g| g == &groupname).unwrap();
user.groups.remove(pos);
}
} else {
return Err(StatusCode::NOT_FOUND);
}
state.save_users(users).map_err(|e| {
error!("Failed to save users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(StatusCode::NO_CONTENT.into_response())
}
pub fn routes(state: State) -> Router {
Router::new()
.route("/groups", routing::get(get_all))
.route("/groups/{username}", routing::get(get))
.route("/groups", routing::post(create))
.route("/groups/{username}", routing::put(update))
.route("/groups/{username}", routing::delete(delete))
.with_state(state)
}

13
src/routes/health.rs Normal file
View File

@@ -0,0 +1,13 @@
use axum::{Router, http::StatusCode, response::IntoResponse, routing};
use crate::state::State;
pub async fn get() -> Result<impl IntoResponse, StatusCode> {
Ok(StatusCode::OK)
}
pub fn routes(state: State) -> Router {
Router::new()
.route("/health", routing::get(get))
.with_state(state)
}

21
src/routes/mod.rs Normal file
View File

@@ -0,0 +1,21 @@
mod auth;
mod groups;
mod health;
mod users;
use axum::Router;
use crate::state::State;
pub fn routes(state: State) -> Router {
let auth = auth::routes(state.clone());
let health = health::routes(state.clone());
let users = users::routes(state.clone());
let groups = groups::routes(state);
Router::new()
.merge(auth)
.merge(health)
.merge(users)
.merge(groups)
}

218
src/routes/users.rs Normal file
View File

@@ -0,0 +1,218 @@
use std::collections::HashMap;
use axum::{
Json, Router, extract,
http::StatusCode,
response::{IntoResponse, Redirect},
routing,
};
use log::error;
use non_empty_string::NonEmptyString;
use serde::{Deserialize, Serialize};
use crate::{
models::authelia, routes::auth, state::State, utils::crypto::generate_random_password_hash,
};
#[derive(Debug, Serialize)]
struct UserResponse {
username: NonEmptyString,
displayname: NonEmptyString,
email: Option<String>,
groups: Option<Vec<NonEmptyString>>,
}
impl From<(NonEmptyString, authelia::User)> for UserResponse {
fn from((username, user): (NonEmptyString, authelia::User)) -> Self {
Self {
username,
displayname: user.displayname,
email: user.email,
groups: Some(user.groups),
}
}
}
type UsersResponse = HashMap<NonEmptyString, UserResponse>;
impl From<authelia::Users> for UsersResponse {
fn from(users: authelia::Users) -> Self {
users
.users
.into_iter()
.map(|(key, user)| (key.clone(), UserResponse::from((key, user))))
.collect()
}
}
pub async fn get_all(
_user: auth::User,
extract::State(state): extract::State<State>,
) -> Result<impl IntoResponse, StatusCode> {
let users = state.load_users().map_err(|e| {
error!("Failed to read users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(UsersResponse::from(users)))
}
pub async fn get(
_user: auth::User,
extract::Path(username): extract::Path<NonEmptyString>,
extract::State(state): extract::State<State>,
) -> Result<impl IntoResponse, StatusCode> {
let users = state.load_users().map_err(|e| {
error!("Failed to read users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
users.get(&username).cloned().map_or_else(
|| Err(StatusCode::NOT_FOUND),
|user| Ok(Json(UserResponse::from((username, user))).into_response()),
)
}
#[derive(Debug, Deserialize)]
pub struct UserCreate {
username: NonEmptyString,
displayname: NonEmptyString,
email: NonEmptyString,
disabled: Option<bool>,
groups: Option<Vec<NonEmptyString>>,
}
#[allow(clippy::fallible_impl_from)]
impl From<UserCreate> for authelia::User {
fn from(user_create: UserCreate) -> Self {
Self {
displayname: user_create.displayname,
email: Some(user_create.email.to_string()),
password: NonEmptyString::new(generate_random_password_hash()).unwrap(),
disabled: user_create.disabled.unwrap_or(false),
groups: user_create.groups.unwrap_or_default(),
extra: None,
}
}
}
pub async fn create(
_user: auth::User,
extract::State(state): extract::State<State>,
extract::Json(user_create): extract::Json<UserCreate>,
) -> Result<impl IntoResponse, StatusCode> {
let mut users = state.load_users().map_err(|e| {
error!("Failed to read users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let username = user_create.username.clone();
if users.contains_key(&username) {
return Err(StatusCode::CONFLICT);
}
let user_created = authelia::User::from(user_create);
users.users.insert(username.clone(), user_created.clone());
state.save_users(users).map_err(|e| {
error!("Failed to save users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(UserResponse::from((username, user_created))).into_response())
}
#[derive(Debug, Deserialize)]
pub struct UserUpdate {
username: Option<NonEmptyString>,
displayname: NonEmptyString,
email: NonEmptyString,
disabled: Option<bool>,
groups: Option<Vec<NonEmptyString>>,
}
impl From<(Self, UserUpdate)> for authelia::User {
fn from((user_existing, user_update): (Self, UserUpdate)) -> Self {
Self {
displayname: user_update.displayname,
email: Some(user_update.email.to_string()),
password: user_existing.password,
disabled: user_update.disabled.unwrap_or(user_existing.disabled),
groups: user_update.groups.unwrap_or(user_existing.groups),
extra: user_existing.extra,
}
}
}
pub async fn update(
user: auth::User,
extract::Path(username): extract::Path<NonEmptyString>,
extract::State(state): extract::State<State>,
extract::Json(user_update): extract::Json<UserUpdate>,
) -> Result<impl IntoResponse, StatusCode> {
let mut users = state.load_users().map_err(|e| {
error!("Failed to read users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let new_username = user_update
.username
.clone()
.unwrap_or_else(|| username.clone());
let user_existing = users.remove(&username).ok_or(StatusCode::NOT_FOUND)?;
let user_updated = authelia::User::from((user_existing, user_update));
users
.users
.insert(new_username.clone(), user_updated.clone());
state.save_users(users).map_err(|e| {
error!("Failed to save users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
if user.username.to_string() == username && (username != new_username || user_updated.disabled)
{
return Ok(Redirect::to("/api/auth/logout").into_response());
}
Ok(Json(UserResponse::from((new_username, user_updated))).into_response())
}
pub async fn delete(
user: auth::User,
extract::Path(username): extract::Path<String>,
extract::State(state): extract::State<State>,
) -> Result<impl IntoResponse, StatusCode> {
let mut users = state.load_users().map_err(|e| {
error!("Failed to read users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
if users.remove(&username).is_none() {
return Err(StatusCode::NOT_FOUND);
}
state.save_users(users).map_err(|e| {
error!("Failed to save users file: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
if user.username.to_string() == username {
return Ok(Redirect::to("/api/auth/logout").into_response());
}
Ok(StatusCode::NO_CONTENT.into_response())
}
pub fn routes(state: State) -> Router {
Router::new()
.route("/users", routing::get(get_all))
.route("/users/{username}", routing::get(get))
.route("/users", routing::post(create))
.route("/users/{username}", routing::put(update))
.route("/users/{username}", routing::delete(delete))
.with_state(state)
}

169
src/state.rs Normal file
View File

@@ -0,0 +1,169 @@
use std::error::Error;
use async_session::MemoryStore;
use axum::extract::FromRef;
use log::error;
use openidconnect::{
ClientId, ClientSecret, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet,
IssuerUrl, RedirectUrl, StandardErrorResponse,
core::{
CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey,
CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreRevocableToken,
CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenResponse,
},
reqwest,
};
use tokio::{
spawn,
time::{Duration, sleep},
};
use crate::{config::Config, models::authelia};
pub type OAuthClient<
HasAuthUrl = EndpointSet,
HasDeviceAuthUrl = EndpointNotSet,
HasIntrospectionUrl = EndpointNotSet,
HasRevocationUrl = EndpointNotSet,
HasTokenUrl = EndpointMaybeSet,
HasUserInfoUrl = EndpointMaybeSet,
> = openidconnect::Client<
EmptyAdditionalClaims,
CoreAuthDisplay,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJsonWebKey,
CoreAuthPrompt,
StandardErrorResponse<CoreErrorResponseType>,
CoreTokenResponse,
CoreTokenIntrospectionResponse,
CoreRevocableToken,
CoreRevocationErrorResponse,
HasAuthUrl,
HasDeviceAuthUrl,
HasIntrospectionUrl,
HasRevocationUrl,
HasTokenUrl,
HasUserInfoUrl,
>;
#[derive(Clone)]
pub struct State {
pub config: Config,
pub oauth_http_client: reqwest::Client,
pub oauth_client: OAuthClient,
pub session_store: MemoryStore,
}
impl State {
pub async fn from_config(config: Config) -> Result<Self, Box<dyn Error + Send + Sync>> {
let (oauth_http_client, oauth_client) = oauth(&config).await?;
let session_store = session_store();
Ok(Self {
config,
oauth_http_client,
oauth_client,
session_store,
})
}
pub fn load_users(&self) -> Result<authelia::Users, Box<dyn Error + Send + Sync>> {
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
let users_file: authelia::UsersFile = serde_yaml::from_str(&file_contents)?;
let users = authelia::Users::from(users_file);
Ok(users)
}
pub fn load_groups(&self) -> Result<authelia::Groups, Box<dyn Error + Send + Sync>> {
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
let users_file = serde_yaml::from_str::<authelia::UsersFile>(&file_contents)?;
let groups = authelia::Groups::from(users_file);
Ok(groups)
}
pub fn load_users_and_groups(
&self,
) -> Result<(authelia::Users, authelia::Groups), Box<dyn Error + Send + Sync>> {
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
let users_file = serde_yaml::from_str::<authelia::UsersFile>(&file_contents)?;
let users = authelia::Users::from(users_file.clone());
let groups = authelia::Groups::from(users_file);
Ok((users, groups))
}
pub fn save_users(&self, users: authelia::Users) -> Result<(), Box<dyn Error + Send + Sync>> {
let users_file = authelia::UsersFile::from(users);
let file_contents = serde_yaml::to_string(&users_file)?;
std::fs::write(&self.config.authelia.user_database, file_contents)?;
Ok(())
}
}
impl FromRef<State> for Config {
fn from_ref(state: &State) -> Self {
state.config.clone()
}
}
impl FromRef<State> for reqwest::Client {
fn from_ref(state: &State) -> Self {
state.oauth_http_client.clone()
}
}
impl FromRef<State> for OAuthClient {
fn from_ref(state: &State) -> Self {
state.oauth_client.clone()
}
}
impl FromRef<State> for MemoryStore {
fn from_ref(state: &State) -> Self {
state.session_store.clone()
}
}
async fn oauth(
config: &Config,
) -> Result<(reqwest::Client, OAuthClient), Box<dyn Error + Send + Sync>> {
let oauth_http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.danger_accept_invalid_certs(config.oauth.insecure)
.build()?;
let provider_metadata = CoreProviderMetadata::discover_async(
IssuerUrl::new(config.oauth.issuer_url.clone())?,
&oauth_http_client,
)
.await?;
let oauth_client = OAuthClient::from_provider_metadata(
provider_metadata,
ClientId::new(config.oauth.client_id.clone()),
Some(ClientSecret::new(config.oauth.client_secret.clone())),
)
.set_redirect_uri(RedirectUrl::new(format!(
"{}{}/api/auth/callback",
config.server.host, config.server.subpath
))?);
Ok((oauth_http_client, oauth_client))
}
fn session_store() -> MemoryStore {
let session_store = MemoryStore::new();
let session_store_clone = session_store.clone();
spawn(async move {
loop {
match session_store_clone.cleanup().await {
Ok(()) => {}
Err(e) => error!("Failed to clean up session store: {e}"),
}
sleep(Duration::from_secs(60)).await;
}
});
session_store
}

35
src/utils/crypto.rs Normal file
View File

@@ -0,0 +1,35 @@
use argon2::{
Argon2,
password_hash::{PasswordHasher, SaltString, rand_core::OsRng},
};
use passwords::PasswordGenerator;
pub fn generate_random_password_hash() -> String {
let password = (PasswordGenerator {
length: 64,
numbers: true,
lowercase_letters: true,
uppercase_letters: true,
symbols: true,
spaces: false,
exclude_similar_characters: false,
strict: true,
})
.generate_one()
.unwrap();
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::new(
argon2::Algorithm::Argon2id,
argon2::Version::V0x13,
argon2::Params::new(65536, 3, 4, Some(32)).unwrap(),
);
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)
.unwrap()
.to_string();
password_hash
}

1
src/utils/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod crypto;