Skip to content

Commit 1f99f39

Browse files
authored
fix(memory): enforce strict session retrieval semantics (#185)
## Summary - thread strict session retrieval options through API, service, storage, and MCP - skip graph retrieval when strict session scope is requested and keep default cross-session semantics unchanged - add a strict scoped vector fallback plus REST/MCP regressions covering issue #184 semantics ## Testing - DATABASE_URL=mysql://root:111@localhost:6001/memoria_test SQLX_OFFLINE=true cargo test -p memoria-storage --test store_crud test_search_vector_from_filtered_scoped_prefilters_by_session -- --nocapture - DATABASE_URL=mysql://root:111@localhost:6001/memoria_test SQLX_OFFLINE=true cargo test -p memoria-storage --test store_crud test_search_vector_from_filtered_scoped_fills_limit_with_session_candidates -- --nocapture - DATABASE_URL=mysql://root:111@localhost:6001/memoria_test SQLX_OFFLINE=true cargo test -p memoria-api --test api_e2e test_retrieve_filter_session_prefilters_and_skips_graph -- --nocapture - DATABASE_URL=mysql://root:111@localhost:6001/memoria_test SQLX_OFFLINE=true cargo test -p memoria-api --test api_e2e test_retrieve_filter_session_preserves_top_k_with_session_candidates -- --nocapture - DATABASE_URL=mysql://root:111@localhost:6001/memoria_test SQLX_OFFLINE=true cargo test -p memoria-api --test api_e2e test_mcp_memory_retrieve_filter_session_end_to_end -- --nocapture Fixes #184 Approved by: @XuPeng-SH
1 parent 897d6d5 commit 1f99f39

9 files changed

Lines changed: 984 additions & 172 deletions

File tree

memoria/crates/memoria-api/src/models.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ pub struct RetrieveRequest {
3333
#[serde(default = "default_top_k")]
3434
pub top_k: i64,
3535
pub session_id: Option<String>,
36+
/// Explicit strict session filter. Overrides include_cross_session when provided.
37+
#[serde(default)]
38+
pub filter_session: Option<bool>,
3639
/// When false and session_id is set, only return memories from that session.
3740
#[serde(default = "default_true")]
3841
pub include_cross_session: bool,
@@ -47,6 +50,16 @@ fn default_true() -> bool {
4750
true
4851
}
4952

53+
impl RetrieveRequest {
54+
pub fn retrieve_options(&self) -> memoria_service::RetrieveOptions {
55+
memoria_service::RetrieveOptions::from_session_scope(
56+
self.session_id.as_deref(),
57+
self.filter_session,
58+
Some(self.include_cross_session),
59+
)
60+
}
61+
}
62+
5063
fn deserialize_explain<'de, D: serde::Deserializer<'de>>(d: D) -> Result<String, D::Error> {
5164
use serde::Deserialize;
5265
#[derive(Deserialize)]

memoria/crates/memoria-api/src/routes/memory.rs

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -211,39 +211,40 @@ pub async fn retrieve(
211211
AuthUser { user_id, .. }: AuthUser,
212212
Json(req): Json<RetrieveRequest>,
213213
) -> ApiResult<serde_json::Value> {
214+
if req.session_id.is_none() && (req.filter_session == Some(true) || !req.include_cross_session)
215+
{
216+
return Err((
217+
StatusCode::UNPROCESSABLE_ENTITY,
218+
"session_id is required for strict session retrieval".to_string(),
219+
));
220+
}
214221
let top_k = req.top_k.clamp(1, 100);
215222
let level = memoria_service::ExplainLevel::from_str_or_bool(&req.explain);
216-
let filter_session = req
217-
.session_id
218-
.as_deref()
219-
.filter(|_| !req.include_cross_session);
220-
221-
let apply_filter = |mut mems: Vec<memoria_core::Memory>| -> Vec<memoria_core::Memory> {
222-
if let Some(sid) = filter_session {
223-
mems.retain(|m| m.session_id.as_deref() == Some(sid));
224-
}
225-
mems
226-
};
223+
let retrieve_options = req.retrieve_options();
227224

228225
if level != memoria_service::ExplainLevel::None {
229226
let (results, explain) = state
230227
.service
231-
.retrieve_explain_level(&user_id, &req.query, top_k, level)
228+
.retrieve_explain_level_with_options(
229+
&user_id,
230+
&req.query,
231+
top_k,
232+
level,
233+
&retrieve_options,
234+
)
232235
.await
233236
.map_err(api_err)?;
234-
let items: Vec<MemoryResponse> =
235-
apply_filter(results).into_iter().map(Into::into).collect();
237+
let items: Vec<MemoryResponse> = results.into_iter().map(Into::into).collect();
236238
Ok(Json(
237239
serde_json::json!({"results": items, "explain": explain}),
238240
))
239241
} else {
240242
let results = state
241243
.service
242-
.retrieve(&user_id, &req.query, top_k)
244+
.retrieve_with_options(&user_id, &req.query, top_k, &retrieve_options)
243245
.await
244246
.map_err(api_err)?;
245-
let items: Vec<MemoryResponse> =
246-
apply_filter(results).into_iter().map(Into::into).collect();
247+
let items: Vec<MemoryResponse> = results.into_iter().map(Into::into).collect();
247248
Ok(Json(serde_json::json!(items)))
248249
}
249250
}

0 commit comments

Comments
 (0)