Skip to content

Commit d025a0e

Browse files
authored
Merge pull request #73 from 7df-lab/dev/0519
fix: correct interleaved tool input deltas and recover command actions from tool results
2 parents 64b4001 + e0e7068 commit d025a0e

129 files changed

Lines changed: 16559 additions & 91 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

crates/core/src/query.rs

Lines changed: 256 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::HashMap;
12
use std::collections::HashSet;
23
use std::sync::Arc;
34
use std::time::Duration;
@@ -96,6 +97,8 @@ pub enum QueryEvent {
9697
/// A tool call completed.
9798
ToolResult {
9899
tool_use_id: String,
100+
tool_name: String,
101+
input: serde_json::Value,
99102
content: ToolContent,
100103
display_content: Option<String>,
101104
is_error: bool,
@@ -610,7 +613,8 @@ pub async fn query(
610613

611614
let mut assistant_text = String::new();
612615
let mut reasoning_text = String::new();
613-
let mut tool_uses: Vec<(String, String, serde_json::Value, String, bool)> = Vec::new();
616+
let mut tool_uses: Vec<(usize, String, String, serde_json::Value, String, bool)> =
617+
Vec::new();
614618
let mut emitted_tool_use_starts: HashSet<String> = HashSet::new();
615619
let mut final_response = None;
616620
let mut stop_reason = None;
@@ -631,7 +635,10 @@ pub async fn query(
631635
emit(QueryEvent::ReasoningCompleted);
632636
}
633637
Ok(StreamEvent::ToolCallStart {
634-
id, name, input, ..
638+
index,
639+
id,
640+
name,
641+
input,
635642
}) => {
636643
if emitted_tool_use_starts.insert(id.clone()) {
637644
emit(QueryEvent::ToolUseStart {
@@ -640,12 +647,19 @@ pub async fn query(
640647
input: input.clone(),
641648
});
642649
}
643-
tool_uses.push((id, name, input, String::new(), false));
650+
tool_uses.push((index, id, name, input, String::new(), false));
644651
}
645-
Ok(StreamEvent::ToolCallInputDelta { partial_json, .. }) => {
646-
if let Some(last) = tool_uses.last_mut() {
647-
last.3.push_str(&partial_json);
648-
last.4 = true;
652+
Ok(StreamEvent::ToolCallInputDelta {
653+
index,
654+
partial_json,
655+
}) => {
656+
if let Some(tool_use) = tool_uses
657+
.iter_mut()
658+
.rev()
659+
.find(|(tool_index, ..)| *tool_index == index)
660+
{
661+
tool_use.4.push_str(&partial_json);
662+
tool_use.5 = true;
649663
}
650664
}
651665
Ok(StreamEvent::MessageDone { response }) => {
@@ -741,8 +755,10 @@ pub async fn query(
741755
tool_uses = response
742756
.content
743757
.iter()
744-
.filter_map(|block| match block {
758+
.enumerate()
759+
.filter_map(|(index, block)| match block {
745760
ResponseContent::ToolUse { id, name, input } => Some((
761+
index,
746762
id.clone(),
747763
name.clone(),
748764
input.clone(),
@@ -785,13 +801,31 @@ pub async fn query(
785801
});
786802
}
787803

804+
let final_tool_inputs: HashMap<String, serde_json::Value> = final_response
805+
.as_ref()
806+
.map(|response| {
807+
response
808+
.content
809+
.iter()
810+
.filter_map(|block| match block {
811+
ResponseContent::ToolUse { id, input, .. } => {
812+
Some((id.clone(), input.clone()))
813+
}
814+
ResponseContent::Text(_) => None,
815+
})
816+
.collect()
817+
})
818+
.unwrap_or_default();
819+
788820
let tool_calls: Vec<ToolCall> = tool_uses
789821
.into_iter()
790-
.map(|(id, name, initial_input, json_str, saw_delta)| {
822+
.map(|(_index, id, name, initial_input, json_str, saw_delta)| {
791823
let input = if saw_delta {
792-
serde_json::from_str(&json_str).unwrap_or(initial_input)
824+
serde_json::from_str(&json_str).unwrap_or_else(|_| {
825+
final_tool_inputs.get(&id).cloned().unwrap_or(initial_input)
826+
})
793827
} else {
794-
initial_input
828+
final_tool_inputs.get(&id).cloned().unwrap_or(initial_input)
795829
};
796830
if emitted_tool_use_starts.insert(id.clone()) {
797831
emit(QueryEvent::ToolUseStart {
@@ -831,12 +865,20 @@ pub async fn query(
831865
return Ok(());
832866
}
833867

834-
let tool_result_summaries: std::collections::HashMap<String, String> = tool_calls
868+
let tool_result_metadata: HashMap<String, (String, serde_json::Value, String)> = tool_calls
835869
.iter()
836870
.map(|call| {
837871
(
838872
call.id.clone(),
839-
devo_tools::tool_summary::tool_summary(&call.name, &call.input, &session.cwd),
873+
(
874+
call.name.clone(),
875+
call.input.clone(),
876+
devo_tools::tool_summary::tool_summary(
877+
&call.name,
878+
&call.input,
879+
&session.cwd,
880+
),
881+
),
840882
)
841883
})
842884
.collect();
@@ -846,7 +888,7 @@ pub async fn query(
846888
// long-running and parallel tools can render before the whole batch ends.
847889
let results = if let Some(progress_events) = on_event.clone() {
848890
let completion_events = Arc::clone(&progress_events);
849-
let summaries = Arc::new(tool_result_summaries.clone());
891+
let metadata = Arc::new(tool_result_metadata.clone());
850892
runtime
851893
.execute_batch_streaming_with_completion(
852894
&tool_calls,
@@ -859,12 +901,16 @@ pub async fn query(
859901
move |result| {
860902
let content = compact_tool_content(result.content.clone());
861903
let display_content = result.display_content.clone().map(micro_compact);
862-
let summary = summaries
904+
let (tool_name, input, summary) = metadata
863905
.get(result.tool_use_id.as_str())
864906
.cloned()
865-
.unwrap_or_default();
907+
.unwrap_or_else(|| {
908+
(String::new(), serde_json::Value::Null, String::new())
909+
});
866910
completion_events(QueryEvent::ToolResult {
867911
tool_use_id: result.tool_use_id.clone(),
912+
tool_name,
913+
input,
868914
content,
869915
display_content,
870916
is_error: result.is_error,
@@ -1025,6 +1071,10 @@ mod tests {
10251071
requests: AtomicUsize,
10261072
}
10271073

1074+
struct InterleavedToolUseProvider {
1075+
requests: AtomicUsize,
1076+
}
1077+
10281078
struct ParallelToolUseProvider {
10291079
requests: AtomicUsize,
10301080
}
@@ -1093,6 +1143,87 @@ mod tests {
10931143
}
10941144
}
10951145

1146+
#[async_trait]
1147+
impl devo_provider::ModelProviderSDK for InterleavedToolUseProvider {
1148+
async fn completion(&self, _request: ModelRequest) -> Result<ModelResponse> {
1149+
unreachable!("tests stream responses only")
1150+
}
1151+
1152+
async fn completion_stream(
1153+
&self,
1154+
_request: ModelRequest,
1155+
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
1156+
let request_number = self.requests.fetch_add(1, Ordering::SeqCst);
1157+
1158+
let events = if request_number == 0 {
1159+
vec![
1160+
Ok(StreamEvent::ToolCallStart {
1161+
index: 0,
1162+
id: "tool-1".into(),
1163+
name: "mutating_tool".into(),
1164+
input: json!({}),
1165+
}),
1166+
Ok(StreamEvent::ToolCallStart {
1167+
index: 1,
1168+
id: "tool-2".into(),
1169+
name: "mutating_tool".into(),
1170+
input: json!({}),
1171+
}),
1172+
Ok(StreamEvent::ToolCallInputDelta {
1173+
index: 0,
1174+
partial_json: r#"{"value":1}"#.into(),
1175+
}),
1176+
Ok(StreamEvent::ToolCallInputDelta {
1177+
index: 1,
1178+
partial_json: r#"{"value":2}"#.into(),
1179+
}),
1180+
Ok(StreamEvent::MessageDone {
1181+
response: ModelResponse {
1182+
id: "resp-1".into(),
1183+
content: vec![
1184+
ResponseContent::ToolUse {
1185+
id: "tool-1".into(),
1186+
name: "mutating_tool".into(),
1187+
input: json!({}),
1188+
},
1189+
ResponseContent::ToolUse {
1190+
id: "tool-2".into(),
1191+
name: "mutating_tool".into(),
1192+
input: json!({}),
1193+
},
1194+
],
1195+
stop_reason: Some(StopReason::ToolUse),
1196+
usage: Usage::default(),
1197+
metadata: Default::default(),
1198+
},
1199+
}),
1200+
]
1201+
} else {
1202+
vec![
1203+
Ok(StreamEvent::TextDelta {
1204+
index: 0,
1205+
text: "done".into(),
1206+
}),
1207+
Ok(StreamEvent::MessageDone {
1208+
response: ModelResponse {
1209+
id: "resp-2".into(),
1210+
content: vec![ResponseContent::Text("done".into())],
1211+
stop_reason: Some(StopReason::EndTurn),
1212+
usage: Usage::default(),
1213+
metadata: Default::default(),
1214+
},
1215+
}),
1216+
]
1217+
};
1218+
1219+
Ok(Box::pin(futures::stream::iter(events)))
1220+
}
1221+
1222+
fn name(&self) -> &str {
1223+
"interleaved-test-provider"
1224+
}
1225+
}
1226+
10961227
#[async_trait]
10971228
impl devo_provider::ModelProviderSDK for ParallelToolUseProvider {
10981229
async fn completion(&self, _request: ModelRequest) -> Result<ModelResponse> {
@@ -2012,6 +2143,115 @@ mod tests {
20122143
}
20132144
}
20142145

2146+
#[tokio::test]
2147+
async fn query_tool_result_event_includes_final_tool_input() {
2148+
let mut builder = ToolRegistryBuilder::new();
2149+
builder.register_handler("mutating_tool", Arc::new(DisplayContentTool));
2150+
builder.push_spec(ToolSpec {
2151+
name: "mutating_tool".into(),
2152+
description: String::new(),
2153+
input_schema: JsonSchema::object(Default::default(), None, None),
2154+
output_mode: ToolOutputMode::Text,
2155+
execution_mode: ToolExecutionMode::ReadOnly,
2156+
capability_tags: vec![],
2157+
supports_parallel: false,
2158+
preparation_feedback: ToolPreparationFeedback::None,
2159+
});
2160+
let registry = Arc::new(builder.build());
2161+
let runtime = ToolRuntime::new_without_permissions(Arc::clone(&registry));
2162+
2163+
let mut session = SessionState::new(SessionConfig::default(), std::env::temp_dir());
2164+
session.push_message(Message::user("run the tool"));
2165+
2166+
let seen = Arc::new(Mutex::new(Vec::new()));
2167+
let seen_clone = Arc::clone(&seen);
2168+
let callback = Arc::new(move |event: QueryEvent| {
2169+
if let QueryEvent::ToolResult {
2170+
tool_name, input, ..
2171+
} = event
2172+
{
2173+
seen_clone.lock().unwrap().push((tool_name, input));
2174+
}
2175+
});
2176+
2177+
query(
2178+
&mut session,
2179+
&TurnConfig {
2180+
model: Model::default(),
2181+
thinking_selection: None,
2182+
},
2183+
Arc::new(SingleToolUseProvider {
2184+
requests: AtomicUsize::new(0),
2185+
}),
2186+
registry,
2187+
&runtime,
2188+
Some(callback),
2189+
)
2190+
.await
2191+
.expect("query should complete");
2192+
2193+
assert_eq!(
2194+
seen.lock().unwrap().as_slice(),
2195+
&[(String::from("mutating_tool"), json!({ "value": 1 }))]
2196+
);
2197+
}
2198+
2199+
#[tokio::test]
2200+
async fn query_tool_result_event_matches_input_delta_by_tool_index() {
2201+
let mut builder = ToolRegistryBuilder::new();
2202+
builder.register_handler("mutating_tool", Arc::new(DisplayContentTool));
2203+
builder.push_spec(ToolSpec {
2204+
name: "mutating_tool".into(),
2205+
description: String::new(),
2206+
input_schema: JsonSchema::object(Default::default(), None, None),
2207+
output_mode: ToolOutputMode::Text,
2208+
execution_mode: ToolExecutionMode::ReadOnly,
2209+
capability_tags: vec![],
2210+
supports_parallel: false,
2211+
preparation_feedback: ToolPreparationFeedback::None,
2212+
});
2213+
let registry = Arc::new(builder.build());
2214+
let runtime = ToolRuntime::new_without_permissions(Arc::clone(&registry));
2215+
2216+
let mut session = SessionState::new(SessionConfig::default(), std::env::temp_dir());
2217+
session.push_message(Message::user("run the tools"));
2218+
2219+
let seen = Arc::new(Mutex::new(Vec::new()));
2220+
let seen_clone = Arc::clone(&seen);
2221+
let callback = Arc::new(move |event: QueryEvent| {
2222+
if let QueryEvent::ToolResult {
2223+
tool_use_id, input, ..
2224+
} = event
2225+
{
2226+
seen_clone.lock().unwrap().push((tool_use_id, input));
2227+
}
2228+
});
2229+
2230+
query(
2231+
&mut session,
2232+
&TurnConfig {
2233+
model: Model::default(),
2234+
thinking_selection: None,
2235+
},
2236+
Arc::new(InterleavedToolUseProvider {
2237+
requests: AtomicUsize::new(0),
2238+
}),
2239+
registry,
2240+
&runtime,
2241+
Some(callback),
2242+
)
2243+
.await
2244+
.expect("query should complete");
2245+
2246+
assert_eq!(
2247+
seen.lock().unwrap().as_slice(),
2248+
&[
2249+
(String::from("tool-1"), json!({ "value": 1 })),
2250+
(String::from("tool-2"), json!({ "value": 2 })),
2251+
]
2252+
);
2253+
}
2254+
20152255
#[tokio::test]
20162256
async fn query_emits_tool_result_display_content() {
20172257
let mut builder = ToolRegistryBuilder::new();

crates/server/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ mod provider_config;
1313
mod runtime;
1414
mod session;
1515
mod titles;
16+
mod tool_actions;
1617
mod transport;
1718
mod turn;
1819

0 commit comments

Comments
 (0)