Add fuse callbacks

Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
2025-06-07 11:00:33 +01:00
parent ab3cb8bd4e
commit ab9f2cbc09
47 changed files with 884 additions and 451 deletions

View File

@@ -1,5 +1,6 @@
use clap::Parser;
use serde::Deserialize;
use sqlx::query;
use std::{
error::Error,
fs,
@@ -7,6 +8,8 @@ use std::{
path::PathBuf,
};
use crate::utils::crypto::hash_password;
#[derive(Clone, Deserialize)]
pub struct ServerConfig {
pub host: String,
@@ -59,6 +62,40 @@ pub struct RedisConfig {
pub database: u8,
}
/// Bootstrap administrator account, deserialized from the application
/// config file (`[admin]` section).
#[derive(Clone, Deserialize)]
pub struct AdminConfig {
    // Unique login name; used as the conflict key when upserting.
    pub name: String,
    pub display_name: String,
    // Plaintext password from the config; hashed before it is stored
    // (see `AdminConfig::upsert`).
    pub password: String,
    pub email: String,
}
impl AdminConfig {
    /// Inserts the configured admin account into `glyph_users`, or updates
    /// the existing row matched by `name`.
    ///
    /// The plaintext password from the config is hashed before being
    /// written, and the account is always stored as enabled
    /// (`disabled = false`).
    pub async fn upsert(&self, pool: &sqlx::PgPool) -> Result<(), Box<dyn Error + Send + Sync>> {
        let hashed = hash_password(&self.password);

        query!(
            r#"
            INSERT INTO glyph_users (name, display_name, password, email, disabled)
            VALUES ($1, $2, $3, $4, $5)
            ON CONFLICT (name) DO UPDATE
            SET display_name = EXCLUDED.display_name,
                password = EXCLUDED.password,
                email = EXCLUDED.email,
                disabled = EXCLUDED.disabled
            "#,
            self.name,
            self.display_name,
            hashed,
            self.email,
            false
        )
        .execute(pool)
        .await?;

        Ok(())
    }
}
#[derive(Clone, Deserialize)]
pub struct Config {
pub server: ServerConfig,
@@ -66,6 +103,7 @@ pub struct Config {
pub fuse: FuseConfig,
pub postgresql: PostgresqlConfig,
pub redis: RedisConfig,
pub admin: AdminConfig,
}
impl TryFrom<&PathBuf> for Config {
@@ -87,4 +125,7 @@ pub struct Args {
/// Path to the log4rs config file
#[arg(short, long, value_name = "FILE", default_value = "log4rs.yaml")]
pub log_config: PathBuf,
/// Additional arguments to pass to Authelia
#[arg(last = true, num_args = 0.., allow_hyphen_values = true)]
pub passthrough: Vec<String>,
}

View File

@@ -4,6 +4,7 @@
use std::{
cmp,
collections::HashMap,
error::Error,
ffi::CString,
mem::MaybeUninit,
ops::Deref,
@@ -11,17 +12,20 @@ use std::{
time::{Duration, SystemTime},
};
use fuser::{FileType, Filesystem};
use fuser::{FileType, Filesystem, Notifier, Session};
use libc::{
EACCES, EINVAL, EISDIR, ENOENT, ENOSYS, ENOTDIR, EPERM, O_ACCMODE, O_APPEND, O_RDONLY, O_TRUNC,
O_WRONLY, R_OK, W_OK, X_OK, c_int, gid_t, uid_t,
};
use parking_lot::{RwLock, RwLockWriteGuard};
use sqlx::PgPool;
use tokio::{fs, task::spawn_blocking};
use crate::config::FuseConfig;
type WriteCallback = Box<dyn Fn(&str) + Send + Sync>;
type WriteCallback = Box<dyn Fn(&PgPool, &str) + Send + Sync>;
#[derive(Clone, Copy)]
struct StaticState {
creation_time: SystemTime,
user: u32,
@@ -29,12 +33,14 @@ struct StaticState {
block_size: u32,
}
#[derive(Clone)]
struct VariableState {
contents: String,
access_time: SystemTime,
modification_time: SystemTime,
}
#[derive(Clone, Copy)]
struct Handle {
inode: u64,
uid: u32,
@@ -42,6 +48,7 @@ struct Handle {
cursor: i64,
}
#[derive(Clone)]
struct Handles {
handles: HashMap<u64, Handle>,
next_handle: u64,
@@ -55,12 +62,15 @@ impl Handles {
}
}
#[derive(Clone)]
pub struct AutheliaFS {
config: FuseConfig,
write_callback: Option<WriteCallback>,
static_state: Arc<StaticState>,
static_state: StaticState,
variable_state: Arc<RwLock<VariableState>>,
handles: Arc<RwLock<Handles>>,
write_callback: Arc<RwLock<WriteCallback>>,
notifier: Arc<RwLock<Option<Notifier>>>,
pg_pool: PgPool,
}
const TTL: Duration = Duration::from_secs(1);
@@ -222,13 +232,19 @@ enum HandleCheckResult {
}
impl AutheliaFS {
pub fn new(config: FuseConfig, write_callback: Option<WriteCallback>) -> Self {
pub async fn new(
config: FuseConfig,
write_callback: Option<WriteCallback>,
pg_pool: PgPool,
) -> Self {
let contents = String::new();
let time = SystemTime::now();
let uid = getuid();
let gid = getgid();
let _ = fs::create_dir_all(&config.mount_directory).await;
let block_size = u32::try_from(
stat(config.mount_directory.to_str().unwrap())
.unwrap()
@@ -236,14 +252,14 @@ impl AutheliaFS {
)
.unwrap_or(4096);
let static_file_state = Arc::new(StaticState {
let static_state = StaticState {
creation_time: time,
user: uid,
group: gid,
block_size,
});
};
let variable_file_state = Arc::new(RwLock::new(VariableState {
let variable_state = Arc::new(RwLock::new(VariableState {
contents,
access_time: time,
modification_time: time,
@@ -254,18 +270,48 @@ impl AutheliaFS {
next_handle: 1,
}));
let write_callback = Arc::new(RwLock::new(
write_callback.unwrap_or_else(|| Box::new(|_, _| {})),
));
let notifier = Arc::new(RwLock::new(None));
Self {
config,
write_callback,
variable_state: variable_file_state,
static_state: static_file_state,
static_state,
variable_state,
handles,
write_callback,
notifier,
pg_pool,
}
}
pub fn mount(self) -> std::io::Result<()> {
let mountpoint = self.config.mount_directory.clone();
fuser::mount2(self, mountpoint, &vec![])
/// Mounts the filesystem and drives the FUSE session to completion.
///
/// The mountpoint directory is created best-effort (it may already exist).
/// The session's notifier is published into `self.notifier` before the
/// session loop starts, so other tasks (e.g. `store`) can push
/// content-change notifications to the kernel.
pub async fn run(self) -> Result<(), Box<dyn Error + Send + Sync>> {
    // Best-effort: ignore the error if the directory already exists.
    let _ = fs::create_dir_all(&self.config.mount_directory).await;
    let mut session = Session::new(self.clone(), self.config.mount_directory.clone(), &[])?;
    self.notifier.write().replace(session.notifier());
    // The FUSE session loop is blocking, so it runs on a blocking task.
    // NOTE(review): the inner `unwrap` turns a session error into a task
    // panic, which surfaces only as a JoinError here — confirm that is
    // the intended failure mode rather than propagating the io::Error.
    Ok(spawn_blocking(move || session.run().unwrap()).await?)
}
/// Replaces the virtual file's contents and notifies the kernel so cached
/// readers see the new data.
///
/// Runs the mutation on a blocking task because it takes `parking_lot`
/// locks, which must not be held across `.await` points.
pub async fn store(&self, contents: String) -> Result<(), Box<dyn Error + Send + Sync>> {
    let variable_state = self.variable_state.clone();
    let notifier = self.notifier.clone();
    Ok(spawn_blocking(move || {
        // Capture one timestamp so mtime and atime agree exactly
        // (the original called `SystemTime::now()` twice).
        let now = SystemTime::now();
        let mut variable_state = variable_state.write();
        variable_state.contents = contents;
        variable_state.modification_time = now;
        variable_state.access_time = now;
        // A read lock suffices: the notifier is only consulted here,
        // never replaced.
        if let Some(notifier) = notifier.read().as_ref() {
            notifier
                .store(USERS_FILE_INODE, 0, variable_state.contents.as_bytes())
                .unwrap();
        }
    })
    .await?)
}
#[allow(clippy::fn_params_excessive_bools)]
@@ -483,30 +529,36 @@ impl Filesystem for AutheliaFS {
return;
}
if let Some(size) = size {
if size == 0 {
let mut variable_file_state = self.variable_state.write();
variable_file_state.contents.clear();
} else {
reply.error(ENOSYS);
return;
if size.is_some() && (atime.is_some() || mtime.is_some()) {
let mut variable_state = self.variable_state.write();
if let Some(size) = size {
if size == 0 {
variable_state.contents.clear();
} else {
reply.error(ENOSYS);
return;
}
}
}
if mtime.is_some() || atime.is_some() {
let mut variable_file_state = self.variable_state.write();
variable_file_state.modification_time = match mtime {
variable_state.modification_time = match mtime {
Some(fuser::TimeOrNow::Now) => SystemTime::now(),
Some(fuser::TimeOrNow::SpecificTime(time)) => time,
None => variable_file_state.modification_time,
None => variable_state.modification_time,
};
variable_file_state.access_time = match atime {
variable_state.access_time = match atime {
Some(fuser::TimeOrNow::Now) => SystemTime::now(),
Some(fuser::TimeOrNow::SpecificTime(time)) => time,
None => variable_file_state.access_time,
None => variable_state.access_time,
};
self.notifier
.write()
.as_ref()
.unwrap()
.store(ino, 0, variable_state.contents.as_bytes())
.unwrap();
}
let attr = file.to_file_attr(self);
@@ -544,8 +596,8 @@ impl Filesystem for AutheliaFS {
drop(handles);
if flags & O_TRUNC != 0 && flags & O_ACCMODE != O_RDONLY {
let mut variable_file_state = self.variable_state.write();
variable_file_state.contents.clear();
let mut variable_state = self.variable_state.write();
variable_state.contents.clear();
}
reply.opened(handle, 0);
@@ -578,11 +630,11 @@ impl Filesystem for AutheliaFS {
AccessCheckResult::Ok(_) => {}
}
let mut variable_file_state = self.variable_state.write();
variable_file_state.access_time = SystemTime::now();
let mut variable_state = self.variable_state.write();
variable_state.access_time = SystemTime::now();
let variable_file_state = RwLockWriteGuard::downgrade(variable_file_state);
let contents = variable_file_state.contents.as_bytes();
let variable_state = RwLockWriteGuard::downgrade(variable_state);
let contents = variable_state.contents.as_bytes();
let contents_len = i64::try_from(contents.len()).unwrap();
if offset < 0 || offset >= contents_len {
@@ -626,9 +678,9 @@ impl Filesystem for AutheliaFS {
let mut handles = self.handles.write();
let handle = handles.handles.get_mut(&fh).unwrap();
let mut variable_file_state = self.variable_state.write();
let mut variable_state = self.variable_state.write();
let old_end = variable_file_state.contents.len();
let old_end = variable_state.contents.len();
let offset = if handle.flags & O_APPEND != 0 {
handle.cursor = i64::try_from(old_end).unwrap();
@@ -641,8 +693,8 @@ impl Filesystem for AutheliaFS {
usize::try_from(offset).unwrap()
};
variable_file_state.access_time = SystemTime::now();
variable_file_state.modification_time = SystemTime::now();
variable_state.access_time = SystemTime::now();
variable_state.modification_time = SystemTime::now();
let Ok(new_data) = std::str::from_utf8(data) else {
reply.error(EINVAL);
@@ -653,22 +705,27 @@ impl Filesystem for AutheliaFS {
let new_real_end = cmp::max(new_end, old_end);
let mut new_contents = String::with_capacity(new_real_end);
new_contents.push_str(&variable_file_state.contents[..offset]);
new_contents.push_str(&variable_state.contents[..offset]);
new_contents.push_str(new_data);
if new_end < old_end {
new_contents.push_str(&variable_file_state.contents[new_end..]);
new_contents.push_str(&variable_state.contents[new_end..]);
}
variable_file_state.contents = new_contents;
variable_state.contents = new_contents;
handle.cursor = i64::try_from(offset + new_data.len()).unwrap();
drop(handles);
if let Some(callback) = &self.write_callback {
callback(&variable_file_state.contents);
}
self.write_callback.read().deref()(&self.pg_pool, &variable_state.contents);
drop(variable_file_state);
self.notifier
.write()
.as_ref()
.unwrap()
.store(ino, 0, variable_state.contents.as_bytes())
.unwrap();
drop(variable_state);
reply.written(u32::try_from(data.len()).unwrap());
}
@@ -954,10 +1011,10 @@ impl Filesystem for AutheliaFS {
AccessCheckResult::Ok(_) => {}
}
let variable_file_state = self.variable_state.read();
let blocks = (variable_file_state.contents.len() as u64)
let variable_state = self.variable_state.read();
let blocks = (variable_state.contents.len() as u64)
.div_ceil(u64::from(self.static_state.block_size));
drop(variable_file_state);
drop(variable_state);
reply.statfs(
blocks,

View File

@@ -2,7 +2,7 @@
#![allow(clippy::missing_docs_in_private_items)]
mod config;
mod fuser;
mod fuse;
mod models;
mod routes;
mod state;
@@ -15,28 +15,26 @@ use log4rs::config::Deserializers;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use config::{Args, Config};
use config::Args;
use state::State;
#[tokio::main]
async fn main() {
let args = Args::parse();
log4rs::init_file(args.log_config, Deserializers::default()).unwrap();
let args: Args = Args::parse();
log4rs::init_file(args.log_config.clone(), Deserializers::default()).unwrap();
let config = Config::try_from(&args.config).unwrap();
let state = State::from_config(config.clone()).await;
let state = State::from_args(args).await;
sqlx::migrate!("./migrations")
.run(&state.pg_pool)
.await
.unwrap();
let routes = routes::routes(state.clone());
let app = axum::Router::new().nest(&format!("{}/api", state.config.server.subpath), routes);
let routes = routes::routes(state);
let app = axum::Router::new().nest(&format!("{}/api", config.server.subpath), routes);
let addr = SocketAddr::from((config.server.address, config.server.port));
let addr = SocketAddr::from((state.config.server.address, state.config.server.port));
let listener = TcpListener::bind(addr).await.unwrap();
info!("Listening on {}", listener.local_addr().unwrap());
serve(listener, app).await.unwrap();
serve(listener, app)
.with_graceful_shutdown(utils::shutdown_signal())
.await
.unwrap();
}

View File

@@ -1,24 +1,102 @@
use log::warn;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_yaml::Value;
use sqlx::PgPool;
use std::collections::HashMap;
use std::{collections::HashMap, error::Error};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsersFile {
pub users: HashMap<String, UserFile>,
pub struct Users {
pub users: HashMap<String, User>,
#[serde(flatten)]
pub extra: Option<HashMap<String, Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserFile {
pub struct User {
pub displayname: String,
pub password: String,
pub email: Option<String>,
pub disabled: Option<bool>,
pub picture: Option<String>,
pub groups: Option<Vec<String>>,
#[serde(flatten)]
pub extra: Option<HashMap<String, Value>>,
}
impl TryInto<Vec<super::users::UserWithGroups>> for Users {
type Error = Box<dyn Error + Send + Sync>;
fn try_into(self) -> Result<Vec<super::users::UserWithGroups>, Self::Error> {
self.users
.into_iter()
.map(|(name, user)| {
let groups = user.groups.unwrap_or_default();
Ok(super::users::UserWithGroups {
name: name.clone(),
display_name: user.displayname,
password: user.password,
email: user
.email
.ok_or_else(|| format!("User {} is missing an email", &name))?,
disabled: user.disabled.unwrap_or(false),
picture: user.picture,
groups,
})
})
.collect()
}
}
impl Users {
    /// Write-callback entry point: parses the YAML users file that Authelia
    /// wrote through FUSE and mirrors it into the database.
    ///
    /// All failures are logged and swallowed deliberately — a malformed
    /// write must not crash the filesystem thread.
    pub fn from_fuse(pool: &PgPool, contents: &str) {
        let users = match serde_yaml::from_str::<Self>(contents) {
            Ok(users) => users,
            Err(e) => {
                // The file is YAML, not JSON (fixed message), and the parse
                // error is now included instead of discarded.
                warn!("Failed to parse users from YAML: {e}");
                return;
            }
        };

        let users_with_groups: Vec<super::users::UserWithGroups> = match users.try_into() {
            Ok(users) => users,
            Err(e) => {
                warn!("Failed to convert Users to UserWithGroups: {e}");
                return;
            }
        };

        // A current-thread runtime is far cheaper than `Runtime::new()`
        // (which spawns a worker-thread pool) for a single `block_on`, and
        // failure to build it no longer panics the callback.
        // NOTE(review): `block_on` still panics if this is ever called from
        // within an async context — confirm the callback only runs on FUSE
        // (blocking) threads.
        let rt = match tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
        {
            Ok(rt) => rt,
            Err(e) => {
                warn!("Failed to build runtime for user sync: {e}");
                return;
            }
        };
        rt.block_on(async {
            super::users::UserWithGroups::upsert_many_delete_remaining(pool, &users_with_groups)
                .await
                .unwrap_or_else(|e| warn!("Failed to upsert users: {e}"));
        });
    }

    /// Serializes every database user (with group memberships) into the
    /// YAML document Authelia reads from the FUSE file.
    pub async fn to_fuse(pool: &PgPool) -> Result<String, Box<dyn Error + Send + Sync>> {
        let users_with_groups = super::users::UserWithGroups::select_all(pool).await?;

        let users = Self {
            users: users_with_groups
                .into_iter()
                .map(|user| {
                    (
                        user.name.clone(),
                        User {
                            displayname: user.display_name,
                            password: user.password,
                            email: Some(user.email),
                            disabled: Some(user.disabled),
                            picture: user.picture,
                            groups: Some(user.groups),
                            extra: None,
                        },
                    )
                })
                .collect(),
            extra: None,
        };

        Ok(serde_yaml::to_string(&users)?)
    }
}

View File

@@ -9,7 +9,7 @@ pub struct Group {
}
impl Group {
pub async fn select_by_name(
pub async fn select(
pool: &PgPool,
name: &str,
) -> Result<Option<Self>, Box<dyn Error + Send + Sync>> {
@@ -17,7 +17,7 @@ impl Group {
Group,
r#"
SELECT name
FROM groups
FROM glyph_groups
WHERE name = $1
"#,
name
@@ -28,13 +28,10 @@ impl Group {
Ok(group)
}
pub async fn delete_by_name(
pool: &PgPool,
name: &str,
) -> Result<(), Box<dyn Error + Send + Sync>> {
pub async fn delete(pool: &PgPool, name: &str) -> Result<(), Box<dyn Error + Send + Sync>> {
query!(
r#"
DELETE FROM groups
DELETE FROM glyph_groups
WHERE name = $1
"#,
name
@@ -45,14 +42,14 @@ impl Group {
Ok(())
}
pub async fn all_exist_by_names(
pub async fn all_exist(
pool: &PgPool,
names: &[String],
) -> Result<bool, Box<dyn Error + Send + Sync>> {
let row = query!(
r#"
SELECT COUNT(*) AS "count!"
FROM groups
FROM glyph_groups
WHERE name = ANY($1)
"#,
names
@@ -67,20 +64,19 @@ impl Group {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupWithUsers {
pub name: String,
#[serde(default)]
pub users: Vec<String>,
}
impl GroupWithUsers {
pub async fn select(pool: &PgPool) -> Result<Vec<Self>, Box<dyn Error + Send + Sync>> {
pub async fn select_all(pool: &PgPool) -> Result<Vec<Self>, Box<dyn Error + Send + Sync>> {
let groups = query_as!(
GroupWithUsers,
r#"
SELECT
g.name,
COALESCE(array_agg(ug.user_name ORDER BY ug.user_name), ARRAY[]::TEXT[]) AS "users!"
FROM groups g
LEFT JOIN users_groups ug ON g.name = ug.group_name
GROUP BY g.name
ARRAY(SELECT ug.user_name FROM glyph_users_groups ug WHERE ug.group_name = g.name) AS "users!"
FROM glyph_groups g
"#
)
.fetch_all(pool)
@@ -89,7 +85,7 @@ impl GroupWithUsers {
Ok(groups)
}
pub async fn select_by_name(
pub async fn select(
pool: &PgPool,
name: &str,
) -> Result<Option<Self>, Box<dyn Error + Send + Sync>> {
@@ -98,11 +94,9 @@ impl GroupWithUsers {
r#"
SELECT
g.name,
COALESCE(array_agg(ug.user_name ORDER BY ug.user_name), ARRAY[]::TEXT[]) AS "users!"
FROM groups g
LEFT JOIN users_groups ug ON g.name = ug.group_name
ARRAY(SELECT ug.user_name FROM glyph_users_groups ug WHERE ug.group_name = g.name) AS "users!"
FROM glyph_groups g
WHERE g.name = $1
GROUP BY g.name
"#,
name
)
@@ -119,7 +113,7 @@ impl GroupWithUsers {
let mut tx = pool.begin().await?;
query!(
r#"INSERT INTO groups (name) VALUES ($1)"#,
r#"INSERT INTO glyph_groups (name) VALUES ($1)"#,
group_with_users.name
)
.execute(&mut *tx)
@@ -127,8 +121,8 @@ impl GroupWithUsers {
query!(
r#"
INSERT INTO users_groups (user_name, group_name)
SELECT * FROM UNNEST($1::text[], $2::text[])
INSERT INTO glyph_users_groups (user_name, group_name)
SELECT * FROM UNNEST($1::text[], $2::text[])
"#,
&group_with_users.users,
&vec![group_with_users.name.clone(); group_with_users.users.len()]

View File

@@ -19,7 +19,7 @@ impl UsersGroups {
query!(
r#"
DELETE FROM users_groups
DELETE FROM glyph_users_groups
WHERE group_name = $1
"#,
group_name
@@ -29,7 +29,7 @@ impl UsersGroups {
query!(
r#"
INSERT INTO users_groups (user_name, group_name)
INSERT INTO glyph_users_groups (user_name, group_name)
SELECT * FROM UNNEST($1::text[], $2::text[])
"#,
users,
@@ -50,7 +50,7 @@ impl UsersGroups {
query!(
r#"
DELETE FROM users_groups
DELETE FROM glyph_users_groups
WHERE user_name = $1
"#,
user_name
@@ -60,7 +60,7 @@ impl UsersGroups {
query!(
r#"
INSERT INTO users_groups (user_name, group_name)
INSERT INTO glyph_users_groups (user_name, group_name)
SELECT * FROM UNNEST($1::text[], $2::text[])
"#,
&vec![user_name.to_string(); groups.len()],

View File

@@ -1,4 +1,4 @@
use std::error::Error;
use std::{collections::HashSet, error::Error};
use serde::{Deserialize, Serialize};
use sqlx::{FromRow, PgPool, query, query_as};
@@ -12,20 +12,20 @@ pub struct User {
#[serde(default)]
pub disabled: bool,
#[serde(default)]
pub image: Option<String>,
pub picture: Option<String>,
}
impl User {
pub async fn select_by_name(
pub async fn select(
pool: &PgPool,
name: &str,
) -> Result<Option<Self>, Box<dyn Error + Send + Sync>> {
let user = query_as!(
User,
r#"
SELECT name, display_name, password, email, disabled, image
FROM users
WHERE name = $1
SELECT name, display_name, password, email, disabled, picture
FROM glyph_users
WHERE name = $1
"#,
name
)
@@ -38,21 +38,21 @@ impl User {
pub async fn upsert(pool: &PgPool, user: &Self) -> Result<(), Box<dyn Error + Send + Sync>> {
query!(
r#"
INSERT INTO users (name, display_name, password, email, disabled, image)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (name) DO UPDATE
SET display_name = EXCLUDED.display_name,
password = EXCLUDED.password,
email = EXCLUDED.email,
disabled = EXCLUDED.disabled,
image = EXCLUDED.image
INSERT INTO glyph_users (name, display_name, password, email, disabled, picture)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (name) DO UPDATE
SET display_name = EXCLUDED.display_name,
password = EXCLUDED.password,
email = EXCLUDED.email,
disabled = EXCLUDED.disabled,
picture = EXCLUDED.picture
"#,
user.name,
user.display_name,
user.password,
user.email,
user.disabled,
user.image
user.picture
)
.execute(pool)
.await?;
@@ -60,14 +60,11 @@ impl User {
Ok(())
}
pub async fn delete_by_name(
pool: &PgPool,
name: &str,
) -> Result<(), Box<dyn Error + Send + Sync>> {
pub async fn delete(pool: &PgPool, name: &str) -> Result<(), Box<dyn Error + Send + Sync>> {
query!(
r#"
DELETE FROM users
WHERE name = $1
DELETE FROM glyph_users
WHERE name = $1
"#,
name
)
@@ -77,15 +74,15 @@ impl User {
Ok(())
}
pub async fn all_exist_by_names(
pub async fn all_exist(
pool: &PgPool,
names: &[String],
) -> Result<bool, Box<dyn Error + Send + Sync>> {
let row = query!(
r#"
SELECT COUNT(*) AS "count!"
FROM users
WHERE name = ANY($1)
SELECT COUNT(*) AS "count!"
FROM glyph_users
WHERE name = ANY($1)
"#,
names
)
@@ -105,26 +102,25 @@ pub struct UserWithGroups {
#[serde(default)]
pub disabled: bool,
#[serde(default)]
pub image: Option<String>,
pub picture: Option<String>,
#[serde(default)]
pub groups: Vec<String>,
}
impl UserWithGroups {
pub async fn select(pool: &PgPool) -> Result<Vec<Self>, Box<dyn Error + Send + Sync>> {
pub async fn select_all(pool: &PgPool) -> Result<Vec<Self>, Box<dyn Error + Send + Sync>> {
let users = query_as!(
UserWithGroups,
r#"
SELECT
u.name,
u.display_name,
u.password,
u.email,
u.disabled,
u.image,
COALESCE(array_agg(ug.group_name ORDER BY ug.group_name), ARRAY[]::TEXT[]) AS "groups!"
FROM users u
LEFT JOIN users_groups ug ON u.name = ug.user_name
GROUP BY u.name, u.email, u.disabled, u.image
SELECT
u.name,
u.display_name,
u.password,
u.email,
u.disabled,
u.picture,
ARRAY(SELECT ug.group_name FROM glyph_users_groups ug WHERE ug.user_name = u.name) AS "groups!"
FROM glyph_users u
"#
)
.fetch_all(pool)
@@ -133,25 +129,23 @@ impl UserWithGroups {
Ok(users)
}
pub async fn select_by_name(
pub async fn select(
pool: &PgPool,
name: &str,
) -> Result<Option<Self>, Box<dyn Error + Send + Sync>> {
let user = query_as!(
UserWithGroups,
r#"
SELECT
u.name,
u.display_name,
u.password,
u.email,
u.disabled,
u.image,
COALESCE(array_agg(ug.group_name ORDER BY ug.group_name), ARRAY[]::TEXT[]) AS "groups!"
FROM users u
LEFT JOIN users_groups ug ON u.name = ug.user_name
WHERE u.name = $1
GROUP BY u.name, u.email, u.disabled, u.image
SELECT
u.name,
u.display_name,
u.password,
u.email,
u.disabled,
u.picture,
ARRAY(SELECT ug.group_name FROM glyph_users_groups ug WHERE ug.user_name = u.name) AS "groups!"
FROM glyph_users u
WHERE u.name = $1
"#,
name
)
@@ -168,23 +162,24 @@ impl UserWithGroups {
let mut tx = pool.begin().await?;
query!(
r#"INSERT INTO users (name, display_name, password, email, disabled, image)
VALUES ($1, $2, $3, $4, $5, $6)
r#"
INSERT INTO glyph_users (name, display_name, password, email, disabled, picture)
VALUES ($1, $2, $3, $4, $5, $6)
"#,
user_with_groups.name,
user_with_groups.display_name,
user_with_groups.password,
user_with_groups.email,
user_with_groups.disabled,
user_with_groups.image
user_with_groups.picture
)
.execute(&mut *tx)
.await?;
query!(
r#"
INSERT INTO users_groups (user_name, group_name)
SELECT * FROM UNNEST($1::text[], $2::text[])
INSERT INTO glyph_users_groups (user_name, group_name)
SELECT * FROM UNNEST($1::text[], $2::text[])
"#,
&user_with_groups.groups,
&vec![user_with_groups.name.clone(); user_with_groups.groups.len()]
@@ -196,4 +191,93 @@ impl UserWithGroups {
Ok(())
}
/// Makes the database exactly mirror the given snapshot: upserts every
/// user (replacing its group memberships wholesale), then deletes any
/// user or group not present in the snapshot.
///
/// The entire operation runs in one transaction so readers never observe
/// a partially-applied snapshot.
pub async fn upsert_many_delete_remaining(
    pool: &PgPool,
    users_with_groups: &[Self],
) -> Result<(), Box<dyn Error + Send + Sync>> {
    let mut tx = pool.begin().await?;

    for user in users_with_groups {
        // Upsert the user row itself.
        query!(
            r#"
            INSERT INTO glyph_users (name, display_name, password, email, disabled, picture)
            VALUES ($1, $2, $3, $4, $5, $6)
            ON CONFLICT (name) DO UPDATE
            SET display_name = EXCLUDED.display_name,
                password = EXCLUDED.password,
                email = EXCLUDED.email,
                disabled = EXCLUDED.disabled,
                picture = EXCLUDED.picture
            "#,
            user.name,
            user.display_name,
            user.password,
            user.email,
            user.disabled,
            user.picture
        )
        .execute(&mut *tx)
        .await?;

        // Replace the user's memberships: clear, then re-insert.
        // NOTE(review): assumes every group named in `user.groups` already
        // exists in `glyph_groups` (foreign key) — confirm the callers
        // guarantee this.
        query!(
            r#"
            DELETE FROM glyph_users_groups
            WHERE user_name = $1
            "#,
            user.name
        )
        .execute(&mut *tx)
        .await?;

        if !user.groups.is_empty() {
            query!(
                r#"
                INSERT INTO glyph_users_groups (user_name, group_name)
                SELECT * FROM UNNEST($1::text[], $2::text[])
                "#,
                &user.groups,
                &vec![user.name.clone(); user.groups.len()]
            )
            .execute(&mut *tx)
            .await?;
        }
    }

    // Delete users absent from the snapshot (an empty snapshot deletes all).
    let users = users_with_groups
        .iter()
        .map(|user| user.name.clone())
        .collect::<Vec<_>>();
    query!(
        r#"
        DELETE FROM glyph_users
        WHERE name <> ALL($1)
        "#,
        &users
    )
    .execute(&mut *tx)
    .await?;

    // Delete groups no longer referenced by any snapshot user.
    let groups = users_with_groups
        .iter()
        .flat_map(|user| user.groups.iter().cloned())
        .collect::<HashSet<_>>()
        .into_iter()
        .collect::<Vec<_>>();
    query!(
        r#"
        DELETE FROM glyph_groups
        WHERE name <> ALL($1)
        "#,
        &groups
    )
    // BUG FIX: this previously ran on `pool`, escaping the transaction —
    // it executed on a separate connection before `tx.commit()`, breaking
    // atomicity (groups could vanish even if the transaction rolled back).
    .execute(&mut *tx)
    .await?;

    tx.commit().await?;

    Ok(())
}
}

View File

@@ -35,7 +35,7 @@ pub async fn get_all(
_: auth::User,
extract::State(pg_pool): extract::State<PgPool>,
) -> Result<impl IntoResponse, StatusCode> {
let groups_with_users = models::groups::GroupWithUsers::select(&pg_pool)
let groups_with_users = models::groups::GroupWithUsers::select_all(&pg_pool)
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?;
@@ -52,7 +52,7 @@ pub async fn get(
extract::Path(name): extract::Path<NonEmptyString>,
extract::State(pg_pool): extract::State<PgPool>,
) -> Result<impl IntoResponse, StatusCode> {
let group_with_users = models::groups::GroupWithUsers::select_by_name(&pg_pool, name.as_str())
let group_with_users = models::groups::GroupWithUsers::select(&pg_pool, name.as_str())
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
.ok_or(StatusCode::NOT_FOUND)?;
@@ -71,7 +71,7 @@ pub async fn create(
extract::State(pg_pool): extract::State<PgPool>,
extract::Json(group_create): extract::Json<GroupCreate>,
) -> Result<impl IntoResponse, StatusCode> {
if models::groups::Group::select_by_name(&pg_pool, group_create.name.as_str())
if models::groups::Group::select(&pg_pool, group_create.name.as_str())
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
.is_some()
@@ -85,7 +85,7 @@ pub async fn create(
.map(|u| u.to_string())
.collect::<Vec<_>>();
if !models::users::User::all_exist_by_names(&pg_pool, &users)
if !models::users::User::all_exist(&pg_pool, &users)
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
{
@@ -116,7 +116,7 @@ pub async fn update(
extract::State(config): extract::State<Config>,
extract::Json(group_update): extract::Json<GroupUpdate>,
) -> Result<impl IntoResponse, StatusCode> {
let group = models::groups::Group::select_by_name(&pg_pool, name.as_str())
let group = models::groups::Group::select(&pg_pool, name.as_str())
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
.ok_or(StatusCode::NOT_FOUND)?;
@@ -126,7 +126,7 @@ pub async fn update(
if let Some(users) = &group_update.users {
let users = users.iter().map(ToString::to_string).collect::<Vec<_>>();
if !models::users::User::all_exist_by_names(&pg_pool, &users)
if !models::users::User::all_exist(&pg_pool, &users)
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
{
@@ -163,12 +163,12 @@ pub async fn delete(
return Err(StatusCode::FORBIDDEN);
}
let group = models::groups::Group::select_by_name(&pg_pool, &name)
let group = models::groups::Group::select(&pg_pool, &name)
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
.ok_or(StatusCode::NOT_FOUND)?;
Group::delete_by_name(&pg_pool, &group.name)
Group::delete(&pg_pool, &group.name)
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?;

View File

@@ -21,7 +21,7 @@ struct UserResponse {
display_name: String,
email: String,
disabled: bool,
image: Option<String>,
picture: Option<String>,
groups: Vec<String>,
}
@@ -31,7 +31,7 @@ impl From<models::users::UserWithGroups> for UserResponse {
display_name: user.display_name,
email: user.email,
disabled: user.disabled,
image: user.image,
picture: user.picture,
groups: user.groups,
}
}
@@ -43,7 +43,7 @@ pub async fn get_all(
_: auth::User,
extract::State(pg_pool): extract::State<PgPool>,
) -> Result<impl IntoResponse, StatusCode> {
let users_with_groups = models::users::UserWithGroups::select(&pg_pool)
let users_with_groups = models::users::UserWithGroups::select_all(&pg_pool)
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?;
@@ -60,7 +60,7 @@ pub async fn get(
extract::Path(name): extract::Path<NonEmptyString>,
extract::State(pg_pool): extract::State<PgPool>,
) -> Result<impl IntoResponse, StatusCode> {
let user_with_groups = models::users::UserWithGroups::select_by_name(&pg_pool, name.as_str())
let user_with_groups = models::users::UserWithGroups::select(&pg_pool, name.as_str())
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
.ok_or(StatusCode::NOT_FOUND)?;
@@ -74,7 +74,7 @@ pub struct UserCreate {
displayname: NonEmptyString,
email: NonEmptyString,
disabled: bool,
image: Option<NonEmptyString>,
picture: Option<NonEmptyString>,
groups: Vec<NonEmptyString>,
}
@@ -83,7 +83,7 @@ pub async fn create(
extract::State(pg_pool): extract::State<PgPool>,
extract::Json(user_create): extract::Json<UserCreate>,
) -> Result<impl IntoResponse, StatusCode> {
if models::users::User::select_by_name(&pg_pool, user_create.name.as_str())
if models::users::User::select(&pg_pool, user_create.name.as_str())
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
.is_some()
@@ -97,7 +97,7 @@ pub async fn create(
.map(|g| g.to_string())
.collect::<Vec<_>>();
if !models::groups::Group::all_exist_by_names(&pg_pool, &groups)
if !models::groups::Group::all_exist(&pg_pool, &groups)
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
{
@@ -110,7 +110,7 @@ pub async fn create(
password: generate_random_password_hash(),
email: user_create.email.to_string(),
disabled: user_create.disabled,
image: user_create.image.map(|i| i.to_string()),
picture: user_create.picture.map(|i| i.to_string()),
groups,
};
@@ -126,7 +126,7 @@ pub struct UserUpdate {
display_name: Option<NonEmptyString>,
email: Option<NonEmptyString>,
disabled: Option<bool>,
image: Option<NonEmptyString>,
picture: Option<NonEmptyString>,
groups: Option<Vec<NonEmptyString>>,
}
@@ -137,7 +137,7 @@ pub async fn update(
extract::State(config): extract::State<Config>,
extract::Json(user_update): extract::Json<UserUpdate>,
) -> Result<impl IntoResponse, StatusCode> {
let user = models::users::User::select_by_name(&pg_pool, name.as_str())
let user = models::users::User::select(&pg_pool, name.as_str())
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
.ok_or(StatusCode::NOT_FOUND)?;
@@ -150,7 +150,7 @@ pub async fn update(
.map(|g| g.to_string())
.collect::<Vec<_>>();
if !models::groups::Group::all_exist_by_names(&pg_pool, &groups)
if !models::groups::Group::all_exist(&pg_pool, &groups)
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
{
@@ -183,7 +183,7 @@ pub async fn update(
.map(|e| e.to_string())
.unwrap_or(user.email),
disabled: user_update.disabled.unwrap_or(user.disabled),
image: user_update.image.map(|i| i.to_string()).or(user.image),
picture: user_update.picture.map(|i| i.to_string()).or(user.picture),
};
models::users::User::upsert(&pg_pool, &user)
@@ -206,12 +206,12 @@ pub async fn delete(
return Err(StatusCode::FORBIDDEN);
}
let user = models::users::User::select_by_name(&pg_pool, &name)
let user = models::users::User::select(&pg_pool, &name)
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
.ok_or(StatusCode::NOT_FOUND)?;
models::users::User::delete_by_name(&pg_pool, &user.name)
models::users::User::delete(&pg_pool, &user.name)
.await
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?;

View File

@@ -1,3 +1,5 @@
use std::{sync::Arc, time::Duration};
use async_redis_session::RedisSessionStore;
use axum::extract::FromRef;
use openidconnect::{
@@ -10,11 +12,14 @@ use openidconnect::{
},
reqwest,
};
use redis::{self, AsyncCommands};
use sqlx::{PgPool, postgres::PgPoolOptions};
use tokio::spawn;
use tokio::{process::Command, spawn, task::JoinHandle, time::sleep};
use crate::config::Config;
use crate::{
config::{Args, Config},
fuse::AutheliaFS,
models,
};
pub type OAuthClient<
HasAuthUrl = EndpointSet,
@@ -46,26 +51,44 @@ pub type OAuthClient<
#[derive(Clone)]
pub struct State {
pub config: Config,
pub oauth_http_client: reqwest::Client,
pub oauth_client: OAuthClient,
pub pg_pool: PgPool,
pub redis_client: redis::aio::MultiplexedConnection,
pub filesystem: AutheliaFS,
pub mount: Arc<JoinHandle<()>>,
pub authelia: Arc<JoinHandle<()>>,
pub oauth_http_client: reqwest::Client,
pub oauth_client: OAuthClient,
pub session_store: RedisSessionStore,
}
impl State {
pub async fn from_config(config: Config) -> Self {
let (oauth_http_client, oauth_client) = oauth_client(&config).await;
pub async fn from_args(args: Args) -> Self {
let config = Config::try_from(&args.config).unwrap();
let pg_pool = pg_pool(&config).await;
sqlx::migrate!("./migrations").run(&pg_pool).await.unwrap();
config.admin.upsert(&pg_pool).await.unwrap();
let redis_client = redis_client(&config).await;
let (filesystem, mount) = fuse(&config, &pg_pool).await;
let contents = models::authelia::Users::to_fuse(&pg_pool).await.unwrap();
filesystem.store(contents).await.unwrap();
let authelia = authelia(args.passthrough);
let (oauth_http_client, oauth_client) = oauth_client(&config).await;
let session_store = session_store(&config);
Self {
config,
oauth_http_client,
oauth_client,
pg_pool,
redis_client,
filesystem,
mount,
authelia,
oauth_http_client,
oauth_client,
session_store,
}
}
@@ -77,18 +100,6 @@ impl FromRef<State> for Config {
}
}
impl FromRef<State> for reqwest::Client {
fn from_ref(state: &State) -> Self {
state.oauth_http_client.clone()
}
}
impl FromRef<State> for OAuthClient {
fn from_ref(state: &State) -> Self {
state.oauth_client.clone()
}
}
impl FromRef<State> for PgPool {
fn from_ref(state: &State) -> Self {
state.pg_pool.clone()
@@ -101,42 +112,30 @@ impl FromRef<State> for redis::aio::MultiplexedConnection {
}
}
/// Allow axum extractors to pull the FUSE filesystem handle out of the app state.
impl FromRef<State> for AutheliaFS {
    fn from_ref(state: &State) -> Self {
        Clone::clone(&state.filesystem)
    }
}
/// Allow axum extractors to pull the OAuth HTTP client out of the app state.
impl FromRef<State> for reqwest::Client {
    fn from_ref(state: &State) -> Self {
        let client = &state.oauth_http_client;
        client.clone()
    }
}
/// Allow axum extractors to pull the configured OIDC client out of the app state.
impl FromRef<State> for OAuthClient {
    fn from_ref(state: &State) -> Self {
        Clone::clone(&state.oauth_client)
    }
}
/// Allow axum extractors to pull the Redis-backed session store out of the app state.
impl FromRef<State> for RedisSessionStore {
    fn from_ref(state: &State) -> Self {
        let store = &state.session_store;
        store.clone()
    }
}
async fn oauth_client(config: &Config) -> (reqwest::Client, OAuthClient) {
let oauth_http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.danger_accept_invalid_certs(config.oauth.insecure)
.build()
.unwrap();
let provider_metadata = CoreProviderMetadata::discover_async(
IssuerUrl::new(config.oauth.issuer_url.clone()).unwrap(),
&oauth_http_client,
)
.await
.unwrap();
let oauth_client = OAuthClient::from_provider_metadata(
provider_metadata,
ClientId::new(config.oauth.client_id.clone()),
Some(ClientSecret::new(config.oauth.client_secret.clone())),
)
.set_redirect_uri(
RedirectUrl::new(format!(
"{}{}/api/auth/callback",
config.server.host, config.server.subpath
))
.unwrap(),
);
(oauth_http_client, oauth_client)
}
async fn pg_pool(config: &Config) -> PgPool {
PgPoolOptions::new()
.max_connections(5)
@@ -159,43 +158,86 @@ async fn redis_client(config: &Config) -> redis::aio::MultiplexedConnection {
);
let client = redis::Client::open(url).unwrap();
let mut connection = client.get_multiplexed_async_connection().await.unwrap();
client.get_multiplexed_async_connection().await.unwrap()
}
let _: () = redis::cmd("CONFIG")
.arg("SET")
.arg("notify-keyspace-events")
.arg("Ex")
.query_async(&mut connection)
.await
async fn fuse(config: &Config, pg_pool: &PgPool) -> (AutheliaFS, Arc<JoinHandle<()>>) {
let fs = AutheliaFS::new(
config.fuse.clone(),
Some(Box::new(models::authelia::Users::from_fuse)),
pg_pool.clone(),
)
.await;
let fs_clone = fs.clone();
let mount = Arc::new(spawn(async move {
loop {
let _ = fs_clone.clone().run().await;
}
}));
(fs, mount)
}
fn authelia(args: Vec<String>) -> Arc<JoinHandle<()>> {
Arc::new(spawn(async move {
loop {
let _ = Command::new("authelia")
.args(args.clone())
.spawn()
.unwrap()
.wait()
.await;
}
}))
}
async fn oauth_client(config: &Config) -> (reqwest::Client, OAuthClient) {
let oauth_http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.danger_accept_invalid_certs(config.oauth.insecure)
.build()
.unwrap();
let database = config.redis.database.to_string();
spawn(async move {
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
let rconfig = redis::AsyncConnectionConfig::new().set_push_sender(tx);
let mut connection = client
.get_multiplexed_async_connection_with_config(&rconfig)
.await
.unwrap();
let mut provider_metadata = None;
let channel = format!("__keyevent@{database}__:expired");
connection.subscribe(&[channel]).await.unwrap();
let retries = 10;
let mut backoff = Duration::from_secs(1);
while let Some(msg) = rx.recv().await {
if let Some(msg) = redis::Msg::from_push_info(msg) {
if let Ok(key) = msg.get_payload::<String>() {
if !key.starts_with("invite:") {
continue;
}
let id = key.trim_start_matches("invite:").to_string();
let _: i64 = connection.srem("invite:all", id).await.unwrap();
}
}
for i in 0..retries {
if let Ok(metadata) = CoreProviderMetadata::discover_async(
IssuerUrl::new(config.oauth.issuer_url.clone()).unwrap(),
&oauth_http_client,
)
.await
{
provider_metadata = Some(metadata);
break;
}
if i == retries - 1 {
break;
}
});
connection
sleep(backoff).await;
backoff *= 2;
}
let provider_metadata = provider_metadata.unwrap();
let oauth_client = OAuthClient::from_provider_metadata(
provider_metadata,
ClientId::new(config.oauth.client_id.clone()),
Some(ClientSecret::new(config.oauth.client_secret.clone())),
)
.set_redirect_uri(
RedirectUrl::new(format!(
"{}{}/api/auth/callback",
config.server.host, config.server.subpath
))
.unwrap(),
);
(oauth_http_client, oauth_client)
}
fn session_store(config: &Config) -> RedisSessionStore {

View File

@@ -33,3 +33,20 @@ pub fn generate_random_password_hash() -> String {
password_hash
}
/// Hash a plaintext password with Argon2id (64 MiB memory, 3 iterations,
/// 4 lanes, 32-byte output) using a freshly generated random salt, returning
/// the PHC-format hash string.
pub fn hash_password(password: &str) -> String {
    let salt = SaltString::generate(&mut OsRng);
    let hasher = Argon2::new(
        argon2::Algorithm::Argon2id,
        argon2::Version::V0x13,
        argon2::Params::new(65536, 3, 4, Some(32)).unwrap(),
    );
    hasher
        .hash_password(password.as_bytes(), &salt)
        .unwrap()
        .to_string()
}

View File

@@ -1 +1,21 @@
use tokio::{select, signal};
pub mod crypto;
/// Resolve when the process receives either Ctrl-C (SIGINT) or SIGTERM,
/// whichever arrives first — used as axum's graceful-shutdown future.
pub async fn shutdown_signal() {
    let interrupt = async {
        signal::ctrl_c().await.unwrap();
    };

    let terminated = async {
        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate()).unwrap();
        sigterm.recv().await;
    };

    select! {
        () = interrupt => {},
        () = terminated => {},
    }
}