Skip to content

Commit 1892cbd

Browse files
committed
Add the ability to load/save the config to the db
1 parent 33c3bf5 commit 1892cbd

3 files changed

Lines changed: 64 additions & 11 deletions

File tree

config.sample.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
---
22
# Apart from oidc_clients and users, the following are the default values
33
# To create a new config, just set the `users` (or point to `users_file`) and any other settings you want to change
4-
database_url: database.db
4+
database_url: sqlite://database.db
55

66
# Host that the server will listen on
77
listen_host: 0.0.0.0

src/config.rs

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use axum::http::request::Parts;
1515
use chrono::Duration;
1616
use lettre::transport::smtp;
1717
use notify::{PollWatcher, Watcher};
18-
use serde::{Deserialize, Serialize};
18+
use serde::{Deserialize, Serialize, Serializer};
1919
use tracing::{error, info};
2020

2121
use crate::CONFIG;
@@ -31,7 +31,7 @@ use crate::user_store::{FileUserStore, SQLUserStore, StaticUserStore, UserStore}
3131
/// TODO: Move the comments from here to the config.sample.yaml so the code
3232
/// is the source of truth
3333
// TODO: Generate a validation schema
34-
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
34+
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
3535
#[serde(default, deny_unknown_fields)]
3636
#[allow(clippy::struct_excessive_bools)]
3737
pub struct Config {
@@ -42,13 +42,22 @@ pub struct Config {
4242
pub path_prefix: String,
4343
pub external_url: String,
4444

45-
#[serde(deserialize_with = "duration_str::deserialize_duration_chrono")]
45+
#[serde(
46+
deserialize_with = "duration_str::deserialize_duration_chrono",
47+
serialize_with = "serialize_duration_chrono"
48+
)]
4649
pub link_duration: Duration,
47-
#[serde(deserialize_with = "duration_str::deserialize_duration_chrono")]
50+
#[serde(
51+
deserialize_with = "duration_str::deserialize_duration_chrono",
52+
serialize_with = "serialize_duration_chrono"
53+
)]
4854
pub session_duration: Duration,
4955

5056
/// Interval for periodic cleanup of expired secrets
51-
#[serde(deserialize_with = "duration_str::deserialize_duration_chrono")]
57+
#[serde(
58+
deserialize_with = "duration_str::deserialize_duration_chrono",
59+
serialize_with = "serialize_duration_chrono"
60+
)]
5261
pub secrets_cleanup_interval: Duration,
5362

5463
pub title: String,
@@ -60,7 +69,10 @@ pub struct Config {
6069
pub auth_url_email_header: String,
6170
pub auth_url_realms_header: String,
6271

63-
#[serde(deserialize_with = "duration_str::deserialize_duration_chrono")]
72+
#[serde(
73+
deserialize_with = "duration_str::deserialize_duration_chrono",
74+
serialize_with = "serialize_duration_chrono"
75+
)]
6476
pub oidc_code_duration: Duration,
6577

6678
pub saml_cert_pem_path: String,
@@ -96,7 +108,7 @@ impl Default for Config {
96108
#[allow(clippy::unwrap_used)] // All the cases are either const or on start (e.g. port)
97109
fn default() -> Self {
98110
Self {
99-
database_url: std::env::var("DATABASE_URL").unwrap_or("database.db".to_string()),
111+
database_url: std::env::var("DATABASE_URL").unwrap_or("sqlite://database.db".to_string()),
100112

101113
listen_host : std::env::var("LISTEN_HOST").unwrap_or("127.0.0.1".to_string()),
102114
listen_port : std::env::var("LISTEN_PORT").unwrap_or("8080".to_string()).parse().unwrap(),
@@ -299,6 +311,36 @@ impl Config {
299311

300312
result
301313
}
314+
315+
/// Enterprise-only feature
316+
pub async fn load_from_db(db: &Database) -> anyhow::Result<Option<Self>> {
317+
info!("Loading config from database");
318+
let config = ConfigKV::get(&ConfigKeys::Config, db)
319+
.await?
320+
.and_then(|c| serde_json::from_str::<Self>(&c).ok());
321+
println!("{config:?}");
322+
Ok(config)
323+
}
324+
325+
/// Enterprise-only feature
326+
pub async fn reload_from_db(config: Arc<ArcSwap<Config>>, db: &Database) -> anyhow::Result<()> {
327+
let Some(new_config) = Self::load_from_db(db).await? else {
328+
return Err(anyhow::anyhow!("Failed to load config from database"));
329+
};
330+
let new_config_arc = Arc::new(new_config.clone());
331+
// TODO: secrets and static pages still use the global config, updating it for the time being
332+
let mut config_guard = CONFIG.write().await;
333+
*config_guard = new_config_arc.clone();
334+
config.store(new_config_arc.into());
335+
Ok(())
336+
}
337+
338+
/// Enterprise-only feature
339+
pub async fn save_to_db(&self, db: &Database) -> anyhow::Result<()> {
340+
info!("Saving config to database");
341+
ConfigKV::set(&ConfigKeys::Config, Some(serde_json::to_string(self)?), db).await?;
342+
Ok(())
343+
}
302344
}
303345

304346
#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize)]
@@ -345,7 +387,7 @@ pub struct ConfigKV {
345387

346388
impl ConfigKV {
347389
/// Set the provided key to the provided value - overwrites any previous values
348-
pub async fn set(key: ConfigKeys, value: Option<String>, db: &Database) -> anyhow::Result<()> {
390+
pub async fn set(key: &ConfigKeys, value: Option<String>, db: &Database) -> anyhow::Result<()> {
349391
let key_str = serde_json::to_string(&key)?;
350392
let value_str = value.unwrap_or_default();
351393

@@ -371,6 +413,17 @@ impl ConfigKV {
371413
pub enum ConfigKeys {
372414
Secret,
373415
JWTSecret,
416+
Config,
374417
}
375418

376-
// Remove AsBytes trait as it's no longer needed for SQLx
419+
fn serialize_duration_chrono<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
420+
where
421+
S: Serializer,
422+
{
423+
let duration_str = format!(
424+
"{}s + {}ns",
425+
duration.num_seconds(),
426+
duration.subsec_nanos()
427+
);
428+
serializer.serialize_str(&duration_str)
429+
}

src/oidc/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ pub async fn init(db: &Database) -> EncodingKey {
4141
rand::fill(&mut buffer);
4242
let secret = hex::encode(buffer);
4343

44-
ConfigKV::set(ConfigKeys::JWTSecret, Some(secret.clone()), db)
44+
ConfigKV::set(&ConfigKeys::JWTSecret, Some(secret.clone()), db)
4545
.await
4646
.expect("Unable to save secret in the database");
4747

0 commit comments

Comments
 (0)