Skip to content

Commit dc5812a

Browse files
gouhongshenCopilot
andcommitted
Add selective memory_pick branch apply
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent b34717d commit dc5812a

9 files changed

Lines changed: 1849 additions & 26 deletions

File tree

memoria/crates/memoria-api/src/lib.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,12 @@ async fn call_log_mw(
8484
// which is necessary because JSON-RPC errors return HTTP 200.
8585
if !is_dashboard && !path.starts_with("/v1/mcp") {
8686
if let Some(reporter) = &state.stats_reporter {
87-
reporter.report(
88-
memoria_service::stats_reporter::StatsEvent::ApiCallLogged {
89-
user_id: uid.clone(),
90-
path: path.clone(),
91-
is_mcp: false,
92-
is_success: status_code < 400,
93-
},
94-
);
87+
reporter.report(memoria_service::stats_reporter::StatsEvent::ApiCallLogged {
88+
user_id: uid.clone(),
89+
path: path.clone(),
90+
is_mcp: false,
91+
is_success: status_code < 400,
92+
});
9593
}
9694
}
9795
if let Some(mask) = should_mark_metrics_dirty(&method, &path, status_code) {
@@ -147,6 +145,7 @@ fn should_mark_metrics_dirty(
147145
} else if path.starts_with("/v1/snapshots/") && path.ends_with("/rollback")
148146
|| path.starts_with("/v1/branches/") && path.ends_with("/checkout")
149147
|| path.starts_with("/v1/branches/") && path.ends_with("/merge")
148+
|| path.starts_with("/v1/branches/") && path.ends_with("/pick")
150149
{
151150
Some(DirtyMask::FULL)
152151
} else if path.starts_with("/v1/sessions/") && path.ends_with("/summary") {
@@ -283,6 +282,10 @@ pub fn build_router(state: AppState) -> Router {
283282
"/v1/branches/:name/diff",
284283
get(routes::snapshots::diff_branch),
285284
)
285+
.route(
286+
"/v1/branches/:name/pick",
287+
post(routes::snapshots::pick_branch),
288+
)
286289
.route(
287290
"/v1/branches/:name",
288291
delete(routes::snapshots::delete_branch),

memoria/crates/memoria-api/src/models.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,44 @@ fn default_strategy() -> String {
284284
"accept".to_string()
285285
}
286286

287+
#[derive(Debug, Deserialize, Serialize)]
288+
pub struct PickRequest {
289+
#[serde(default = "default_pick_target")]
290+
pub target: String,
291+
#[serde(default = "default_pick_strategy")]
292+
pub strategy: String,
293+
pub selector: PickSelector,
294+
}
295+
296+
fn default_pick_target() -> String {
297+
"main".to_string()
298+
}
299+
300+
fn default_pick_strategy() -> String {
301+
"fail".to_string()
302+
}
303+
304+
fn default_pick_top_k() -> i64 {
305+
5
306+
}
307+
308+
#[derive(Debug, Deserialize, Serialize)]
309+
#[serde(tag = "type", rename_all = "snake_case")]
310+
pub enum PickSelector {
311+
KeyList {
312+
keys: Vec<String>,
313+
},
314+
SnapshotRange {
315+
from_snapshot: String,
316+
to_snapshot: String,
317+
},
318+
Retrieve {
319+
query: String,
320+
#[serde(default = "default_pick_top_k")]
321+
top_k: i64,
322+
},
323+
}
324+
287325
// ── Helpers ───────────────────────────────────────────────────────────────────
288326

289327
pub fn parse_memory_type(s: &str) -> Result<MemoryType, String> {

memoria/crates/memoria-api/src/routes/mcp.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ fn mcp_tool_dirty_mask(tool: &str) -> Option<crate::metrics_summary::DirtyMask>
9191
"memory_snapshot" | "memory_snapshot_delete" => Some(DirtyMask::SNAPSHOT),
9292
"memory_rollback" => Some(DirtyMask::FULL),
9393
"memory_branch" | "memory_branch_delete" => Some(DirtyMask::BRANCH),
94-
"memory_checkout" | "memory_merge" => Some(DirtyMask::FULL),
94+
"memory_checkout" | "memory_merge" | "memory_pick" => Some(DirtyMask::FULL),
9595
_ => None,
9696
}
9797
}
@@ -133,14 +133,12 @@ pub async fn mcp_handler(
133133
RpcMeta::err($code),
134134
);
135135
if let Some(reporter) = &state.stats_reporter {
136-
reporter.report(
137-
memoria_service::stats_reporter::StatsEvent::ApiCallLogged {
138-
user_id: auth.user_id.clone(),
139-
path: $path.to_string(),
140-
is_mcp: true,
141-
is_success: false,
142-
},
143-
);
136+
reporter.report(memoria_service::stats_reporter::StatsEvent::ApiCallLogged {
137+
user_id: auth.user_id.clone(),
138+
path: $path.to_string(),
139+
is_mcp: true,
140+
is_success: false,
141+
});
144142
}
145143
return Json($body).into_response();
146144
}};

memoria/crates/memoria-api/src/routes/snapshots.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::{
1313
routes::memory::{api_err, api_err_typed},
1414
state::AppState,
1515
};
16-
use memoria_core::TrustTier;
16+
use memoria_core::{MemoriaError, TrustTier};
1717
use memoria_git::GitForDataService;
1818
use std::sync::Arc;
1919

@@ -66,6 +66,27 @@ async fn git_call(
6666
Ok(json!({ "result": text }))
6767
}
6868

69+
async fn git_call_pick(
70+
state: &AppState,
71+
user_id: &str,
72+
args: serde_json::Value,
73+
) -> Result<serde_json::Value, (StatusCode, String)> {
74+
let result =
75+
memoria_mcp::git_tools::call("memory_pick", args, &state.git, &state.service, user_id)
76+
.await
77+
.map_err(|e| match e {
78+
MemoriaError::Validation(msg) if msg.starts_with("Conflict:") => {
79+
(StatusCode::CONFLICT, msg)
80+
}
81+
other => api_err_typed(other),
82+
})?;
83+
let text = result["content"][0]["text"]
84+
.as_str()
85+
.unwrap_or("")
86+
.to_string();
87+
Ok(json!({ "result": text }))
88+
}
89+
6990
async fn user_snapshot_store(
7091
state: &AppState,
7192
user_id: &str,
@@ -687,6 +708,26 @@ pub async fn diff_branch(
687708
Ok(Json(r))
688709
}
689710

711+
pub async fn pick_branch(
712+
State(state): State<AppState>,
713+
AuthUser { user_id, .. }: AuthUser,
714+
Path(name): Path<String>,
715+
Json(req): Json<PickRequest>,
716+
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
717+
let r = git_call_pick(
718+
&state,
719+
&user_id,
720+
json!({
721+
"source": name,
722+
"target": req.target,
723+
"strategy": req.strategy,
724+
"selector": req.selector,
725+
}),
726+
)
727+
.await?;
728+
Ok(Json(r))
729+
}
730+
690731
pub async fn delete_branch(
691732
State(state): State<AppState>,
692733
AuthUser { user_id, .. }: AuthUser,

0 commit comments

Comments
 (0)