Skip to content

Commit aedb529

Browse files
authored
feat(cli): connectivity checks for DB and embedding in interactive init (#42)
## What type of PR is this? - [x] feat (new feature) - [x] chore (maintenance, tooling) ## Which issue(s) this PR fixes N/A — follow-up to #41 ## What this PR does / why we need it Adds connectivity validation to `memoria init -i` after user confirms settings, before writing config files. ### Connectivity checks Between the confirm step and writing config: - **Database**: TCP connect to `host:port` (3s timeout) — catches wrong host/port/DB not running - **Embedding**: POST `/embeddings` with `{"model":"...","input":"test"}` — verifies URL reachable, API key valid, model exists On failure, shows the error and prompts `Continue anyway?` so users can still write config if they plan to start services later. ### Other changes - Fix clippy `collapsible_if` warning - Add `scripts/migrate_embedding_dim.sh` to `.gitignore` (local-only utility) ### Changes | File | Change | |------|--------| | `memoria/crates/memoria-cli/src/main.rs` | `check_db`, `check_embedding`, `check_embedding_request` functions; fix `read_password_line` tty fallback; clippy fix | | `.gitignore` | Ignore local migration script |
1 parent e0df7b9 commit aedb529

2 files changed

Lines changed: 90 additions & 3 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@ build/
3535
.mypy_cache/
3636
.ruff_cache/
3737
htmlcov/
38+
scripts/migrate_embedding_dim.sh

memoria/crates/memoria-cli/src/main.rs

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -617,15 +617,23 @@ fn prompt_secret(label: &str, existing: &str) -> String {
617617
fn read_password_line() -> String {
618618
#[cfg(unix)]
619619
{
620+
use std::os::unix::io::AsRawFd;
621+
let stdin = std::io::stdin();
622+
let fd = stdin.as_raw_fd();
620623
unsafe {
621624
let mut termios: libc::termios = std::mem::zeroed();
622-
libc::tcgetattr(0, &mut termios);
625+
if libc::tcgetattr(fd, &mut termios) != 0 {
626+
// Not a terminal (pipe/redirect) — fall back to normal read
627+
let mut buf = String::new();
628+
std::io::stdin().read_line(&mut buf).ok();
629+
return buf.trim().to_string();
630+
}
623631
let old = termios;
624632
termios.c_lflag &= !libc::ECHO;
625-
libc::tcsetattr(0, libc::TCSANOW, &termios);
633+
libc::tcsetattr(fd, libc::TCSANOW, &termios);
626634
let mut buf = String::new();
627635
std::io::stdin().read_line(&mut buf).ok();
628-
libc::tcsetattr(0, libc::TCSANOW, &old);
636+
libc::tcsetattr(fd, libc::TCSANOW, &old);
629637
buf.trim().to_string()
630638
}
631639
}
@@ -694,6 +702,72 @@ fn prompt_confirm(label: &str) -> bool {
694702
!buf.trim().eq_ignore_ascii_case("n")
695703
}
696704

705+
fn check_db(db_url: &str) -> bool {
706+
use std::net::TcpStream;
707+
use std::time::Duration;
708+
// Parse host:port from mysql://user:pass@host:port/db
709+
let addr = db_url
710+
.strip_prefix("mysql://")
711+
.and_then(|s| s.split_once('@'))
712+
.and_then(|(_, hostdb)| hostdb.split_once('/'))
713+
.map(|(hostport, _)| hostport.to_string())
714+
.unwrap_or_default();
715+
if addr.is_empty() {
716+
println!(" ✗ Database: invalid URL");
717+
return false;
718+
}
719+
match TcpStream::connect_timeout(
720+
&addr.parse().unwrap_or_else(|_| {
721+
// Resolve manually for host:port format
722+
use std::net::ToSocketAddrs;
723+
addr.to_socket_addrs()
724+
.ok()
725+
.and_then(|mut a| a.next())
726+
.unwrap_or_else(|| ([127, 0, 0, 1], 6001).into())
727+
}),
728+
Duration::from_secs(3),
729+
) {
730+
Ok(_) => { println!(" ✓ Database: {} reachable", addr); true }
731+
Err(e) => { println!(" ✗ Database: {} — {}", addr, e); false }
732+
}
733+
}
734+
735+
fn check_embedding(base_url: &str, api_key: &str, model: &str) -> bool {
736+
if base_url.is_empty() {
737+
// OpenAI official — use default URL
738+
return check_embedding_request("https://api.openai.com/v1", api_key, model);
739+
}
740+
check_embedding_request(base_url, api_key, model)
741+
}
742+
743+
fn check_embedding_request(base_url: &str, api_key: &str, model: &str) -> bool {
744+
let url = format!("{}/embeddings", base_url.trim_end_matches('/'));
745+
let client = reqwest::blocking::Client::builder()
746+
.timeout(std::time::Duration::from_secs(10))
747+
.build()
748+
.unwrap();
749+
let mut req = client.post(&url)
750+
.header("Content-Type", "application/json")
751+
.body(format!(r#"{{"model":"{}","input":"test"}}"#, model));
752+
if !api_key.is_empty() {
753+
req = req.header("Authorization", format!("Bearer {}", api_key));
754+
}
755+
match req.send() {
756+
Ok(resp) if resp.status().is_success() => {
757+
println!(" ✓ Embedding: {} OK", base_url);
758+
true
759+
}
760+
Ok(resp) => {
761+
println!(" ✗ Embedding: {} — HTTP {}", base_url, resp.status());
762+
false
763+
}
764+
Err(e) => {
765+
println!(" ✗ Embedding: {} — {}", base_url, e);
766+
false
767+
}
768+
}
769+
}
770+
697771
fn cmd_init_interactive(project_dir: &Path, force: bool) {
698772
let existing = load_existing_config(project_dir);
699773

@@ -793,6 +867,18 @@ fn cmd_init_interactive(project_dir: &Path, force: bool) {
793867
return;
794868
}
795869

870+
// ── Connectivity checks ──
871+
println!("\n━━━ Checking connections ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
872+
println!();
873+
874+
let db_ok = check_db(&db_url);
875+
let emb_ok = check_embedding(&emb_base_url, &emb_api_key, &emb_model);
876+
877+
if (!db_ok || !emb_ok) && !prompt_confirm("Continue anyway?") {
878+
println!(" Aborted.");
879+
return;
880+
}
881+
796882
println!();
797883
cmd_init(
798884
project_dir, tools,

0 commit comments

Comments
 (0)