Skip to content

Commit 67f4632

Browse files
committed
Add the ability to update each config field procedurally
1 parent 1892cbd commit 67f4632

2 files changed

Lines changed: 109 additions & 139 deletions

File tree

src/config.rs

Lines changed: 109 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -24,141 +24,118 @@ use crate::service::Services;
2424
use crate::user::User;
2525
use crate::user_store::{FileUserStore, SQLUserStore, StaticUserStore, UserStore};
2626

27-
/// The actual, deserialized config data
28-
///
29-
/// To see what each field represents check out the [config.sample.yaml](https://github.com/dzervas/magicentry/blob/main/config.sample.yaml) file
30-
///
31-
/// TODO: Move the comments from here to the config.sample.yaml so the code
32-
/// is the source of truth
33-
// TODO: Generate a validation schema
34-
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
35-
#[serde(default, deny_unknown_fields)]
36-
#[allow(clippy::struct_excessive_bools)]
37-
pub struct Config {
38-
pub database_url: String,
39-
40-
pub listen_host: String,
41-
pub listen_port: u16,
42-
pub path_prefix: String,
43-
pub external_url: String,
27+
macro_rules! config_struct {
28+
(
29+
$(
30+
$(#[$meta:meta])*
31+
$pub:vis $name:ident: $type:ty = $default:expr
32+
),+ $(,)?
33+
) => {
34+
35+
/// The actual, deserialized config data
36+
///
37+
/// To see what each field represents check out the [config.sample.yaml](https://github.com/dzervas/magicentry/blob/main/config.sample.yaml) file
38+
///
39+
/// TODO: Move the comments from here to the config.sample.yaml so the code
40+
/// is the source of truth
41+
// TODO: Generate a validation schema
42+
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
43+
#[serde(default, deny_unknown_fields)]
44+
pub struct Config {
45+
$(
46+
$(#[$meta])*
47+
$pub $name: $type
48+
),+
49+
}
50+
51+
impl Default for Config {
52+
fn default() -> Self {
53+
Self {
54+
$( $name: $default ),+
55+
}
56+
}
57+
}
58+
59+
impl Config {
60+
pub async fn update_field(&mut self, db: &Database, field: &str, value: serde_json::Value) -> anyhow::Result<()> {
61+
match field {
62+
$(
63+
stringify!($name) => {
64+
self.$name = serde_json::from_value::<$type>(value)?;
65+
},
66+
)+
67+
_ => anyhow::bail!("Unknown field: {field}"),
68+
};
69+
70+
self.save_to_db(db).await?;
71+
Ok(())
72+
}
73+
}
74+
};
75+
}
76+
77+
config_struct! {
78+
pub database_url: String = std::env::var("DATABASE_URL").unwrap_or("sqlite://database.db".to_string()),
79+
80+
pub listen_host: String = std::env::var("LISTEN_HOST").unwrap_or("127.0.0.1".to_string()),
81+
pub listen_port: u16 = std::env::var("LISTEN_PORT").unwrap_or("8080".to_string()).parse().unwrap(),
82+
pub path_prefix: String = "/".to_string(),
83+
pub external_url: String = "http://localhost:8080".to_string(),
4484

4585
#[serde(
4686
deserialize_with = "duration_str::deserialize_duration_chrono",
47-
serialize_with = "serialize_duration_chrono"
87+
serialize_with = "crate::config::serialize_duration_chrono"
4888
)]
49-
pub link_duration: Duration,
89+
pub link_duration: Duration = Duration::try_hours(12).unwrap(),
5090
#[serde(
5191
deserialize_with = "duration_str::deserialize_duration_chrono",
52-
serialize_with = "serialize_duration_chrono"
92+
serialize_with = "crate::config::serialize_duration_chrono"
5393
)]
54-
pub session_duration: Duration,
55-
56-
/// Interval for periodic cleanup of expired secrets
94+
pub session_duration: Duration = Duration::try_days(30).unwrap(),
5795
#[serde(
5896
deserialize_with = "duration_str::deserialize_duration_chrono",
59-
serialize_with = "serialize_duration_chrono"
97+
serialize_with = "crate::config::serialize_duration_chrono"
6098
)]
61-
pub secrets_cleanup_interval: Duration,
99+
pub secrets_cleanup_interval: Duration = Duration::try_hours(24).unwrap(),
62100

63-
pub title: String,
64-
pub static_path: String,
101+
pub title: String = "MagicEntry".to_string(),
102+
pub static_path: String = "static".to_string(),
65103

66-
pub auth_url_enable: bool,
67-
pub auth_url_user_header: String,
68-
pub auth_url_name_header: String,
69-
pub auth_url_email_header: String,
70-
pub auth_url_realms_header: String,
104+
pub auth_url_enable: bool = true,
105+
pub auth_url_user_header: String = "X-Auth-User".to_string(),
106+
pub auth_url_name_header: String = "X-Auth-Name".to_string(),
107+
pub auth_url_email_header: String = "X-Auth-Email".to_string(),
108+
pub auth_url_realms_header: String = "X-Auth-Realms".to_string(),
71109

72110
#[serde(
73111
deserialize_with = "duration_str::deserialize_duration_chrono",
74-
serialize_with = "serialize_duration_chrono"
112+
serialize_with = "crate::config::serialize_duration_chrono"
75113
)]
76-
pub oidc_code_duration: Duration,
77-
78-
pub saml_cert_pem_path: String,
79-
pub saml_key_pem_path: String,
80-
81-
pub smtp_enable: bool,
82-
pub smtp_url: String,
83-
pub smtp_from: String,
84-
pub smtp_subject: String,
85-
pub smtp_body: String,
86-
87-
pub request_enable: bool,
88-
pub request_url: String,
89-
pub request_method: String,
90-
pub request_data: Option<String>,
91-
pub request_content_type: String,
92-
93-
pub webauthn_enable: bool,
94-
95-
// pub force_https_redirects: bool,
96-
// Private to avoid reading from the field instead of the user store
97-
users: Vec<User>,
98-
/// Path to a file containing the user definitions
99-
pub users_file: Option<String>,
100-
pub users_sql_query_all: Option<String>,
101-
pub users_sql_query_email: Option<String>,
102-
pub users_sql_url: Option<String>,
103-
pub services: Services,
104-
}
105-
106-
impl Default for Config {
107-
#[allow(clippy::or_fun_call)]
108-
#[allow(clippy::unwrap_used)] // All the cases are either const or on start (e.g. port)
109-
fn default() -> Self {
110-
Self {
111-
database_url: std::env::var("DATABASE_URL").unwrap_or("sqlite://database.db".to_string()),
112-
113-
listen_host : std::env::var("LISTEN_HOST").unwrap_or("127.0.0.1".to_string()),
114-
listen_port : std::env::var("LISTEN_PORT").unwrap_or("8080".to_string()).parse().unwrap(),
115-
path_prefix : "/".to_string(),
116-
external_url: "http://localhost:8080".to_string(),
117-
118-
link_duration : Duration::try_hours(12).unwrap(),
119-
session_duration: Duration::try_days(30).unwrap(),
120-
121-
secrets_cleanup_interval: Duration::try_hours(24).unwrap(),
122-
123-
title: "MagicEntry".to_string(),
124-
static_path: "static".to_string(),
125-
126-
auth_url_enable : true,
127-
auth_url_user_header : "X-Remote-User".to_string(),
128-
auth_url_email_header : "X-Remote-Email".to_string(),
129-
auth_url_name_header : "X-Remote-Name".to_string(),
130-
auth_url_realms_header: "X-Remote-Realms".to_string(),
131-
132-
oidc_code_duration: Duration::try_minutes(1).unwrap(),
133-
134-
saml_cert_pem_path: "saml_cert.pem".to_string(),
135-
saml_key_pem_path : "saml_key.pem".to_string(),
136-
137-
smtp_enable : false,
138-
smtp_url : "smtp://localhost:25".to_string(),
139-
smtp_from : "{title} <magicentry@example.com>".to_string(),
140-
smtp_subject: "{title} Login".to_string(),
141-
smtp_body : "Click the link to login: {magic_link}".to_string(),
142-
143-
request_enable : false,
144-
request_url : "https://www.cinotify.cc/api/notify".to_string(),
145-
request_method : "POST".to_string(),
146-
request_data : Some(std::env::var("REQUEST_DATA").unwrap_or("to={email}&subject={title} Login&body=Click the link to login: <a href=\"{magic_link}\">Login</a>&type=text/html".to_string())),
147-
request_content_type: "application/x-www-form-urlencoded".to_string(),
148-
149-
webauthn_enable: true,
150-
151-
// force_https_redirects: true,
152-
153-
users: vec![],
154-
users_file: None,
155-
users_sql_query_all: None,
156-
users_sql_query_email: None,
157-
users_sql_url: None,
158-
159-
services: Services(vec![]),
160-
}
161-
}
114+
pub oidc_code_duration: Duration = Duration::try_minutes(1).unwrap(),
115+
116+
pub saml_cert_pem_path: String = "saml_cert.pem".to_string(),
117+
pub saml_key_pem_path: String = "saml_key.pem".to_string(),
118+
119+
pub smtp_enable: bool = false,
120+
pub smtp_url: String = "smtp://localhost:25".to_string(),
121+
pub smtp_from: String = "{title} <magicentry@example.com>".to_string(),
122+
pub smtp_subject: String = "{title} Login".to_string(),
123+
pub smtp_body: String = "Click the link to login: {magic_link}".to_string(),
124+
125+
pub request_enable: bool = false,
126+
pub request_url: String = "https://www.cinotify.cc/api/notify".to_string(),
127+
pub request_method: String = "POST".to_string(),
128+
pub request_data: Option<String> = Some(std::env::var("REQUEST_DATA").unwrap_or("to={email}&subject={title} Login&body=Click the link to login: <a href=\"{magic_link}\">Login</a>&type=text/html".to_string())),
129+
pub request_content_type: String = "application/x-www-form-urlencoded".to_string(),
130+
131+
pub webauthn_enable: bool = true,
132+
133+
users: Vec<User> = vec![],
134+
pub users_file: Option<String> = None,
135+
pub users_sql_query_all: Option<String> = None,
136+
pub users_sql_query_email: Option<String> = None,
137+
pub users_sql_url: Option<String> = None,
138+
pub services: Services = Services::default(),
162139
}
163140

164141
impl Config {
@@ -183,11 +160,8 @@ impl Config {
183160
/// Note that live-updating the `CONFIG_FILE` environment variable
184161
/// is **NOT** supported (and is probably impossible anyway)
185162
pub async fn reload(config_path: &str, config: Arc<ArcSwap<Config>>) -> anyhow::Result<()> {
186-
let new_config: Arc<Config> = Self::reload_from_path(config_path).await?.into();
187-
// TODO: secrets and static pages still use the global config, updating it for the time being
188-
let mut config_guard = CONFIG.write().await;
189-
*config_guard = new_config.clone();
190-
config.store(new_config);
163+
let new_config = Self::reload_from_path(config_path).await?;
164+
new_config.replace(config).await?;
191165
Ok(())
192166
}
193167

@@ -312,30 +286,33 @@ impl Config {
312286
result
313287
}
314288

289+
async fn replace(self, config: Arc<ArcSwap<Config>>) -> anyhow::Result<()> {
290+
let new_config_arc = Arc::new(self);
291+
292+
// TODO: secrets and static pages still use the global config, updating it for the time being
293+
let mut config_guard = CONFIG.write().await;
294+
*config_guard = new_config_arc.clone();
295+
config.store(new_config_arc);
296+
Ok(())
297+
}
298+
315299
/// Enterprise-only feature
316300
pub async fn load_from_db(db: &Database) -> anyhow::Result<Option<Self>> {
317301
info!("Loading config from database");
318302
let config = ConfigKV::get(&ConfigKeys::Config, db)
319303
.await?
320304
.and_then(|c| serde_json::from_str::<Self>(&c).ok());
321-
println!("{config:?}");
322305
Ok(config)
323306
}
324307

325-
/// Enterprise-only feature
326308
pub async fn reload_from_db(config: Arc<ArcSwap<Config>>, db: &Database) -> anyhow::Result<()> {
327309
let Some(new_config) = Self::load_from_db(db).await? else {
328310
return Err(anyhow::anyhow!("Failed to load config from database"));
329311
};
330-
let new_config_arc = Arc::new(new_config.clone());
331-
// TODO: secrets and static pages still use the global config, updating it for the time being
332-
let mut config_guard = CONFIG.write().await;
333-
*config_guard = new_config_arc.clone();
334-
config.store(new_config_arc.into());
312+
new_config.replace(config).await?;
335313
Ok(())
336314
}
337315

338-
/// Enterprise-only feature
339316
pub async fn save_to_db(&self, db: &Database) -> anyhow::Result<()> {
340317
info!("Saving config to database");
341318
ConfigKV::set(&ConfigKeys::Config, Some(serde_json::to_string(self)?), db).await?;

src/service.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ pub struct ServiceOIDC {
5858
pub struct Services(pub Vec<Service>);
5959

6060
impl Services {
61-
#[must_use]
6261
pub fn get(&self, name: &str) -> Option<&Service> {
6362
self.0.iter().find(|s| s.name == name)
6463
}
@@ -68,7 +67,6 @@ impl Services {
6867
}
6968

7069
/// Returns all the services that the provided user has access to
71-
#[must_use]
7270
pub fn from_user(&self, user: &User) -> Self {
7371
let res = self
7472
.0
@@ -81,7 +79,6 @@ impl Services {
8179
}
8280

8381
/// Returns the first service that matches the given OIDC client ID
84-
#[must_use]
8582
pub fn from_oidc_client_id(&self, client_id: &str) -> Option<Service> {
8683
self.0
8784
.iter()
@@ -90,7 +87,6 @@ impl Services {
9087
}
9188

9289
/// Returns the first service that matches the given OIDC redirect URL
93-
#[must_use]
9490
pub fn from_oidc_redirect_url(&self, redirect_url: &url::Url) -> Option<Service> {
9591
self.0
9692
.iter()
@@ -103,7 +99,6 @@ impl Services {
10399
}
104100

105101
/// Returns the first service that matches the given SAML entity ID
106-
#[must_use]
107102
pub fn from_saml_entity_id(&self, entity_id: &str) -> Option<Service> {
108103
self.0
109104
.iter()
@@ -112,7 +107,6 @@ impl Services {
112107
}
113108

114109
/// Returns the first service that matches the given redirect URL
115-
#[must_use]
116110
pub fn from_saml_redirect_url(&self, redirect_url: &url::Url) -> Option<Service> {
117111
self.0
118112
.iter()
@@ -125,7 +119,6 @@ impl Services {
125119
}
126120

127121
/// Returns the first service that matches the given redirect URL
128-
#[must_use]
129122
pub fn from_auth_url_origin(&self, origin: &url::Origin) -> Option<Service> {
130123
let origin_str = origin.ascii_serialization();
131124
self.0

0 commit comments

Comments
 (0)