170 lines
5.1 KiB
Rust
170 lines
5.1 KiB
Rust
use std::error::Error;
|
|
|
|
use async_session::MemoryStore;
|
|
use axum::extract::FromRef;
|
|
use log::error;
|
|
use openidconnect::{
|
|
ClientId, ClientSecret, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet,
|
|
IssuerUrl, RedirectUrl, StandardErrorResponse,
|
|
core::{
|
|
CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey,
|
|
CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreRevocableToken,
|
|
CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenResponse,
|
|
},
|
|
reqwest,
|
|
};
|
|
use tokio::{
|
|
spawn,
|
|
time::{Duration, sleep},
|
|
};
|
|
|
|
use crate::{config::Config, models::authelia};
|
|
|
|
/// The concrete `openidconnect` client type used by this service.
///
/// It is specialized to the standard `core` OIDC types, with no additional ID-token claims.
///
/// The six trailing generic parameters record, at the type level, which optional endpoints
/// are configured on the client. The defaults are the type that `oauth()` in this module
/// produces from `from_provider_metadata(...).set_redirect_uri(...)`:
/// - the authorization URL is always set;
/// - the device-authorization, introspection and revocation URLs are not set;
/// - the token and user-info URLs may or may not be set, depending on the discovered
///   provider metadata.
pub type OAuthClient<
    HasAuthUrl = EndpointSet,
    HasDeviceAuthUrl = EndpointNotSet,
    HasIntrospectionUrl = EndpointNotSet,
    HasRevocationUrl = EndpointNotSet,
    HasTokenUrl = EndpointMaybeSet,
    HasUserInfoUrl = EndpointMaybeSet,
> = openidconnect::Client<
    EmptyAdditionalClaims,
    CoreAuthDisplay,
    CoreGenderClaim,
    CoreJweContentEncryptionAlgorithm,
    CoreJsonWebKey,
    CoreAuthPrompt,
    StandardErrorResponse<CoreErrorResponseType>,
    CoreTokenResponse,
    CoreTokenIntrospectionResponse,
    CoreRevocableToken,
    CoreRevocationErrorResponse,
    // Endpoint type-states; their order must match `openidconnect::Client`'s generics.
    HasAuthUrl,
    HasDeviceAuthUrl,
    HasIntrospectionUrl,
    HasRevocationUrl,
    HasTokenUrl,
    HasUserInfoUrl,
>;
|
|
|
|
/// Shared application state.
///
/// The `FromRef<State>` impls in this module let axum extractors pull individual fields
/// out of it. Cloning a `State` clones every field.
#[derive(Clone)]
pub struct State {
    /// Application configuration this state was built from.
    pub config: Config,
    /// HTTP client used for OAuth / OpenID Connect requests.
    /// Built in `oauth()` with redirect following disabled.
    pub oauth_http_client: reqwest::Client,
    /// OpenID Connect client configured from provider discovery.
    pub oauth_client: OAuthClient,
    /// In-memory session storage, periodically cleaned up by a background task.
    pub session_store: MemoryStore,
}
|
|
|
|
impl State {
|
|
pub async fn from_config(config: Config) -> Result<Self, Box<dyn Error + Send + Sync>> {
|
|
let (oauth_http_client, oauth_client) = oauth(&config).await?;
|
|
let session_store = session_store();
|
|
|
|
Ok(Self {
|
|
config,
|
|
oauth_http_client,
|
|
oauth_client,
|
|
session_store,
|
|
})
|
|
}
|
|
|
|
pub fn load_users(&self) -> Result<authelia::Users, Box<dyn Error + Send + Sync>> {
|
|
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
|
let users_file: authelia::UsersFile = serde_yaml::from_str(&file_contents)?;
|
|
let users = authelia::Users::from(users_file);
|
|
Ok(users)
|
|
}
|
|
|
|
pub fn load_groups(&self) -> Result<authelia::Groups, Box<dyn Error + Send + Sync>> {
|
|
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
|
let users_file = serde_yaml::from_str::<authelia::UsersFile>(&file_contents)?;
|
|
let groups = authelia::Groups::from(users_file);
|
|
Ok(groups)
|
|
}
|
|
|
|
pub fn load_users_and_groups(
|
|
&self,
|
|
) -> Result<(authelia::Users, authelia::Groups), Box<dyn Error + Send + Sync>> {
|
|
let file_contents = std::fs::read_to_string(&self.config.authelia.user_database)?;
|
|
let users_file = serde_yaml::from_str::<authelia::UsersFile>(&file_contents)?;
|
|
let users = authelia::Users::from(users_file.clone());
|
|
let groups = authelia::Groups::from(users_file);
|
|
Ok((users, groups))
|
|
}
|
|
|
|
pub fn save_users(&self, users: authelia::Users) -> Result<(), Box<dyn Error + Send + Sync>> {
|
|
let users_file = authelia::UsersFile::from(users);
|
|
let file_contents = serde_yaml::to_string(&users_file)?;
|
|
std::fs::write(&self.config.authelia.user_database, file_contents)?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl FromRef<State> for Config {
|
|
fn from_ref(state: &State) -> Self {
|
|
state.config.clone()
|
|
}
|
|
}
|
|
|
|
impl FromRef<State> for reqwest::Client {
|
|
fn from_ref(state: &State) -> Self {
|
|
state.oauth_http_client.clone()
|
|
}
|
|
}
|
|
|
|
impl FromRef<State> for OAuthClient {
|
|
fn from_ref(state: &State) -> Self {
|
|
state.oauth_client.clone()
|
|
}
|
|
}
|
|
|
|
impl FromRef<State> for MemoryStore {
|
|
fn from_ref(state: &State) -> Self {
|
|
state.session_store.clone()
|
|
}
|
|
}
|
|
|
|
async fn oauth(
|
|
config: &Config,
|
|
) -> Result<(reqwest::Client, OAuthClient), Box<dyn Error + Send + Sync>> {
|
|
let oauth_http_client = reqwest::ClientBuilder::new()
|
|
.redirect(reqwest::redirect::Policy::none())
|
|
.danger_accept_invalid_certs(config.oauth.insecure)
|
|
.build()?;
|
|
|
|
let provider_metadata = CoreProviderMetadata::discover_async(
|
|
IssuerUrl::new(config.oauth.issuer_url.clone())?,
|
|
&oauth_http_client,
|
|
)
|
|
.await?;
|
|
|
|
let oauth_client = OAuthClient::from_provider_metadata(
|
|
provider_metadata,
|
|
ClientId::new(config.oauth.client_id.clone()),
|
|
Some(ClientSecret::new(config.oauth.client_secret.clone())),
|
|
)
|
|
.set_redirect_uri(RedirectUrl::new(format!(
|
|
"{}{}/api/auth/callback",
|
|
config.server.host, config.server.subpath
|
|
))?);
|
|
|
|
Ok((oauth_http_client, oauth_client))
|
|
}
|
|
|
|
fn session_store() -> MemoryStore {
|
|
let session_store = MemoryStore::new();
|
|
|
|
let session_store_clone = session_store.clone();
|
|
spawn(async move {
|
|
loop {
|
|
match session_store_clone.cleanup().await {
|
|
Ok(()) => {}
|
|
Err(e) => error!("Failed to clean up session store: {e}"),
|
|
}
|
|
sleep(Duration::from_secs(60)).await;
|
|
}
|
|
});
|
|
|
|
session_store
|
|
}
|