diff --git a/.sqlx/query-ac624e63c0bb3b353838703deeaf1f1724edfa5462000f9d6148602d6f3d2431.json b/.sqlx/query-ac624e63c0bb3b353838703deeaf1f1724edfa5462000f9d6148602d6f3d2431.json new file mode 100644 index 0000000..ffa2878 --- /dev/null +++ b/.sqlx/query-ac624e63c0bb3b353838703deeaf1f1724edfa5462000f9d6148602d6f3d2431.json @@ -0,0 +1,46 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT name, address, port, private_key, default_network_netmask\n FROM interfaces\n WHERE name = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "name", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "address", + "type_info": "Inet" + }, + { + "ordinal": 2, + "name": "port", + "type_info": "Int4" + }, + { + "ordinal": 3, + "name": "private_key", + "type_info": "Bytea" + }, + { + "ordinal": 4, + "name": "default_network_netmask", + "type_info": "Int2" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false, + false + ] + }, + "hash": "ac624e63c0bb3b353838703deeaf1f1724edfa5462000f9d6148602d6f3d2431" +} diff --git a/.sqlx/query-f4bf4e45fbd5a0590653d6d89ecab8019f08e7f22c3cf7c45c164afe57017cb2.json b/.sqlx/query-f4bf4e45fbd5a0590653d6d89ecab8019f08e7f22c3cf7c45c164afe57017cb2.json new file mode 100644 index 0000000..3294952 --- /dev/null +++ b/.sqlx/query-f4bf4e45fbd5a0590653d6d89ecab8019f08e7f22c3cf7c45c164afe57017cb2.json @@ -0,0 +1,18 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO interfaces (name, address, port, private_key, default_network_netmask)\n VALUES ($1, $2, $3, $4, $5)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text", + "Inet", + "Int4", + "Bytea", + "Int2" + ] + }, + "nullable": [] + }, + "hash": "f4bf4e45fbd5a0590653d6d89ecab8019f08e7f22c3cf7c45c164afe57017cb2" +} diff --git a/Cargo.lock b/Cargo.lock index b69d536..2cc2410 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -566,6 +566,7 @@ dependencies = [ "digest 0.10.7", "fiat-crypto", "rustc_version", + "serde", "subtle", "zeroize", ] @@ -1396,6 +1397,15 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "ipnetwork" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e" +dependencies = [ + "serde", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -1571,6 +1581,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mktemp" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69fed8fbcd01affec44ac226784c6476a6006d98d13e33bc0ca7977aaf046bd8" +dependencies = [ + "uuid", +] + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -2342,6 +2361,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + [[package]] name = "signature" version = "2.2.0" @@ -2430,6 +2458,7 @@ dependencies = [ "hashbrown 0.15.2", "hashlink", "indexmap 2.8.0", + "ipnetwork", "log", "memchr", "once_cell", @@ -2548,6 +2577,7 @@ dependencies = [ "hkdf", "hmac 0.12.1", "home", + "ipnetwork", "itoa", "log", "md-5", @@ -2791,6 +2821,7 @@ dependencies = [ "libc", "mio", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.52.0", @@ -2987,6 +3018,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" +dependencies = [ + "getrandom 0.2.15", +] + [[package]] name = "vcpkg" version = "0.2.15" @@ -3000,15 +3040,19 @@ dependencies = [ "async-session", "axum", "axum-extra", + "base64 0.22.1", "clap", "log", "log4rs", + "mktemp", "openidconnect", "serde", "serde_json", "serde_yaml", "sqlx", "tokio", + "x25519-dalek", + "zeroize", ] [[package]] @@ -3456,6 +3500,18 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +[[package]] +name = "x25519-dalek" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" +dependencies = [ + "curve25519-dalek", + "rand_core 0.6.4", + "serde", + "zeroize", +] + [[package]] name = "yoke" version = "0.7.5" @@ -3526,6 +3582,20 @@ name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] [[package]] name = "zerovec" diff --git a/Cargo.toml b/Cargo.toml index 68ea091..b455cb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,12 +17,16 @@ codegen-units = 1 async-session = "3.0.0" axum = "0.8.1" axum-extra = { version = "0.10.0", features = ["typed-header"] } +base64 = "0.22.1" clap = { version = "4.5.32", features = ["derive"] } log = "0.4.27" log4rs = "1.3.0" +mktemp = "0.5.1" openidconnect = { version = "4.0.0", features = ["reqwest"] } serde = "1.0.219" serde_json = "1.0.140" serde_yaml = "0.9.34" -sqlx = { version = "0.8.3", features = ["postgres", "runtime-tokio", "time"] } -tokio = { version = "1.44.1", features = ["rt-multi-thread"] } +sqlx = { version = "0.8.3", features = ["ipnetwork", "postgres", "runtime-tokio", "time"] } +tokio = { version = "1.44.1", features = ["process", "rt-multi-thread"] } +x25519-dalek = { version = "2.0.1", features = ["getrandom", "serde", "static_secrets"] } +zeroize = "1.8.1" diff --git a/Containerfile b/Containerfile index 462d317..b0d6b6b 100644 --- a/Containerfile +++ b/Containerfile @@ -18,7 +18,7 @@ RUN cargo build --target=x86_64-unknown-linux-musl --release FROM docker.io/library/alpine -RUN apk add --no-cache wireguard-tools iptables +RUN apk add --no-cache wireguard-tools iptables iproute2 COPY --from=builder /app/target/x86_64-unknown-linux-musl/release/veil /usr/local/bin/veil diff --git a/manifest.yaml b/manifest.yaml index d767ff5..c83f80a 100644 --- a/manifest.yaml +++ b/manifest.yaml @@ -4,6 +4,25 @@ metadata: name: veil spec: containers: + - name: veil + image: registry.karaolidis.com/karaolidis/veil:latest + volumeMounts: + - name: veil-config + mountPath: /etc/veil + command: + [ + "veil", + "--config", + "/etc/veil/default.yml", + --log-config, + "/etc/veil/log4rs.yml", + ] + securityContext: + capabilities: + add: + - NET_ADMIN + - NET_RAW + - name: postgresql image: docker.io/library/postgres:latest env: @@ -23,20 +42,6 @@ spec: - name: authelia-config mountPath: /config - - name: veil - image: registry.karaolidis.com/karaolidis/veil:latest - volumeMounts: - - name: veil-config - mountPath: /etc/veil - command: - [ - "veil", - "--config", - "/etc/veil/default.yml", - --log-config, - "/etc/veil/log4rs.yml", - ] - - name: traefik image: docker.io/library/traefik:latest args: diff --git a/migrations/20250327112746_init.sql b/migrations/20250327112746_init.sql index e69de29..2e4360d 100644 --- a/migrations/20250327112746_init.sql +++ b/migrations/20250327112746_init.sql @@ -0,0 +1,7 @@ +CREATE TABLE interfaces ( + name TEXT PRIMARY KEY, + address INET NOT NULL, + port INTEGER NOT NULL CHECK (port > 0 AND port <= 65535), + private_key BYTEA NOT NULL, + default_network_netmask SMALLINT NOT NULL CHECK (default_network_netmask >= 0 AND default_network_netmask <= 32) +); diff --git a/src/config.rs b/src/config.rs index f95905b..371b3bb 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,27 +1,30 @@ +use base64::{Engine, prelude::BASE64_STANDARD}; use clap::Parser; use serde::Deserialize; +use sqlx::types::ipnetwork::IpNetwork; use std::{ fs, net::{IpAddr, Ipv4Addr}, path::PathBuf, }; +use x25519_dalek::StaticSecret; #[derive(Clone, Deserialize)] pub struct ServerConfig { pub host: String, - #[serde(default = "default_address")] + #[serde(default = "default_server_address")] pub address: std::net::IpAddr, - #[serde(default = "default_port")] + #[serde(default = "default_server_port")] pub port: u16, #[serde(default)] pub subpath: String, } -const fn default_address() -> IpAddr { +const fn default_server_address() -> IpAddr { IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)) } -const fn default_port() -> u16 { +const fn default_server_port() -> u16 { 51821 } @@ -34,7 +37,7 @@ pub struct DatabaseConfig { pub database: String, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Clone, Deserialize)] pub struct OAuthConfig { pub issuer_url: String, pub client_id: String, @@ -45,11 +48,60 @@ pub struct OAuthConfig { pub admin_group: Option, } +#[derive(Clone, Deserialize)] +pub struct WireguardConfig { + #[serde(default = "default_wireguard_address")] + pub address: IpNetwork, + #[serde(default = "default_wireguard_port")] + pub port: u16, + #[serde(default = "default_wireguard_interface")] + pub interface: String, + #[serde(default = "default_wireguard_private_key")] + pub private_key: String, + #[serde(default = "default_wireguard_default_network_netmask")] + pub default_network_netmask: u8, +} + +fn default_wireguard_address() -> IpNetwork { + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 8).unwrap() +} + +const fn default_wireguard_port() -> u16 { + 51820 +} + +fn default_wireguard_interface() -> String { + "wg0".to_string() +} + +fn default_wireguard_private_key() -> String { + let private_key = StaticSecret::random(); + BASE64_STANDARD.encode(private_key.as_bytes()) +} + +const fn default_wireguard_default_network_netmask() -> u8 { + 24 +} + +impl Default for WireguardConfig { + fn default() -> Self { + Self { + address: default_wireguard_address(), + port: default_wireguard_port(), + interface: default_wireguard_interface(), + private_key: default_wireguard_private_key(), + default_network_netmask: default_wireguard_default_network_netmask(), + } + } +} + #[derive(Clone, Deserialize)] pub struct Config { pub server: ServerConfig, pub database: DatabaseConfig, pub oauth: OAuthConfig, + #[serde(default)] + pub wireguard: WireguardConfig, } impl Config { diff --git a/src/main.rs b/src/main.rs index 0e26843..8214638 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,11 +7,14 @@ mod routes; mod state; use axum::serve; +use base64::{Engine, prelude::BASE64_STANDARD}; use clap::Parser; use log::info; use log4rs::config::Deserializers; -use std::net::SocketAddr; -use tokio::net::TcpListener; +use mktemp::Temp; +use models::interface::Interface; +use std::{error::Error, fs::File, io::Write, net::SocketAddr}; +use tokio::{net::TcpListener, process::Command}; use config::{Args, Config}; use state::State; @@ -23,10 +26,7 @@ async fn main() { let config = Config::from_yaml(&args.config).unwrap(); let state = State::from_config(config.clone()).await.unwrap(); - sqlx::migrate!("./migrations") - .run(&state.pg_pool) - .await - .expect("Failed to run migrations"); + init(&state).await.unwrap(); let routes = routes::routes(state); let app = axum::Router::new().nest(&format!("{}/api", config.server.subpath), routes); @@ -37,3 +37,113 @@ async fn main() { info!("Listening on {}", listener.local_addr().unwrap()); serve(listener, app).await.unwrap(); } + +async fn init(state: &State) -> Result<(), Box> { + sqlx::migrate!("./migrations") + .run(&state.pg_pool) + .await + .expect("Failed to run migrations"); + + let interface_name = &state.config.wireguard.interface; + let interface = { + let maybe_interface = Interface::select_by_name(&state.pg_pool, interface_name).await?; + + if let Some(interface) = maybe_interface { + interface + } else { + let interface = Interface::try_from(state.config.wireguard.clone())?; + Interface::insert(&state.pg_pool, interface.clone()).await?; + interface + } + }; + + let private_key_file_path = Temp::new_file()?; + File::options() + .write(true) + .open(&private_key_file_path)? + .write_all( + BASE64_STANDARD + .encode(interface.private_key.to_bytes()) + .as_bytes(), + )?; + + if !Command::new("ip") + .args(["link", "add", "dev", interface_name, "type", "wireguard"]) + .status() + .await? + .success() + { + return Err("Failed to create WireGuard interface".into()); + } + + if !Command::new("ip") + .args([ + "address", + "add", + &interface.address.to_string(), + "dev", + interface_name, + ]) + .status() + .await? + .success() + { + return Err("Failed to assign IP address".into()); + } + + if !Command::new("wg") + .args([ + "set", + interface_name, + "listen-port", + &interface.port.to_string(), + "private-key", + private_key_file_path + .to_str() + .ok_or("Invalid private key file path")?, + ]) + .status() + .await? + .success() + { + return Err("Failed to set WireGuard interface options".into()); + } + + if !Command::new("ip") + .args(["link", "set", "up", "dev", interface_name]) + .status() + .await? + .success() + { + return Err("Failed to set WireGuard interface up".into()); + } + + if !Command::new("iptables") + .args([ + "-t", + "nat", + "-A", + "POSTROUTING", + "-o", + "eth0", + "-j", + "MASQUERADE", + ]) + .status() + .await? + .success() + { + return Err("Failed to set iptables NAT rule".into()); + } + + if !Command::new("iptables") + .args(["-P", "FORWARD", "DROP"]) + .status() + .await? + .success() + { + return Err("Failed to set FORWARD policy to DROP".into()); + } + + Ok(()) +} diff --git a/src/models/interface.rs b/src/models/interface.rs new file mode 100644 index 0000000..f5a1f5e --- /dev/null +++ b/src/models/interface.rs @@ -0,0 +1,116 @@ +use std::error::Error; + +use base64::{Engine, prelude::BASE64_STANDARD}; +use sqlx::{PgPool, query, query_as, types::ipnetwork::IpNetwork}; +use x25519_dalek::StaticSecret; + +use crate::config::WireguardConfig; + +#[derive(Clone)] +pub struct Interface { + pub name: String, + pub address: IpNetwork, + pub port: u16, + pub private_key: StaticSecret, + pub default_network_netmask: u8, +} + +struct InterfacePostgres { + name: String, + address: IpNetwork, + port: i32, + private_key: Vec, + default_network_netmask: i16, +} + +impl TryFrom for Interface { + type Error = Box; + + fn try_from(config: WireguardConfig) -> Result { + Ok(Self { + name: config.interface, + address: config.address, + port: config.port, + private_key: { + let decoded_key = BASE64_STANDARD.decode(config.private_key)?; + let key_array: [u8; 32] = + decoded_key.try_into().map_err(|_| "Invalid key length")?; + StaticSecret::from(key_array) + }, + default_network_netmask: config.default_network_netmask, + }) + } +} + +// We allow .unwrap() here because we set the lengths of the variables ourselves +#[allow(clippy::fallible_impl_from)] +impl From for Interface { + fn from(row: InterfacePostgres) -> Self { + Self { + name: row.name, + address: row.address, + port: row.port.try_into().unwrap(), + private_key: { + let key_array: [u8; 32] = row.private_key.try_into().unwrap(); + StaticSecret::from(key_array) + }, + default_network_netmask: row.default_network_netmask.try_into().unwrap(), + } + } +} + +impl From for InterfacePostgres { + fn from(interface: Interface) -> Self { + Self { + name: interface.name, + address: interface.address, + port: i32::from(interface.port), + private_key: interface.private_key.to_bytes().to_vec(), + default_network_netmask: i16::from(interface.default_network_netmask), + } + } +} + +impl Interface { + pub async fn insert( + pool: &PgPool, + interface: Self, + ) -> Result<(), Box> { + let interface = InterfacePostgres::from(interface); + + query!( + r#" + INSERT INTO interfaces (name, address, port, private_key, default_network_netmask) + VALUES ($1, $2, $3, $4, $5) + "#, + interface.name, + interface.address, + interface.port, + interface.private_key, + interface.default_network_netmask, + ) + .execute(pool) + .await?; + + Ok(()) + } + + pub async fn select_by_name( + pool: &PgPool, + name: &str, + ) -> Result, Box> { + let row = query_as!( + InterfacePostgres, + r#" + SELECT name, address, port, private_key, default_network_netmask + FROM interfaces + WHERE name = $1 + "#, + name + ) + .fetch_optional(pool) + .await?; + + row.map_or_else(|| Ok(None), |row| Ok(Some(Self::from(row)))) + } +} diff --git a/src/models/mod.rs b/src/models/mod.rs index e69de29..8d3d626 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -0,0 +1 @@ +pub mod interface; diff --git a/src/state.rs b/src/state.rs index 74e4df5..9f13edd 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,3 +1,5 @@ +use std::error::Error; + use async_session::MemoryStore; use axum::extract::FromRef; use log::error; @@ -11,6 +13,7 @@ use openidconnect::{ }, reqwest, }; +use sqlx::{PgPool, postgres::PgPoolOptions}; use tokio::{ spawn, time::{Duration, sleep}, @@ -48,59 +51,17 @@ pub type OAuthClient< #[derive(Clone)] pub struct State { pub config: Config, - pub pg_pool: sqlx::PgPool, + pub pg_pool: PgPool, pub oauth_http_client: reqwest::Client, pub oauth_client: OAuthClient, - pub session_store: async_session::MemoryStore, + pub session_store: MemoryStore, } impl State { - pub async fn from_config(config: Config) -> Result> { - let pg_pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(5) - .connect(&format!( - "postgres://{}:{}@{}:{}/{}", - config.database.user, - config.database.password, - config.database.host, - config.database.port, - config.database.database - )) - .await?; - - let oauth_http_client = reqwest::ClientBuilder::new() - .redirect(reqwest::redirect::Policy::none()) - .danger_accept_invalid_certs(config.oauth.insecure) - .build()?; - - let provider_metadata = CoreProviderMetadata::discover_async( - IssuerUrl::new(config.oauth.issuer_url.clone())?, - &oauth_http_client, - ) - .await?; - - let oauth_client = OAuthClient::from_provider_metadata( - provider_metadata, - ClientId::new(config.oauth.client_id.clone()), - Some(ClientSecret::new(config.oauth.client_secret.clone())), - ) - .set_redirect_uri(RedirectUrl::new(format!( - "{}{}/api/auth/callback", - config.server.host, config.server.subpath - ))?); - - let session_store = MemoryStore::new(); - - let session_store_clone = session_store.clone(); - spawn(async move { - loop { - match session_store_clone.cleanup().await { - Ok(()) => {} - Err(e) => error!("Failed to clean up session store: {e}"), - } - sleep(Duration::from_secs(60)).await; - } - }); + pub async fn from_config(config: Config) -> Result> { + let pg_pool = pg_pool(&config).await?; + let (oauth_http_client, oauth_client) = oauth(&config).await?; + let session_store = session_store(); Ok(Self { config, @@ -118,7 +79,7 @@ impl FromRef for Config { } } -impl FromRef for sqlx::PgPool { +impl FromRef for PgPool { fn from_ref(state: &State) -> Self { state.pg_pool.clone() } @@ -136,8 +97,68 @@ impl FromRef for OAuthClient { } } -impl FromRef for async_session::MemoryStore { +impl FromRef for MemoryStore { fn from_ref(state: &State) -> Self { state.session_store.clone() } } + +async fn pg_pool(config: &Config) -> Result> { + let pg_pool = PgPoolOptions::new() + .max_connections(5) + .connect(&format!( + "postgres://{}:{}@{}:{}/{}", + config.database.user, + config.database.password, + config.database.host, + config.database.port, + config.database.database + )) + .await?; + + Ok(pg_pool) +} + +async fn oauth( + config: &Config, +) -> Result<(reqwest::Client, OAuthClient), Box> { + let oauth_http_client = reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .danger_accept_invalid_certs(config.oauth.insecure) + .build()?; + + let provider_metadata = CoreProviderMetadata::discover_async( + IssuerUrl::new(config.oauth.issuer_url.clone())?, + &oauth_http_client, + ) + .await?; + + let oauth_client = OAuthClient::from_provider_metadata( + provider_metadata, + ClientId::new(config.oauth.client_id.clone()), + Some(ClientSecret::new(config.oauth.client_secret.clone())), + ) + .set_redirect_uri(RedirectUrl::new(format!( + "{}{}/api/auth/callback", + config.server.host, config.server.subpath + ))?); + + Ok((oauth_http_client, oauth_client)) +} + +fn session_store() -> MemoryStore { + let session_store = MemoryStore::new(); + + let session_store_clone = session_store.clone(); + spawn(async move { + loop { + match session_store_clone.cleanup().await { + Ok(()) => {} + Err(e) => error!("Failed to clean up session store: {e}"), + } + sleep(Duration::from_secs(60)).await; + } + }); + + session_store +}