Skip to content

Commit 79116ee

Browse files
author
Andrew T.
committed
Translate rollback turn ids to numTurns
1 parent b481d27 commit 79116ee

1 file changed

Lines changed: 86 additions & 1 deletion

File tree

src-tauri/src/shared/codex_core.rs

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,12 +379,50 @@ pub(crate) async fn rollback_thread_core(
379379
turn_id: String,
380380
) -> Result<Value, String> {
381381
let session = get_session_clone(sessions, &workspace_id).await?;
382-
let params = json!({ "threadId": thread_id, "turnId": turn_id });
382+
let thread_response = read_thread_core(sessions, workspace_id.clone(), thread_id.clone()).await?;
383+
let num_turns = rollback_num_turns_from_response(&thread_response, &turn_id)?;
384+
let params = json!({ "threadId": thread_id, "numTurns": num_turns });
383385
session
384386
.send_request_for_workspace(&workspace_id, "thread/rollback", params)
385387
.await
386388
}
387389

390+
fn rollback_num_turns_from_response(response: &Value, turn_id: &str) -> Result<usize, String> {
391+
let thread = extract_thread_from_response(response)
392+
.ok_or_else(|| "Rollback failed: thread/read response missing thread payload.".to_string())?;
393+
let turns = thread
394+
.get("turns")
395+
.and_then(Value::as_array)
396+
.ok_or_else(|| "Rollback failed: thread/read response missing turns.".to_string())?;
397+
let turn_index = turns
398+
.iter()
399+
.position(|turn| {
400+
turn.as_object().is_some_and(|record| {
401+
record
402+
.get("id")
403+
.or_else(|| record.get("turnId"))
404+
.or_else(|| record.get("turn_id"))
405+
.and_then(Value::as_str)
406+
.is_some_and(|value| value == turn_id)
407+
})
408+
})
409+
.ok_or_else(|| format!("Rollback failed: turn '{turn_id}' was not found in thread."))?;
410+
Ok(turns.len().saturating_sub(turn_index))
411+
}
412+
413+
fn extract_thread_from_response<'a>(response: &'a Value) -> Option<&'a Map<String, Value>> {
414+
response
415+
.as_object()
416+
.and_then(|record| {
417+
record
418+
.get("result")
419+
.and_then(Value::as_object)
420+
.and_then(|result| result.get("thread"))
421+
.or_else(|| record.get("thread"))
422+
})
423+
.and_then(Value::as_object)
424+
}
425+
388426
pub(crate) async fn compact_thread_core(
389427
sessions: &Mutex<HashMap<String, Arc<WorkspaceSession>>>,
390428
workspace_id: String,
@@ -1043,4 +1081,51 @@ mod tests {
10431081
assert!(THREAD_LIST_SOURCE_KINDS.contains(&"subAgentCompact"));
10441082
assert!(THREAD_LIST_SOURCE_KINDS.contains(&"subAgentThreadSpawn"));
10451083
}
1084+
1085+
#[test]
1086+
fn rollback_num_turns_counts_from_target_turn_to_end() {
1087+
let response = json!({
1088+
"result": {
1089+
"thread": {
1090+
"turns": [
1091+
{ "id": "turn-1" },
1092+
{ "id": "turn-2" },
1093+
{ "id": "turn-3" }
1094+
]
1095+
}
1096+
}
1097+
});
1098+
1099+
let num_turns = rollback_num_turns_from_response(&response, "turn-2").unwrap();
1100+
assert_eq!(num_turns, 2);
1101+
}
1102+
1103+
#[test]
1104+
fn rollback_num_turns_supports_turn_id_aliases() {
1105+
let response = json!({
1106+
"thread": {
1107+
"turns": [
1108+
{ "turn_id": "turn-1" },
1109+
{ "turnId": "turn-2" }
1110+
]
1111+
}
1112+
});
1113+
1114+
let num_turns = rollback_num_turns_from_response(&response, "turn-2").unwrap();
1115+
assert_eq!(num_turns, 1);
1116+
}
1117+
1118+
#[test]
1119+
fn rollback_num_turns_errors_when_turn_is_missing() {
1120+
let response = json!({
1121+
"result": {
1122+
"thread": {
1123+
"turns": [{ "id": "turn-1" }]
1124+
}
1125+
}
1126+
});
1127+
1128+
let err = rollback_num_turns_from_response(&response, "turn-9").unwrap_err();
1129+
assert!(err.contains("turn-9"));
1130+
}
10461131
}

0 commit comments

Comments
 (0)