Add fuse callbacks
Signed-off-by: Nikolaos Karaolidis <nick@karaolidis.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
use clap::Parser;
|
||||
use serde::Deserialize;
|
||||
use sqlx::query;
|
||||
use std::{
|
||||
error::Error,
|
||||
fs,
|
||||
@@ -7,6 +8,8 @@ use std::{
|
||||
path::PathBuf,
|
||||
};
|
||||
|
||||
use crate::utils::crypto::hash_password;
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
@@ -59,6 +62,40 @@ pub struct RedisConfig {
|
||||
pub database: u8,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct AdminConfig {
|
||||
pub name: String,
|
||||
pub display_name: String,
|
||||
pub password: String,
|
||||
pub email: String,
|
||||
}
|
||||
|
||||
impl AdminConfig {
|
||||
pub async fn upsert(&self, pool: &sqlx::PgPool) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
let password = hash_password(&self.password);
|
||||
query!(
|
||||
r#"
|
||||
INSERT INTO glyph_users (name, display_name, password, email, disabled)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (name) DO UPDATE
|
||||
SET display_name = EXCLUDED.display_name,
|
||||
password = EXCLUDED.password,
|
||||
email = EXCLUDED.email,
|
||||
disabled = EXCLUDED.disabled
|
||||
"#,
|
||||
self.name,
|
||||
self.display_name,
|
||||
password,
|
||||
self.email,
|
||||
false
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct Config {
|
||||
pub server: ServerConfig,
|
||||
@@ -66,6 +103,7 @@ pub struct Config {
|
||||
pub fuse: FuseConfig,
|
||||
pub postgresql: PostgresqlConfig,
|
||||
pub redis: RedisConfig,
|
||||
pub admin: AdminConfig,
|
||||
}
|
||||
|
||||
impl TryFrom<&PathBuf> for Config {
|
||||
@@ -87,4 +125,7 @@ pub struct Args {
|
||||
/// Path to the log4rs config file
|
||||
#[arg(short, long, value_name = "FILE", default_value = "log4rs.yaml")]
|
||||
pub log_config: PathBuf,
|
||||
/// Additional arguments to pass to Authelia
|
||||
#[arg(last = true, num_args = 0.., allow_hyphen_values = true)]
|
||||
pub passthrough: Vec<String>,
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@
|
||||
use std::{
|
||||
cmp,
|
||||
collections::HashMap,
|
||||
error::Error,
|
||||
ffi::CString,
|
||||
mem::MaybeUninit,
|
||||
ops::Deref,
|
||||
@@ -11,17 +12,20 @@ use std::{
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use fuser::{FileType, Filesystem};
|
||||
use fuser::{FileType, Filesystem, Notifier, Session};
|
||||
use libc::{
|
||||
EACCES, EINVAL, EISDIR, ENOENT, ENOSYS, ENOTDIR, EPERM, O_ACCMODE, O_APPEND, O_RDONLY, O_TRUNC,
|
||||
O_WRONLY, R_OK, W_OK, X_OK, c_int, gid_t, uid_t,
|
||||
};
|
||||
use parking_lot::{RwLock, RwLockWriteGuard};
|
||||
use sqlx::PgPool;
|
||||
use tokio::{fs, task::spawn_blocking};
|
||||
|
||||
use crate::config::FuseConfig;
|
||||
|
||||
type WriteCallback = Box<dyn Fn(&str) + Send + Sync>;
|
||||
type WriteCallback = Box<dyn Fn(&PgPool, &str) + Send + Sync>;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct StaticState {
|
||||
creation_time: SystemTime,
|
||||
user: u32,
|
||||
@@ -29,12 +33,14 @@ struct StaticState {
|
||||
block_size: u32,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct VariableState {
|
||||
contents: String,
|
||||
access_time: SystemTime,
|
||||
modification_time: SystemTime,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct Handle {
|
||||
inode: u64,
|
||||
uid: u32,
|
||||
@@ -42,6 +48,7 @@ struct Handle {
|
||||
cursor: i64,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Handles {
|
||||
handles: HashMap<u64, Handle>,
|
||||
next_handle: u64,
|
||||
@@ -55,12 +62,15 @@ impl Handles {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AutheliaFS {
|
||||
config: FuseConfig,
|
||||
write_callback: Option<WriteCallback>,
|
||||
static_state: Arc<StaticState>,
|
||||
static_state: StaticState,
|
||||
variable_state: Arc<RwLock<VariableState>>,
|
||||
handles: Arc<RwLock<Handles>>,
|
||||
write_callback: Arc<RwLock<WriteCallback>>,
|
||||
notifier: Arc<RwLock<Option<Notifier>>>,
|
||||
pg_pool: PgPool,
|
||||
}
|
||||
|
||||
const TTL: Duration = Duration::from_secs(1);
|
||||
@@ -222,13 +232,19 @@ enum HandleCheckResult {
|
||||
}
|
||||
|
||||
impl AutheliaFS {
|
||||
pub fn new(config: FuseConfig, write_callback: Option<WriteCallback>) -> Self {
|
||||
pub async fn new(
|
||||
config: FuseConfig,
|
||||
write_callback: Option<WriteCallback>,
|
||||
pg_pool: PgPool,
|
||||
) -> Self {
|
||||
let contents = String::new();
|
||||
let time = SystemTime::now();
|
||||
|
||||
let uid = getuid();
|
||||
let gid = getgid();
|
||||
|
||||
let _ = fs::create_dir_all(&config.mount_directory).await;
|
||||
|
||||
let block_size = u32::try_from(
|
||||
stat(config.mount_directory.to_str().unwrap())
|
||||
.unwrap()
|
||||
@@ -236,14 +252,14 @@ impl AutheliaFS {
|
||||
)
|
||||
.unwrap_or(4096);
|
||||
|
||||
let static_file_state = Arc::new(StaticState {
|
||||
let static_state = StaticState {
|
||||
creation_time: time,
|
||||
user: uid,
|
||||
group: gid,
|
||||
block_size,
|
||||
});
|
||||
};
|
||||
|
||||
let variable_file_state = Arc::new(RwLock::new(VariableState {
|
||||
let variable_state = Arc::new(RwLock::new(VariableState {
|
||||
contents,
|
||||
access_time: time,
|
||||
modification_time: time,
|
||||
@@ -254,18 +270,48 @@ impl AutheliaFS {
|
||||
next_handle: 1,
|
||||
}));
|
||||
|
||||
let write_callback = Arc::new(RwLock::new(
|
||||
write_callback.unwrap_or_else(|| Box::new(|_, _| {})),
|
||||
));
|
||||
|
||||
let notifier = Arc::new(RwLock::new(None));
|
||||
|
||||
Self {
|
||||
config,
|
||||
write_callback,
|
||||
variable_state: variable_file_state,
|
||||
static_state: static_file_state,
|
||||
static_state,
|
||||
variable_state,
|
||||
handles,
|
||||
write_callback,
|
||||
notifier,
|
||||
pg_pool,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mount(self) -> std::io::Result<()> {
|
||||
let mountpoint = self.config.mount_directory.clone();
|
||||
fuser::mount2(self, mountpoint, &vec![])
|
||||
pub async fn run(self) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
let _ = fs::create_dir_all(&self.config.mount_directory).await;
|
||||
let mut session = Session::new(self.clone(), self.config.mount_directory.clone(), &[])?;
|
||||
self.notifier.write().replace(session.notifier());
|
||||
Ok(spawn_blocking(move || session.run().unwrap()).await?)
|
||||
}
|
||||
|
||||
pub async fn store(&self, contents: String) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
let variable_state = self.variable_state.clone();
|
||||
let notifier = self.notifier.clone();
|
||||
|
||||
Ok(spawn_blocking(move || {
|
||||
let mut variable_state = variable_state.write();
|
||||
|
||||
variable_state.contents = contents;
|
||||
variable_state.modification_time = SystemTime::now();
|
||||
variable_state.access_time = SystemTime::now();
|
||||
|
||||
if let Some(notifier) = notifier.write().as_ref() {
|
||||
notifier
|
||||
.store(USERS_FILE_INODE, 0, variable_state.contents.as_bytes())
|
||||
.unwrap();
|
||||
}
|
||||
})
|
||||
.await?)
|
||||
}
|
||||
|
||||
#[allow(clippy::fn_params_excessive_bools)]
|
||||
@@ -483,30 +529,36 @@ impl Filesystem for AutheliaFS {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(size) = size {
|
||||
if size == 0 {
|
||||
let mut variable_file_state = self.variable_state.write();
|
||||
variable_file_state.contents.clear();
|
||||
} else {
|
||||
reply.error(ENOSYS);
|
||||
return;
|
||||
if size.is_some() && (atime.is_some() || mtime.is_some()) {
|
||||
let mut variable_state = self.variable_state.write();
|
||||
|
||||
if let Some(size) = size {
|
||||
if size == 0 {
|
||||
variable_state.contents.clear();
|
||||
} else {
|
||||
reply.error(ENOSYS);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mtime.is_some() || atime.is_some() {
|
||||
let mut variable_file_state = self.variable_state.write();
|
||||
|
||||
variable_file_state.modification_time = match mtime {
|
||||
variable_state.modification_time = match mtime {
|
||||
Some(fuser::TimeOrNow::Now) => SystemTime::now(),
|
||||
Some(fuser::TimeOrNow::SpecificTime(time)) => time,
|
||||
None => variable_file_state.modification_time,
|
||||
None => variable_state.modification_time,
|
||||
};
|
||||
|
||||
variable_file_state.access_time = match atime {
|
||||
variable_state.access_time = match atime {
|
||||
Some(fuser::TimeOrNow::Now) => SystemTime::now(),
|
||||
Some(fuser::TimeOrNow::SpecificTime(time)) => time,
|
||||
None => variable_file_state.access_time,
|
||||
None => variable_state.access_time,
|
||||
};
|
||||
|
||||
self.notifier
|
||||
.write()
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.store(ino, 0, variable_state.contents.as_bytes())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let attr = file.to_file_attr(self);
|
||||
@@ -544,8 +596,8 @@ impl Filesystem for AutheliaFS {
|
||||
drop(handles);
|
||||
|
||||
if flags & O_TRUNC != 0 && flags & O_ACCMODE != O_RDONLY {
|
||||
let mut variable_file_state = self.variable_state.write();
|
||||
variable_file_state.contents.clear();
|
||||
let mut variable_state = self.variable_state.write();
|
||||
variable_state.contents.clear();
|
||||
}
|
||||
|
||||
reply.opened(handle, 0);
|
||||
@@ -578,11 +630,11 @@ impl Filesystem for AutheliaFS {
|
||||
AccessCheckResult::Ok(_) => {}
|
||||
}
|
||||
|
||||
let mut variable_file_state = self.variable_state.write();
|
||||
variable_file_state.access_time = SystemTime::now();
|
||||
let mut variable_state = self.variable_state.write();
|
||||
variable_state.access_time = SystemTime::now();
|
||||
|
||||
let variable_file_state = RwLockWriteGuard::downgrade(variable_file_state);
|
||||
let contents = variable_file_state.contents.as_bytes();
|
||||
let variable_state = RwLockWriteGuard::downgrade(variable_state);
|
||||
let contents = variable_state.contents.as_bytes();
|
||||
let contents_len = i64::try_from(contents.len()).unwrap();
|
||||
|
||||
if offset < 0 || offset >= contents_len {
|
||||
@@ -626,9 +678,9 @@ impl Filesystem for AutheliaFS {
|
||||
let mut handles = self.handles.write();
|
||||
let handle = handles.handles.get_mut(&fh).unwrap();
|
||||
|
||||
let mut variable_file_state = self.variable_state.write();
|
||||
let mut variable_state = self.variable_state.write();
|
||||
|
||||
let old_end = variable_file_state.contents.len();
|
||||
let old_end = variable_state.contents.len();
|
||||
|
||||
let offset = if handle.flags & O_APPEND != 0 {
|
||||
handle.cursor = i64::try_from(old_end).unwrap();
|
||||
@@ -641,8 +693,8 @@ impl Filesystem for AutheliaFS {
|
||||
usize::try_from(offset).unwrap()
|
||||
};
|
||||
|
||||
variable_file_state.access_time = SystemTime::now();
|
||||
variable_file_state.modification_time = SystemTime::now();
|
||||
variable_state.access_time = SystemTime::now();
|
||||
variable_state.modification_time = SystemTime::now();
|
||||
|
||||
let Ok(new_data) = std::str::from_utf8(data) else {
|
||||
reply.error(EINVAL);
|
||||
@@ -653,22 +705,27 @@ impl Filesystem for AutheliaFS {
|
||||
let new_real_end = cmp::max(new_end, old_end);
|
||||
|
||||
let mut new_contents = String::with_capacity(new_real_end);
|
||||
new_contents.push_str(&variable_file_state.contents[..offset]);
|
||||
new_contents.push_str(&variable_state.contents[..offset]);
|
||||
new_contents.push_str(new_data);
|
||||
if new_end < old_end {
|
||||
new_contents.push_str(&variable_file_state.contents[new_end..]);
|
||||
new_contents.push_str(&variable_state.contents[new_end..]);
|
||||
}
|
||||
variable_file_state.contents = new_contents;
|
||||
variable_state.contents = new_contents;
|
||||
|
||||
handle.cursor = i64::try_from(offset + new_data.len()).unwrap();
|
||||
|
||||
drop(handles);
|
||||
|
||||
if let Some(callback) = &self.write_callback {
|
||||
callback(&variable_file_state.contents);
|
||||
}
|
||||
self.write_callback.read().deref()(&self.pg_pool, &variable_state.contents);
|
||||
|
||||
drop(variable_file_state);
|
||||
self.notifier
|
||||
.write()
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.store(ino, 0, variable_state.contents.as_bytes())
|
||||
.unwrap();
|
||||
|
||||
drop(variable_state);
|
||||
|
||||
reply.written(u32::try_from(data.len()).unwrap());
|
||||
}
|
||||
@@ -954,10 +1011,10 @@ impl Filesystem for AutheliaFS {
|
||||
AccessCheckResult::Ok(_) => {}
|
||||
}
|
||||
|
||||
let variable_file_state = self.variable_state.read();
|
||||
let blocks = (variable_file_state.contents.len() as u64)
|
||||
let variable_state = self.variable_state.read();
|
||||
let blocks = (variable_state.contents.len() as u64)
|
||||
.div_ceil(u64::from(self.static_state.block_size));
|
||||
drop(variable_file_state);
|
||||
drop(variable_state);
|
||||
|
||||
reply.statfs(
|
||||
blocks,
|
28
src/main.rs
28
src/main.rs
@@ -2,7 +2,7 @@
|
||||
#![allow(clippy::missing_docs_in_private_items)]
|
||||
|
||||
mod config;
|
||||
mod fuser;
|
||||
mod fuse;
|
||||
mod models;
|
||||
mod routes;
|
||||
mod state;
|
||||
@@ -15,28 +15,26 @@ use log4rs::config::Deserializers;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use config::{Args, Config};
|
||||
use config::Args;
|
||||
use state::State;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let args = Args::parse();
|
||||
log4rs::init_file(args.log_config, Deserializers::default()).unwrap();
|
||||
let args: Args = Args::parse();
|
||||
log4rs::init_file(args.log_config.clone(), Deserializers::default()).unwrap();
|
||||
|
||||
let config = Config::try_from(&args.config).unwrap();
|
||||
let state = State::from_config(config.clone()).await;
|
||||
let state = State::from_args(args).await;
|
||||
|
||||
sqlx::migrate!("./migrations")
|
||||
.run(&state.pg_pool)
|
||||
.await
|
||||
.unwrap();
|
||||
let routes = routes::routes(state.clone());
|
||||
let app = axum::Router::new().nest(&format!("{}/api", state.config.server.subpath), routes);
|
||||
|
||||
let routes = routes::routes(state);
|
||||
let app = axum::Router::new().nest(&format!("{}/api", config.server.subpath), routes);
|
||||
|
||||
let addr = SocketAddr::from((config.server.address, config.server.port));
|
||||
let addr = SocketAddr::from((state.config.server.address, state.config.server.port));
|
||||
let listener = TcpListener::bind(addr).await.unwrap();
|
||||
|
||||
info!("Listening on {}", listener.local_addr().unwrap());
|
||||
serve(listener, app).await.unwrap();
|
||||
|
||||
serve(listener, app)
|
||||
.with_graceful_shutdown(utils::shutdown_signal())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
@@ -1,24 +1,102 @@
|
||||
use log::warn;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_yaml::Value;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, error::Error};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UsersFile {
|
||||
pub users: HashMap<String, UserFile>,
|
||||
pub struct Users {
|
||||
pub users: HashMap<String, User>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub extra: Option<HashMap<String, Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UserFile {
|
||||
pub struct User {
|
||||
pub displayname: String,
|
||||
pub password: String,
|
||||
pub email: Option<String>,
|
||||
pub disabled: Option<bool>,
|
||||
pub picture: Option<String>,
|
||||
pub groups: Option<Vec<String>>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub extra: Option<HashMap<String, Value>>,
|
||||
}
|
||||
|
||||
impl TryInto<Vec<super::users::UserWithGroups>> for Users {
|
||||
type Error = Box<dyn Error + Send + Sync>;
|
||||
|
||||
fn try_into(self) -> Result<Vec<super::users::UserWithGroups>, Self::Error> {
|
||||
self.users
|
||||
.into_iter()
|
||||
.map(|(name, user)| {
|
||||
let groups = user.groups.unwrap_or_default();
|
||||
Ok(super::users::UserWithGroups {
|
||||
name: name.clone(),
|
||||
display_name: user.displayname,
|
||||
password: user.password,
|
||||
email: user
|
||||
.email
|
||||
.ok_or_else(|| format!("User {} is missing an email", &name))?,
|
||||
disabled: user.disabled.unwrap_or(false),
|
||||
picture: user.picture,
|
||||
groups,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Users {
|
||||
pub fn from_fuse(pool: &PgPool, contents: &str) {
|
||||
let Ok(users) = serde_yaml::from_str::<Self>(contents) else {
|
||||
warn!("Failed to parse users from JSON.");
|
||||
return;
|
||||
};
|
||||
|
||||
let users_with_groups: Vec<super::users::UserWithGroups> = match users.try_into() {
|
||||
Ok(users) => users,
|
||||
Err(e) => {
|
||||
warn!("Failed to convert Users to UserWithGroups: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||
rt.block_on(async {
|
||||
super::users::UserWithGroups::upsert_many_delete_remaining(pool, &users_with_groups)
|
||||
.await
|
||||
.unwrap_or_else(|e| warn!("Failed to upsert users: {e}"));
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn to_fuse(pool: &PgPool) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
let users_with_groups = super::users::UserWithGroups::select_all(pool).await?;
|
||||
|
||||
let users = Self {
|
||||
users: users_with_groups
|
||||
.into_iter()
|
||||
.map(|user| {
|
||||
(
|
||||
user.name.clone(),
|
||||
User {
|
||||
displayname: user.display_name,
|
||||
password: user.password,
|
||||
email: Some(user.email),
|
||||
disabled: Some(user.disabled),
|
||||
picture: user.picture,
|
||||
groups: Some(user.groups),
|
||||
extra: None,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
extra: None,
|
||||
};
|
||||
|
||||
Ok(serde_yaml::to_string(&users)?)
|
||||
}
|
||||
}
|
||||
|
@@ -9,7 +9,7 @@ pub struct Group {
|
||||
}
|
||||
|
||||
impl Group {
|
||||
pub async fn select_by_name(
|
||||
pub async fn select(
|
||||
pool: &PgPool,
|
||||
name: &str,
|
||||
) -> Result<Option<Self>, Box<dyn Error + Send + Sync>> {
|
||||
@@ -17,7 +17,7 @@ impl Group {
|
||||
Group,
|
||||
r#"
|
||||
SELECT name
|
||||
FROM groups
|
||||
FROM glyph_groups
|
||||
WHERE name = $1
|
||||
"#,
|
||||
name
|
||||
@@ -28,13 +28,10 @@ impl Group {
|
||||
Ok(group)
|
||||
}
|
||||
|
||||
pub async fn delete_by_name(
|
||||
pool: &PgPool,
|
||||
name: &str,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
pub async fn delete(pool: &PgPool, name: &str) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
query!(
|
||||
r#"
|
||||
DELETE FROM groups
|
||||
DELETE FROM glyph_groups
|
||||
WHERE name = $1
|
||||
"#,
|
||||
name
|
||||
@@ -45,14 +42,14 @@ impl Group {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn all_exist_by_names(
|
||||
pub async fn all_exist(
|
||||
pool: &PgPool,
|
||||
names: &[String],
|
||||
) -> Result<bool, Box<dyn Error + Send + Sync>> {
|
||||
let row = query!(
|
||||
r#"
|
||||
SELECT COUNT(*) AS "count!"
|
||||
FROM groups
|
||||
FROM glyph_groups
|
||||
WHERE name = ANY($1)
|
||||
"#,
|
||||
names
|
||||
@@ -67,20 +64,19 @@ impl Group {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GroupWithUsers {
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub users: Vec<String>,
|
||||
}
|
||||
|
||||
impl GroupWithUsers {
|
||||
pub async fn select(pool: &PgPool) -> Result<Vec<Self>, Box<dyn Error + Send + Sync>> {
|
||||
pub async fn select_all(pool: &PgPool) -> Result<Vec<Self>, Box<dyn Error + Send + Sync>> {
|
||||
let groups = query_as!(
|
||||
GroupWithUsers,
|
||||
r#"
|
||||
SELECT
|
||||
g.name,
|
||||
COALESCE(array_agg(ug.user_name ORDER BY ug.user_name), ARRAY[]::TEXT[]) AS "users!"
|
||||
FROM groups g
|
||||
LEFT JOIN users_groups ug ON g.name = ug.group_name
|
||||
GROUP BY g.name
|
||||
ARRAY(SELECT ug.user_name FROM glyph_users_groups ug WHERE ug.group_name = g.name) AS "users!"
|
||||
FROM glyph_groups g
|
||||
"#
|
||||
)
|
||||
.fetch_all(pool)
|
||||
@@ -89,7 +85,7 @@ impl GroupWithUsers {
|
||||
Ok(groups)
|
||||
}
|
||||
|
||||
pub async fn select_by_name(
|
||||
pub async fn select(
|
||||
pool: &PgPool,
|
||||
name: &str,
|
||||
) -> Result<Option<Self>, Box<dyn Error + Send + Sync>> {
|
||||
@@ -98,11 +94,9 @@ impl GroupWithUsers {
|
||||
r#"
|
||||
SELECT
|
||||
g.name,
|
||||
COALESCE(array_agg(ug.user_name ORDER BY ug.user_name), ARRAY[]::TEXT[]) AS "users!"
|
||||
FROM groups g
|
||||
LEFT JOIN users_groups ug ON g.name = ug.group_name
|
||||
ARRAY(SELECT ug.user_name FROM glyph_users_groups ug WHERE ug.group_name = g.name) AS "users!"
|
||||
FROM glyph_groups g
|
||||
WHERE g.name = $1
|
||||
GROUP BY g.name
|
||||
"#,
|
||||
name
|
||||
)
|
||||
@@ -119,7 +113,7 @@ impl GroupWithUsers {
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
query!(
|
||||
r#"INSERT INTO groups (name) VALUES ($1)"#,
|
||||
r#"INSERT INTO glyph_groups (name) VALUES ($1)"#,
|
||||
group_with_users.name
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
@@ -127,8 +121,8 @@ impl GroupWithUsers {
|
||||
|
||||
query!(
|
||||
r#"
|
||||
INSERT INTO users_groups (user_name, group_name)
|
||||
SELECT * FROM UNNEST($1::text[], $2::text[])
|
||||
INSERT INTO glyph_users_groups (user_name, group_name)
|
||||
SELECT * FROM UNNEST($1::text[], $2::text[])
|
||||
"#,
|
||||
&group_with_users.users,
|
||||
&vec![group_with_users.name.clone(); group_with_users.users.len()]
|
||||
|
@@ -19,7 +19,7 @@ impl UsersGroups {
|
||||
|
||||
query!(
|
||||
r#"
|
||||
DELETE FROM users_groups
|
||||
DELETE FROM glyph_users_groups
|
||||
WHERE group_name = $1
|
||||
"#,
|
||||
group_name
|
||||
@@ -29,7 +29,7 @@ impl UsersGroups {
|
||||
|
||||
query!(
|
||||
r#"
|
||||
INSERT INTO users_groups (user_name, group_name)
|
||||
INSERT INTO glyph_users_groups (user_name, group_name)
|
||||
SELECT * FROM UNNEST($1::text[], $2::text[])
|
||||
"#,
|
||||
users,
|
||||
@@ -50,7 +50,7 @@ impl UsersGroups {
|
||||
|
||||
query!(
|
||||
r#"
|
||||
DELETE FROM users_groups
|
||||
DELETE FROM glyph_users_groups
|
||||
WHERE user_name = $1
|
||||
"#,
|
||||
user_name
|
||||
@@ -60,7 +60,7 @@ impl UsersGroups {
|
||||
|
||||
query!(
|
||||
r#"
|
||||
INSERT INTO users_groups (user_name, group_name)
|
||||
INSERT INTO glyph_users_groups (user_name, group_name)
|
||||
SELECT * FROM UNNEST($1::text[], $2::text[])
|
||||
"#,
|
||||
&vec![user_name.to_string(); groups.len()],
|
||||
|
@@ -1,4 +1,4 @@
|
||||
use std::error::Error;
|
||||
use std::{collections::HashSet, error::Error};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{FromRow, PgPool, query, query_as};
|
||||
@@ -12,20 +12,20 @@ pub struct User {
|
||||
#[serde(default)]
|
||||
pub disabled: bool,
|
||||
#[serde(default)]
|
||||
pub image: Option<String>,
|
||||
pub picture: Option<String>,
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub async fn select_by_name(
|
||||
pub async fn select(
|
||||
pool: &PgPool,
|
||||
name: &str,
|
||||
) -> Result<Option<Self>, Box<dyn Error + Send + Sync>> {
|
||||
let user = query_as!(
|
||||
User,
|
||||
r#"
|
||||
SELECT name, display_name, password, email, disabled, image
|
||||
FROM users
|
||||
WHERE name = $1
|
||||
SELECT name, display_name, password, email, disabled, picture
|
||||
FROM glyph_users
|
||||
WHERE name = $1
|
||||
"#,
|
||||
name
|
||||
)
|
||||
@@ -38,21 +38,21 @@ impl User {
|
||||
pub async fn upsert(pool: &PgPool, user: &Self) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
query!(
|
||||
r#"
|
||||
INSERT INTO users (name, display_name, password, email, disabled, image)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (name) DO UPDATE
|
||||
SET display_name = EXCLUDED.display_name,
|
||||
password = EXCLUDED.password,
|
||||
email = EXCLUDED.email,
|
||||
disabled = EXCLUDED.disabled,
|
||||
image = EXCLUDED.image
|
||||
INSERT INTO glyph_users (name, display_name, password, email, disabled, picture)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (name) DO UPDATE
|
||||
SET display_name = EXCLUDED.display_name,
|
||||
password = EXCLUDED.password,
|
||||
email = EXCLUDED.email,
|
||||
disabled = EXCLUDED.disabled,
|
||||
picture = EXCLUDED.picture
|
||||
"#,
|
||||
user.name,
|
||||
user.display_name,
|
||||
user.password,
|
||||
user.email,
|
||||
user.disabled,
|
||||
user.image
|
||||
user.picture
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
@@ -60,14 +60,11 @@ impl User {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_by_name(
|
||||
pool: &PgPool,
|
||||
name: &str,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
pub async fn delete(pool: &PgPool, name: &str) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
query!(
|
||||
r#"
|
||||
DELETE FROM users
|
||||
WHERE name = $1
|
||||
DELETE FROM glyph_users
|
||||
WHERE name = $1
|
||||
"#,
|
||||
name
|
||||
)
|
||||
@@ -77,15 +74,15 @@ impl User {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn all_exist_by_names(
|
||||
pub async fn all_exist(
|
||||
pool: &PgPool,
|
||||
names: &[String],
|
||||
) -> Result<bool, Box<dyn Error + Send + Sync>> {
|
||||
let row = query!(
|
||||
r#"
|
||||
SELECT COUNT(*) AS "count!"
|
||||
FROM users
|
||||
WHERE name = ANY($1)
|
||||
SELECT COUNT(*) AS "count!"
|
||||
FROM glyph_users
|
||||
WHERE name = ANY($1)
|
||||
"#,
|
||||
names
|
||||
)
|
||||
@@ -105,26 +102,25 @@ pub struct UserWithGroups {
|
||||
#[serde(default)]
|
||||
pub disabled: bool,
|
||||
#[serde(default)]
|
||||
pub image: Option<String>,
|
||||
pub picture: Option<String>,
|
||||
#[serde(default)]
|
||||
pub groups: Vec<String>,
|
||||
}
|
||||
|
||||
impl UserWithGroups {
|
||||
pub async fn select(pool: &PgPool) -> Result<Vec<Self>, Box<dyn Error + Send + Sync>> {
|
||||
pub async fn select_all(pool: &PgPool) -> Result<Vec<Self>, Box<dyn Error + Send + Sync>> {
|
||||
let users = query_as!(
|
||||
UserWithGroups,
|
||||
r#"
|
||||
SELECT
|
||||
u.name,
|
||||
u.display_name,
|
||||
u.password,
|
||||
u.email,
|
||||
u.disabled,
|
||||
u.image,
|
||||
COALESCE(array_agg(ug.group_name ORDER BY ug.group_name), ARRAY[]::TEXT[]) AS "groups!"
|
||||
FROM users u
|
||||
LEFT JOIN users_groups ug ON u.name = ug.user_name
|
||||
GROUP BY u.name, u.email, u.disabled, u.image
|
||||
SELECT
|
||||
u.name,
|
||||
u.display_name,
|
||||
u.password,
|
||||
u.email,
|
||||
u.disabled,
|
||||
u.picture,
|
||||
ARRAY(SELECT ug.group_name FROM glyph_users_groups ug WHERE ug.user_name = u.name) AS "groups!"
|
||||
FROM glyph_users u
|
||||
"#
|
||||
)
|
||||
.fetch_all(pool)
|
||||
@@ -133,25 +129,23 @@ impl UserWithGroups {
|
||||
Ok(users)
|
||||
}
|
||||
|
||||
pub async fn select_by_name(
|
||||
pub async fn select(
|
||||
pool: &PgPool,
|
||||
name: &str,
|
||||
) -> Result<Option<Self>, Box<dyn Error + Send + Sync>> {
|
||||
let user = query_as!(
|
||||
UserWithGroups,
|
||||
r#"
|
||||
SELECT
|
||||
u.name,
|
||||
u.display_name,
|
||||
u.password,
|
||||
u.email,
|
||||
u.disabled,
|
||||
u.image,
|
||||
COALESCE(array_agg(ug.group_name ORDER BY ug.group_name), ARRAY[]::TEXT[]) AS "groups!"
|
||||
FROM users u
|
||||
LEFT JOIN users_groups ug ON u.name = ug.user_name
|
||||
WHERE u.name = $1
|
||||
GROUP BY u.name, u.email, u.disabled, u.image
|
||||
SELECT
|
||||
u.name,
|
||||
u.display_name,
|
||||
u.password,
|
||||
u.email,
|
||||
u.disabled,
|
||||
u.picture,
|
||||
ARRAY(SELECT ug.group_name FROM glyph_users_groups ug WHERE ug.user_name = u.name) AS "groups!"
|
||||
FROM glyph_users u
|
||||
WHERE u.name = $1
|
||||
"#,
|
||||
name
|
||||
)
|
||||
@@ -168,23 +162,24 @@ impl UserWithGroups {
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
query!(
|
||||
r#"INSERT INTO users (name, display_name, password, email, disabled, image)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
r#"
|
||||
INSERT INTO glyph_users (name, display_name, password, email, disabled, picture)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
"#,
|
||||
user_with_groups.name,
|
||||
user_with_groups.display_name,
|
||||
user_with_groups.password,
|
||||
user_with_groups.email,
|
||||
user_with_groups.disabled,
|
||||
user_with_groups.image
|
||||
user_with_groups.picture
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
query!(
|
||||
r#"
|
||||
INSERT INTO users_groups (user_name, group_name)
|
||||
SELECT * FROM UNNEST($1::text[], $2::text[])
|
||||
INSERT INTO glyph_users_groups (user_name, group_name)
|
||||
SELECT * FROM UNNEST($1::text[], $2::text[])
|
||||
"#,
|
||||
&user_with_groups.groups,
|
||||
&vec![user_with_groups.name.clone(); user_with_groups.groups.len()]
|
||||
@@ -196,4 +191,93 @@ impl UserWithGroups {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn upsert_many_delete_remaining(
|
||||
pool: &PgPool,
|
||||
users_with_groups: &[Self],
|
||||
) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
for user in users_with_groups {
|
||||
query!(
|
||||
r#"
|
||||
INSERT INTO glyph_users (name, display_name, password, email, disabled, picture)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (name) DO UPDATE
|
||||
SET display_name = EXCLUDED.display_name,
|
||||
password = EXCLUDED.password,
|
||||
email = EXCLUDED.email,
|
||||
disabled = EXCLUDED.disabled,
|
||||
picture = EXCLUDED.picture
|
||||
"#,
|
||||
user.name,
|
||||
user.display_name,
|
||||
user.password,
|
||||
user.email,
|
||||
user.disabled,
|
||||
user.picture
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
query!(
|
||||
r#"
|
||||
DELETE FROM glyph_users_groups
|
||||
WHERE user_name = $1
|
||||
"#,
|
||||
user.name
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
if !user.groups.is_empty() {
|
||||
query!(
|
||||
r#"
|
||||
INSERT INTO glyph_users_groups (user_name, group_name)
|
||||
SELECT * FROM UNNEST($1::text[], $2::text[])
|
||||
"#,
|
||||
&user.groups,
|
||||
&vec![user.name.clone(); user.groups.len()]
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
let users = users_with_groups
|
||||
.iter()
|
||||
.map(|user| user.name.clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
query!(
|
||||
r#"
|
||||
DELETE FROM glyph_users
|
||||
WHERE name <> ALL($1)
|
||||
"#,
|
||||
&users
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
let groups = users_with_groups
|
||||
.iter()
|
||||
.flat_map(|user| user.groups.iter().cloned())
|
||||
.collect::<HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
query!(
|
||||
r#"
|
||||
DELETE FROM glyph_groups
|
||||
WHERE name <> ALL($1)
|
||||
"#,
|
||||
&groups
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@@ -35,7 +35,7 @@ pub async fn get_all(
|
||||
_: auth::User,
|
||||
extract::State(pg_pool): extract::State<PgPool>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let groups_with_users = models::groups::GroupWithUsers::select(&pg_pool)
|
||||
let groups_with_users = models::groups::GroupWithUsers::select_all(&pg_pool)
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?;
|
||||
|
||||
@@ -52,7 +52,7 @@ pub async fn get(
|
||||
extract::Path(name): extract::Path<NonEmptyString>,
|
||||
extract::State(pg_pool): extract::State<PgPool>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let group_with_users = models::groups::GroupWithUsers::select_by_name(&pg_pool, name.as_str())
|
||||
let group_with_users = models::groups::GroupWithUsers::select(&pg_pool, name.as_str())
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
@@ -71,7 +71,7 @@ pub async fn create(
|
||||
extract::State(pg_pool): extract::State<PgPool>,
|
||||
extract::Json(group_create): extract::Json<GroupCreate>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
if models::groups::Group::select_by_name(&pg_pool, group_create.name.as_str())
|
||||
if models::groups::Group::select(&pg_pool, group_create.name.as_str())
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
.is_some()
|
||||
@@ -85,7 +85,7 @@ pub async fn create(
|
||||
.map(|u| u.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if !models::users::User::all_exist_by_names(&pg_pool, &users)
|
||||
if !models::users::User::all_exist(&pg_pool, &users)
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
{
|
||||
@@ -116,7 +116,7 @@ pub async fn update(
|
||||
extract::State(config): extract::State<Config>,
|
||||
extract::Json(group_update): extract::Json<GroupUpdate>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let group = models::groups::Group::select_by_name(&pg_pool, name.as_str())
|
||||
let group = models::groups::Group::select(&pg_pool, name.as_str())
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
@@ -126,7 +126,7 @@ pub async fn update(
|
||||
if let Some(users) = &group_update.users {
|
||||
let users = users.iter().map(ToString::to_string).collect::<Vec<_>>();
|
||||
|
||||
if !models::users::User::all_exist_by_names(&pg_pool, &users)
|
||||
if !models::users::User::all_exist(&pg_pool, &users)
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
{
|
||||
@@ -163,12 +163,12 @@ pub async fn delete(
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
let group = models::groups::Group::select_by_name(&pg_pool, &name)
|
||||
let group = models::groups::Group::select(&pg_pool, &name)
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
Group::delete_by_name(&pg_pool, &group.name)
|
||||
Group::delete(&pg_pool, &group.name)
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?;
|
||||
|
||||
|
@@ -21,7 +21,7 @@ struct UserResponse {
|
||||
display_name: String,
|
||||
email: String,
|
||||
disabled: bool,
|
||||
image: Option<String>,
|
||||
picture: Option<String>,
|
||||
groups: Vec<String>,
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ impl From<models::users::UserWithGroups> for UserResponse {
|
||||
display_name: user.display_name,
|
||||
email: user.email,
|
||||
disabled: user.disabled,
|
||||
image: user.image,
|
||||
picture: user.picture,
|
||||
groups: user.groups,
|
||||
}
|
||||
}
|
||||
@@ -43,7 +43,7 @@ pub async fn get_all(
|
||||
_: auth::User,
|
||||
extract::State(pg_pool): extract::State<PgPool>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let users_with_groups = models::users::UserWithGroups::select(&pg_pool)
|
||||
let users_with_groups = models::users::UserWithGroups::select_all(&pg_pool)
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?;
|
||||
|
||||
@@ -60,7 +60,7 @@ pub async fn get(
|
||||
extract::Path(name): extract::Path<NonEmptyString>,
|
||||
extract::State(pg_pool): extract::State<PgPool>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let user_with_groups = models::users::UserWithGroups::select_by_name(&pg_pool, name.as_str())
|
||||
let user_with_groups = models::users::UserWithGroups::select(&pg_pool, name.as_str())
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
@@ -74,7 +74,7 @@ pub struct UserCreate {
|
||||
displayname: NonEmptyString,
|
||||
email: NonEmptyString,
|
||||
disabled: bool,
|
||||
image: Option<NonEmptyString>,
|
||||
picture: Option<NonEmptyString>,
|
||||
groups: Vec<NonEmptyString>,
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ pub async fn create(
|
||||
extract::State(pg_pool): extract::State<PgPool>,
|
||||
extract::Json(user_create): extract::Json<UserCreate>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
if models::users::User::select_by_name(&pg_pool, user_create.name.as_str())
|
||||
if models::users::User::select(&pg_pool, user_create.name.as_str())
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
.is_some()
|
||||
@@ -97,7 +97,7 @@ pub async fn create(
|
||||
.map(|g| g.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if !models::groups::Group::all_exist_by_names(&pg_pool, &groups)
|
||||
if !models::groups::Group::all_exist(&pg_pool, &groups)
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
{
|
||||
@@ -110,7 +110,7 @@ pub async fn create(
|
||||
password: generate_random_password_hash(),
|
||||
email: user_create.email.to_string(),
|
||||
disabled: user_create.disabled,
|
||||
image: user_create.image.map(|i| i.to_string()),
|
||||
picture: user_create.picture.map(|i| i.to_string()),
|
||||
groups,
|
||||
};
|
||||
|
||||
@@ -126,7 +126,7 @@ pub struct UserUpdate {
|
||||
display_name: Option<NonEmptyString>,
|
||||
email: Option<NonEmptyString>,
|
||||
disabled: Option<bool>,
|
||||
image: Option<NonEmptyString>,
|
||||
picture: Option<NonEmptyString>,
|
||||
groups: Option<Vec<NonEmptyString>>,
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ pub async fn update(
|
||||
extract::State(config): extract::State<Config>,
|
||||
extract::Json(user_update): extract::Json<UserUpdate>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let user = models::users::User::select_by_name(&pg_pool, name.as_str())
|
||||
let user = models::users::User::select(&pg_pool, name.as_str())
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
@@ -150,7 +150,7 @@ pub async fn update(
|
||||
.map(|g| g.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if !models::groups::Group::all_exist_by_names(&pg_pool, &groups)
|
||||
if !models::groups::Group::all_exist(&pg_pool, &groups)
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
{
|
||||
@@ -183,7 +183,7 @@ pub async fn update(
|
||||
.map(|e| e.to_string())
|
||||
.unwrap_or(user.email),
|
||||
disabled: user_update.disabled.unwrap_or(user.disabled),
|
||||
image: user_update.image.map(|i| i.to_string()).or(user.image),
|
||||
picture: user_update.picture.map(|i| i.to_string()).or(user.picture),
|
||||
};
|
||||
|
||||
models::users::User::upsert(&pg_pool, &user)
|
||||
@@ -206,12 +206,12 @@ pub async fn delete(
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
let user = models::users::User::select_by_name(&pg_pool, &name)
|
||||
let user = models::users::User::select(&pg_pool, &name)
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
models::users::User::delete_by_name(&pg_pool, &user.name)
|
||||
models::users::User::delete(&pg_pool, &user.name)
|
||||
.await
|
||||
.or(Err(StatusCode::INTERNAL_SERVER_ERROR))?;
|
||||
|
||||
|
204
src/state.rs
204
src/state.rs
@@ -1,3 +1,5 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use async_redis_session::RedisSessionStore;
|
||||
use axum::extract::FromRef;
|
||||
use openidconnect::{
|
||||
@@ -10,11 +12,14 @@ use openidconnect::{
|
||||
},
|
||||
reqwest,
|
||||
};
|
||||
use redis::{self, AsyncCommands};
|
||||
use sqlx::{PgPool, postgres::PgPoolOptions};
|
||||
use tokio::spawn;
|
||||
use tokio::{process::Command, spawn, task::JoinHandle, time::sleep};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::{
|
||||
config::{Args, Config},
|
||||
fuse::AutheliaFS,
|
||||
models,
|
||||
};
|
||||
|
||||
pub type OAuthClient<
|
||||
HasAuthUrl = EndpointSet,
|
||||
@@ -46,26 +51,44 @@ pub type OAuthClient<
|
||||
#[derive(Clone)]
|
||||
pub struct State {
|
||||
pub config: Config,
|
||||
pub oauth_http_client: reqwest::Client,
|
||||
pub oauth_client: OAuthClient,
|
||||
pub pg_pool: PgPool,
|
||||
pub redis_client: redis::aio::MultiplexedConnection,
|
||||
pub filesystem: AutheliaFS,
|
||||
pub mount: Arc<JoinHandle<()>>,
|
||||
pub authelia: Arc<JoinHandle<()>>,
|
||||
pub oauth_http_client: reqwest::Client,
|
||||
pub oauth_client: OAuthClient,
|
||||
pub session_store: RedisSessionStore,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub async fn from_config(config: Config) -> Self {
|
||||
let (oauth_http_client, oauth_client) = oauth_client(&config).await;
|
||||
pub async fn from_args(args: Args) -> Self {
|
||||
let config = Config::try_from(&args.config).unwrap();
|
||||
|
||||
let pg_pool = pg_pool(&config).await;
|
||||
sqlx::migrate!("./migrations").run(&pg_pool).await.unwrap();
|
||||
config.admin.upsert(&pg_pool).await.unwrap();
|
||||
|
||||
let redis_client = redis_client(&config).await;
|
||||
|
||||
let (filesystem, mount) = fuse(&config, &pg_pool).await;
|
||||
let contents = models::authelia::Users::to_fuse(&pg_pool).await.unwrap();
|
||||
filesystem.store(contents).await.unwrap();
|
||||
|
||||
let authelia = authelia(args.passthrough);
|
||||
|
||||
let (oauth_http_client, oauth_client) = oauth_client(&config).await;
|
||||
let session_store = session_store(&config);
|
||||
|
||||
Self {
|
||||
config,
|
||||
oauth_http_client,
|
||||
oauth_client,
|
||||
pg_pool,
|
||||
redis_client,
|
||||
filesystem,
|
||||
mount,
|
||||
authelia,
|
||||
oauth_http_client,
|
||||
oauth_client,
|
||||
session_store,
|
||||
}
|
||||
}
|
||||
@@ -77,18 +100,6 @@ impl FromRef<State> for Config {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for reqwest::Client {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.oauth_http_client.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for OAuthClient {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.oauth_client.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for PgPool {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.pg_pool.clone()
|
||||
@@ -101,42 +112,30 @@ impl FromRef<State> for redis::aio::MultiplexedConnection {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for AutheliaFS {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.filesystem.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for reqwest::Client {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.oauth_http_client.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for OAuthClient {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.oauth_client.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for RedisSessionStore {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.session_store.clone()
|
||||
}
|
||||
}
|
||||
|
||||
async fn oauth_client(config: &Config) -> (reqwest::Client, OAuthClient) {
|
||||
let oauth_http_client = reqwest::ClientBuilder::new()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.danger_accept_invalid_certs(config.oauth.insecure)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let provider_metadata = CoreProviderMetadata::discover_async(
|
||||
IssuerUrl::new(config.oauth.issuer_url.clone()).unwrap(),
|
||||
&oauth_http_client,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let oauth_client = OAuthClient::from_provider_metadata(
|
||||
provider_metadata,
|
||||
ClientId::new(config.oauth.client_id.clone()),
|
||||
Some(ClientSecret::new(config.oauth.client_secret.clone())),
|
||||
)
|
||||
.set_redirect_uri(
|
||||
RedirectUrl::new(format!(
|
||||
"{}{}/api/auth/callback",
|
||||
config.server.host, config.server.subpath
|
||||
))
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
(oauth_http_client, oauth_client)
|
||||
}
|
||||
|
||||
async fn pg_pool(config: &Config) -> PgPool {
|
||||
PgPoolOptions::new()
|
||||
.max_connections(5)
|
||||
@@ -159,43 +158,86 @@ async fn redis_client(config: &Config) -> redis::aio::MultiplexedConnection {
|
||||
);
|
||||
|
||||
let client = redis::Client::open(url).unwrap();
|
||||
let mut connection = client.get_multiplexed_async_connection().await.unwrap();
|
||||
client.get_multiplexed_async_connection().await.unwrap()
|
||||
}
|
||||
|
||||
let _: () = redis::cmd("CONFIG")
|
||||
.arg("SET")
|
||||
.arg("notify-keyspace-events")
|
||||
.arg("Ex")
|
||||
.query_async(&mut connection)
|
||||
.await
|
||||
async fn fuse(config: &Config, pg_pool: &PgPool) -> (AutheliaFS, Arc<JoinHandle<()>>) {
|
||||
let fs = AutheliaFS::new(
|
||||
config.fuse.clone(),
|
||||
Some(Box::new(models::authelia::Users::from_fuse)),
|
||||
pg_pool.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let fs_clone = fs.clone();
|
||||
let mount = Arc::new(spawn(async move {
|
||||
loop {
|
||||
let _ = fs_clone.clone().run().await;
|
||||
}
|
||||
}));
|
||||
|
||||
(fs, mount)
|
||||
}
|
||||
|
||||
fn authelia(args: Vec<String>) -> Arc<JoinHandle<()>> {
|
||||
Arc::new(spawn(async move {
|
||||
loop {
|
||||
let _ = Command::new("authelia")
|
||||
.args(args.clone())
|
||||
.spawn()
|
||||
.unwrap()
|
||||
.wait()
|
||||
.await;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
async fn oauth_client(config: &Config) -> (reqwest::Client, OAuthClient) {
|
||||
let oauth_http_client = reqwest::ClientBuilder::new()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.danger_accept_invalid_certs(config.oauth.insecure)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let database = config.redis.database.to_string();
|
||||
spawn(async move {
|
||||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
let rconfig = redis::AsyncConnectionConfig::new().set_push_sender(tx);
|
||||
let mut connection = client
|
||||
.get_multiplexed_async_connection_with_config(&rconfig)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut provider_metadata = None;
|
||||
|
||||
let channel = format!("__keyevent@{database}__:expired");
|
||||
connection.subscribe(&[channel]).await.unwrap();
|
||||
let retries = 10;
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
if let Some(msg) = redis::Msg::from_push_info(msg) {
|
||||
if let Ok(key) = msg.get_payload::<String>() {
|
||||
if !key.starts_with("invite:") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let id = key.trim_start_matches("invite:").to_string();
|
||||
let _: i64 = connection.srem("invite:all", id).await.unwrap();
|
||||
}
|
||||
}
|
||||
for i in 0..retries {
|
||||
if let Ok(metadata) = CoreProviderMetadata::discover_async(
|
||||
IssuerUrl::new(config.oauth.issuer_url.clone()).unwrap(),
|
||||
&oauth_http_client,
|
||||
)
|
||||
.await
|
||||
{
|
||||
provider_metadata = Some(metadata);
|
||||
break;
|
||||
}
|
||||
if i == retries - 1 {
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
connection
|
||||
sleep(backoff).await;
|
||||
backoff *= 2;
|
||||
}
|
||||
|
||||
let provider_metadata = provider_metadata.unwrap();
|
||||
|
||||
let oauth_client = OAuthClient::from_provider_metadata(
|
||||
provider_metadata,
|
||||
ClientId::new(config.oauth.client_id.clone()),
|
||||
Some(ClientSecret::new(config.oauth.client_secret.clone())),
|
||||
)
|
||||
.set_redirect_uri(
|
||||
RedirectUrl::new(format!(
|
||||
"{}{}/api/auth/callback",
|
||||
config.server.host, config.server.subpath
|
||||
))
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
(oauth_http_client, oauth_client)
|
||||
}
|
||||
|
||||
fn session_store(config: &Config) -> RedisSessionStore {
|
||||
|
@@ -33,3 +33,20 @@ pub fn generate_random_password_hash() -> String {
|
||||
|
||||
password_hash
|
||||
}
|
||||
|
||||
pub fn hash_password(password: &str) -> String {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
|
||||
let argon2 = Argon2::new(
|
||||
argon2::Algorithm::Argon2id,
|
||||
argon2::Version::V0x13,
|
||||
argon2::Params::new(65536, 3, 4, Some(32)).unwrap(),
|
||||
);
|
||||
|
||||
let password_hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.unwrap()
|
||||
.to_string();
|
||||
|
||||
password_hash
|
||||
}
|
||||
|
@@ -1 +1,21 @@
|
||||
use tokio::{select, signal};
|
||||
|
||||
pub mod crypto;
|
||||
|
||||
pub async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c().await.unwrap();
|
||||
};
|
||||
|
||||
let terminate = async {
|
||||
signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.unwrap()
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
select! {
|
||||
() = ctrl_c => {},
|
||||
() = terminate => {},
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user