Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions crates/devolutions-pedm/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use axum::extract::{ConnectInfo, Request};
use axum::middleware::{self, Next};
use axum::response::Response;
use axum::{Json, Router};
use camino::Utf8PathBuf;
use futures_util::future::BoxFuture;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server;
Expand Down Expand Up @@ -40,6 +39,7 @@ use revoke::post_revoke;
use state::{AppState, AppStateError};
use status::get_status;

use crate::config::Config;
use crate::db::DbError;
use crate::error::{Error, ErrorResponse};
use crate::utils::AccountExt;
Expand Down Expand Up @@ -94,7 +94,7 @@ async fn named_pipe_middleware(
Ok(next.run(request).await)
}

fn create_pipe(pipe_name: &'static str) -> anyhow::Result<NamedPipeServer> {
fn create_pipe(pipe_name: &str) -> anyhow::Result<NamedPipeServer> {
let pipe = ServerOptions::new().write_dac(true).create(pipe_name)?;

let dacl = Acl::new()?.set_entries(&[
Expand Down Expand Up @@ -165,8 +165,8 @@ async fn health_check() -> &'static str {
"OK"
}

pub async fn serve(pipe_name: &'static str, config_path: Option<Utf8PathBuf>) -> Result<(), ServeError> {
let state = AppState::load(config_path).await?;
pub async fn serve(config: Config) -> Result<(), ServeError> {
let state = AppState::load(&config).await?;

// a plain Axum router
let hello_router = Router::new().route("/health", axum::routing::get(health_check));
Expand All @@ -178,6 +178,7 @@ pub async fn serve(pipe_name: &'static str, config_path: Option<Utf8PathBuf>) ->

let mut make_service = app.into_make_service_with_connect_info::<RawNamedPipeConnectInfo>();

let pipe_name = &config.pipe_name;
let mut server = create_pipe(pipe_name)?;

// Log the server startup.
Expand Down
17 changes: 5 additions & 12 deletions crates/devolutions-pedm/src/api/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::sync::Arc;

use axum::extract::{FromRef, FromRequestParts};
use axum::http::request::Parts;
use camino::Utf8PathBuf;
use hyper::StatusCode;
use parking_lot::RwLock;
use tracing::info;
Expand Down Expand Up @@ -41,18 +40,12 @@ pub(crate) struct AppState {
}

impl AppState {
pub(crate) async fn load(config_path: Option<Utf8PathBuf>) -> Result<Self, AppStateError> {
let config = if let Some(path) = config_path {
Config::load_from_path(&path)
} else {
Config::load_from_default_path()
}?;

pub(crate) async fn load(config: &Config) -> Result<Self, AppStateError> {
let db: Arc<dyn Database + Send + Sync> = match config.db {
#[cfg(feature = "libsql")]
DbBackend::Libsql => {
#[expect(clippy::unwrap_used)]
let c = config.libsql.unwrap(); // already checked by `Config::validate` at the end of the load function
let c = config.libsql.as_ref().unwrap(); // already checked by `Config::validate` at the end of the load function
let db_obj = libsql::Builder::new_local(&c.path)
.build()
.await
Expand All @@ -64,15 +57,15 @@ impl AppState {
#[cfg(feature = "postgres")]
DbBackend::Postgres => {
#[expect(clippy::unwrap_used)]
let c = config.postgres.unwrap(); // already checked by `Config::validate` at the end of the load function
let c = config.postgres.as_ref().unwrap(); // already checked by `Config::validate` at the end of the load function
let mut pg_config = tokio_postgres::Config::new();
pg_config.host(&c.host);
pg_config.dbname(&c.dbname);
if let Some(n) = c.port {
pg_config.port(n);
}
pg_config.user(c.user);
if let Some(s) = c.password {
pg_config.user(&c.user);
if let Some(s) = &c.password {
pg_config.password(s);
}
pg_config.ssl_mode(SslMode::Disable);
Expand Down
44 changes: 29 additions & 15 deletions crates/devolutions-pedm/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,46 @@ use tracing::info;

use crate::data_dir;

/// Specifies the default pipe name.
///
/// This is a workaround for `serde(default)` not taking a raw string literal or escaped backslashes.
fn default_pipe_name() -> String {
"\\\\.\\pipe\\DevolutionsPEDM".to_owned()
}

/// The application config.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub(crate) struct Config {
pub struct Config {
/// The selected database backend.
///
/// Only one can be active at a given time.
pub(crate) db: DbBackend,
pub(crate) postgres: Option<PgConfig>,
pub(crate) libsql: Option<LibsqlConfig>,
pub db: DbBackend,
pub postgres: Option<PgConfig>,
pub libsql: Option<LibsqlConfig>,
/// Specify the pipe name, if desired.
///
/// Backslashes must be escaped, like "\\\\.\\pipe\\foo".
/// This field is intentionally omitted from the example configuration.
#[serde(default = "default_pipe_name")]
pub pipe_name: String,
}

impl Config {
/// Creates a new config with the default values for a new setup.
fn standard() -> Self {
pub fn standard() -> Self {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: Sounds like it should be Default::default()

Self {
db: DbBackend::default(),
postgres: None,
libsql: Some(LibsqlConfig {
path: data_dir().join("pedm.sqlite"),
}),
pipe_name: default_pipe_name(),
}
}

/// Loads the config file from the specified path.
pub(crate) fn load_from_path(path: &Utf8Path) -> Result<Self, ConfigError> {
pub fn load_from_path(path: &Utf8Path) -> Result<Self, ConfigError> {
match fs::read_to_string(path) {
Ok(s) => {
info!("Loading config from {path}");
Expand All @@ -51,7 +65,7 @@ impl Config {
}

/// Loads the config file from the default path.
pub(crate) fn load_from_default_path() -> Result<Self, ConfigError> {
pub fn load_from_default_path() -> Result<Self, ConfigError> {
let path = data_dir().join("config.toml");
Self::load_from_path(&path)
}
Expand Down Expand Up @@ -91,20 +105,20 @@ impl fmt::Display for DbBackend {

#[derive(Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub(crate) struct LibsqlConfig {
pub struct LibsqlConfig {
/// The path to the SQLite database file.
pub(crate) path: Utf8PathBuf,
pub path: Utf8PathBuf,
}

// TODO: SSL support
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub(crate) struct PgConfig {
pub(crate) host: String,
pub(crate) dbname: String,
pub(crate) port: Option<u16>, // 5432 if omitted
pub(crate) user: String,
pub(crate) password: Option<String>,
pub struct PgConfig {
pub host: String,
pub dbname: String,
pub port: Option<u16>, // 5432 if omitted
pub user: String,
pub password: Option<String>,
}

#[derive(Debug)]
Expand Down
11 changes: 8 additions & 3 deletions crates/devolutions-pedm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,24 @@ use camino::Utf8PathBuf;

use devolutions_gateway_task::{ShutdownSignal, Task};

mod config;
pub use config::Config;

cfg_if::cfg_if! {
if #[cfg(target_os = "windows")] {
pub mod api;
mod db;
mod config;

mod elevations;
mod elevator;
mod error;
mod log;
mod policy;
mod utils;
use tokio::select;

pub use api::serve;

use tokio::select;
use tracing::error;
}
}
Expand All @@ -39,7 +44,7 @@ impl Task for PedmTask {
cfg_if::cfg_if! {
if #[cfg(target_os = "windows")] {
select! {
res = api::serve(r"\\.\pipe\DevolutionsPEDM", None) => {
res = serve(Config::standard()) => {
if let Err(error) = &res {
error!(%error, "Named pipe server got error");
}
Expand Down