use std::error::Error; use async_session::MemoryStore; use axum::extract::FromRef; use log::error; use openidconnect::{ ClientId, ClientSecret, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, RedirectUrl, StandardErrorResponse, core::{ CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenResponse, }, reqwest, }; use tokio::{ spawn, time::{Duration, sleep}, }; use crate::{config::Config, models::authelia}; pub type OAuthClient< HasAuthUrl = EndpointSet, HasDeviceAuthUrl = EndpointNotSet, HasIntrospectionUrl = EndpointNotSet, HasRevocationUrl = EndpointNotSet, HasTokenUrl = EndpointMaybeSet, HasUserInfoUrl = EndpointMaybeSet, > = openidconnect::Client< EmptyAdditionalClaims, CoreAuthDisplay, CoreGenderClaim, CoreJweContentEncryptionAlgorithm, CoreJsonWebKey, CoreAuthPrompt, StandardErrorResponse, CoreTokenResponse, CoreTokenIntrospectionResponse, CoreRevocableToken, CoreRevocationErrorResponse, HasAuthUrl, HasDeviceAuthUrl, HasIntrospectionUrl, HasRevocationUrl, HasTokenUrl, HasUserInfoUrl, >; #[derive(Clone)] pub struct State { pub config: Config, pub oauth_http_client: reqwest::Client, pub oauth_client: OAuthClient, pub session_store: MemoryStore, } impl State { pub async fn from_config(config: Config) -> Result> { let (oauth_http_client, oauth_client) = oauth(&config).await?; let session_store = session_store(); Ok(Self { config, oauth_http_client, oauth_client, session_store, }) } pub fn load_users(&self) -> Result> { let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?; let users_file: authelia::UsersFile = serde_yaml::from_str(&file_contents)?; let users = authelia::Users::from(users_file); Ok(users) } pub fn load_groups(&self) -> Result> { let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?; let users_file = serde_yaml::from_str::(&file_contents)?; let groups = authelia::Groups::from(users_file); Ok(groups) } pub fn load_users_and_groups( &self, ) -> Result<(authelia::Users, authelia::Groups), Box> { let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?; let users_file = serde_yaml::from_str::(&file_contents)?; let users = authelia::Users::from(users_file.clone()); let groups = authelia::Groups::from(users_file); Ok((users, groups)) } pub fn save_users(&self, users: authelia::Users) -> Result<(), Box> { let users_file = authelia::UsersFile::from(users); let file_contents = serde_yaml::to_string(&users_file)?; std::fs::write(&self.config.authelia.user_database, file_contents)?; Ok(()) } } impl FromRef for Config { fn from_ref(state: &State) -> Self { state.config.clone() } } impl FromRef for reqwest::Client { fn from_ref(state: &State) -> Self { state.oauth_http_client.clone() } } impl FromRef for OAuthClient { fn from_ref(state: &State) -> Self { state.oauth_client.clone() } } impl FromRef for MemoryStore { fn from_ref(state: &State) -> Self { state.session_store.clone() } } async fn oauth( config: &Config, ) -> Result<(reqwest::Client, OAuthClient), Box> { let oauth_http_client = reqwest::ClientBuilder::new() .redirect(reqwest::redirect::Policy::none()) .danger_accept_invalid_certs(config.oauth.insecure) .build()?; let provider_metadata = CoreProviderMetadata::discover_async( IssuerUrl::new(config.oauth.issuer_url.clone())?, &oauth_http_client, ) .await?; let oauth_client = OAuthClient::from_provider_metadata( provider_metadata, ClientId::new(config.oauth.client_id.clone()), Some(ClientSecret::new(config.oauth.client_secret.clone())), ) .set_redirect_uri(RedirectUrl::new(format!( "{}{}/api/auth/callback", config.server.host, config.server.subpath ))?); Ok((oauth_http_client, oauth_client)) } fn session_store() -> MemoryStore { let session_store = MemoryStore::new(); let session_store_clone = session_store.clone(); spawn(async move { loop { match session_store_clone.cleanup().await { Ok(()) => {} Err(e) => error!("Failed to clean up session store: {e}"), } sleep(Duration::from_secs(60)).await; } }); session_store }