diff --git a/.gitignore b/.gitignore index 33a45d9..9bf2fba 100644 --- a/.gitignore +++ b/.gitignore @@ -21,8 +21,9 @@ Thumbs.db # Local AI/Agent configuration CLAUDE.md AGENTS.md +GEMINI.md # Local documentation (not for repo) docs/20251127_AlphaGrowth_Testing_Bugs.md docs/design-onboarding-improvements.md -docs/plans/ +docs/plans/ \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 40b2707..bee0918 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,15 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.21" @@ -330,6 +339,17 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chrono" +version = "0.4.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" +dependencies = [ + "iana-time-zone", + "num-traits", + "windows-link", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -662,8 +682,10 @@ dependencies = [ "anyhow", "base64 0.21.7", "bson", + "chrono", "clap", "dialoguer", + "dirs", "futures", "indicatif", "inquire", @@ -683,6 +705,7 @@ dependencies = [ "toml", "tracing", "tracing-subscriber", + "url", "which", ] @@ -762,6 +785,27 @@ dependencies = [ "subtle", ] +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -1328,6 +1372,30 @@ dependencies = [ "tokio-native-tls", ] +[[package]] +name = "iana-time-zone" +version = "0.1.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "2.1.1" @@ -2085,6 +2153,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "parking_lot" version = "0.12.5" @@ -2412,6 +2486,17 @@ dependencies = [ "bitflags 2.10.0", ] +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom 0.2.16", + "libredox", + "thiserror 1.0.69", +] + [[package]] name = "regex" version = "1.12.2" @@ -3774,12 +3859,65 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.108", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.108", +] + [[package]] name = "windows-link" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/Cargo.toml b/Cargo.toml index fd1755d..d0047f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ categories = ["command-line-utilities", "database"] [dependencies] tokio = { version = "1.35", features = ["full"] } tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] } -clap = { version = "4.4", features = ["derive"] } +clap = { version = "4.4", features = ["derive", "env"] } anyhow = "1.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } @@ -37,3 +37,21 @@ base64 = "0.21" mongodb = "3.4" bson = "2.9" mysql_async = "0.34" +dirs = "5.0" +url = "2.5" +chrono = { version = "0.4", default-features = false, features = ["clock"] } + +[[test]] +name = "fallback_test" +path = "tests/fallback_test.rs" +doc = false + +[[test]] +name = "state_test" +path = "tests/state_test.rs" +doc = false + +[[test]] +name = "interactive_serendb_test" +path = "tests/interactive_serendb_test.rs" +doc = false diff --git a/README.md b/README.md index f05dbf7..f791441 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,31 @@ SerenAI provides managed PostgreSQL databases optimized for AI workloads. When r - Job monitoring and logging - Optimized for large database transfers -To replicate to SerenDB, simply run: +### Option 1: Interactive Project Selection (Recommended) + +With just your API key set, the tool will interactively guide you through selecting your target project and database: + ```bash export SEREN_API_KEY="your-api-key" # Get from console.serendb.com + +database-replicator init \ + --source "postgresql://user:pass@source:5432/db" \ + --local +``` + +The tool will: + +1. Show a picker to select your SerenDB project +2. Automatically enable logical replication if needed +3. Create missing databases on the target +4. Save your selection for future `sync` commands + +### Option 2: Explicit Connection String + +If you already have your connection string, you can provide it directly: + +```bash +export SEREN_API_KEY="your-api-key" database-replicator init \ --source "postgresql://user:pass@source:5432/db" \ --target "postgresql://user:pass@your-db.serendb.com:5432/db" diff --git a/docs/20251205_APIKEY_Upgrade_Brainstorm.md b/docs/20251205_APIKEY_Upgrade_Brainstorm.md new file mode 100644 index 0000000..2f37cf5 --- /dev/null +++ b/docs/20251205_APIKEY_Upgrade_Brainstorm.md @@ -0,0 +1,32 @@ +I've got an idea I want to talk through with you. I'd like you to help me turn it into a fully formed design and spec (and eventually an implementation plan) Check out the current state of the project in our working directory to understand where we're starting off and then check the idea details below. Once done, ask me questions, one at a time, to help refine the idea. Ideally, the questions would be multiple choice, but open-ended questions are OK, too. Don't forget: only one question per message. Once you believe you understand what we're doing, stop and describe the design to me, in sections of maybe 200-300 words at a time, asking after each section whether it looks right so far. Keep in mind that whatever we fix here will affect the upstream repo /Users/taariqlewis/Projects/Seren_Projects/seren-replicator so all changes must be refactored upstream. + +Here's the idea + +1.would it be better UI/UX to remove the need for users to use the connecton string in the CLI for the databse the replicator/ +2. Instead have users use their SerenDB API key Alone +3. The API Key allows users to lists all projects and all databases so the user just needs to select their target database to replicate against. +4. And then the connection_string can be read from the project by the database replicator instead of the user entering it. Easier UI/UX. +5. This will also fix the `sync` issue where there's no direct API to look up a project by endpoint hostname. +6. The API is setup for users to adjust settings on project. +7. When init is run, the user selects their target project and then database. No more need for connection string. +8. When sync is run, the system selects their already targeted project and database so they are syncing to the same database. We want to avoid users having to be confused by different branches so just make sure they are syncing to same database. +9. Existing APIs: + + GET /api/projects → List user's projects + GET /api/projects/{project_id}/replication → Get replication settings + PATCH /api/projects/{project_id}/replication → Enable logical replication (set enabled: true) + GET /api/projects/{project_id}/branches → List branches + GET /api/projects/{project_id}/branches/{branch_id}/endpoints → List endpoints + GET /api/projects/{project_id}/branches/{branch_id}/databases → List databases + + 10. Proposed UPDATED Replicator flow: + +USER enters their API KEY +GET /api/projects → Show project picker +GET /api/projects/{id}/replication → Check if logical replication is enabled +If not enabled: PATCH /api/projects/{id}/replication with {"enabled": true} +USE the DEFAUTL branch from the project +GET /api/projects/{id}/branches/{bid}/databases → Show database picker +Start replication using the selected database's connection string +Store replication target for sync run so that sync is to the same replication target. + diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 3c99590..86c34f4 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -4,11 +4,13 @@ pub mod init; pub mod status; pub mod sync; +pub mod target; pub mod validate; pub mod verify; pub use init::init; pub use status::status; pub use sync::sync; +pub use target::command as target; pub use validate::validate; pub use verify::verify; diff --git a/src/commands/sync.rs b/src/commands/sync.rs index 1ca6d5a..eed321d 100644 --- a/src/commands/sync.rs +++ b/src/commands/sync.rs @@ -7,7 +7,8 @@ use crate::replication::{ create_publication, create_subscription, detect_subscription_state, drop_subscription, wait_for_sync, SubscriptionState, }; -use anyhow::{Context, Result}; +use crate::serendb::{resolve_target_mode, ConsoleClient, TargetMode}; +use anyhow::{anyhow, Context, Result}; /// Set up logical replication between source and target databases /// @@ -381,6 +382,63 @@ pub async fn sync( Ok(()) } +/// Resolve the effective target URL for sync, honoring saved SerenDB state when using API keys. +pub async fn resolve_target_for_sync( + target: Option, + api_key: Option, + source_url: &str, +) -> Result { + let mode = resolve_target_mode(target, api_key.clone())?; + + match mode { + TargetMode::ConnectionString(url) => Ok(url), + TargetMode::SavedState(state) => { + println!( + "\n\u{1F4C1} Using saved target: {}/{}", + state.project_name, state.branch_name + ); + println!(" Databases: {:?}\n", state.databases); + + if !state.source_matches(source_url) { + eprintln!("\u{26A0} Warning: Source database has changed since the last init run"); + eprintln!(" Saved for: {}", state.source_url_hash); + eprintln!(" Current: {}", source_url); + eprintln!(); + } + + let api_key = api_key + .or_else(|| std::env::var("SEREN_API_KEY").ok()) + .ok_or_else(|| { + anyhow!( + "SEREN_API_KEY required to refresh saved SerenDB credentials. Provide --api-key or set SEREN_API_KEY." + ) + })?; + + let primary_db = state + .databases + .first() + .cloned() + .ok_or_else(|| anyhow!("Saved target has no databases recorded. Re-run init."))?; + + let client = ConsoleClient::new(None, api_key); + let conn_str = client + .get_connection_string( + &state.project_id, + &state.branch_id, + &primary_db, + false, + ) + .await + .context("Failed to fetch connection string for saved SerenDB target")?; + + Ok(conn_str) + } + TargetMode::ApiKey(_) => anyhow::bail!( + "No saved SerenDB target found. Run 'database-replicator init' first or provide --target." + ), + } +} + /// Replace the database name in a PostgreSQL connection URL /// /// # Arguments diff --git a/src/commands/target.rs b/src/commands/target.rs new file mode 100644 index 0000000..9df2c0b --- /dev/null +++ b/src/commands/target.rs @@ -0,0 +1,48 @@ +use anyhow::{Context, Result}; +use clap::{Args, Subcommand}; + +use crate::state; + +#[derive(Args)] +pub struct TargetArgs { + #[command(subcommand)] + command: TargetCommands, +} + +#[derive(Subcommand)] +enum TargetCommands { + /// Set the target database URL + Set { + /// The PostgreSQL URL to set as the target + url: String, + }, + /// Unset the target database URL + Unset, + /// Show the current target database URL + Get, +} + +pub async fn command(args: TargetArgs) -> Result<()> { + match args.command { + TargetCommands::Set { url } => { + let mut state = state::load().context("Failed to load state")?; + state.target_url = Some(url.clone()); + state::save(&state).context("Failed to save state")?; + println!("Target database URL set to: {}", url); + } + TargetCommands::Unset => { + let mut state = state::load().context("Failed to load state")?; + state.target_url = None; + state::save(&state).context("Failed to save state")?; + println!("Target database URL unset."); + } + TargetCommands::Get => { + let state = state::load().context("Failed to load state")?; + match state.target_url { + Some(url) => println!("Current target database URL: {}", url), + None => println!("Target database URL is not set."), + } + } + } + Ok(()) +} diff --git a/src/interactive.rs b/src/interactive.rs index 76d35c8..4799438 100644 --- a/src/interactive.rs +++ b/src/interactive.rs @@ -4,11 +4,72 @@ use crate::{ filters::ReplicationFilter, migration, postgres, + serendb::ConsoleClient, table_rules::{QualifiedTable, TableRules}, }; use anyhow::{Context, Result}; use inquire::{Confirm, MultiSelect, Select, Text}; +/// Prompts the user to select a SerenDB project and database interactively. +/// +/// This function will: +/// 1. Get the SerenDB API key (from environment or prompt). +/// 2. Fetch and display a list of projects for the user to select. +/// 3. Fetch the default branch for the selected project. +/// 4. Fetch and display a list of databases for the user to select. +/// 5. Return the connection string for the selected database. +/// +/// # Returns +/// +/// A `Result` containing the connection string of the selected database. +pub async fn select_seren_database() -> Result { + print_header("Select SerenDB Target"); + + let api_key = get_api_key()?; + let client = ConsoleClient::new(None, api_key); + + // 1. Select a project + let projects = client.list_projects().await?; + if projects.is_empty() { + anyhow::bail!("No projects found for your account."); + } + let project_names: Vec = projects.iter().map(|p| p.name.clone()).collect(); + let selected_project_name = Select::new("Select a project:", project_names).prompt()?; + let selected_project = projects + .into_iter() + .find(|p| p.name == selected_project_name) + .unwrap(); + + // 2. Select a database + let branch = client.get_default_branch(&selected_project.id).await?; + let databases = client + .list_databases(&selected_project.id, &branch.id) + .await?; + if databases.is_empty() { + anyhow::bail!( + "Project '{}' has no databases in its default branch.", + selected_project.name + ); + } + let database_names: Vec = databases.iter().map(|db| db.name.clone()).collect(); + let selected_database_name = Select::new("Select a database:", database_names).prompt()?; + let selected_database = databases + .into_iter() + .find(|db| db.name == selected_database_name) + .unwrap(); + + // 3. Get connection string + let conn_str = client + .get_connection_string( + &selected_project.id, + &branch.id, + &selected_database.name, + false, + ) + .await?; + Ok(conn_str) +} + /// Wizard step state machine enum WizardStep { SelectDatabases, @@ -778,6 +839,38 @@ fn replace_database_in_url(url: &str, new_db_name: &str) -> Result { Ok(new_url) } +pub fn get_api_key() -> anyhow::Result { + use dialoguer::{theme::ColorfulTheme, Input}; + + // Try environment variable first + if let Ok(key) = std::env::var("SEREN_API_KEY") { + if !key.trim().is_empty() { + return Ok(key.trim().to_string()); + } + } + + // Prompt user interactively + println!("\nRemote execution requires a SerenDB API key for authentication."); + println!("\nYou can generate an API key at:"); + println!(" https://console.serendb.com/api-keys\n"); + + let key: String = Input::with_theme(&ColorfulTheme::default()) + .with_prompt("Enter your SerenDB API key") + .allow_empty(false) + .interact_text()?; + + if key.trim().is_empty() { + anyhow::bail!( + "API key is required for remote execution.\n\ + Set the SEREN_API_KEY environment variable or run interactively.\n\ + Get your API key at: https://console.serendb.com/api-keys\n\ + Or use --local to run replication on your machine instead" + ); + } + + Ok(key.trim().to_string()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index d0bfd01..c311d38 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ pub mod remote; pub mod replication; pub mod serendb; pub mod sqlite; +pub mod state; pub mod table_rules; pub mod utils; diff --git a/src/main.rs b/src/main.rs index 1d07982..e468714 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ // ABOUTME: CLI entry point for database-replicator // ABOUTME: Parses commands and routes to appropriate handlers +use anyhow::Context; use clap::{Args, Parser, Subcommand}; use database_replicator::commands; @@ -19,6 +20,9 @@ struct Cli { /// Set the log level (error, warn, info, debug, trace) #[arg(long, global = true, default_value = "info")] log: String, + /// SerenDB API key for interactive target selection (falls back to SEREN_API_KEY env) + #[arg(long = "api-key", env = "SEREN_API_KEY", global = true)] + api_key: Option, #[command(subcommand)] command: Commands, } @@ -46,7 +50,7 @@ enum Commands { #[arg(long)] source: String, #[arg(long)] - target: String, + target: Option, /// Include only these databases (comma-separated) #[arg(long, value_delimiter = ',')] include_databases: Option>, @@ -68,7 +72,7 @@ enum Commands { #[arg(long)] source: String, #[arg(long)] - target: String, + target: Option, /// Skip confirmation prompt #[arg(short = 'y', long)] yes: bool, @@ -119,7 +123,7 @@ enum Commands { #[arg(long)] source: String, #[arg(long)] - target: String, + target: Option, /// Include only these databases (comma-separated) #[arg(long, value_delimiter = ',')] include_databases: Option>, @@ -152,7 +156,7 @@ enum Commands { #[arg(long)] source: String, #[arg(long)] - target: String, + target: Option, /// Include only these databases (comma-separated) #[arg(long, value_delimiter = ',')] include_databases: Option>, @@ -165,7 +169,7 @@ enum Commands { #[arg(long)] source: String, #[arg(long)] - target: String, + target: Option, /// Include only these databases (comma-separated) #[arg(long, value_delimiter = ',')] include_databases: Option>, @@ -179,12 +183,18 @@ enum Commands { #[arg(long, value_delimiter = ',')] exclude_tables: Option>, }, + /// Manage the target database URL + Target { + #[command(flatten)] + args: commands::target::TargetArgs, + }, } #[tokio::main] async fn main() -> anyhow::Result<()> { // We need to parse CLI args early to get the log level let cli = Cli::parse(); + let global_api_key = cli.api_key.clone(); // Initialize logging // 1. RUST_LOG environment variable has highest precedence @@ -215,6 +225,11 @@ async fn main() -> anyhow::Result<()> { exclude_tables, no_interactive, } => { + let state = database_replicator::state::load()?; + let target = target.or(state.target_url).ok_or_else(|| { + anyhow::anyhow!("Target database URL not provided and not set in state. Use `--target` or `database-replicator target set`.") + })?; + let filter = if !no_interactive { // Interactive mode (default) - prompt user to select databases and tables let (filter, rules) = @@ -250,6 +265,23 @@ async fn main() -> anyhow::Result<()> { seren_api, job_timeout, } => { + let mut state = database_replicator::state::load()?; + let mut target = target.or(state.target_url); + + if seren { + if let Some(t) = &target { + if !database_replicator::utils::is_serendb_target(t) { + anyhow::bail!("--seren flag is only compatible with SerenDB targets."); + } + } else { + target = Some(database_replicator::interactive::select_seren_database().await?); + } + } + + let target = target.ok_or_else(|| { + anyhow::anyhow!("Target database URL not provided and not set in state. Use `--target` or `database-replicator target set`.") + })?; + // Check if CLI filter flags were provided (skip interactive if so) let has_cli_filters = include_databases.is_some() || exclude_databases.is_some() @@ -303,9 +335,10 @@ async fn main() -> anyhow::Result<()> { if use_remote { tracing::info!("Using SerenAI cloud execution"); - return init_remote( + init_remote( source, - target, + target.clone(), + None, yes, final_include_databases, final_exclude_databases, @@ -317,61 +350,65 @@ async fn main() -> anyhow::Result<()> { job_timeout, cli.log, ) - .await; - } - - // Local execution path - // Clone filter values for potential fallback to remote - let fallback_include_dbs = final_include_databases.clone(); - let fallback_exclude_dbs = final_exclude_databases.clone(); - let fallback_include_tables = final_include_tables.clone(); - let fallback_exclude_tables = final_exclude_tables.clone(); + .await?; + } else { + // Local execution path + // Clone filter values for potential fallback to remote + let fallback_include_dbs = final_include_databases.clone(); + let fallback_exclude_dbs = final_exclude_databases.clone(); + let fallback_include_tables = final_include_tables.clone(); + let fallback_exclude_tables = final_exclude_tables.clone(); - let filter = database_replicator::filters::ReplicationFilter::new( - final_include_databases, - final_exclude_databases, - final_include_tables, - final_exclude_tables, - )?; - let table_rule_data = build_table_rules(&table_rules)?; - let filter = filter.with_table_rules(table_rule_data); + let filter = database_replicator::filters::ReplicationFilter::new( + final_include_databases, + final_exclude_databases, + final_include_tables, + final_exclude_tables, + )?; + let table_rule_data = build_table_rules(&table_rules)?; + let filter = filter.with_table_rules(table_rule_data); - let enable_sync = !no_sync; // Invert the flag: by default sync is enabled + let enable_sync = !no_sync; // Invert the flag: by default sync is enabled - // Run init with pre-flight checks, handle fallback to remote - match commands::init( - &source, - &target, - yes, - filter, - drop_existing, - enable_sync, - !no_resume, - local, // Pass whether --local was explicit - ) - .await - { - Ok(_) => Ok(()), - Err(e) if e.to_string().contains("PREFLIGHT_FALLBACK_TO_REMOTE") => { - // Auto-fallback to remote execution - init_remote( - source, - target, - yes, - fallback_include_dbs, - fallback_exclude_dbs, - fallback_include_tables, - fallback_exclude_tables, - drop_existing, - no_sync, - seren_api, - job_timeout, - cli.log, - ) - .await + // Run init with pre-flight checks, handle fallback to remote + match commands::init( + &source, + &target, + yes, + filter, + drop_existing, + enable_sync, + !no_resume, + local, // Pass whether --local was explicit + ) + .await + { + Ok(_) => {} + Err(e) if e.to_string().contains("PREFLIGHT_FALLBACK_TO_REMOTE") => { + // Auto-fallback to remote execution + init_remote( + source, + target.clone(), + None, // No saved target state in fallback path + yes, + fallback_include_dbs, + fallback_exclude_dbs, + fallback_include_tables, + fallback_exclude_tables, + drop_existing, + no_sync, + seren_api, + job_timeout, + cli.log, + ) + .await?; + } + Err(e) => return Err(e), } - Err(e) => Err(e), } + state.target_url = Some(target); + database_replicator::state::save(&state)?; + Ok(()) } Commands::Sync { source, @@ -386,6 +423,17 @@ async fn main() -> anyhow::Result<()> { project_id, console_api, } => { + let mut app_state = database_replicator::state::load()?; + let target_candidate = target.or(app_state.target_url.clone()); + let resolved_target = database_replicator::commands::sync::resolve_target_for_sync( + target_candidate, + global_api_key.clone(), + &source, + ) + .await?; + app_state.target_url = Some(resolved_target.clone()); + database_replicator::state::save(&app_state)?; + let filter = if !no_interactive { // Interactive mode (default) - prompt user to select databases and tables let (filter, rules) = @@ -405,12 +453,21 @@ async fn main() -> anyhow::Result<()> { // If project_id is provided and target is SerenDB, check/enable logical replication if let Some(ref project_id) = project_id { - if database_replicator::utils::is_serendb_target(&target) { + if database_replicator::utils::is_serendb_target(&resolved_target) { check_and_enable_logical_replication(project_id, &console_api).await?; } } - commands::sync(&source, &target, Some(filter), None, None, None, force).await + commands::sync( + &source, + &resolved_target, + Some(filter), + None, + None, + None, + force, + ) + .await } Commands::Status { source, @@ -418,6 +475,11 @@ async fn main() -> anyhow::Result<()> { include_databases, exclude_databases, } => { + let state = database_replicator::state::load()?; + let target = target.or(state.target_url).ok_or_else(|| { + anyhow::anyhow!("Target database URL not provided and not set in state. Use `--target` or `database-replicator target set`.") + })?; + let filter = database_replicator::filters::ReplicationFilter::new( include_databases, exclude_databases, @@ -434,6 +496,11 @@ async fn main() -> anyhow::Result<()> { include_tables, exclude_tables, } => { + let state = database_replicator::state::load()?; + let target = target.or(state.target_url).ok_or_else(|| { + anyhow::anyhow!("Target database URL not provided and not set in state. Use `--target` or `database-replicator target set`.") + })?; + let filter = database_replicator::filters::ReplicationFilter::new( include_databases, exclude_databases, @@ -442,42 +509,10 @@ async fn main() -> anyhow::Result<()> { )?; commands::verify(&source, &target, Some(filter)).await } + Commands::Target { args } => commands::target(args).await, } } -#[allow(clippy::too_many_arguments)] -fn get_api_key() -> anyhow::Result { - use dialoguer::{theme::ColorfulTheme, Input}; - - // Try environment variable first - if let Ok(key) = std::env::var("SEREN_API_KEY") { - if !key.trim().is_empty() { - return Ok(key.trim().to_string()); - } - } - - // Prompt user interactively - println!("\nRemote execution requires a SerenDB API key for authentication."); - println!("\nYou can generate an API key at:"); - println!(" https://console.serendb.com/api-keys\n"); - - let key: String = Input::with_theme(&ColorfulTheme::default()) - .with_prompt("Enter your SerenDB API key") - .allow_empty(false) - .interact_text()?; - - if key.trim().is_empty() { - anyhow::bail!( - "API key is required for remote execution.\n\ - Set the SEREN_API_KEY environment variable or run interactively.\n\ - Get your API key at: https://console.serendb.com/api-keys\n\ - Or use --local to run replication on your machine instead" - ); - } - - Ok(key.trim().to_string()) -} - /// Check if logical replication is enabled on SerenDB project and offer to enable it async fn check_and_enable_logical_replication( project_id: &str, @@ -488,8 +523,8 @@ async fn check_and_enable_logical_replication( tracing::info!("Checking logical replication status for SerenDB project..."); - // Get API key - let api_key = get_api_key()?; + // Get API key from interactive module (handles env var or prompt) + let api_key = database_replicator::interactive::get_api_key()?; // Create Console API client let client = ConsoleClient::new(Some(console_api), api_key); @@ -569,6 +604,7 @@ async fn check_and_enable_logical_replication( async fn init_remote( source: String, target: String, + target_state: Option, _yes: bool, include_databases: Option>, exclude_databases: Option>, @@ -588,8 +624,42 @@ async fn init_remote( println!("🌐 SerenAI cloud execution enabled"); println!("API endpoint: {}", seren_api); - // Get API key (from env or prompt user) - let api_key = get_api_key()?; + // Get API key from interactive module (handles env var or prompt) + let api_key = database_replicator::interactive::get_api_key()?; + let remote_api_key = api_key.clone(); + + // Extract SerenDB IDs either from saved state (API-key flow) or the target URL + let ( + target_project_id, + target_branch_id, + target_databases, + connection_string_mode, + resolved_target_url, + ) = if let Some(state) = target_state { + let databases = state.databases; + if databases.is_empty() { + anyhow::bail!("Saved target is missing database entries"); + } + ( + Some(state.project_id), + Some(state.branch_id), + Some(databases), + SerenTargetMode::Project, + None, + ) + } else if database_replicator::utils::is_serendb_target(&target) { + let (p_id, b_id, _) = database_replicator::utils::parse_serendb_url_for_ids(&target) + .context("Failed to parse SerenDB target URL for project, branch, and database IDs.")?; + ( + Some(p_id), + Some(b_id), + None, + SerenTargetMode::Url, + Some(target.clone()), + ) + } else { + (None, None, None, SerenTargetMode::Url, Some(target.clone())) + }; // Estimate database size for automatic instance selection println!("Analyzing database size..."); @@ -670,17 +740,38 @@ async fn init_remote( ); // Note: "yes" is client-side only, not sent to server - let job_spec = JobSpec { - version: "1.0".to_string(), - command: "init".to_string(), - source_url: source, - target_url: target, - filter, - options, + let job_spec = match connection_string_mode { + SerenTargetMode::Project => JobSpec { + version: "1.0".to_string(), + command: "init".to_string(), + source_url: source, + target_url: None, + target_project_id, + target_branch_id, + target_databases, + seren_api_key: Some(api_key.clone()), + filter, + options, + }, + SerenTargetMode::Url => JobSpec { + version: "1.0".to_string(), + command: "init".to_string(), + source_url: source, + target_url: Some( + resolved_target_url + .expect("Seren target URL must exist when using connection string mode"), + ), + target_project_id: None, + target_branch_id: None, + target_databases: None, + seren_api_key: None, + filter, + options, + }, }; // Submit job - let client = RemoteClient::new(seren_api, Some(api_key))?; + let client = RemoteClient::new(seren_api, Some(remote_api_key))?; println!("Submitting replication job..."); tracing::debug!("Job spec: {:?}", job_spec); @@ -746,3 +837,9 @@ fn build_table_rules( rules.apply_time_filter_cli(&args.time_filters)?; Ok(rules) } + +/// Internal mode to track whether we're using project-based or URL-based target +enum SerenTargetMode { + Project, + Url, +} diff --git a/src/remote/client.rs b/src/remote/client.rs index 843ec8b..593f8fa 100644 --- a/src/remote/client.rs +++ b/src/remote/client.rs @@ -7,6 +7,7 @@ use std::time::Duration; use super::models::{JobResponse, JobSpec, JobStatus}; +#[derive(Clone)] pub struct RemoteClient { client: Client, api_base_url: String, diff --git a/src/remote/models.rs b/src/remote/models.rs index 93aa864..bf386fc 100644 --- a/src/remote/models.rs +++ b/src/remote/models.rs @@ -10,7 +10,16 @@ pub struct JobSpec { pub version: String, pub command: String, // "init" or "sync" pub source_url: String, - pub target_url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub target_url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub target_project_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub target_branch_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub target_databases: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub seren_api_key: Option, #[serde(skip_serializing_if = "Option::is_none")] pub filter: Option, pub options: HashMap, @@ -48,3 +57,67 @@ pub struct ProgressInfo { pub databases_total: usize, pub message: Option, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_job_spec_serialization() { + let mut options = HashMap::new(); + options.insert("drop_existing".to_string(), serde_json::Value::Bool(true)); + + // Test with all fields populated + let job_spec = JobSpec { + version: "1.0".to_string(), + command: "init".to_string(), + source_url: "postgresql://source".to_string(), + target_url: Some("postgresql://target".to_string()), + target_project_id: Some("proj123".to_string()), + target_branch_id: Some("brnch456".to_string()), + target_databases: Some(vec!["db1".to_string()]), + seren_api_key: Some("seren_key".to_string()), + filter: Some(FilterSpec { + include_databases: Some(vec!["db1".to_string()]), + exclude_databases: None, + include_tables: None, + exclude_tables: None, + }), + options: options.clone(), + }; + + let parsed: serde_json::Value = serde_json::to_value(&job_spec).unwrap(); + assert_eq!(parsed["target_url"], "postgresql://target"); + assert_eq!(parsed["target_project_id"], "proj123"); + assert_eq!(parsed["target_branch_id"], "brnch456"); + assert_eq!(parsed["target_databases"], serde_json::json!(["db1"])); + assert_eq!(parsed["seren_api_key"], "seren_key"); + assert_eq!( + parsed["filter"], + serde_json::json!({"include_databases": ["db1"], "exclude_databases": null, "include_tables": null, "exclude_tables": null}) + ); + assert_eq!(parsed["schema_version"], "1.0"); + + // Test with optional fields as None + let job_spec_none = JobSpec { + version: "1.0".to_string(), + command: "init".to_string(), + source_url: "postgresql://source".to_string(), + target_url: Some("postgresql://target".to_string()), + target_project_id: None, + target_branch_id: None, + target_databases: None, + seren_api_key: None, + filter: None, + options, + }; + + let parsed_none: serde_json::Value = serde_json::to_value(&job_spec_none).unwrap(); + assert_eq!(parsed_none["target_url"], "postgresql://target"); + assert!(parsed_none.get("target_project_id").is_none()); + assert!(parsed_none.get("target_branch_id").is_none()); + assert!(parsed_none.get("target_databases").is_none()); + assert!(parsed_none.get("seren_api_key").is_none()); + assert!(parsed_none.get("filter").is_none()); + } +} diff --git a/src/serendb/client.rs b/src/serendb/client.rs index 063eb43..2535777 100644 --- a/src/serendb/client.rs +++ b/src/serendb/client.rs @@ -5,6 +5,8 @@ use anyhow::{Context, Result}; use reqwest::Client; use serde::{Deserialize, Serialize}; +use crate::utils::replace_database_in_connection_string; + /// Default SerenDB Console API base URL pub const DEFAULT_CONSOLE_API_URL: &str = "https://console.serendb.com"; @@ -16,7 +18,7 @@ pub struct ConsoleClient { } /// Project information from SerenDB Console API -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct Project { pub id: String, pub name: String, @@ -25,6 +27,58 @@ pub struct Project { pub organization_id: Option, } +/// Branch information from SerenDB Console API +#[allow(dead_code)] +#[derive(Debug, Clone, Deserialize)] +pub struct Branch { + pub id: String, + pub name: String, + pub project_id: String, + #[serde(default)] + pub is_default: bool, +} + +/// Database information from SerenDB Console API +#[allow(dead_code)] +#[derive(Debug, Clone, Deserialize)] +pub struct Database { + pub id: String, + pub name: String, + pub branch_id: String, +} + +/// Connection string response payload +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +pub struct ConnectionStringResponse { + pub connection_string: String, +} + +/// Request payload to create a database +#[allow(dead_code)] +#[derive(Debug, Serialize)] +pub struct CreateDatabaseRequest { + pub name: String, +} + +/// Paginated response wrapper from the Console API +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +pub struct PaginatedResponse { + pub data: Vec, + #[serde(default)] + pub pagination: Option, +} + +/// Pagination metadata returned by the Console API +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +pub struct Pagination { + pub total: i64, + pub page: i64, + pub per_page: i64, +} + /// Wrapper for API responses #[derive(Debug, Deserialize)] pub struct DataResponse { @@ -56,6 +110,202 @@ impl ConsoleClient { } } + /// List all projects accessible to the authenticated user + /// + /// # Returns + /// + /// Vector of projects the user has access to + /// + /// # Examples + /// ```ignore + /// let client = ConsoleClient::new(None, "seren_key".to_string()); + /// let projects = client.list_projects().await?; + /// for project in projects { + /// println!("{}: {}", project.id, project.name); + /// } + /// ``` + pub async fn list_projects(&self) -> Result> { + let url = format!("{}/api/projects", self.api_base_url); + + let response = self + .client + .get(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .send() + .await + .context("Failed to send request to SerenDB Console API")?; + + self.handle_common_errors(&response).await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("SerenDB Console API returned error {}: {}", status, body); + } + + let data: PaginatedResponse = response + .json() + .await + .context("Failed to parse projects response from SerenDB Console API")?; + + Ok(data.data) + } + + /// List all branches for a project + pub async fn list_branches(&self, project_id: &str) -> Result> { + let url = format!("{}/api/projects/{}/branches", self.api_base_url, project_id); + + let response = self + .client + .get(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .send() + .await + .context("Failed to send request to SerenDB Console API")?; + + self.handle_common_errors(&response).await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("SerenDB Console API returned error {}: {}", status, body); + } + + let data: PaginatedResponse = response + .json() + .await + .context("Failed to parse branches response from SerenDB Console API")?; + + Ok(data.data) + } + + /// Get the default branch for a project + /// + /// Returns the branch marked as default, or the first branch if none are marked. + pub async fn get_default_branch(&self, project_id: &str) -> Result { + let branches = self.list_branches(project_id).await?; + select_default_branch(project_id, branches) + } + + /// List all databases within a SerenDB branch + pub async fn list_databases(&self, project_id: &str, branch_id: &str) -> Result> { + let url = format!( + "{}/api/projects/{}/branches/{}/databases", + self.api_base_url, project_id, branch_id + ); + + let response = self + .client + .get(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .send() + .await + .context("Failed to send request to SerenDB Console API")?; + + self.handle_common_errors(&response).await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("SerenDB Console API returned error {}: {}", status, body); + } + + let data: PaginatedResponse = response + .json() + .await + .context("Failed to parse databases response from SerenDB Console API")?; + + Ok(data.data) + } + + /// Create a new SerenDB database inside a branch + pub async fn create_database( + &self, + project_id: &str, + branch_id: &str, + name: &str, + ) -> Result { + let url = format!( + "{}/api/projects/{}/branches/{}/databases", + self.api_base_url, project_id, branch_id + ); + + let request = CreateDatabaseRequest { + name: name.to_string(), + }; + + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&request) + .send() + .await + .context("Failed to send request to SerenDB Console API")?; + + self.handle_common_errors(&response).await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!( + "Failed to create database '{}': {} - {}", + name, + status, + body + ); + } + + let data: DataResponse = response + .json() + .await + .context("Failed to parse create database response from SerenDB Console API")?; + + Ok(data.data) + } + + /// Get a connection string for a branch/database combination + pub async fn get_connection_string( + &self, + project_id: &str, + branch_id: &str, + database: &str, + pooled: bool, + ) -> Result { + let url = format!( + "{}/api/projects/{}/branches/{}/connection-string?pooled={}", + self.api_base_url, project_id, branch_id, pooled + ); + + let response = self + .client + .get(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .send() + .await + .context("Failed to send request to SerenDB Console API")?; + + self.handle_common_errors(&response).await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("SerenDB Console API returned error {}: {}", status, body); + } + + let data: ConnectionStringResponse = response + .json() + .await + .context("Failed to parse connection string response from SerenDB Console API")?; + + replace_database_in_connection_string(&data.connection_string, database) + } + /// Get project information by ID /// /// # Arguments @@ -77,20 +327,15 @@ impl ConsoleClient { .await .context("Failed to send request to SerenDB Console API")?; - if response.status() == reqwest::StatusCode::UNAUTHORIZED { - anyhow::bail!( - "SerenDB API key is invalid or expired.\n\ - Generate a new key at: https://console.serendb.com/api-keys" - ); - } - - if response.status() == reqwest::StatusCode::NOT_FOUND { - anyhow::bail!( + self.handle_common_errors_with_context( + &response, + Some(format!( "Project {} not found.\n\ Verify the project ID is correct and you have access to it.", project_id - ); - } + )), + ) + .await?; if !response.status().is_success() { let status = response.status(); @@ -135,20 +380,15 @@ impl ConsoleClient { .await .context("Failed to send request to SerenDB Console API")?; - if response.status() == reqwest::StatusCode::UNAUTHORIZED { - anyhow::bail!( - "SerenDB API key is invalid or expired.\n\ - Generate a new key at: https://console.serendb.com/api-keys" - ); - } - - if response.status() == reqwest::StatusCode::NOT_FOUND { - anyhow::bail!( + self.handle_common_errors_with_context( + &response, + Some(format!( "Project {} not found.\n\ Verify the project ID is correct and you have access to it.", project_id - ); - } + )), + ) + .await?; if !response.status().is_success() { let status = response.status(); @@ -181,6 +421,45 @@ impl ConsoleClient { let project = self.get_project(project_id).await?; Ok(project.enable_logical_replication) } + + async fn handle_common_errors(&self, response: &reqwest::Response) -> Result<()> { + self.handle_common_errors_with_context(response, None).await + } + + async fn handle_common_errors_with_context( + &self, + response: &reqwest::Response, + not_found_message: Option, + ) -> Result<()> { + if response.status() == reqwest::StatusCode::UNAUTHORIZED { + anyhow::bail!( + "SerenDB API key is invalid or expired.\n\ + Generate a new key at: https://console.serendb.com/api-keys" + ); + } + + if response.status() == reqwest::StatusCode::NOT_FOUND { + if let Some(message) = not_found_message { + anyhow::bail!(message); + } else { + anyhow::bail!("Resource not found. Verify the ID is correct and you have access."); + } + } + + Ok(()) + } +} + +fn select_default_branch(project_id: &str, branches: Vec) -> Result { + if branches.is_empty() { + anyhow::bail!("Project {} has no branches", project_id); + } + + if let Some(default_branch) = branches.iter().find(|branch| branch.is_default) { + return Ok(default_branch.clone()); + } + + Ok(branches.into_iter().next().expect("branches is not empty")) } #[cfg(test)] @@ -211,4 +490,81 @@ mod tests { assert!(json.contains("enable_logical_replication")); assert!(json.contains("true")); } + + #[test] + fn test_branch_deserialization() { + let json = r#"{"id": "abc", "name": "main", "project_id": "xyz", "is_default": true}"#; + let branch: Branch = serde_json::from_str(json).unwrap(); + assert_eq!(branch.name, "main"); + assert!(branch.is_default); + } + + #[test] + fn test_database_deserialization() { + let json = r#"{"id": "db1", "name": "myapp", "branch_id": "br1"}"#; + let db: Database = serde_json::from_str(json).unwrap(); + assert_eq!(db.name, "myapp"); + assert_eq!(db.branch_id, "br1"); + } + + #[test] + fn test_select_default_branch_prefers_flagged_branch() { + let branches = vec![ + Branch { + id: "br1".into(), + name: "preview".into(), + project_id: "proj".into(), + is_default: false, + }, + Branch { + id: "br2".into(), + name: "main".into(), + project_id: "proj".into(), + is_default: true, + }, + ]; + + let default = select_default_branch("proj", branches).unwrap(); + assert_eq!(default.id, "br2"); + assert_eq!(default.name, "main"); + } + + #[test] + fn test_select_default_branch_falls_back_to_first() { + let branches = vec![ + Branch { + id: "br1".into(), + name: "alpha".into(), + project_id: "proj".into(), + is_default: false, + }, + Branch { + id: "br2".into(), + name: "beta".into(), + project_id: "proj".into(), + is_default: false, + }, + ]; + + let default = select_default_branch("proj", branches).unwrap(); + assert_eq!(default.id, "br1"); + assert_eq!(default.name, "alpha"); + } + + #[test] + fn test_select_default_branch_errors_when_empty() { + let err = select_default_branch("proj", Vec::new()).unwrap_err(); + assert!(format!("{err}").contains("has no branches")); + } + + #[test] + fn test_replace_database_in_connection_string() { + let original = + "postgresql://user:pass@host.serendb.com:5432/serendb?sslmode=require&foo=bar"; + let updated = + replace_database_in_connection_string(original, "myapp").expect("replace succeeds"); + assert!(updated.contains("/myapp?")); + assert!(updated.starts_with("postgresql://user:pass@host.serendb.com:5432/")); + assert!(updated.ends_with("sslmode=require&foo=bar")); + } } diff --git a/src/serendb/mod.rs b/src/serendb/mod.rs index c6d39e8..cc1f54a 100644 --- a/src/serendb/mod.rs +++ b/src/serendb/mod.rs @@ -2,5 +2,124 @@ // ABOUTME: Enables checking and enabling logical replication on SerenDB projects mod client; +mod picker; +mod target; -pub use client::ConsoleClient; +pub use client::{Branch, ConsoleClient, Database, Project}; +pub use picker::{create_missing_databases, select_target, TargetSelection}; +pub use target::{clear_target_state, load_target_state, save_target_state, TargetState}; + +use anyhow::Result; + +#[cfg(test)] +pub(crate) fn target_env_mutex() -> &'static std::sync::Mutex<()> { + use std::sync::{Mutex, OnceLock}; + static ENV_MUTEX: OnceLock> = OnceLock::new(); + ENV_MUTEX.get_or_init(|| Mutex::new(())) +} + +/// How the target database is specified +#[derive(Debug, Clone)] +pub enum TargetMode { + /// User provided --target connection string directly + ConnectionString(String), + /// User provided API key, will use interactive selection + ApiKey(String), + /// Using saved target from previous init + SavedState(TargetState), +} + +/// Resolve which target mode to use based on CLI args and environment +pub fn resolve_target_mode(target: Option, api_key: Option) -> Result { + match (target, api_key) { + (Some(url), _) => Ok(TargetMode::ConnectionString(url)), + (None, Some(key)) => { + if let Some(state) = load_target_state()? { + tracing::info!( + "Using saved target configuration: {}/{}", + state.project_name, + state.branch_name + ); + Ok(TargetMode::SavedState(state)) + } else { + Ok(TargetMode::ApiKey(key)) + } + } + (None, None) => { + anyhow::bail!( + "Target database required.\n\n\ + Option 1: Provide --target with a PostgreSQL connection string\n\ + Option 2: Set SEREN_API_KEY or pass --api-key for interactive SerenDB selection\n\n\ + Get your API key at: https://console.serendb.com/api-keys" + ) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::serendb::target::{clear_target_state, save_target_state, TargetState}; + use tempfile::tempdir; + + fn with_temp_state_path(func: F) { + let _guard = crate::serendb::target_env_mutex().lock().unwrap(); + let dir = tempdir().expect("tempdir"); + let path = dir.path().join("target.json"); + std::env::set_var("SEREN_TARGET_STATE_PATH", &path); + func(); + std::env::remove_var("SEREN_TARGET_STATE_PATH"); + } + + #[test] + fn test_resolve_target_mode_connection_string() { + let mode = + resolve_target_mode(Some("postgresql://localhost/db".to_string()), None).unwrap(); + match mode { + TargetMode::ConnectionString(url) => assert!(url.contains("localhost")), + _ => panic!("Expected ConnectionString mode"), + } + } + + #[test] + fn test_resolve_target_mode_prefers_explicit_target() { + let mode = resolve_target_mode( + Some("postgresql://localhost/db".to_string()), + Some("seren_key".to_string()), + ) + .unwrap(); + + if !matches!(mode, TargetMode::ConnectionString(_)) { + panic!("Expected ConnectionString mode"); + } + } + + #[test] + fn test_resolve_target_mode_uses_saved_state() { + with_temp_state_path(|| { + let state = TargetState::new( + "proj".into(), + "Project".into(), + "branch".into(), + "main".into(), + vec!["db1".into()], + "postgresql://localhost/source", + ); + save_target_state(&state).expect("save state"); + + let mode = resolve_target_mode(None, Some("seren_key".into())).unwrap(); + match mode { + TargetMode::SavedState(saved) => assert_eq!(saved.project_id, "proj"), + _ => panic!("Expected SavedState mode"), + } + + clear_target_state().expect("clear state"); + }); + } + + #[test] + fn test_resolve_target_mode_neither_fails() { + let result = resolve_target_mode(None, None); + assert!(result.is_err()); + } +} diff --git a/src/serendb/picker.rs b/src/serendb/picker.rs new file mode 100644 index 0000000..2ff4117 --- /dev/null +++ b/src/serendb/picker.rs @@ -0,0 +1,107 @@ +// ABOUTME: Interactive terminal UI for selecting SerenDB projects and databases +// ABOUTME: Uses dialoguer for consistent UX with existing interactive flows + +use crate::serendb::{Branch, ConsoleClient, Project}; +use anyhow::{Context, Result}; +use dialoguer::{theme::ColorfulTheme, Select}; + +/// Result of the interactive project/database selection +#[derive(Debug, Clone)] +pub struct TargetSelection { + pub project: Project, + pub branch: Branch, + pub databases: Vec, +} + +/// Run interactive SerenDB target selection. +/// Returns the selected project, branch, and database names to mirror the source. +pub async fn select_target( + client: &ConsoleClient, + source_databases: &[String], +) -> Result { + println!("\n=================================================="); + println!("SerenDB Target Selection"); + println!("==================================================\n"); + + let projects = client.list_projects().await?; + + if projects.is_empty() { + anyhow::bail!( + "No SerenDB projects found for this API key.\n\ + Create a project at: https://console.serendb.com" + ); + } + + let project_labels: Vec = projects + .iter() + .map(|p| { + let short_id: String = p.id.chars().take(8).collect(); + format!("{} ({})", p.name, short_id) + }) + .collect(); + + let project_idx = Select::with_theme(&ColorfulTheme::default()) + .with_prompt("Select target project") + .items(&project_labels) + .default(0) + .interact() + .context("Project selection cancelled")?; + + let project = projects[project_idx].clone(); + println!(" Selected project: {}\n", project.name); + + let branch = client.get_default_branch(&project.id).await?; + println!(" Using branch: {}\n", branch.name); + + let existing = client.list_databases(&project.id, &branch.id).await?; + let existing_names: Vec = existing.iter().map(|d| d.name.clone()).collect(); + + println!("Source databases to replicate: {:?}", source_databases); + println!("Existing target databases: {:?}\n", existing_names); + + let mut target_databases = Vec::new(); + for source_db in source_databases { + if existing_names.contains(source_db) { + println!(" \u{2713} {}", source_db); + } else { + println!(" + {} (will be created)", source_db); + } + target_databases.push(source_db.clone()); + } + + println!(); + + Ok(TargetSelection { + project, + branch, + databases: target_databases, + }) +} + +/// Ensure target branch contains all databases required for replication. +pub async fn create_missing_databases( + client: &ConsoleClient, + project_id: &str, + branch_id: &str, + databases: &[String], +) -> Result<()> { + let existing = client.list_databases(project_id, branch_id).await?; + let existing_names: Vec = existing.iter().map(|d| d.name.clone()).collect(); + + for db_name in databases { + if !existing_names.contains(db_name) { + println!(" Creating database '{}'...", db_name); + client + .create_database(project_id, branch_id, db_name) + .await?; + println!(" \u{2713} Created '{}'", db_name); + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + // Interactive picker relies on network + terminal input, so unit tests are not practical here. +} diff --git a/src/serendb/target.rs b/src/serendb/target.rs new file mode 100644 index 0000000..7d4f08e --- /dev/null +++ b/src/serendb/target.rs @@ -0,0 +1,196 @@ +// ABOUTME: Persists SerenDB target selection for reuse across commands +// ABOUTME: Stores project/branch/database selection in .seren-replicator/target.json + +use anyhow::{Context, Result}; +use chrono::Utc; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::path::PathBuf; + +const TARGET_FILE: &str = ".seren-replicator/target.json"; +const TARGET_FILE_ENV: &str = "SEREN_TARGET_STATE_PATH"; +const STATE_VERSION: u32 = 1; + +/// Persisted target selection state +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TargetState { + /// Schema version for forward compatibility + pub version: u32, + /// Selected SerenDB project ID + pub project_id: String, + /// Human-readable project name + pub project_name: String, + /// Selected branch ID + pub branch_id: String, + /// Branch name + pub branch_name: String, + /// List of database names being replicated + pub databases: Vec, + /// SHA256 hash of source URL (to detect mismatches) + pub source_url_hash: String, + /// When this target was configured + pub created_at: String, +} + +impl TargetState { + /// Create a new target state snapshot + pub fn new( + project_id: String, + project_name: String, + branch_id: String, + branch_name: String, + databases: Vec, + source_url: &str, + ) -> Self { + Self { + version: STATE_VERSION, + project_id, + project_name, + branch_id, + branch_name, + databases, + source_url_hash: hash_url(source_url), + created_at: Utc::now().to_rfc3339(), + } + } + + /// Check if a source URL matches the stored configuration + pub fn source_matches(&self, source_url: &str) -> bool { + self.source_url_hash == hash_url(source_url) + } +} + +/// Hash a URL for comparison (strips password for privacy) +fn hash_url(url: &str) -> String { + let sanitized = crate::utils::strip_password_from_url(url).unwrap_or_else(|_| url.to_string()); + let mut hasher = Sha256::new(); + hasher.update(sanitized.as_bytes()); + format!("sha256:{:x}", hasher.finalize()) +} + +/// Get the path to the target state file, allowing an env override for tests +fn target_file_path() -> PathBuf { + if let Ok(custom) = std::env::var(TARGET_FILE_ENV) { + return PathBuf::from(custom); + } + PathBuf::from(TARGET_FILE) +} + +/// Load target state from disk. Returns Ok(None) if the file does not exist. +pub fn load_target_state() -> Result> { + let path = target_file_path(); + + if !path.exists() { + return Ok(None); + } + + let content = std::fs::read_to_string(&path) + .with_context(|| format!("Failed to read {}", path.display()))?; + + let state: TargetState = serde_json::from_str(&content).with_context(|| { + format!( + "Failed to parse {}. Delete it and run init again.", + path.display() + ) + })?; + + if state.version > STATE_VERSION { + anyhow::bail!( + "Target state file was created by a newer database-replicator version. \ + Upgrade this CLI or delete {}", + path.display() + ); + } + + Ok(Some(state)) +} + +/// Save target state to disk +pub fn save_target_state(state: &TargetState) -> Result<()> { + let path = target_file_path(); + + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent) + .with_context(|| format!("Failed to create directory {}", parent.display()))?; + } + + let content = + serde_json::to_string_pretty(state).context("Failed to serialize target state")?; + + std::fs::write(&path, content) + .with_context(|| format!("Failed to write {}", path.display()))?; + + tracing::info!("Saved SerenDB target configuration to {}", path.display()); + Ok(()) +} + +/// Delete persisted target state (if present) +pub fn clear_target_state() -> Result<()> { + let path = target_file_path(); + if path.exists() { + std::fs::remove_file(&path) + .with_context(|| format!("Failed to remove {}", path.display()))?; + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + + fn with_temp_state_path(func: F) { + let _guard = crate::serendb::target_env_mutex().lock().unwrap(); + let dir = tempdir().expect("tempdir"); + let file_path = dir.path().join("target.json"); + std::env::set_var(TARGET_FILE_ENV, &file_path); + func(); + std::env::remove_var(TARGET_FILE_ENV); + } + + #[test] + fn test_target_state_roundtrip() { + with_temp_state_path(|| { + let state = TargetState::new( + "proj-123".to_string(), + "my-project".to_string(), + "branch-456".to_string(), + "main".to_string(), + vec!["db1".to_string(), "db2".to_string()], + "postgresql://localhost/source", + ); + + save_target_state(&state).expect("save target state"); + let loaded = load_target_state() + .expect("load state") + .expect("state present"); + + assert_eq!(loaded.project_id, "proj-123"); + assert_eq!(loaded.databases.len(), 2); + assert!(loaded.source_matches("postgresql://localhost/source")); + }); + } + + #[test] + fn test_source_url_matching() { + let state = TargetState::new( + "p".to_string(), + "proj".to_string(), + "b".to_string(), + "main".to_string(), + vec![], + "postgresql://user:pass@host/db", + ); + + assert!(state.source_matches("postgresql://user:pass@host/db")); + assert!(state.source_matches("postgresql://user:other@host/db")); + assert!(!state.source_matches("postgresql://user:pass@other/db")); + } + + #[test] + fn test_hash_url_strips_password() { + let hash1 = hash_url("postgresql://user:secret1@host/db"); + let hash2 = hash_url("postgresql://user:secret2@host/db"); + assert_eq!(hash1, hash2); + } +} diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..1019dd0 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,36 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::path::PathBuf; + +#[derive(Serialize, Deserialize, Default)] +pub struct AppState { + pub target_url: Option, +} + +fn get_state_path() -> Result { + let home_dir = + dirs::home_dir().ok_or_else(|| anyhow::anyhow!("Could not find home directory"))?; + let state_dir = home_dir.join(".database-replicator"); + if !state_dir.exists() { + fs::create_dir_all(&state_dir)?; + } + Ok(state_dir.join("state.json")) +} + +pub fn load() -> Result { + let state_path = get_state_path()?; + if !state_path.exists() { + return Ok(AppState::default()); + } + let state_file = fs::File::open(state_path)?; + let state = serde_json::from_reader(state_file)?; + Ok(state) +} + +pub fn save(state: &AppState) -> Result<()> { + let state_path = get_state_path()?; + let state_file = fs::File::create(state_path)?; + serde_json::to_writer_pretty(state_file, state)?; + Ok(()) +} diff --git a/src/utils.rs b/src/utils.rs index 35b1db1..af418e7 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -3,6 +3,7 @@ use anyhow::{bail, Context, Result}; use std::time::Duration; +use url::Url; use which::which; /// Get TCP keepalive environment variables for PostgreSQL client tools @@ -1095,6 +1096,49 @@ pub fn cleanup_stale_temp_dirs(max_age_secs: u64) -> Result { Ok(cleaned_count) } +/// Parse a SerenDB URL to extract project, branch, and database IDs +/// +/// SerenDB URLs have the format: postgresql://user:pass@...serendb.com:5432/db +/// This function extracts the three UUIDs from the hostname. +/// +/// # Arguments +/// +/// * `url` - The SerenDB PostgreSQL connection string +/// +/// # Returns +/// +/// An `Option` containing a tuple of `(project_id, branch_id, database_id)` if the +/// URL is a valid SerenDB target and contains the expected ID format, otherwise `None`. +pub fn parse_serendb_url_for_ids(url: &str) -> Option<(String, String, String)> { + let parts = parse_postgres_url(url).ok()?; + + if !is_serendb_target(url) { + return None; + } + + // Hostname format: ...serendb.com + // Or with custom subdomains: ....serendb.com + // We want the last three parts before .serendb.com + let host_parts: Vec<&str> = parts.host.split('.').collect(); + + if host_parts.len() < 4 { + return None; // Not enough parts for SerenDB ID format + } + + let num_host_parts = host_parts.len(); + let database_id = host_parts[num_host_parts - 4].to_string(); + let branch_id = host_parts[num_host_parts - 3].to_string(); + let project_id = host_parts[num_host_parts - 2].to_string(); + + // Basic UUID format validation (optional but good for robustness) + // A real UUID check would be more extensive, but string length is a good start + if database_id.len() == 36 && branch_id.len() == 36 && project_id.len() == 36 { + Some((project_id, branch_id, database_id)) + } else { + None + } +} + /// Remove a managed temporary directory /// /// Explicitly removes a temporary directory created by `create_managed_temp_dir()`. @@ -1143,6 +1187,30 @@ pub fn remove_managed_temp_dir(path: &std::path::Path) -> Result<()> { Ok(()) } +/// Replace the database name in a connection string URL +/// +/// This is used internally by SerenDB to provide a generic connection string +/// which then needs to be specialized for a particular database. +/// +/// # Arguments +/// +/// * `url` - The connection string URL (e.g., postgresql://host/template_db) +/// * `new_db` - The new database name to insert into the URL +/// +/// # Returns +/// +/// A new URL string with the database name replaced. +/// +/// # Errors +/// +/// Returns an error if the URL is invalid and cannot be parsed. +pub fn replace_database_in_connection_string(url: &str, new_db: &str) -> Result { + let mut parsed = Url::parse(url).context("Invalid connection string URL")?; + parsed.set_path(&format!("/{}", new_db)); + + Ok(parsed.to_string()) +} + /// Check if a PostgreSQL URL points to a SerenDB instance /// /// SerenDB hosts have domains ending with `.serendb.com` diff --git a/tests/fallback_test.rs b/tests/fallback_test.rs new file mode 100644 index 0000000..0427503 --- /dev/null +++ b/tests/fallback_test.rs @@ -0,0 +1,35 @@ +use std::process::Command; +use tempfile::tempdir; + +#[test] +fn test_remote_to_local_fallback() { + let temp_dir = tempdir().unwrap(); + let source_db_path = temp_dir.path().join("source.db"); + let target_db_path = temp_dir.path().join("target.db"); + + // Create dummy database files + std::fs::write(&source_db_path, "").unwrap(); + std::fs::write(&target_db_path, "").unwrap(); + + let bin_path = env!("CARGO_BIN_EXE_database-replicator"); + + let output = Command::new(bin_path) + .arg("init") + .arg("--source") + .arg(source_db_path.to_str().unwrap()) + .arg("--target") + .arg(target_db_path.to_str().unwrap()) + .arg("--seren") + .arg("--no-interactive") + .output() + .expect("Failed to execute command"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + + println!("stdout: {}", stdout); + println!("stderr: {}", stderr); + dbg!(&output); + + assert!(stderr.contains("--seren flag is only compatible with SerenDB targets.")); +} diff --git a/tests/integration_remote_test.rs b/tests/integration_remote_test.rs index ea083e1..45fe800 100644 --- a/tests/integration_remote_test.rs +++ b/tests/integration_remote_test.rs @@ -33,12 +33,15 @@ async fn test_remote_job_submission() { let api_key = std::env::var("SEREN_API_KEY").ok(); let client = RemoteClient::new(api_url, api_key).expect("Failed to create remote client"); - // Create a job spec for validation (safe, read-only) let job_spec = JobSpec { version: "1.0".to_string(), command: "validate".to_string(), source_url, - target_url, + target_url: Some(target_url), + target_project_id: None, + target_branch_id: None, + target_databases: None, + seren_api_key: None, filter: None, options: HashMap::new(), }; @@ -84,18 +87,22 @@ async fn test_remote_job_polling() { let api_key = std::env::var("SEREN_API_KEY").ok(); let client = RemoteClient::new(api_url, api_key).expect("Failed to create remote client"); - // Create and submit a job spec for validation (safe, read-only) let job_spec = JobSpec { version: "1.0".to_string(), command: "validate".to_string(), source_url, - target_url, + target_url: Some(target_url), + target_project_id: None, + target_branch_id: None, + target_databases: None, + seren_api_key: None, filter: None, options: HashMap::new(), }; // Submit the job let job_response = client + .clone() .submit_job(&job_spec) .await .expect("Failed to submit job"); @@ -106,6 +113,7 @@ async fn test_remote_job_polling() { // Poll for initial status println!("Polling for job status..."); let status = client + .clone() .get_job_status(&job_response.job_id) .await .expect("Failed to get job status"); @@ -135,6 +143,7 @@ async fn test_remote_job_polling() { tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; let updated_status = client + .clone() .get_job_status(&job_response.job_id) .await .expect("Failed to get job status"); @@ -185,13 +194,18 @@ async fn test_remote_job_poll_until_complete() { version: "1.0".to_string(), command: "validate".to_string(), source_url, - target_url, + target_url: Some(target_url), + target_project_id: None, + target_branch_id: None, + target_databases: None, + seren_api_key: None, filter: None, options: HashMap::new(), }; // Submit the job let job_response = client + .clone() .submit_job(&job_spec) .await .expect("Failed to submit job"); @@ -202,6 +216,7 @@ async fn test_remote_job_poll_until_complete() { // Poll until completion with a callback let final_status = client + .clone() .poll_until_complete(&job_response.job_id, |status| { println!(" Status update: {}", status.status); if let Some(progress) = &status.progress { @@ -275,13 +290,17 @@ async fn test_remote_job_submission_with_filters() { version: "1.0".to_string(), command: "validate".to_string(), source_url, - target_url, + target_url: Some(target_url), + target_project_id: None, + target_branch_id: None, + target_databases: None, + seren_api_key: None, filter: Some(filter), options: HashMap::new(), }; // Submit the job - let result = client.submit_job(&job_spec).await; + let result = client.clone().submit_job(&job_spec).await; match &result { Ok(job_response) => { @@ -323,13 +342,17 @@ async fn test_remote_job_submission_with_options() { version: "1.0".to_string(), command: "validate".to_string(), source_url, - target_url, + target_url: Some(target_url), + target_project_id: None, + target_branch_id: None, + target_databases: None, + seren_api_key: None, filter: None, options, }; // Submit the job - let result = client.submit_job(&job_spec).await; + let result = client.clone().submit_job(&job_spec).await; match &result { Ok(job_response) => { diff --git a/tests/interactive_serendb_test.rs b/tests/interactive_serendb_test.rs new file mode 100644 index 0000000..704d708 --- /dev/null +++ b/tests/interactive_serendb_test.rs @@ -0,0 +1,41 @@ +use std::fs; +use std::process::Command; +use tempfile::tempdir; + +#[tokio::test] +#[ignore] // This test requires manual interaction due to inquire prompts +async fn test_interactive_serendb_selection() { + let temp_dir = tempdir().unwrap(); + let home_dir = temp_dir.path(); + let state_dir = home_dir.join(".database-replicator"); + let _ = fs::create_dir_all(&state_dir); + + let bin_path = env!("CARGO_BIN_EXE_database-replicator"); + + println!("\n--- Starting interactive SerenDB selection test ---"); + println!("This test requires manual interaction. Please follow the prompts."); + println!("Ensure SEREN_API_KEY is set in your environment or be ready to enter it."); + + // This command will trigger the interactive selection. We cannot assert stdout directly + // due to the interactive nature, but we can verify it doesn't crash and potentially + // manually observe the prompts. + let output = Command::new(bin_path) + .arg("init") + .arg("--source") + .arg("sqlite:///tmp/dummy.db") // Dummy source, won't be connected + .arg("--seren") + .env("HOME", home_dir) // Use temp home for state file + .output() + .expect("Failed to execute command"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + + println!("stdout: {}", stdout); + println!("stderr: {}", stderr); + + // If the interactive selection was successful and a target was chosen, it should be saved + // We can't automate the interaction, so this test mainly verifies the flow doesn't panic. + // A more sophisticated integration test would use expect-test or similar for interactive prompts. + assert!(output.status.success() || stderr.contains("Target database URL not provided")); +} diff --git a/tests/serendb_api_test.rs b/tests/serendb_api_test.rs new file mode 100644 index 0000000..8dadd63 --- /dev/null +++ b/tests/serendb_api_test.rs @@ -0,0 +1,178 @@ +// ABOUTME: Integration tests for SerenDB Console API +// ABOUTME: Tests require SEREN_API_KEY and optionally TEST_SERENDB_PROJECT_ID + +//! Integration tests for SerenDB Console API +//! +//! These tests require: +//! - SEREN_API_KEY environment variable +//! - TEST_SERENDB_PROJECT_ID environment variable (for project-specific tests) +//! +//! Run with: cargo test --test serendb_api_test -- --ignored --nocapture + +use database_replicator::serendb::ConsoleClient; + +fn get_test_client() -> Option { + let api_key = std::env::var("SEREN_API_KEY").ok()?; + Some(ConsoleClient::new(None, api_key)) +} + +fn get_test_project_id() -> Option { + std::env::var("TEST_SERENDB_PROJECT_ID").ok() +} + +#[tokio::test] +#[ignore] +async fn test_list_projects() { + let client = get_test_client().expect("SEREN_API_KEY required"); + + let projects = client.list_projects().await.unwrap(); + + assert!(!projects.is_empty(), "Should have at least one project"); + println!("Found {} projects:", projects.len()); + for project in &projects { + println!( + " - {} (id: {}, logical_replication: {})", + project.name, project.id, project.enable_logical_replication + ); + } +} + +#[tokio::test] +#[ignore] +async fn test_get_project() { + let client = get_test_client().expect("SEREN_API_KEY required"); + let project_id = get_test_project_id().expect("TEST_SERENDB_PROJECT_ID required"); + + let project = client.get_project(&project_id).await.unwrap(); + + assert_eq!(project.id, project_id); + println!("Project: {} ({})", project.name, project.id); + println!( + " Logical replication: {}", + project.enable_logical_replication + ); +} + +#[tokio::test] +#[ignore] +async fn test_list_branches() { + let client = get_test_client().expect("SEREN_API_KEY required"); + let project_id = get_test_project_id().expect("TEST_SERENDB_PROJECT_ID required"); + + let branches = client.list_branches(&project_id).await.unwrap(); + + assert!(!branches.is_empty(), "Should have at least one branch"); + println!("Found {} branches:", branches.len()); + for branch in &branches { + let default_marker = if branch.is_default { " (default)" } else { "" }; + println!(" - {}{} (id: {})", branch.name, default_marker, branch.id); + } +} + +#[tokio::test] +#[ignore] +async fn test_get_default_branch() { + let client = get_test_client().expect("SEREN_API_KEY required"); + let project_id = get_test_project_id().expect("TEST_SERENDB_PROJECT_ID required"); + + let branch = client.get_default_branch(&project_id).await.unwrap(); + + println!("Default branch: {} (id: {})", branch.name, branch.id); + assert!(!branch.id.is_empty()); + assert!(!branch.name.is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_list_databases() { + let client = get_test_client().expect("SEREN_API_KEY required"); + let project_id = get_test_project_id().expect("TEST_SERENDB_PROJECT_ID required"); + + let branch = client.get_default_branch(&project_id).await.unwrap(); + let databases = client + .list_databases(&project_id, &branch.id) + .await + .unwrap(); + + println!( + "Found {} databases in branch {}:", + databases.len(), + branch.name + ); + for db in &databases { + println!(" - {} (id: {})", db.name, db.id); + } +} + +#[tokio::test] +#[ignore] +async fn test_get_connection_string() { + let client = get_test_client().expect("SEREN_API_KEY required"); + let project_id = get_test_project_id().expect("TEST_SERENDB_PROJECT_ID required"); + + let branch = client.get_default_branch(&project_id).await.unwrap(); + let conn_str = client + .get_connection_string(&project_id, &branch.id, "serendb", false) + .await + .unwrap(); + + assert!( + conn_str.starts_with("postgresql://"), + "Should be a PostgreSQL connection string" + ); + assert!( + conn_str.contains("serendb.com") || conn_str.contains("localhost"), + "Should contain SerenDB hostname" + ); + println!("Connection string retrieved successfully (credentials redacted)"); +} + +#[tokio::test] +#[ignore] +async fn test_is_logical_replication_enabled() { + let client = get_test_client().expect("SEREN_API_KEY required"); + let project_id = get_test_project_id().expect("TEST_SERENDB_PROJECT_ID required"); + + let enabled = client + .is_logical_replication_enabled(&project_id) + .await + .unwrap(); + + println!("Logical replication enabled: {}", enabled); +} + +#[tokio::test] +#[ignore] +async fn test_invalid_api_key_returns_error() { + let client = ConsoleClient::new(None, "invalid_key".to_string()); + + let result = client.list_projects().await; + + assert!(result.is_err(), "Should fail with invalid API key"); + let error = result.unwrap_err().to_string(); + assert!( + error.contains("invalid") || error.contains("expired") || error.contains("401"), + "Error should indicate authentication failure: {}", + error + ); + println!("Correctly rejected invalid API key"); +} + +#[tokio::test] +#[ignore] +async fn test_nonexistent_project_returns_error() { + let client = get_test_client().expect("SEREN_API_KEY required"); + + let result = client + .get_project("00000000-0000-0000-0000-000000000000") + .await; + + assert!(result.is_err(), "Should fail for nonexistent project"); + let error = result.unwrap_err().to_string(); + assert!( + error.contains("not found") || error.contains("404"), + "Error should indicate not found: {}", + error + ); + println!("Correctly returned not found for nonexistent project"); +} diff --git a/tests/state_test.rs b/tests/state_test.rs new file mode 100644 index 0000000..55e907b --- /dev/null +++ b/tests/state_test.rs @@ -0,0 +1,63 @@ +use std::fs; +use std::process::Command; +use tempfile::tempdir; + +#[test] +fn test_target_command() { + let temp_dir = tempdir().unwrap(); + let home_dir = temp_dir.path(); + let state_dir = home_dir.join(".database-replicator"); + let state_file = state_dir.join("state.json"); + + let bin_path = env!("CARGO_BIN_EXE_database-replicator"); + + // Test `target get` when state is not set + let output = Command::new(bin_path) + .arg("target") + .arg("get") + .env("HOME", home_dir) + .output() + .expect("Failed to execute command"); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("Target database URL is not set.")); + + // Test `target set` + let target_url = "postgres://user:pass@host:5432/db"; + let output = Command::new(bin_path) + .arg("target") + .arg("set") + .arg(target_url) + .env("HOME", home_dir) + .output() + .expect("Failed to execute command"); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains(&format!("Target database URL set to: {}", target_url))); + + // Verify state file content + let state_content = fs::read_to_string(&state_file).unwrap(); + assert!(state_content.contains(target_url)); + + // Test `target get` when state is set + let output = Command::new(bin_path) + .arg("target") + .arg("get") + .env("HOME", home_dir) + .output() + .expect("Failed to execute command"); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains(&format!("Current target database URL: {}", target_url))); + + // Test `target unset` + let output = Command::new(bin_path) + .arg("target") + .arg("unset") + .env("HOME", home_dir) + .output() + .expect("Failed to execute command"); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("Target database URL unset.")); + + // Verify state file content + let state_content = fs::read_to_string(&state_file).unwrap(); + assert!(!state_content.contains(target_url)); +}