diff --git a/.gitignore b/.gitignore index e68a3e1e..8ee21062 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ NDK *.sqlite* *.db *.db-journal + +.vscode diff --git a/Cargo.lock b/Cargo.lock index d7014c4c..c67a1ae5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -275,6 +275,7 @@ dependencies = [ "log", "log-panics", "openssl-sys", + "regex", "rocket", "rocket_cors", "rust-embed", diff --git a/README.md b/README.md index 4766e574..14f6feb4 100644 --- a/README.md +++ b/README.md @@ -60,16 +60,33 @@ Available options: # Additional regex CORS origins to allow (e.g. for sideloaded browser extensions) #cors_regex = ["chrome-extension://yourextensionidhere"] + +# Allow official ActivityWatch Chrome extension? (default: true) +#cors_allow_aw_chrome_extension = true + +# Allow all Firefox extensions? (default: false, DANGEROUS) +#cors_allow_all_mozilla_extension = false ``` +#### Persistence and Settings UI + +The CORS-related settings (`cors`, `cors_regex`, `cors_allow_aw_chrome_extension`, and `cors_allow_all_mozilla_extension`) follow a precedence logic between the configuration file and the database: + +- **TOML Precedence**: If a field is explicitly defined in your `config.toml`, it takes absolute precedence. The server will use the value from the file, and that setting will be **read-only** in the Web UI (marked as "Fixed in config file"). +- **Database Fallback**: If a field is **missing** or commented out in the `config.toml`, the server will look for it in the database. These can be managed and edited via the **Security & CORS** modal in the Settings page. +- **Initial Setup**: On the first start, a default `config.toml` is created with all settings commented out, allowing the Web UI to take control of the configuration by default while providing a template for manual overrides. + +> [!IMPORTANT] +> **Server Restart Required**: Changing any CORS-related settings (whether via `config.toml` or the Web UI) requires stopping and restarting the server for the changes to take effect. These settings are loaded into memory once during the server's initialization and are not hot-reloadable. + #### Custom CORS Origins By default, the server allows requests from: - The server's own origin (`http://127.0.0.1:`, `http://localhost:`) -- The official Chrome extension (`chrome-extension://nglaklhklhcoonedhgnpgddginnjdadi`) -- All Firefox extensions (`moz-extension://.*`) +- The official Chrome extension (`chrome-extension://nglaklhklhcoonedhgnpgddginnjdadi`) if `cors_allow_aw_chrome_extension` is true (default). +- All Firefox extensions (`moz-extension://.*`) ONLY IF `cors_allow_all_mozilla_extension` is set to true. -To allow additional origins (e.g. a sideloaded Chrome extension), add them to your config: +To allow additional origins (e.g. a sideloaded Chrome extension), add them to your `cors` or `cors_regex` config: ```toml # Allow a specific sideloaded Chrome extension diff --git a/aw-datastore/src/worker.rs b/aw-datastore/src/worker.rs index b116a1f3..c42b5ce2 100644 --- a/aw-datastore/src/worker.rs +++ b/aw-datastore/src/worker.rs @@ -294,7 +294,10 @@ impl DatastoreWorker { Err(e) => Err(e), }, Command::SetKeyValue(key, data) => match ds.insert_key_value(tx, &key, &data) { - Ok(()) => Ok(Response::Empty()), + Ok(()) => { + self.commit = true; + Ok(Response::Empty()) + } Err(e) => Err(e), }, Command::GetKeyValue(key) => match ds.get_key_value(tx, &key) { @@ -302,7 +305,10 @@ impl DatastoreWorker { Err(e) => Err(e), }, Command::DeleteKeyValue(key) => match ds.delete_key_value(tx, &key) { - Ok(()) => Ok(Response::Empty()), + Ok(()) => { + self.commit = true; + Ok(Response::Empty()) + } Err(e) => Err(e), }, Command::Close() => { diff --git a/aw-server/Cargo.toml b/aw-server/Cargo.toml index da6c908e..d12edb23 100644 --- a/aw-server/Cargo.toml +++ b/aw-server/Cargo.toml @@ -29,6 +29,7 @@ uuid = { version = "1.3", features = ["serde", "v4"] } clap = { version = "4.1", features = ["derive", "cargo"] } log-panics = { version = "2", features = ["with-backtrace"]} rust-embed = { version = "8.0.0", features = ["interpolate-folder-path", "debug-embed"] } +regex = "1" aw-datastore = { path = "../aw-datastore" } aw-models = { path = "../aw-models" } diff --git a/aw-server/src/config.rs b/aw-server/src/config.rs index 35ac435c..f357901f 100644 --- a/aw-server/src/config.rs +++ b/aw-server/src/config.rs @@ -1,5 +1,6 @@ -use std::fs::File; -use std::io::{Read, Write}; +use std::collections::HashSet; +use std::fs::{self, File}; +use std::io::Write; use rocket::config::Config; use rocket::data::{Limits, ToByteUnit}; @@ -7,6 +8,14 @@ use rocket::log::LogLevel; use serde::{Deserialize, Serialize}; use crate::dirs; +use serde_json; + +pub const CORS_FIELDS: &[&str] = &[ + "cors", + "cors_regex", + "cors_allow_aw_chrome_extension", + "cors_allow_all_mozilla_extension", +]; // Far from an optimal way to solve it, but works and is simple static mut TESTING: bool = true; @@ -19,7 +28,7 @@ pub fn is_testing() -> bool { unsafe { TESTING } } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub struct AWConfig { #[serde(default = "default_address")] pub address: String, @@ -36,6 +45,12 @@ pub struct AWConfig { #[serde(default = "default_cors")] pub cors_regex: Vec, + #[serde(default = "default_true")] + pub cors_allow_aw_chrome_extension: bool, + + #[serde(default = "default_false")] + pub cors_allow_all_mozilla_extension: bool, + // A mapping of watcher names to paths where the // custom visualizations are located. #[serde(default = "default_custom_static")] @@ -50,6 +65,8 @@ impl Default for AWConfig { testing: default_testing(), cors: default_cors(), cors_regex: default_cors(), + cors_allow_aw_chrome_extension: default_true(), + cors_allow_all_mozilla_extension: default_false(), custom_static: default_custom_static(), } } @@ -91,6 +108,14 @@ fn default_testing() -> bool { is_testing() } +fn default_true() -> bool { + true +} + +fn default_false() -> bool { + false +} + fn default_port() -> u16 { if is_testing() { 5666 @@ -103,14 +128,40 @@ fn default_custom_static() -> std::collections::HashMap { std::collections::HashMap::new() } -pub fn create_config(testing: bool) -> AWConfig { - set_testing(testing); +pub fn get_config_path(testing: bool) -> (std::path::PathBuf, Vec) { let mut config_path = dirs::get_config_dir().unwrap(); if !testing { config_path.push("config.toml") } else { config_path.push("config-testing.toml") } + if !config_path.is_file() { + return ( + config_path, + CORS_FIELDS.iter().map(|f| f.to_string()).collect(), + ); + } + let content = fs::read_to_string(&config_path).unwrap_or_default(); + let toml_value: toml::Value = + toml::from_str(&content).unwrap_or_else(|_| toml::Value::Table(toml::Table::new())); + + let file_keys: HashSet = toml_value + .as_table() + .map(|t| t.keys().cloned().collect()) + .unwrap_or_default(); + + let missing = CORS_FIELDS + .iter() + .filter(|f| !file_keys.contains(&f.to_string())) + .map(|f| f.to_string()) + .collect(); + + (config_path, missing) +} + +pub fn create_config(testing: bool, datastore: &aw_datastore::Datastore) -> AWConfig { + set_testing(testing); + let (config_path, missing_cors_fields) = get_config_path(testing); /* If there is no config file, create a new config file with default values but every value is * commented out by default in case we would change a default value at some point in the future */ @@ -132,12 +183,22 @@ pub fn create_config(testing: bool) -> AWConfig { } debug!("Reading config at {:?}", config_path); - let mut rfile = File::open(config_path).expect("Failed to open config file for reading"); - let mut content = String::new(); - rfile - .read_to_string(&mut content) - .expect("Failed to read config as a string"); - let aw_config: AWConfig = toml::from_str(&content).expect("Failed to parse config file"); + let content = fs::read_to_string(config_path).expect("Failed to read config file"); + let toml_value: toml::Value = toml::from_str(&content).expect("Failed to parse config file"); + let mut aw_config: AWConfig = + toml_value.try_into().expect("Failed to convert TOML value to AWConfig"); + + for field in missing_cors_fields { + let Ok(value_str) = datastore.get_key_value(&format!("cors.{field}")) else { continue }; + + match field.as_str() { + "cors" => aw_config.cors = serde_json::from_str(&value_str).unwrap_or_default(), + "cors_regex" => aw_config.cors_regex = serde_json::from_str(&value_str).unwrap_or_default(), + "cors_allow_aw_chrome_extension" => aw_config.cors_allow_aw_chrome_extension = serde_json::from_str(&value_str).unwrap_or_default(), + "cors_allow_all_mozilla_extension" => aw_config.cors_allow_all_mozilla_extension = serde_json::from_str(&value_str).unwrap_or_default(), + _ => {} + } + } aw_config } diff --git a/aw-server/src/endpoints/cors.rs b/aw-server/src/endpoints/cors.rs index 530be147..54a47820 100644 --- a/aw-server/src/endpoints/cors.rs +++ b/aw-server/src/endpoints/cors.rs @@ -6,21 +6,26 @@ use crate::config::AWConfig; pub fn cors(config: &AWConfig) -> rocket_cors::Cors { let root_url = format!("http://127.0.0.1:{}", config.port); let root_url_localhost = format!("http://localhost:{}", config.port); - let mut allowed_exact_origins = vec![root_url, root_url_localhost]; + let mut allowed_exact_origins = vec![root_url.clone(), root_url_localhost.clone()]; allowed_exact_origins.extend(config.cors.clone()); - if config.testing { - allowed_exact_origins.push("http://127.0.0.1:27180".to_string()); - allowed_exact_origins.push("http://localhost:27180".to_string()); + let mut allowed_regex_origins = config.cors_regex.clone(); + + if config.cors_allow_aw_chrome_extension { + allowed_regex_origins.push("chrome-extension://nglaklhklhcoonedhgnpgddginnjdadi".to_string()); } - let mut allowed_regex_origins = vec![ - "chrome-extension://nglaklhklhcoonedhgnpgddginnjdadi".to_string(), + + if config.cors_allow_all_mozilla_extension { // Every version of a mozilla extension has its own ID to avoid fingerprinting, so we // unfortunately have to allow all extensions to have access to aw-server - "moz-extension://.*".to_string(), - ]; - allowed_regex_origins.extend(config.cors_regex.clone()); + allowed_regex_origins.push("moz-extension://.*".to_string()); + } + if config.testing { + allowed_exact_origins.extend(vec![ + "http://127.0.0.1:27180".to_string(), + "http://localhost:27180".to_string(), + ]); allowed_regex_origins.push("chrome-extension://.*".to_string()); } @@ -32,13 +37,29 @@ pub fn cors(config: &AWConfig) -> rocket_cors::Cors { let allowed_headers = AllowedHeaders::all(); // TODO: is this unsafe? // You can also deserialize this - rocket_cors::CorsOptions { + let cors_options = rocket_cors::CorsOptions { allowed_origins, allowed_methods, allowed_headers, allow_credentials: false, ..Default::default() + }; + + match cors_options.to_cors() { + Ok(cors) => cors, + Err(e) => { + error!("Failed to set up CORS with provided origins: {:?}", e); + error!("Exact origins: {:?}", allowed_exact_origins); + error!("Regex origins: {:?}", allowed_regex_origins); + // Fallback to a safe default to allow the server to at least start + let fallback_origins = vec![root_url, root_url_localhost]; + let empty_regex: &[String] = &[]; + rocket_cors::CorsOptions { + allowed_origins: AllowedOrigins::some(&fallback_origins, empty_regex), + ..Default::default() + } + .to_cors() + .expect("Safe default CORS should always work") + } } - .to_cors() - .expect("Failed to set up CORS") } diff --git a/aw-server/src/endpoints/cors_config.rs b/aw-server/src/endpoints/cors_config.rs new file mode 100644 index 00000000..46c77490 --- /dev/null +++ b/aw-server/src/endpoints/cors_config.rs @@ -0,0 +1,120 @@ +use crate::endpoints::ServerState; +use rocket::http::Status; +use rocket::serde::json::Json; +use rocket::State; +use serde::{Deserialize, Serialize}; +use crate::endpoints::HttpErrorJson; +use crate::config; + +#[derive(Serialize, Deserialize)] +pub struct CorsConfig { + pub cors: Vec, + pub cors_regex: Vec, + pub cors_allow_aw_chrome_extension: bool, + pub cors_allow_all_mozilla_extension: bool, + pub in_file: Vec, + #[serde(skip_deserializing)] + pub needs_restart: bool, +} + +#[get("/")] +pub fn cors_config_get(state: &State) -> Result, HttpErrorJson> { + let config = endpoints_get_lock!(state.config); + let (_, missing_fields) = config::get_config_path(config.testing); + let in_file = config::CORS_FIELDS + .iter() + .filter(|&&f| !missing_fields.contains(&f.to_string())) + .map(|&f| f.to_string()) + .collect(); + Ok(Json(CorsConfig { + cors: config.cors.clone(), + cors_regex: config.cors_regex.clone(), + cors_allow_aw_chrome_extension: config.cors_allow_aw_chrome_extension, + cors_allow_all_mozilla_extension: config.cors_allow_all_mozilla_extension, + in_file, + needs_restart: true, + })) +} + +#[post("/", data = "")] +pub fn cors_config_set( + state: &State, + new_cors: Json, +) -> Result { + let datastore = endpoints_get_lock!(state.datastore); + + // Identify which fields are allowed to be modified (those missing from the TOML file) + let (_, missing_fields) = { + let config = endpoints_get_lock!(state.config); + config::get_config_path(config.testing) + }; + + // Validate exact origins before persisting + if missing_fields.contains(&"cors".to_string()) { + for origin in &new_cors.cors { + if !origin.starts_with("http://") && !origin.starts_with("https://") { + return Err(HttpErrorJson::new( + Status::BadRequest, + format!("Invalid CORS origin: {}. Must start with 'http://' or 'https://'", origin), + )); + } + } + } + + // Validate regular expressions before persisting + if missing_fields.contains(&"cors_regex".to_string()) { + for pattern in &new_cors.cors_regex { + if let Err(e) = regex::Regex::new(pattern) { + return Err(HttpErrorJson::new( + Status::BadRequest, + format!("Invalid regular expression in CORS settings: {}. Error: {}", pattern, e), + )); + } + } + } + + let fields = [ + ("cors", serde_json::to_string(&new_cors.cors).unwrap()), + ("cors_regex", serde_json::to_string(&new_cors.cors_regex).unwrap()), + ( + "cors_allow_aw_chrome_extension", + serde_json::to_string(&new_cors.cors_allow_aw_chrome_extension).unwrap(), + ), + ( + "cors_allow_all_mozilla_extension", + serde_json::to_string(&new_cors.cors_allow_all_mozilla_extension).unwrap(), + ), + ]; + + for (field, value_str) in fields { + // Only save to datastore if the field is not fixed in the config file + if missing_fields.contains(&field.to_string()) { + let key = format!("cors.{}", field); + datastore.set_key_value(&key, &value_str).map_err(|e| { + HttpErrorJson::new( + Status::InternalServerError, + format!("Failed to save {}: {:?}", field, e), + ) + })?; + } + } + + // Update the in-memory config for permitted fields so that GET reflect the changes immediately + { + let mut config = endpoints_get_lock!(state.config); + if missing_fields.contains(&"cors".to_string()) { + config.cors = new_cors.cors.clone(); + } + if missing_fields.contains(&"cors_regex".to_string()) { + config.cors_regex = new_cors.cors_regex.clone(); + } + if missing_fields.contains(&"cors_allow_aw_chrome_extension".to_string()) { + config.cors_allow_aw_chrome_extension = new_cors.cors_allow_aw_chrome_extension; + } + if missing_fields.contains(&"cors_allow_all_mozilla_extension".to_string()) { + config.cors_allow_all_mozilla_extension = new_cors.cors_allow_all_mozilla_extension; + } + } + + Ok(Status::Ok) +} diff --git a/aw-server/src/endpoints/mod.rs b/aw-server/src/endpoints/mod.rs index f6c9271e..71baefb5 100644 --- a/aw-server/src/endpoints/mod.rs +++ b/aw-server/src/endpoints/mod.rs @@ -42,6 +42,7 @@ pub struct ServerState { pub datastore: Mutex, pub asset_resolver: AssetResolver, pub device_id: String, + pub config: Mutex, } #[macro_use] @@ -53,6 +54,7 @@ mod hostcheck; mod import; mod query; mod settings; +mod cors_config; pub use util::HttpErrorJson; @@ -189,6 +191,13 @@ pub fn build_rocket(server_state: ServerState, config: AWConfig) -> rocket::Rock settings::settings_get, ], ) + .mount( + "/api/0/cors-config", + routes![ + cors_config::cors_config_get, + cors_config::cors_config_set, + ], + ) .mount("/", rocket_cors::catch_all_options_routes()); // for each custom static directory, mount it at the given name diff --git a/aw-server/src/main.rs b/aw-server/src/main.rs index 2cbf39e7..b3aa651e 100644 --- a/aw-server/src/main.rs +++ b/aw-server/src/main.rs @@ -79,7 +79,27 @@ async fn main() -> Result<(), rocket::Error> { info!("Running server in Testing mode"); } - let mut config = config::create_config(testing); + // Set db path if overridden + let db_path: String = if let Some(dbpath) = opts.dbpath.clone() { + dbpath + } else { + dirs::db_path(testing) + .expect("Failed to get db path") + .to_str() + .unwrap() + .to_string() + }; + info!("Using DB at path {:?}", db_path); + + // Only use legacy import if opts.dbpath is not set + let legacy_import = !opts.no_legacy_import && opts.dbpath.is_none(); + if opts.dbpath.is_some() { + info!("Since custom dbpath is set, --no-legacy-import is implied"); + } + + let datastore = aw_datastore::Datastore::new(db_path, legacy_import); + + let mut config = config::create_config(testing, &datastore); // set host if overridden if let Some(host) = opts.host { @@ -114,27 +134,9 @@ async fn main() -> Result<(), rocket::Error> { } } - // Set db path if overridden - let db_path: String = if let Some(dbpath) = opts.dbpath.clone() { - dbpath - } else { - dirs::db_path(testing) - .expect("Failed to get db path") - .to_str() - .unwrap() - .to_string() - }; - info!("Using DB at path {:?}", db_path); - let asset_path = opts.webpath.map(|webpath| PathBuf::from(webpath)); info!("Using aw-webui assets at path {:?}", asset_path); - // Only use legacy import if opts.dbpath is not set - let legacy_import = !opts.no_legacy_import && opts.dbpath.is_none(); - if opts.dbpath.is_some() { - info!("Since custom dbpath is set, --no-legacy-import is implied"); - } - let device_id: String = if let Some(id) = opts.device_id { id } else { @@ -144,9 +146,10 @@ async fn main() -> Result<(), rocket::Error> { let server_state = endpoints::ServerState { // Even if legacy_import is set to true it is disabled on Android so // it will not happen there - datastore: Mutex::new(aw_datastore::Datastore::new(db_path, legacy_import)), + datastore: Mutex::new(datastore), asset_resolver: endpoints::AssetResolver::new(asset_path), device_id, + config: Mutex::new(config.clone()), }; let _rocket = endpoints::build_rocket(server_state, config) diff --git a/aw-server/tests/api.rs b/aw-server/tests/api.rs index c5da04b0..4c9bb6b8 100644 --- a/aw-server/tests/api.rs +++ b/aw-server/tests/api.rs @@ -20,12 +20,13 @@ mod api_tests { use rocket::local::blocking::Client; fn setup_testserver() -> rocket::Rocket { + let aw_config = config::AWConfig::default(); let state = endpoints::ServerState { datastore: Mutex::new(aw_datastore::Datastore::new_in_memory(false)), asset_resolver: endpoints::AssetResolver::new(None), device_id: "test_id".to_string(), + config: Mutex::new(aw_config.clone()), }; - let aw_config = config::AWConfig::default(); endpoints::build_rocket(state, aw_config) }