diff --git a/Cargo.lock b/Cargo.lock index 13dae28a6..7f4799e6f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -866,6 +866,7 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", + "serde", "wasm-bindgen", "windows-link", ] @@ -974,9 +975,9 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.14" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" dependencies = [ "crossbeam-utils", ] @@ -1225,7 +1226,7 @@ dependencies = [ "tap", "thiserror 1.0.69", "tokio 1.44.2", - "tokio-rustls 0.26.2", + "tokio-rustls", "tracing", "uuid", "win-api-wrappers", @@ -1305,7 +1306,7 @@ dependencies = [ "thiserror 1.0.69", "time", "tokio 1.44.2", - "tokio-rustls 0.26.2", + "tokio-rustls", "tokio-test", "tokio-tungstenite", "tower 0.5.2", @@ -1368,6 +1369,7 @@ dependencies = [ "bb8-postgres", "camino", "cfg-if", + "chrono", "devolutions-agent-shared", "devolutions-gateway-task", "devolutions-pedm-shared", @@ -2338,24 +2340,6 @@ dependencies = [ "want", ] -[[package]] -name = "hyper-rustls" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "399c78f9338483cb7e630c8474b07268983c6bd5acee012e4211f9f7bb21b070" -dependencies = [ - "futures-util", - "http 0.2.12", - "hyper 0.14.32", - "log", - "rustls 0.22.4", - "rustls-native-certs 0.7.3", - "rustls-pki-types", - "tokio 1.44.2", - "tokio-rustls 0.25.0", - "webpki-roots", -] - [[package]] name = "hyper-rustls" version = "0.27.5" @@ -2367,10 +2351,10 @@ dependencies = [ "hyper 1.6.0", "hyper-util", "rustls 0.23.25", - "rustls-native-certs 0.8.1", + "rustls-native-certs", "rustls-pki-types", "tokio 1.44.2", - "tokio-rustls 0.26.2", + "tokio-rustls", "tower-service", ] @@ -2895,7 +2879,7 @@ dependencies = [ "ironrdp-svc", "ironrdp-tokio", "tokio 1.44.2", - "tokio-rustls 0.26.2", + "tokio-rustls", "tracing", ] @@ -2966,7 +2950,7 @@ dependencies = [ "proxy-types", "proxy_cfg", "rustls 0.23.25", - "rustls-native-certs 0.8.1", + "rustls-native-certs", "rustls-pemfile 2.2.0", "seahorse", "sysinfo", @@ -3156,7 +3140,6 @@ dependencies = [ "futures", "http 0.2.12", "hyper 0.14.32", - "hyper-rustls 0.25.0", "libsql-hrana", "libsql-sqlite3-parser", "libsql-sys", @@ -3284,9 +3267,9 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "linux-raw-sys" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" [[package]] name = "litemap" @@ -3424,9 +3407,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff70ce3e48ae43fa075863cef62e8b43b71a4f2382229920e0df362592919430" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", ] @@ -4456,6 +4439,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" dependencies = [ "bytes 1.10.1", + "chrono", "fallible-iterator 0.2.0", "postgres-protocol", ] @@ -4939,7 +4923,7 @@ dependencies = [ "http-body 1.0.1", "http-body-util", "hyper 1.6.0", - "hyper-rustls 0.27.5", + "hyper-rustls", "hyper-util", "ipnet", "js-sys", @@ -4950,7 +4934,7 @@ dependencies = [ "pin-project-lite 0.2.16", "quinn", "rustls 0.23.25", - "rustls-native-certs 0.8.1", + "rustls-native-certs", "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", @@ -4958,7 +4942,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.2", "tokio 1.44.2", - "tokio-rustls 0.26.2", + "tokio-rustls", "tokio-util", "tower 0.5.2", "tower-service", @@ -5151,7 +5135,7 @@ dependencies = [ "bitflags 2.9.0", "errno", "libc", - "linux-raw-sys 0.9.3", + "linux-raw-sys 0.9.4", "windows-sys 0.59.0", ] @@ -5167,20 +5151,6 @@ dependencies = [ "webpki", ] -[[package]] -name = "rustls" -version = "0.22.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" -dependencies = [ - "log", - "ring 0.17.14", - "rustls-pki-types", - "rustls-webpki 0.102.8", - "subtle", - "zeroize", -] - [[package]] name = "rustls" version = "0.23.25" @@ -5192,7 +5162,7 @@ dependencies = [ "once_cell", "ring 0.17.14", "rustls-pki-types", - "rustls-webpki 0.103.1", + "rustls-webpki", "subtle", "zeroize", ] @@ -5208,19 +5178,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "rustls-native-certs" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" -dependencies = [ - "openssl-probe", - "rustls-pemfile 2.2.0", - "rustls-pki-types", - "schannel", - "security-framework 2.11.1", -] - [[package]] name = "rustls-native-certs" version = "0.8.1" @@ -5260,17 +5217,6 @@ dependencies = [ "web-time", ] -[[package]] -name = "rustls-webpki" -version = "0.102.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" -dependencies = [ - "ring 0.17.14", - "rustls-pki-types", - "untrusted 0.9.0", -] - [[package]] name = "rustls-webpki" version = "0.103.1" @@ -5337,6 +5283,7 @@ version = "0.8.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" dependencies = [ + "chrono", "dyn-clone", "indexmap 2.9.0", "schemars_derive", @@ -6143,17 +6090,6 @@ dependencies = [ "tokio 1.44.2", ] -[[package]] -name = "tokio-rustls" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" -dependencies = [ - "rustls 0.22.4", - "rustls-pki-types", - "tokio 1.44.2", -] - [[package]] name = "tokio-rustls" version = "0.26.2" @@ -6198,11 +6134,11 @@ dependencies = [ "log", "native-tls", "rustls 0.23.25", - "rustls-native-certs 0.8.1", + "rustls-native-certs", "rustls-pki-types", "tokio 1.44.2", "tokio-native-tls", - "tokio-rustls 0.26.2", + "tokio-rustls", "tungstenite", ] @@ -7055,15 +6991,6 @@ dependencies = [ "untrusted 0.9.0", ] -[[package]] -name = "webpki-roots" -version = "0.26.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "which" version = "4.4.2" diff --git a/crates/devolutions-pedm/Cargo.toml b/crates/devolutions-pedm/Cargo.toml index 972212eb4..6f3a2a4a7 100644 --- a/crates/devolutions-pedm/Cargo.toml +++ b/crates/devolutions-pedm/Cargo.toml @@ -14,10 +14,11 @@ anyhow = "1.0" axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "tracing", "tower-log", "form", "original-uri", "matched-path"] } base16ct = { version = "0.2", features = ["std", "alloc"] } base64 = "0.22" +chrono = { version = "0.4", features = ["serde"] } digest = "0.10" hyper = { version = "1.3", features = ["server"] } hyper-util = { version = "0.1", features = ["tokio"] } -schemars = "0.8" +schemars = { version = "0.8", features = ["chrono"] } serde = "1.0" serde_json = "1.0" sha1 = "0.10" @@ -39,10 +40,10 @@ uuid = "1" dunce = "1.0" tower = "0.5" futures-util = "0.3" -libsql = { version = "0.9", optional = true, features = [ "core", "stream"] } -tokio-postgres = { version = "0.7", optional = true } -bb8 = { version = "0.9.0", optional = true } -bb8-postgres = { version = "0.9.0", optional = true } +libsql = { version = "0.9", optional = true, default-features = false, features = [ "core", "sync"] } +tokio-postgres = { version = "0.7", optional = true, features = ["with-chrono-0_4"] } +bb8 = { version = "0.9", optional = true } +bb8-postgres = { version = "0.9", optional = true } [features] default = ["libsql"] diff --git a/crates/devolutions-pedm/schema/libsql.sql b/crates/devolutions-pedm/schema/libsql.sql index 0c999b0ba..1f280aa72 100644 --- a/crates/devolutions-pedm/schema/libsql.sql +++ b/crates/devolutions-pedm/schema/libsql.sql @@ -1,20 +1,39 @@ -/* In SQLite, we store time as integer with microsecond precision. This is the same precision used by TIMESTAMPTZ in Postgres. */ +/* In SQLite, we store time as an 8-byte integer (i64) with microsecond precision. This matches TIMESTAMPTZ in Postgres. + Use `chrono::DateTime::timestamp_micros` when inserting or fetching timestamps in Rust. +*/ -CREATE TABLE pedm_run ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - start_time INTEGER NOT NULL DEFAULT (CAST(strftime('%f', 'now') * 1000000 AS INTEGER)), - pipe_name TEXT NOT NULL +CREATE TABLE IF NOT EXISTS version +( + version integer PRIMARY KEY, + updated_at integer NOT NULL DEFAULT ( + CAST(strftime('%s', 'now') AS integer) * 1000000 + CAST(strftime('%f', 'now') * 1000000 AS integer) % 1000000 + ) ); -CREATE TABLE http_request ( - id INTEGER PRIMARY KEY, - at INTEGER NOT NULL DEFAULT (CAST(strftime('%f', 'now') * 1000000 AS INTEGER)), - method TEXT NOT NULL, - path TEXT NOT NULL, - status_code INTEGER NOT NULL +CREATE TABLE IF NOT EXISTS run +( + id integer PRIMARY KEY AUTOINCREMENT, + start_time integer NOT NULL DEFAULT ( + CAST(strftime('%s', 'now') AS integer) * 1000000 + CAST(strftime('%f', 'now') * 1000000 AS integer) % 1000000 + ), + pipe_name text NOT NULL ); -CREATE TABLE elevate_tmp_request ( - req_id INTEGER PRIMARY KEY, - seconds INTEGER NOT NULL +CREATE TABLE IF NOT EXISTS http_request +( + id integer PRIMARY KEY, + at integer NOT NULL DEFAULT ( + CAST(strftime('%s', 'now') AS integer) * 1000000 + CAST(strftime('%f', 'now') * 1000000 AS integer) % 1000000 + ), + method text NOT NULL, + path text NOT NULL, + status_code integer NOT NULL ); + +CREATE TABLE IF NOT EXISTS elevate_tmp_request +( + req_id integer PRIMARY KEY, + seconds integer NOT NULL +); + +INSERT INTO version (version) VALUES (0) ON CONFLICT DO NOTHING; \ No newline at end of file diff --git a/crates/devolutions-pedm/schema/pg.sql b/crates/devolutions-pedm/schema/pg.sql index 6c74949b6..393df9ca4 100644 --- a/crates/devolutions-pedm/schema/pg.sql +++ b/crates/devolutions-pedm/schema/pg.sql @@ -1,22 +1,31 @@ +CREATE TABLE IF NOT EXISTS version +( + version smallint PRIMARY KEY, + add_time timestamptz NOT NULL DEFAULT NOW() +); + /* The startup of the server */ -CREATE TABLE pedm_run +CREATE TABLE IF NOT EXISTS run ( id int PRIMARY KEY GENERATED ALWAYS AS IDENTITY, start_time timestamptz NOT NULL DEFAULT NOW(), pipe_name text NOT NULL ); -CREATE TABLE http_request +CREATE TABLE IF NOT EXISTS http_request ( id integer PRIMARY KEY, at timestamptz NOT NULL DEFAULT NOW(), method text NOT NULL, - path text NOT NULL, + path text NOT NULL, status_code smallint NOT NULL ); -CREATE TABLE elevate_tmp_request +/* The request ID is `http_request(id)` but the http_request INSERT only executes in middleware after the response, so we don't use a FK. */ +CREATE TABLE IF NOT EXISTS elevate_tmp_request ( - req_id integer PRIMARY KEY, /* this is http_request but the http_request INSERT only executes in middleware after the response, so we don't use a FK */ + req_id integer PRIMARY KEY, seconds int NOT NULL -); \ No newline at end of file +); + +INSERT INTO version (version) VALUES (0) ON CONFLICT DO NOTHING; \ No newline at end of file diff --git a/crates/devolutions-pedm/src/api/about.rs b/crates/devolutions-pedm/src/api/about.rs new file mode 100644 index 000000000..1ca388a5e --- /dev/null +++ b/crates/devolutions-pedm/src/api/about.rs @@ -0,0 +1,25 @@ +use std::sync::atomic::Ordering; + +use aide::NoApi; +use axum::extract::State; +use axum::Json; + +use crate::db::Db; +use crate::model::AboutData; + +use super::err::HandlerError; +use super::state::AppState; + +/// Gets info about the current state of the application. +pub(crate) async fn about( + NoApi(State(state)): NoApi>, + NoApi(Db(db)): NoApi, +) -> Result, HandlerError> { + Ok(Json(AboutData { + run_id: state.startup_info.run_id, + start_time: state.startup_info.start_time, + startup_request_count: state.startup_info.request_count, + current_request_count: state.req_counter.load(Ordering::Relaxed), + last_request_time: db.get_last_request_time().await?, + })) +} diff --git a/crates/devolutions-pedm/src/api/elevate_temporary.rs b/crates/devolutions-pedm/src/api/elevate_temporary.rs index b193acdf9..df3258329 100644 --- a/crates/devolutions-pedm/src/api/elevate_temporary.rs +++ b/crates/devolutions-pedm/src/api/elevate_temporary.rs @@ -9,11 +9,11 @@ use parking_lot::RwLock; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use crate::db::Db; use crate::elevations; use crate::policy::Policy; use super::err::HandlerError; -use super::state::Db; use super::NamedPipeConnectInfo; #[derive(Deserialize, Serialize, JsonSchema, Debug)] diff --git a/crates/devolutions-pedm/src/api/err.rs b/crates/devolutions-pedm/src/api/err.rs index b4429905c..77e46e4ec 100644 --- a/crates/devolutions-pedm/src/api/err.rs +++ b/crates/devolutions-pedm/src/api/err.rs @@ -6,6 +6,8 @@ use hyper::StatusCode; use crate::db::DbError; /// An error type for route handlers. +/// +/// The error contains a status code and an optional error message. #[derive(Debug)] pub(crate) struct HandlerError(StatusCode, Option); diff --git a/crates/devolutions-pedm/src/api/mod.rs b/crates/devolutions-pedm/src/api/mod.rs index c146c80cf..78caa9f58 100644 --- a/crates/devolutions-pedm/src/api/mod.rs +++ b/crates/devolutions-pedm/src/api/mod.rs @@ -32,18 +32,12 @@ use win_api_wrappers::token::Token; use win_api_wrappers::undoc::PIPE_ACCESS_FULL_CONTROL; use win_api_wrappers::utils::Pipe; -use elevate_session::elevate_session; -use elevate_temporary::elevate_temporary; -use launch::post_launch; -use revoke::post_revoke; -use state::{AppState, AppStateError}; -use status::get_status; - use crate::config::Config; -use crate::db::DbError; +use crate::db::{Db, DbError, InitSchemaError}; use crate::error::{Error, ErrorResponse}; use crate::utils::AccountExt; +mod about; mod elevate_session; mod elevate_temporary; mod err; @@ -53,6 +47,14 @@ mod revoke; pub(crate) mod state; mod status; +use self::about::about; +use self::elevate_session::elevate_session; +use self::elevate_temporary::elevate_temporary; +use self::launch::post_launch; +use self::revoke::post_revoke; +use self::state::{AppState, AppStateError}; +use self::status::get_status; + #[derive(Debug, Clone)] struct NamedPipeConnectInfo { pub(crate) user: User, @@ -132,6 +134,7 @@ fn create_pipe(pipe_name: &str) -> anyhow::Result { pub(crate) fn api_router() -> ApiRouter { ApiRouter::new() + .api_route("/about", aide::axum::routing::get(about)) .api_route("/elevate/temporary", aide::axum::routing::post(elevate_temporary)) .api_route("/elevate/session", aide::axum::routing::post(elevate_session)) .api_route("/launch", aide::axum::routing::post(post_launch)) @@ -165,8 +168,12 @@ async fn health_check() -> &'static str { "OK" } +/// Initializes the appliation and starts the named pipe server. pub async fn serve(config: Config) -> Result<(), ServeError> { - let state = AppState::load(&config).await?; + let db = Db::new(&config).await?; + db.setup().await?; + + let state = AppState::new(db, &config.pipe_name).await?; // a plain Axum router let hello_router = Router::new().route("/health", axum::routing::get(health_check)); @@ -182,11 +189,11 @@ pub async fn serve(config: Config) -> Result<(), ServeError> { let mut server = create_pipe(pipe_name)?; // Log the server startup. - let run_id = state.db.log_server_startup(pipe_name).await?; - info!("Started server at {pipe_name} with run ID {run_id}"); + info!("Started named pipe server with name `{pipe_name}`"); info!( - "Starting request ID counter at {}", - state.req_counter.load(Ordering::Relaxed) + "Run ID is {run_id}, request ID counter is {req_count}", + run_id = state.startup_info.run_id, + req_count = state.req_counter.load(Ordering::Relaxed) ); loop { @@ -219,6 +226,7 @@ pub enum ServeError { TokioIo(tokio::io::Error), AppState(AppStateError), Db(DbError), + InitSchema(InitSchemaError), Other(anyhow::Error), } @@ -228,6 +236,7 @@ impl core::error::Error for ServeError { Self::TokioIo(e) => Some(e), Self::AppState(e) => Some(e), Self::Db(e) => Some(e), + Self::InitSchema(e) => Some(e), Self::Other(e) => Some(e.as_ref()), } } @@ -239,6 +248,7 @@ impl fmt::Display for ServeError { Self::TokioIo(e) => e.fmt(f), Self::AppState(e) => e.fmt(f), Self::Db(e) => e.fmt(f), + Self::InitSchema(e) => e.fmt(f), Self::Other(e) => e.fmt(f), } } @@ -259,6 +269,11 @@ impl From for ServeError { Self::Db(e) } } +impl From for ServeError { + fn from(e: InitSchemaError) -> Self { + Self::InitSchema(e) + } +} impl From for ServeError { fn from(e: anyhow::Error) -> Self { Self::Other(e) @@ -325,8 +340,12 @@ where info!("request ID: {req_id}, status code: {status_code}"); tokio::spawn(async move { #[expect(clippy::cast_possible_wrap)] - db.log_http_request(req_id, method.as_str(), &path, status_code.as_u16() as i16) - .await?; + if let Err(error) = db + .log_http_request(req_id, method.as_str(), &path, status_code.as_u16() as i16) + .await + { + error!(%error, "Failed to log HTTP request"); + } Ok::<_, DbError>(()) }); Ok(resp) diff --git a/crates/devolutions-pedm/src/api/state.rs b/crates/devolutions-pedm/src/api/state.rs index 2168c715b..73b84b531 100644 --- a/crates/devolutions-pedm/src/api/state.rs +++ b/crates/devolutions-pedm/src/api/state.rs @@ -4,31 +4,18 @@ use std::sync::Arc; use axum::extract::{FromRef, FromRequestParts}; use axum::http::request::Parts; +use chrono::Utc; use hyper::StatusCode; use parking_lot::RwLock; -use tracing::info; -use crate::config::{Config, ConfigError, DbBackend}; -use crate::db::{Database, DbError}; +use crate::db::{Database, Db, DbError}; +use crate::model::StartupInfo; use crate::policy::{LoadPolicyError, Policy}; -#[cfg(feature = "libsql")] -use crate::db::LibsqlConn; - -#[cfg(feature = "postgres")] -use bb8::Pool; -#[cfg(feature = "postgres")] -use bb8_postgres::PostgresConnectionManager; -#[cfg(feature = "postgres")] -use tokio_postgres::config::SslMode; -#[cfg(feature = "postgres")] -use tokio_postgres::NoTls; - -#[cfg(feature = "postgres")] -use crate::db::PgPool; - +/// Axum application state. #[derive(Clone)] pub(crate) struct AppState { + pub(crate) startup_info: StartupInfo, /// Request counter. /// /// The current count is the last used request ID. @@ -40,59 +27,29 @@ pub(crate) struct AppState { } impl AppState { - pub(crate) async fn load(config: &Config) -> Result { - let db: Arc = match config.db { - #[cfg(feature = "libsql")] - DbBackend::Libsql => { - #[expect(clippy::unwrap_used)] - let c = config.libsql.as_ref().unwrap(); // already checked by `Config::validate` at the end of the load function - let db_obj = libsql::Builder::new_local(&c.path) - .build() - .await - .map_err(DbError::from)?; - let conn = db_obj.connect().map_err(DbError::from)?; - info!("Connecting to LibSQL database at {}", c.path); - Arc::new(LibsqlConn::new(conn)) - } - #[cfg(feature = "postgres")] - DbBackend::Postgres => { - #[expect(clippy::unwrap_used)] - let c = config.postgres.as_ref().unwrap(); // already checked by `Config::validate` at the end of the load function - let mut pg_config = tokio_postgres::Config::new(); - pg_config.host(&c.host); - pg_config.dbname(&c.dbname); - if let Some(n) = c.port { - pg_config.port(n); - } - pg_config.user(&c.user); - if let Some(s) = &c.password { - pg_config.password(s); - } - pg_config.ssl_mode(SslMode::Disable); + pub(crate) async fn new(db: Db, pipe_name: &str) -> Result { + let policy = Policy::load()?; - let mgr = PostgresConnectionManager::new(pg_config, NoTls); - let pool = Pool::builder().build(mgr).await.map_err(DbError::from)?; + let last_req_id = db.get_last_request_id().await?; + let startup_time = Utc::now(); + let run_id = db.log_server_startup(startup_time, pipe_name).await?; - info!("Connecting to Postgres database {} on host {}", c.dbname, c.host); - Arc::new(PgPool::new(pool)) - } + let startup_info = StartupInfo { + run_id, + request_count: last_req_id, + start_time: startup_time, }; - let policy = Policy::load()?; - - let last_req_id = db.get_latest_request_id().await?; - Ok(Self { + startup_info, req_counter: Arc::new(AtomicI32::new(last_req_id)), - db, + db: db.0, policy: Arc::new(RwLock::new(policy)), }) } } /// Axum extractor for an object that is `Database`. -pub(crate) struct Db(pub Arc); - impl FromRequestParts for Db where AppState: FromRef, @@ -114,7 +71,6 @@ impl FromRef for Arc> { #[derive(Debug)] pub enum AppStateError { - Config(ConfigError), LoadPolicy(LoadPolicyError), Db(DbError), } @@ -122,7 +78,6 @@ pub enum AppStateError { impl error::Error for AppStateError { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { - Self::Config(e) => Some(e), Self::LoadPolicy(e) => Some(e), Self::Db(e) => Some(e), } @@ -132,17 +87,12 @@ impl error::Error for AppStateError { impl fmt::Display for AppStateError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Config(e) => e.fmt(f), Self::LoadPolicy(e) => e.fmt(f), Self::Db(e) => e.fmt(f), } } } -impl From for AppStateError { - fn from(e: ConfigError) -> Self { - Self::Config(e) - } -} + impl From for AppStateError { fn from(e: LoadPolicyError) -> Self { Self::LoadPolicy(e) diff --git a/crates/devolutions-pedm/src/db/err.rs b/crates/devolutions-pedm/src/db/err.rs index e2dc9db0d..8dfe50e0f 100644 --- a/crates/devolutions-pedm/src/db/err.rs +++ b/crates/devolutions-pedm/src/db/err.rs @@ -1,20 +1,43 @@ +use core::error::Error; use core::fmt; +#[cfg(feature = "libsql")] +use std::num::TryFromIntError; + +#[cfg(feature = "libsql")] +use chrono::{DateTime, Utc}; + +#[cfg(feature = "postgres")] +use tokio_postgres::error::SqlState; + +/// Error type for DB operations. #[derive(Debug)] pub enum DbError { #[cfg(feature = "libsql")] Libsql(libsql::Error), + /// This is to handle some type conversions. + /// + /// For example, we may have a value that is `i16` in the data model but it is stored as `i64` in libSQL. + #[cfg(feature = "libsql")] + TryFromInt(TryFromIntError), + /// An error that occurs when parsing microseconds since epoch into `chrono::DateTime`. + #[cfg(feature = "libsql")] + Timestamp(ParseTimestampError), #[cfg(feature = "postgres")] Bb8(bb8::RunError), #[cfg(feature = "postgres")] Postgres(tokio_postgres::Error), } -impl core::error::Error for DbError { - fn source(&self) -> Option<&(dyn core::error::Error + 'static)> { +impl Error for DbError { + fn source(&self) -> Option<&(dyn Error + 'static)> { match self { #[cfg(feature = "libsql")] Self::Libsql(e) => Some(e), + #[cfg(feature = "libsql")] + Self::TryFromInt(e) => Some(e), + #[cfg(feature = "libsql")] + Self::Timestamp(e) => Some(e), #[cfg(feature = "postgres")] Self::Bb8(e) => Some(e), #[cfg(feature = "postgres")] @@ -28,8 +51,12 @@ impl fmt::Display for DbError { match self { #[cfg(feature = "libsql")] Self::Libsql(e) => e.fmt(f), + #[cfg(feature = "libsql")] + Self::TryFromInt(e) => e.fmt(f), + #[cfg(feature = "libsql")] + Self::Timestamp(e) => e.fmt(f), #[cfg(feature = "postgres")] - Self::Bb8(e) => e.fmt(f), + Self::Bb8(e) => write!(f, "could not connect to the database: {e}"), #[cfg(feature = "postgres")] Self::Postgres(e) => e.fmt(f), } @@ -42,6 +69,18 @@ impl From for DbError { Self::Libsql(e) } } +#[cfg(feature = "libsql")] +impl From for DbError { + fn from(e: TryFromIntError) -> Self { + Self::TryFromInt(e) + } +} +#[cfg(feature = "libsql")] +impl From for DbError { + fn from(e: ParseTimestampError) -> Self { + Self::Timestamp(e) + } +} #[cfg(feature = "postgres")] impl From> for DbError { @@ -55,3 +94,41 @@ impl From for DbError { Self::Postgres(e) } } + +impl DbError { + pub fn is_table_does_not_exist(&self) -> bool { + match self { + #[cfg(feature = "libsql")] + Self::Libsql(libsql::Error::SqliteFailure(1, msg)) => msg.starts_with("no such table"), + #[cfg(feature = "postgres")] + Self::Postgres(e) => e.code() == Some(&SqlState::UNDEFINED_TABLE), + _ => false, + } + } +} + +/// A custom error type equivalent for `chrono::LocalResult`. +#[cfg(feature = "libsql")] +#[derive(Debug)] +pub enum ParseTimestampError { + None, + /// This should be unreachable when using UTC. + Ambiguous(DateTime, DateTime), +} + +#[cfg(feature = "libsql")] +impl Error for ParseTimestampError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +#[cfg(feature = "libsql")] +impl fmt::Display for ParseTimestampError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::None => write!(f, "no timestamp found"), + Self::Ambiguous(dt1, dt2) => write!(f, "ambiguous timestamp: {dt1} or {dt2}"), + } + } +} diff --git a/crates/devolutions-pedm/src/db/libsql.rs b/crates/devolutions-pedm/src/db/libsql.rs index 713b31f7c..889da0a4f 100644 --- a/crates/devolutions-pedm/src/db/libsql.rs +++ b/crates/devolutions-pedm/src/db/libsql.rs @@ -1,10 +1,13 @@ use std::ops::Deref; use async_trait::async_trait; +use chrono::{DateTime, Utc}; use libsql::params::IntoParams; use libsql::{params, Row}; +use tracing::info; -use super::{Database, DbError}; +use super::err::ParseTimestampError; +use super::{Database, DbError, InitSchemaError}; pub(crate) struct LibsqlConn(libsql::Connection); @@ -35,7 +38,23 @@ impl Deref for LibsqlConn { #[async_trait] impl Database for LibsqlConn { - async fn get_latest_request_id(&self) -> Result { + async fn get_schema_version(&self) -> Result { + let version = self.query_one("SELECT version FROM version", ()).await?.get::(0)?; + Ok(i16::try_from(version)?) + } + + async fn init_schema(&self) -> Result<(), DbError> { + let sql = include_str!("../../schema/libsql.sql"); + self.execute_batch(sql).await?; + Ok(()) + } + + async fn apply_pragmas(&self) -> Result<(), DbError> { + // TODO: run pragmas + Ok(()) + } + + async fn get_last_request_id(&self) -> Result { match self .query_one("SELECT id FROM http_request ORDER BY id DESC LIMIT 1", ()) .await @@ -46,9 +65,26 @@ impl Database for LibsqlConn { } } - async fn log_server_startup(&self, pipe_name: &str) -> Result { + async fn get_last_request_time(&self) -> Result>, DbError> { + match self + .query_one("SELECT at FROM http_request ORDER BY id DESC LIMIT 1", ()) + .await + { + Ok(row) => { + let micros: i64 = row.get(0)?; + Ok(Some(parse_micros(micros)?)) + } + Err(libsql::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(DbError::Libsql(e)), + } + } + + async fn log_server_startup(&self, start_time: DateTime, pipe_name: &str) -> Result { Ok(self - .query_one("INSERT INTO pedm_run (pipe_name) VALUES (?) RETURNING id", [pipe_name]) + .query_one( + "INSERT INTO run (start_time, pipe_name) VALUES (?1, ?2) RETURNING id", + params![start_time.timestamp_micros(), pipe_name], + ) .await? .get(0)?) } @@ -71,3 +107,38 @@ impl Database for LibsqlConn { Ok(()) } } + +/// Converts a timestamp in microseconds to a `DateTime`. +fn parse_micros(micros: i64) -> Result, ParseTimestampError> { + use chrono::offset::LocalResult; + use chrono::TimeZone; + + match Utc.timestamp_micros(micros) { + LocalResult::Single(dt) => Ok(dt), + LocalResult::Ambiguous(earliest, latest) => Err(ParseTimestampError::Ambiguous(earliest, latest)), + LocalResult::None => Err(ParseTimestampError::None), + } +} + +#[cfg(test)] +mod tests { + use chrono::{TimeZone, Utc}; + + use super::parse_micros; + use crate::db::err::ParseTimestampError; + + #[test] + fn test_valid_micros() { + let dt = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + let micros = dt.timestamp_micros(); + let parsed = parse_micros(micros).unwrap(); + assert_eq!(parsed.timestamp(), dt.timestamp()); + } + + #[test] + fn test_invalid_micros_none() { + // i32::MIN is too far in the past and produces LocalResult::None + let e = parse_micros(i64::MIN).unwrap_err(); + matches!(e, ParseTimestampError::None); + } +} diff --git a/crates/devolutions-pedm/src/db/mod.rs b/crates/devolutions-pedm/src/db/mod.rs index 5c1fc419e..b009c7f6f 100644 --- a/crates/devolutions-pedm/src/db/mod.rs +++ b/crates/devolutions-pedm/src/db/mod.rs @@ -1,7 +1,15 @@ +use core::fmt; +use std::ops::Deref; +use std::sync::Arc; + use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use tracing::info; mod err; +use crate::config::DbBackend; +use crate::Config; pub(crate) use err::DbError; #[cfg(feature = "libsql")] @@ -14,22 +22,174 @@ mod pg; #[cfg(feature = "postgres")] pub(crate) use pg::PgPool; +#[cfg(feature = "postgres")] +use bb8::Pool; +#[cfg(feature = "postgres")] +use bb8_postgres::PostgresConnectionManager; +#[cfg(feature = "postgres")] +use tokio_postgres::config::SslMode; +#[cfg(feature = "postgres")] +use tokio_postgres::NoTls; + +pub(crate) const CURRENT_SCHEMA_VERSION: i16 = 0; + +/// A wrapper around the database connection. +#[derive(Clone)] +pub(crate) struct Db(pub Arc); + +impl Deref for Db { + type Target = dyn Database + Send + Sync; + + fn deref(&self) -> &Self::Target { + &*self.0 + } +} + +impl Db { + /// Creates a new `Db` instance. + pub(crate) async fn new(config: &Config) -> Result { + let db: Arc = match config.db { + #[cfg(feature = "libsql")] + DbBackend::Libsql => { + #[expect(clippy::unwrap_used)] + let c = config.libsql.as_ref().unwrap(); // already checked by `Config::validate` at the end of the load function + let db_obj = ::libsql::Builder::new_local(&c.path).build().await?; + let conn = db_obj.connect()?; + info!("Connecting to libSQL database at {}", c.path); + Arc::new(LibsqlConn::new(conn)) + } + #[cfg(feature = "postgres")] + DbBackend::Postgres => { + #[expect(clippy::unwrap_used)] + let c = config.postgres.as_ref().unwrap(); // already checked by `Config::validate` at the end of the load function + let mut pg_config = tokio_postgres::Config::new(); + pg_config.host(&c.host); + pg_config.dbname(&c.dbname); + if let Some(n) = c.port { + pg_config.port(n); + } + pg_config.user(&c.user); + if let Some(s) = &c.password { + pg_config.password(s); + } + pg_config.ssl_mode(SslMode::Disable); + + let mgr = PostgresConnectionManager::new(pg_config, NoTls); + let pool = Pool::builder().build(mgr).await?; + + info!( + "Connecting to postgres://{user}@{host}:{port}/{dbname}", + user = c.user, + host = c.host, + port = c.port.unwrap_or(5432), + dbname = c.dbname + ); + // Check if the connection works. + let conn = pool.get().await?; + conn.query_one("SELECT 1", &[]).await?; + drop(conn); + Arc::new(PgPool::new(pool)) + } + }; + info!("Successfully connected to the database"); + Ok(Self(db)) + } + + /// Sets up the database. + /// + /// The schema version is checked. Tables are created if needed, such as for first run. + pub(crate) async fn setup(&self) -> Result<(), InitSchemaError> { + match self.0.get_schema_version().await { + Ok(version) => { + info!("Schema version: {version}"); + if version != CURRENT_SCHEMA_VERSION { + return Err(InitSchemaError::VersionMismatch { + expected: CURRENT_SCHEMA_VERSION, + actual: version, + }); + } + } + Err(error) => { + if error.is_table_does_not_exist() { + info!("Initializing schema"); + self.0.init_schema().await?; + } else { + return Err(error.into()); + } + } + } + self.0.apply_pragmas().await?; + Ok(()) + } +} + +#[derive(Debug)] +pub enum InitSchemaError { + VersionMismatch { expected: i16, actual: i16 }, + Db(DbError), +} + +impl core::error::Error for InitSchemaError { + fn source(&self) -> Option<&(dyn core::error::Error + 'static)> { + match self { + Self::VersionMismatch { .. } => None, + Self::Db(e) => Some(e), + } + } +} + +impl fmt::Display for InitSchemaError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::VersionMismatch { expected, actual } => { + write!( + f, + "schema version mismatch: expected version {expected}, got version {actual}" + ) + } + Self::Db(e) => e.fmt(f), + } + } +} + +impl From for InitSchemaError { + fn from(e: DbError) -> Self { + Self::Db(e) + } +} + /// Abstracts database operations for backends such as Postgres or libSQL. /// /// All queries required by the application are defined here. They must be implemented by each backend. #[async_trait] pub(crate) trait Database: Send + Sync { + /// Returns the schema version from the `version` table. + async fn get_schema_version(&self) -> Result; + + /// Initializes the database schema. + /// + /// This creates tables. + async fn init_schema(&self) -> Result<(), DbError>; + + /// Applies pragmas, if applicable. + async fn apply_pragmas(&self) -> Result<(), DbError>; + /// Gets the latest request ID from the HTTP request table. /// /// This is used to set the atomic request counter. /// /// It returns an error if there is a database error, except for "no rows found". In that case, it returns 0. - async fn get_latest_request_id(&self) -> Result; + async fn get_last_request_id(&self) -> Result; + + /// Gets the time of the latest request. + /// + /// This is used in endpoints like `/about`. + async fn get_last_request_time(&self) -> Result>, DbError>; /// Logs the server startup. /// /// Returns the run ID. - async fn log_server_startup(&self, pipe_name: &str) -> Result; + async fn log_server_startup(&self, start_time: DateTime, pipe_name: &str) -> Result; /// Logs an HTTP request. /// diff --git a/crates/devolutions-pedm/src/db/pg.rs b/crates/devolutions-pedm/src/db/pg.rs index afa80157b..95d49b670 100644 --- a/crates/devolutions-pedm/src/db/pg.rs +++ b/crates/devolutions-pedm/src/db/pg.rs @@ -3,6 +3,7 @@ use std::ops::Deref; use async_trait::async_trait; use bb8::Pool; use bb8_postgres::PostgresConnectionManager; +use chrono::{DateTime, Utc}; use tokio_postgres::NoTls; use super::{Database, DbError}; @@ -25,7 +26,27 @@ impl Deref for PgPool { #[async_trait] impl Database for PgPool { - async fn get_latest_request_id(&self) -> Result { + async fn get_schema_version(&self) -> Result { + Ok(self + .get() + .await? + .query_one("SELECT version FROM version", &[]) + .await? + .get(0)) + } + + async fn init_schema(&self) -> Result<(), DbError> { + let sql = include_str!("../../schema/pg.sql"); + self.get().await?.batch_execute(sql).await?; + Ok(()) + } + + async fn apply_pragmas(&self) -> Result<(), DbError> { + // nothing to do + Ok(()) + } + + async fn get_last_request_id(&self) -> Result { Ok(self .get() .await? @@ -35,13 +56,22 @@ impl Database for PgPool { .unwrap_or_default()) } - async fn log_server_startup(&self, pipe_name: &str) -> Result { + async fn get_last_request_time(&self) -> Result>, DbError> { + Ok(self + .get() + .await? + .query_opt("SELECT at FROM http_request ORDER BY id DESC LIMIT 1", &[]) + .await? + .map(|r| r.get(0))) + } + + async fn log_server_startup(&self, start_time: DateTime, pipe_name: &str) -> Result { Ok(self .get() .await? .query_one( - "INSERT INTO pedm_run (pipe_name) VALUES ($1) RETURNING id", - &[&pipe_name], + "INSERT INTO run (start_time, pipe_name) VALUES ($1, $2) RETURNING id", + &[&start_time, &pipe_name], ) .await? .get(0)) @@ -50,7 +80,7 @@ impl Database for PgPool { async fn log_http_request(&self, req_id: i32, method: &str, path: &str, status_code: i16) -> Result<(), DbError> { self.get() .await? - .query_one( + .execute( "INSERT INTO http_request (id, method, path, status_code) VALUES ($1, $2, $3, $4)", &[&req_id, &method, &path, &status_code], ) diff --git a/crates/devolutions-pedm/src/lib.rs b/crates/devolutions-pedm/src/lib.rs index 30b5ab05d..929582552 100644 --- a/crates/devolutions-pedm/src/lib.rs +++ b/crates/devolutions-pedm/src/lib.rs @@ -4,13 +4,14 @@ use camino::Utf8PathBuf; use devolutions_gateway_task::{ShutdownSignal, Task}; mod config; +mod db; +pub mod model; + pub use config::Config; cfg_if::cfg_if! { if #[cfg(target_os = "windows")] { pub mod api; - mod db; - mod elevations; mod elevator; mod error; diff --git a/crates/devolutions-pedm/src/model.rs b/crates/devolutions-pedm/src/model.rs new file mode 100644 index 000000000..4c98af6a5 --- /dev/null +++ b/crates/devolutions-pedm/src/model.rs @@ -0,0 +1,27 @@ +use chrono::{DateTime, Utc}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "PascalCase")] +pub(crate) struct AboutData { + pub run_id: i32, + pub start_time: DateTime, + pub startup_request_count: i32, + pub current_request_count: i32, + /// The time of the most recent request. + /// + /// This can be `None` if `/about` is the first request made. + pub last_request_time: Option>, +} + +/// Immutable startup info. +/// +/// It is used in the `/about` endpoint. +#[derive(Clone)] +pub(crate) struct StartupInfo { + pub(crate) run_id: i32, + pub(crate) start_time: DateTime, + /// The request count at the time of the server startup. + pub(crate) request_count: i32, +} diff --git a/crates/devolutions-pedm/src/utils.rs b/crates/devolutions-pedm/src/utils.rs index 273d7c293..60d88044c 100644 --- a/crates/devolutions-pedm/src/utils.rs +++ b/crates/devolutions-pedm/src/utils.rs @@ -1,11 +1,13 @@ +use std::collections::HashMap; +use std::fs; +use std::path::Path; + use devolutions_pedm_shared::policy::{Hash, User}; use digest::Update; use sha1::Sha1; use sha2::{Digest, Sha256}; -use std::collections::HashMap; -use std::fs; -use std::path::Path; use tracing::info; + use win_api_wrappers::fs::create_directory; use win_api_wrappers::identity::account::Account; use win_api_wrappers::identity::sid::Sid;