Skip to content

Commit de98df6

Browse files
authored
fix(multidb): default fresh deployments to multi-db (#188)
## What type of PR is this?

- [x] fix (bug fix)
- [x] docs (documentation)
- [x] test (adding or updating tests)

## Which issue(s) this PR fixes

Fixes #

## What this PR does / why we need it

Fresh empty deployments on the latest runtime still stayed in single-db unless `MEMORIA_MULTI_DB` was set explicitly, even though new-version fresh environments are intended to run multi-db by default.

This PR changes the fresh auto-detect startup path to enable multi-db by default while keeping the other startup paths intact:

- `FreshSingleDb` now enables multi-db during runtime bootstrap
- `PendingLegacyMigration` still migrates legacy single-db data before switching to multi-db
- `MultiDbReady` still enters multi-db directly
- add bootstrap unit tests for fresh / pending-migration / ready behavior
- add an API multi-db routing integration test
- document that fresh empty deployments default to multi-db while explicit multi-db envs are still recommended in production

## Testing

- `cargo test -p memoria-cli fresh_topology_bootstrap_defaults_to_multi_db -- --nocapture`
- `cargo test -p memoria-cli multi_db_ready_bootstrap_keeps_multi_db_enabled -- --nocapture`
- `cargo test -p memoria-cli pending_migration_bootstrap_waits_for_migration_before_switching -- --nocapture`
- `cargo test -p memoria-storage classify_runtime_topology_detects_fresh_single_db -- --nocapture`
- `cargo test -p memoria-api api_multi_db_routes_reads_and_writes_to_each_users_database -- --nocapture`
- `cargo test -p memoria-mcp test_multi_db_snapshot_rollback_isolates_users -- --nocapture`
- full `memoria serve` smoke on isolated Docker MatrixOne for fresh / ready / pending-migration startup states

Approved by: @XuPeng-SH
1 parent 11603af commit de98df6

3 files changed

Lines changed: 421 additions & 40 deletions

File tree

docs/per-user-database-architecture.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,9 @@ SELECT COUNT(*) FROM `<db_name-from-registry>`.mem_memories WHERE is_active = 1;
603603

604604
如果你只是要让新版 multi-db 服务正常启动并工作,最小集合是:
605605

606+
> 最新运行时在**全新空环境**下,即使没显式设置 `MEMORIA_MULTI_DB`,也会自动按 multi-db 启动。
607+
> 但生产环境仍然建议把 `MEMORIA_MULTI_DB=1` 和 `MEMORIA_SHARED_DATABASE_URL` 明确写出来,避免歧义,也兼容旧版本节点。
608+
606609
```bash
607610
DATABASE_URL=<shared-db-url>
608611
MEMORIA_MULTI_DB=1
@@ -620,7 +623,7 @@ EMBEDDING_DIM=<must-match-schema>
620623
| 变量 | 是否迁移后重启必需 | 说明 |
621624
|---|---|---|
622625
| `DATABASE_URL` || multi-db 模式下应指向 shared DB |
623-
| `MEMORIA_MULTI_DB` | | 置为 `1` / `true` 打开新架构 |
626+
| `MEMORIA_MULTI_DB` | 建议显式设置 | 置为 `1` / `true` 可强制打开新架构;最新运行时对 fresh empty 环境会自动切到 multi-db |
624627
| `MEMORIA_SHARED_DATABASE_URL` || 显式告诉服务 shared DB 在哪;**不要依赖默认推导** |
625628
| `EMBEDDING_PROVIDER` | 基本必需 | 不设会退回默认 `mock`,检索语义会变掉 |
626629
| `EMBEDDING_BASE_URL` | 基本必需 | HTTP embedding 服务地址 |
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
use serde_json::{json, Value};
2+
use serial_test::serial;
3+
use sqlx::{mysql::MySqlPoolOptions, MySqlPool};
4+
use std::{sync::Arc, time::Duration};
5+
6+
/// Handles for one in-process multi-db API server used by an integration test.
///
/// Built by `spawn_multi_db_server` and consumed by `cleanup`, which shuts the
/// server down and drops the databases created during the test.
struct MultiDbTestServer {
7+
/// Base URL of the server, `http://127.0.0.1:<port>` on an OS-assigned port.
base: String,
8+
/// HTTP client used to talk to the server; built with proxies disabled.
client: reqwest::Client,
9+
/// Router connected to the shared database; used to resolve each user's database name.
router: Arc<memoria_storage::DbRouter>,
10+
/// Connection URL of the per-run shared database (random `memoria_api_multi_*` name).
shared_db_url: String,
11+
/// Application state handed to the API router; kept so `cleanup` can drain background work.
state: memoria_api::AppState,
12+
/// Sender that triggers graceful shutdown; wrapped in `Option` so `cleanup` can take it.
shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
13+
/// Join handle of the spawned task that runs `axum::serve`.
server_handle: tokio::task::JoinHandle<()>,
14+
}
15+
16+
impl MultiDbTestServer {
17+
async fn cleanup(mut self, user_db_names: &[String], mut direct_pools: Vec<MySqlPool>) {
18+
if let Some(tx) = self.shutdown_tx.take() {
19+
let _ = tx.send(());
20+
}
21+
self.server_handle
22+
.await
23+
.expect("server task should shut down cleanly");
24+
self.state.service.drain_edit_log().await;
25+
self.state.drain_flushers().await;
26+
27+
let shared_pool = self.router.shared_pool().clone();
28+
let global_user_pool = self.router.global_user_pool().clone();
29+
let mut db_names = sqlx::query_scalar::<_, String>(
30+
"SELECT DISTINCT db_name FROM mem_user_registry WHERE status = 'active'",
31+
)
32+
.fetch_all(&shared_pool)
33+
.await
34+
.unwrap_or_default();
35+
db_names.extend_from_slice(user_db_names);
36+
db_names.sort();
37+
db_names.dedup();
38+
39+
for pool in direct_pools.drain(..) {
40+
pool.close().await;
41+
}
42+
43+
drop(self.client);
44+
drop(self.state);
45+
drop(self.router);
46+
shared_pool.close().await;
47+
global_user_pool.close().await;
48+
cleanup_databases(&self.shared_db_url, &db_names).await;
49+
}
50+
}
51+
52+
/// Returns the embedding dimension used by the test.
///
/// Reads `EMBEDDING_DIM` from the environment; falls back to 1024 when the
/// variable is unset or does not parse as a `usize`.
fn test_dim() -> usize {
    const DEFAULT_DIM: usize = 1024;

    match std::env::var("EMBEDDING_DIM") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_DIM),
        Err(_) => DEFAULT_DIM,
    }
}
58+
59+
/// Returns the database URL the tests start from.
///
/// Uses `DATABASE_URL` when it is set, otherwise a local default instance.
fn db_url() -> String {
    match std::env::var("DATABASE_URL") {
        Ok(url) => url,
        Err(_) => String::from("mysql://root:111@localhost:6001/memoria"),
    }
}
63+
64+
/// Returns `database_url` with its database path segment replaced by `db_name`.
///
/// Any `?query` or `#fragment` suffix is preserved unchanged.
///
/// # Panics
///
/// Panics with "database url must include db name" when the URL has no `/`
/// after its authority part. The slashes of the scheme's `://` are ignored
/// when looking for the database segment; otherwise a URL such as
/// `mysql://host:6001` would silently have its host replaced instead.
fn replace_db_name(database_url: &str, db_name: &str) -> String {
    let suffix_start = database_url.find(['?', '#']).unwrap_or(database_url.len());
    let (without_suffix, suffix) = database_url.split_at(suffix_start);

    // Skip past "scheme://" so its slashes are never mistaken for the path separator.
    let authority_start = without_suffix.find("://").map_or(0, |idx| idx + 3);
    let path_slash = without_suffix[authority_start..]
        .rfind('/')
        .map(|idx| authority_start + idx)
        .expect("database url must include db name");

    let base = &without_suffix[..path_slash];
    format!("{base}/{db_name}{suffix}")
}
72+
73+
fn shared_db_url() -> String {
74+
replace_db_name(
75+
&db_url(),
76+
&format!(
77+
"memoria_api_multi_{}",
78+
&uuid::Uuid::new_v4().simple().to_string()[..8]
79+
),
80+
)
81+
}
82+
83+
fn uid(prefix: &str) -> String {
84+
format!("{prefix}_{}", uuid::Uuid::new_v4().simple())
85+
}
86+
87+
/// Splits `database_url` into `(server_url, db_name)`.
///
/// Any `?query` or `#fragment` suffix is discarded. `server_url` is everything
/// before the last `/` of the path, and `db_name` is everything after it.
///
/// # Panics
///
/// Panics with "database url must include db name" when the URL has no `/`
/// after its authority part. The slashes of the scheme's `://` are ignored;
/// otherwise `mysql://host:6001` would be split into `mysql:/` and
/// `host:6001`, and the host would be treated as a database name.
fn split_db_url(database_url: &str) -> (&str, &str) {
    let suffix_start = database_url.find(['?', '#']).unwrap_or(database_url.len());
    let without_suffix = &database_url[..suffix_start];

    // Skip past "scheme://" so its slashes are never mistaken for the path separator.
    let authority_start = without_suffix.find("://").map_or(0, |idx| idx + 3);
    let path_slash = without_suffix[authority_start..]
        .rfind('/')
        .map(|idx| authority_start + idx)
        .expect("database url must include db name");

    (&without_suffix[..path_slash], &without_suffix[path_slash + 1..])
}
94+
95+
/// Wraps `name` in backticks for use as an SQL identifier, doubling any
/// backtick that appears inside it.
fn quote_ident(name: &str) -> String {
    let escaped = name.replace('`', "``");
    let mut quoted = String::with_capacity(escaped.len() + 2);
    quoted.push('`');
    quoted.push_str(&escaped);
    quoted.push('`');
    quoted
}
98+
99+
async fn cleanup_databases(shared_db_url: &str, user_db_names: &[String]) {
100+
let (base_url, shared_db_name) = split_db_url(shared_db_url);
101+
let admin_pool = MySqlPoolOptions::new()
102+
.max_connections(1)
103+
.connect(base_url)
104+
.await
105+
.expect("connect admin pool");
106+
107+
for db_name in user_db_names {
108+
sqlx::raw_sql(&format!("DROP DATABASE IF EXISTS {}", quote_ident(db_name)))
109+
.execute(&admin_pool)
110+
.await
111+
.expect("drop user db");
112+
}
113+
sqlx::raw_sql(&format!(
114+
"DROP DATABASE IF EXISTS {}",
115+
quote_ident(shared_db_name)
116+
))
117+
.execute(&admin_pool)
118+
.await
119+
.expect("drop shared db");
120+
admin_pool.close().await;
121+
}
122+
123+
async fn wait_for_server(client: &reqwest::Client, base: &str, pool: &MySqlPool) {
124+
for _ in 0..20 {
125+
if client.get(format!("{base}/health")).send().await.is_ok() {
126+
break;
127+
}
128+
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
129+
}
130+
131+
for _ in 0..20 {
132+
if sqlx::query("SELECT 1").execute(pool).await.is_ok() {
133+
return;
134+
}
135+
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
136+
}
137+
138+
panic!("DB not ready after 1s");
139+
}
140+
141+
async fn spawn_multi_db_server() -> MultiDbTestServer {
142+
use memoria_git::GitForDataService;
143+
use memoria_service::MemoryService;
144+
use memoria_storage::{DbRouter, SqlMemoryStore};
145+
146+
let shared_db_url = shared_db_url();
147+
memoria_test_utils::wait_for_mysql_ready(&shared_db_url, Duration::from_secs(30)).await;
148+
149+
let router = Arc::new(
150+
DbRouter::connect(&shared_db_url, test_dim(), uuid::Uuid::new_v4().to_string())
151+
.await
152+
.expect("router"),
153+
);
154+
let shared_pool = router.shared_pool().clone();
155+
let shared_pool_max_connections = router.shared_pool_max_connections();
156+
let mut store = SqlMemoryStore::from_existing_pool(
157+
shared_pool.clone(),
158+
test_dim(),
159+
uuid::Uuid::new_v4().to_string(),
160+
Some(shared_db_url.clone()),
161+
Some(shared_pool_max_connections),
162+
"api_multi_db_shared_pool",
163+
);
164+
store.migrate_shared().await.expect("migrate_shared");
165+
store.set_db_router(router.clone());
166+
167+
let git = Arc::new(GitForDataService::new(
168+
shared_pool.clone(),
169+
router.shared_db_name().to_string(),
170+
));
171+
let service = Arc::new(
172+
MemoryService::new_sql_with_llm_and_router(
173+
Arc::new(store),
174+
Some(router.clone()),
175+
None,
176+
None,
177+
)
178+
.await,
179+
);
180+
let state = memoria_api::AppState::new(service, git, String::new());
181+
let app = memoria_api::build_router(state.clone());
182+
183+
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
184+
.await
185+
.expect("bind");
186+
let port = listener.local_addr().expect("local addr").port();
187+
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
188+
let server_handle = tokio::spawn(async move {
189+
axum::serve(listener, app)
190+
.with_graceful_shutdown(async {
191+
let _ = shutdown_rx.await;
192+
})
193+
.await
194+
.unwrap();
195+
});
196+
197+
let client = reqwest::Client::builder().no_proxy().build().unwrap();
198+
let base = format!("http://127.0.0.1:{port}");
199+
wait_for_server(&client, &base, &shared_pool).await;
200+
MultiDbTestServer {
201+
base,
202+
client,
203+
router,
204+
shared_db_url,
205+
state,
206+
shutdown_tx: Some(shutdown_tx),
207+
server_handle,
208+
}
209+
}
210+
211+
async fn store_memory(
212+
base: &str,
213+
client: &reqwest::Client,
214+
user_id: &str,
215+
content: &str,
216+
) -> String {
217+
let response = client
218+
.post(format!("{base}/v1/memories"))
219+
.header("X-User-Id", user_id)
220+
.json(&json!({ "content": content }))
221+
.send()
222+
.await
223+
.expect("store request");
224+
assert_eq!(response.status(), 201);
225+
let body: Value = response.json().await.expect("store response body");
226+
body["memory_id"].as_str().expect("memory_id").to_string()
227+
}
228+
229+
async fn list_memory_ids(base: &str, client: &reqwest::Client, user_id: &str) -> Vec<String> {
230+
let response = client
231+
.get(format!("{base}/v1/memories"))
232+
.header("X-User-Id", user_id)
233+
.send()
234+
.await
235+
.expect("list request");
236+
assert_eq!(response.status(), 200);
237+
let body: Value = response.json().await.expect("list response body");
238+
body["items"]
239+
.as_array()
240+
.expect("items")
241+
.iter()
242+
.map(|item| {
243+
item["memory_id"]
244+
.as_str()
245+
.expect("list memory_id")
246+
.to_string()
247+
})
248+
.collect()
249+
}
250+
251+
async fn user_db_pool(shared_db_url: &str, db_name: &str) -> MySqlPool {
252+
MySqlPoolOptions::new()
253+
.max_connections(1)
254+
.connect(&replace_db_name(shared_db_url, db_name))
255+
.await
256+
.expect("connect user db")
257+
}
258+
259+
async fn count_user_memories(pool: &MySqlPool, user_id: &str) -> i64 {
260+
sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM mem_memories WHERE user_id = ?")
261+
.bind(user_id)
262+
.fetch_one(pool)
263+
.await
264+
.expect("count user memories")
265+
}
266+
267+
#[tokio::test]
268+
#[serial]
269+
async fn api_multi_db_routes_reads_and_writes_to_each_users_database() {
270+
let server = spawn_multi_db_server().await;
271+
let user_a = uid("api_multi_a");
272+
let user_b = uid("api_multi_b");
273+
274+
let memory_a = store_memory(&server.base, &server.client, &user_a, "alpha memory").await;
275+
let memory_b = store_memory(&server.base, &server.client, &user_b, "beta memory").await;
276+
277+
let db_a = server
278+
.router
279+
.user_db_name(&user_a)
280+
.await
281+
.expect("user A db");
282+
let db_b = server
283+
.router
284+
.user_db_name(&user_b)
285+
.await
286+
.expect("user B db");
287+
assert_ne!(
288+
db_a, db_b,
289+
"multi-db API test must route users to different databases"
290+
);
291+
292+
let listed_a = list_memory_ids(&server.base, &server.client, &user_a).await;
293+
let listed_b = list_memory_ids(&server.base, &server.client, &user_b).await;
294+
assert_eq!(listed_a, vec![memory_a]);
295+
assert_eq!(listed_b, vec![memory_b]);
296+
297+
let pool_a = user_db_pool(&server.shared_db_url, &db_a).await;
298+
let pool_b = user_db_pool(&server.shared_db_url, &db_b).await;
299+
assert_eq!(count_user_memories(&pool_a, &user_a).await, 1);
300+
assert_eq!(count_user_memories(&pool_a, &user_b).await, 0);
301+
assert_eq!(count_user_memories(&pool_b, &user_b).await, 1);
302+
assert_eq!(count_user_memories(&pool_b, &user_a).await, 0);
303+
304+
server.cleanup(&[db_a, db_b], vec![pool_a, pool_b]).await;
305+
}

0 commit comments

Comments
 (0)