Skip to content

Commit e6c50f0

Browse files
committed
Add the ability to update each config field procedurally
1 parent 1892cbd commit e6c50f0

5 files changed

Lines changed: 117 additions & 143 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/app_build.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ pub async fn axum_build(
5555

5656
let state = AppState {
5757
db,
58-
config,
58+
config_arc: config,
5959
link_senders,
6060
key,
6161
webauthn,

src/config.rs

Lines changed: 112 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -24,141 +24,121 @@ use crate::service::Services;
2424
use crate::user::User;
2525
use crate::user_store::{FileUserStore, SQLUserStore, StaticUserStore, UserStore};
2626

27-
/// The actual, deserialized config data
28-
///
29-
/// To see what each field represents check out the [config.sample.yaml](https://github.com/dzervas/magicentry/blob/main/config.sample.yaml) file
30-
///
31-
/// TODO: Move the comments from here to the config.sample.yaml so the code
32-
/// is the source of truth
33-
// TODO: Generate a validation schema
34-
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
35-
#[serde(default, deny_unknown_fields)]
36-
#[allow(clippy::struct_excessive_bools)]
37-
pub struct Config {
38-
pub database_url: String,
39-
40-
pub listen_host: String,
41-
pub listen_port: u16,
42-
pub path_prefix: String,
43-
pub external_url: String,
27+
macro_rules! config_struct {
28+
(
29+
$(
30+
$(#[$meta:meta])*
31+
$pub:vis $name:ident: $type:ty = $default:expr
32+
),+ $(,)?
33+
) => {
34+
35+
/// The actual, deserialized config data
36+
///
37+
/// To see what each field represents check out the [config.sample.yaml](https://github.com/dzervas/magicentry/blob/main/config.sample.yaml) file
38+
///
39+
/// TODO: Move the comments from here to the config.sample.yaml so the code
40+
/// is the source of truth
41+
// TODO: Generate a validation schema
42+
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
43+
#[serde(default, deny_unknown_fields)]
44+
pub struct Config {
45+
$(
46+
$(#[$meta])*
47+
$pub $name: $type
48+
),+
49+
}
50+
51+
impl Default for Config {
52+
fn default() -> Self {
53+
Self {
54+
$( $name: $default ),+
55+
}
56+
}
57+
}
58+
59+
impl Config {
60+
pub async fn update_field(config_arc: Arc<ArcSwap<Config>>, db: &Database, field: &str, value: serde_json::Value) -> anyhow::Result<()> {
61+
let config_full = config_arc.load().to_owned();
62+
let mut config = Arc::unwrap_or_clone(config_full);
63+
match field {
64+
$(
65+
stringify!($name) => {
66+
config.$name = serde_json::from_value::<$type>(value)?;
67+
},
68+
)+
69+
_ => anyhow::bail!("Unknown field: {field}"),
70+
};
71+
72+
config.save_to_db(db).await?;
73+
config.replace(config_arc).await?;
74+
Ok(())
75+
}
76+
}
77+
};
78+
}
79+
80+
config_struct! {
81+
pub database_url: String = std::env::var("DATABASE_URL").unwrap_or("sqlite://database.db".to_string()),
82+
83+
pub listen_host: String = std::env::var("LISTEN_HOST").unwrap_or("127.0.0.1".to_string()),
84+
pub listen_port: u16 = std::env::var("LISTEN_PORT").unwrap_or("8080".to_string()).parse().unwrap(),
85+
pub path_prefix: String = "/".to_string(),
86+
pub external_url: String = "http://localhost:8080".to_string(),
4487

4588
#[serde(
4689
deserialize_with = "duration_str::deserialize_duration_chrono",
47-
serialize_with = "serialize_duration_chrono"
90+
serialize_with = "crate::config::serialize_duration_chrono"
4891
)]
49-
pub link_duration: Duration,
92+
pub link_duration: Duration = Duration::try_hours(12).unwrap(),
5093
#[serde(
5194
deserialize_with = "duration_str::deserialize_duration_chrono",
52-
serialize_with = "serialize_duration_chrono"
95+
serialize_with = "crate::config::serialize_duration_chrono"
5396
)]
54-
pub session_duration: Duration,
55-
56-
/// Interval for periodic cleanup of expired secrets
97+
pub session_duration: Duration = Duration::try_days(30).unwrap(),
5798
#[serde(
5899
deserialize_with = "duration_str::deserialize_duration_chrono",
59-
serialize_with = "serialize_duration_chrono"
100+
serialize_with = "crate::config::serialize_duration_chrono"
60101
)]
61-
pub secrets_cleanup_interval: Duration,
102+
pub secrets_cleanup_interval: Duration = Duration::try_hours(24).unwrap(),
62103

63-
pub title: String,
64-
pub static_path: String,
104+
pub title: String = "MagicEntry".to_string(),
105+
pub static_path: String = "static".to_string(),
65106

66-
pub auth_url_enable: bool,
67-
pub auth_url_user_header: String,
68-
pub auth_url_name_header: String,
69-
pub auth_url_email_header: String,
70-
pub auth_url_realms_header: String,
107+
pub auth_url_enable: bool = true,
108+
pub auth_url_user_header: String = "X-Auth-User".to_string(),
109+
pub auth_url_name_header: String = "X-Auth-Name".to_string(),
110+
pub auth_url_email_header: String = "X-Auth-Email".to_string(),
111+
pub auth_url_realms_header: String = "X-Auth-Realms".to_string(),
71112

72113
#[serde(
73114
deserialize_with = "duration_str::deserialize_duration_chrono",
74-
serialize_with = "serialize_duration_chrono"
115+
serialize_with = "crate::config::serialize_duration_chrono"
75116
)]
76-
pub oidc_code_duration: Duration,
77-
78-
pub saml_cert_pem_path: String,
79-
pub saml_key_pem_path: String,
80-
81-
pub smtp_enable: bool,
82-
pub smtp_url: String,
83-
pub smtp_from: String,
84-
pub smtp_subject: String,
85-
pub smtp_body: String,
86-
87-
pub request_enable: bool,
88-
pub request_url: String,
89-
pub request_method: String,
90-
pub request_data: Option<String>,
91-
pub request_content_type: String,
92-
93-
pub webauthn_enable: bool,
94-
95-
// pub force_https_redirects: bool,
96-
// Private to avoid reading from the field instead of the user store
97-
users: Vec<User>,
98-
/// Path to a file containing the user definitions
99-
pub users_file: Option<String>,
100-
pub users_sql_query_all: Option<String>,
101-
pub users_sql_query_email: Option<String>,
102-
pub users_sql_url: Option<String>,
103-
pub services: Services,
104-
}
105-
106-
impl Default for Config {
107-
#[allow(clippy::or_fun_call)]
108-
#[allow(clippy::unwrap_used)] // All the cases are either const or on start (e.g. port)
109-
fn default() -> Self {
110-
Self {
111-
database_url: std::env::var("DATABASE_URL").unwrap_or("sqlite://database.db".to_string()),
112-
113-
listen_host : std::env::var("LISTEN_HOST").unwrap_or("127.0.0.1".to_string()),
114-
listen_port : std::env::var("LISTEN_PORT").unwrap_or("8080".to_string()).parse().unwrap(),
115-
path_prefix : "/".to_string(),
116-
external_url: "http://localhost:8080".to_string(),
117-
118-
link_duration : Duration::try_hours(12).unwrap(),
119-
session_duration: Duration::try_days(30).unwrap(),
120-
121-
secrets_cleanup_interval: Duration::try_hours(24).unwrap(),
122-
123-
title: "MagicEntry".to_string(),
124-
static_path: "static".to_string(),
125-
126-
auth_url_enable : true,
127-
auth_url_user_header : "X-Remote-User".to_string(),
128-
auth_url_email_header : "X-Remote-Email".to_string(),
129-
auth_url_name_header : "X-Remote-Name".to_string(),
130-
auth_url_realms_header: "X-Remote-Realms".to_string(),
131-
132-
oidc_code_duration: Duration::try_minutes(1).unwrap(),
133-
134-
saml_cert_pem_path: "saml_cert.pem".to_string(),
135-
saml_key_pem_path : "saml_key.pem".to_string(),
136-
137-
smtp_enable : false,
138-
smtp_url : "smtp://localhost:25".to_string(),
139-
smtp_from : "{title} <magicentry@example.com>".to_string(),
140-
smtp_subject: "{title} Login".to_string(),
141-
smtp_body : "Click the link to login: {magic_link}".to_string(),
142-
143-
request_enable : false,
144-
request_url : "https://www.cinotify.cc/api/notify".to_string(),
145-
request_method : "POST".to_string(),
146-
request_data : Some(std::env::var("REQUEST_DATA").unwrap_or("to={email}&subject={title} Login&body=Click the link to login: <a href=\"{magic_link}\">Login</a>&type=text/html".to_string())),
147-
request_content_type: "application/x-www-form-urlencoded".to_string(),
148-
149-
webauthn_enable: true,
150-
151-
// force_https_redirects: true,
152-
153-
users: vec![],
154-
users_file: None,
155-
users_sql_query_all: None,
156-
users_sql_query_email: None,
157-
users_sql_url: None,
158-
159-
services: Services(vec![]),
160-
}
161-
}
117+
pub oidc_code_duration: Duration = Duration::try_minutes(1).unwrap(),
118+
119+
pub saml_cert_pem_path: String = "saml_cert.pem".to_string(),
120+
pub saml_key_pem_path: String = "saml_key.pem".to_string(),
121+
122+
pub smtp_enable: bool = false,
123+
pub smtp_url: String = "smtp://localhost:25".to_string(),
124+
pub smtp_from: String = "{title} <magicentry@example.com>".to_string(),
125+
pub smtp_subject: String = "{title} Login".to_string(),
126+
pub smtp_body: String = "Click the link to login: {magic_link}".to_string(),
127+
128+
pub request_enable: bool = false,
129+
pub request_url: String = "https://www.cinotify.cc/api/notify".to_string(),
130+
pub request_method: String = "POST".to_string(),
131+
pub request_data: Option<String> = Some(std::env::var("REQUEST_DATA").unwrap_or("to={email}&subject={title} Login&body=Click the link to login: <a href=\"{magic_link}\">Login</a>&type=text/html".to_string())),
132+
pub request_content_type: String = "application/x-www-form-urlencoded".to_string(),
133+
134+
pub webauthn_enable: bool = true,
135+
136+
users: Vec<User> = vec![],
137+
pub users_file: Option<String> = None,
138+
pub users_sql_query_all: Option<String> = None,
139+
pub users_sql_query_email: Option<String> = None,
140+
pub users_sql_url: Option<String> = None,
141+
pub services: Services = Services::default(),
162142
}
163143

164144
impl Config {
@@ -183,11 +163,8 @@ impl Config {
183163
/// Note that live-updating the `CONFIG_FILE` environment variable
184164
/// is **NOT** supported (and is probably impossible anyway)
185165
pub async fn reload(config_path: &str, config: Arc<ArcSwap<Config>>) -> anyhow::Result<()> {
186-
let new_config: Arc<Config> = Self::reload_from_path(config_path).await?.into();
187-
// TODO: secrets and static pages still use the global config, updating it for the time being
188-
let mut config_guard = CONFIG.write().await;
189-
*config_guard = new_config.clone();
190-
config.store(new_config);
166+
let new_config = Self::reload_from_path(config_path).await?;
167+
new_config.replace(config).await?;
191168
Ok(())
192169
}
193170

@@ -312,30 +289,33 @@ impl Config {
312289
result
313290
}
314291

292+
async fn replace(self, config: Arc<ArcSwap<Config>>) -> anyhow::Result<()> {
293+
let new_config_arc = Arc::new(self);
294+
295+
// TODO: secrets and static pages still use the global config, updating it for the time being
296+
let mut config_guard = CONFIG.write().await;
297+
*config_guard = new_config_arc.clone();
298+
config.store(new_config_arc);
299+
Ok(())
300+
}
301+
315302
/// Enterprise-only feature
316303
pub async fn load_from_db(db: &Database) -> anyhow::Result<Option<Self>> {
317304
info!("Loading config from database");
318305
let config = ConfigKV::get(&ConfigKeys::Config, db)
319306
.await?
320307
.and_then(|c| serde_json::from_str::<Self>(&c).ok());
321-
println!("{config:?}");
322308
Ok(config)
323309
}
324310

325-
/// Enterprise-only feature
326311
pub async fn reload_from_db(config: Arc<ArcSwap<Config>>, db: &Database) -> anyhow::Result<()> {
327312
let Some(new_config) = Self::load_from_db(db).await? else {
328313
return Err(anyhow::anyhow!("Failed to load config from database"));
329314
};
330-
let new_config_arc = Arc::new(new_config.clone());
331-
// TODO: secrets and static pages still use the global config, updating it for the time being
332-
let mut config_guard = CONFIG.write().await;
333-
*config_guard = new_config_arc.clone();
334-
config.store(new_config_arc.into());
315+
new_config.replace(config).await?;
335316
Ok(())
336317
}
337318

338-
/// Enterprise-only feature
339319
pub async fn save_to_db(&self, db: &Database) -> anyhow::Result<()> {
340320
info!("Saving config to database");
341321
ConfigKV::set(&ConfigKeys::Config, Some(serde_json::to_string(self)?), db).await?;

src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ pub struct InFlightConfig(Arc<Config>);
141141
#[derive(Clone)]
142142
pub struct AppState {
143143
pub db: crate::Database,
144-
config: Arc<ArcSwap<Config>>,
144+
pub config_arc: Arc<ArcSwap<Config>>,
145145
pub user_store: Arc<dyn UserStore>,
146146
pub link_senders: Vec<Arc<dyn LinkSender>>,
147147

@@ -153,7 +153,7 @@ impl AppState {
153153
pub async fn send_magic_link(&self, user: &User, link: &str) -> Result<(), AppError> {
154154
// TODO: Make this concurrent and return multiple errors
155155
// It's ok to re-read the config here since it only uses the link_senders
156-
let config = self.config.load();
156+
let config = self.config_arc.load();
157157
for sender in &self.link_senders {
158158
sender.send_magic_link(user, link, &config).await?;
159159
}
@@ -166,7 +166,7 @@ impl AppState {
166166
request: axum::http::Request<axum::body::Body>,
167167
next: Next,
168168
) -> impl IntoResponse {
169-
let config_arc = state.config.load_full();
169+
let config_arc = state.config_arc.load_full();
170170

171171
let mut request = request;
172172
request.extensions_mut().insert(config_arc);

src/service.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ pub struct ServiceOIDC {
5858
pub struct Services(pub Vec<Service>);
5959

6060
impl Services {
61-
#[must_use]
6261
pub fn get(&self, name: &str) -> Option<&Service> {
6362
self.0.iter().find(|s| s.name == name)
6463
}
@@ -68,7 +67,6 @@ impl Services {
6867
}
6968

7069
/// Returns all the services that the provided user has access to
71-
#[must_use]
7270
pub fn from_user(&self, user: &User) -> Self {
7371
let res = self
7472
.0
@@ -81,7 +79,6 @@ impl Services {
8179
}
8280

8381
/// Returns the first service that matches the given OIDC client ID
84-
#[must_use]
8582
pub fn from_oidc_client_id(&self, client_id: &str) -> Option<Service> {
8683
self.0
8784
.iter()
@@ -90,7 +87,6 @@ impl Services {
9087
}
9188

9289
/// Returns the first service that matches the given OIDC redirect URL
93-
#[must_use]
9490
pub fn from_oidc_redirect_url(&self, redirect_url: &url::Url) -> Option<Service> {
9591
self.0
9692
.iter()
@@ -103,7 +99,6 @@ impl Services {
10399
}
104100

105101
/// Returns the first service that matches the given SAML entity ID
106-
#[must_use]
107102
pub fn from_saml_entity_id(&self, entity_id: &str) -> Option<Service> {
108103
self.0
109104
.iter()
@@ -112,7 +107,6 @@ impl Services {
112107
}
113108

114109
/// Returns the first service that matches the given redirect URL
115-
#[must_use]
116110
pub fn from_saml_redirect_url(&self, redirect_url: &url::Url) -> Option<Service> {
117111
self.0
118112
.iter()
@@ -125,7 +119,6 @@ impl Services {
125119
}
126120

127121
/// Returns the first service that matches the given redirect URL
128-
#[must_use]
129122
pub fn from_auth_url_origin(&self, origin: &url::Origin) -> Option<Service> {
130123
let origin_str = origin.ascii_serialization();
131124
self.0

0 commit comments

Comments
 (0)