Skip to content

Commit 2417fb9

Browse files
committed
tui: add preparing state for large tool calls
1 parent 05df1b5 commit 2417fb9

12 files changed

Lines changed: 488 additions & 48 deletions

File tree

crates/core/src/query.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::HashSet;
12
use std::sync::Arc;
23
use std::time::Duration;
34

@@ -610,6 +611,7 @@ pub async fn query(
610611
let mut assistant_text = String::new();
611612
let mut reasoning_text = String::new();
612613
let mut tool_uses: Vec<(String, String, serde_json::Value, String, bool)> = Vec::new();
614+
let mut emitted_tool_use_starts: HashSet<String> = HashSet::new();
613615
let mut final_response = None;
614616
let mut stop_reason = None;
615617

@@ -631,6 +633,13 @@ pub async fn query(
631633
Ok(StreamEvent::ToolCallStart {
632634
id, name, input, ..
633635
}) => {
636+
if emitted_tool_use_starts.insert(id.clone()) {
637+
emit(QueryEvent::ToolUseStart {
638+
id: id.clone(),
639+
name: name.clone(),
640+
input: input.clone(),
641+
});
642+
}
634643
tool_uses.push((id, name, input, String::new(), false));
635644
}
636645
Ok(StreamEvent::ToolCallInputDelta { partial_json, .. }) => {
@@ -784,11 +793,13 @@ pub async fn query(
784793
} else {
785794
initial_input
786795
};
787-
emit(QueryEvent::ToolUseStart {
788-
id: id.clone(),
789-
name: name.clone(),
790-
input: input.clone(),
791-
});
796+
if emitted_tool_use_starts.insert(id.clone()) {
797+
emit(QueryEvent::ToolUseStart {
798+
id: id.clone(),
799+
name: name.clone(),
800+
input: input.clone(),
801+
});
802+
}
792803
assistant_content.push(ContentBlock::ToolUse {
793804
id: id.clone(),
794805
name: name.clone(),
@@ -977,6 +988,7 @@ mod tests {
977988
use devo_protocol::Usage;
978989
use devo_provider::ModelProviderSDK;
979990
use devo_safety::PermissionMode;
991+
use devo_tools::ToolPreparationFeedback;
980992
use devo_tools::ToolRegistry;
981993
use devo_tools::ToolRuntime;
982994
use devo_tools::errors::ToolExecutionError;
@@ -1462,6 +1474,7 @@ mod tests {
14621474
execution_mode: ToolExecutionMode::Mutating,
14631475
capability_tags: vec![],
14641476
supports_parallel: false,
1477+
preparation_feedback: ToolPreparationFeedback::None,
14651478
});
14661479
let registry = Arc::new(builder.build());
14671480
let deny_checker = PermissionChecker::new(|request| {
@@ -1957,6 +1970,7 @@ mod tests {
19571970
execution_mode: ToolExecutionMode::Mutating,
19581971
capability_tags: vec![],
19591972
supports_parallel: false,
1973+
preparation_feedback: ToolPreparationFeedback::None,
19601974
});
19611975
let registry = Arc::new(builder.build());
19621976
let runtime = ToolRuntime::new_without_permissions(Arc::clone(&registry));
@@ -2010,6 +2024,7 @@ mod tests {
20102024
execution_mode: ToolExecutionMode::ReadOnly,
20112025
capability_tags: vec![],
20122026
supports_parallel: false,
2027+
preparation_feedback: ToolPreparationFeedback::None,
20132028
});
20142029
let registry = Arc::new(builder.build());
20152030
let runtime = ToolRuntime::new_without_permissions(Arc::clone(&registry));
@@ -2067,6 +2082,7 @@ mod tests {
20672082
execution_mode: ToolExecutionMode::Mutating,
20682083
capability_tags: vec![],
20692084
supports_parallel: false,
2085+
preparation_feedback: ToolPreparationFeedback::None,
20702086
});
20712087
let registry = Arc::new(builder.build());
20722088
let runtime = ToolRuntime::new_without_permissions(Arc::clone(&registry));
@@ -2144,6 +2160,7 @@ mod tests {
21442160
execution_mode: ToolExecutionMode::ReadOnly,
21452161
capability_tags: vec![],
21462162
supports_parallel: true,
2163+
preparation_feedback: ToolPreparationFeedback::None,
21472164
});
21482165
let registry = Arc::new(builder.build());
21492166
let runtime = ToolRuntime::new_without_permissions(Arc::clone(&registry));

crates/protocol/src/event.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ pub struct CommandExecutionPayload {
6161
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
6262
pub struct FileChangePayload {
6363
pub tool_call_id: String,
64+
#[serde(default, skip_serializing_if = "Option::is_none")]
65+
pub tool_name: Option<String>,
6466
pub changes: Vec<(std::path::PathBuf, FileChange)>,
6567
pub is_error: bool,
6668
}

crates/server/src/runtime/turn_exec.rs

Lines changed: 149 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ use std::sync::Arc;
33

44
use super::*;
55
use crate::{FileChangePayload, TurnPlanStepPayload, TurnPlanUpdatedPayload};
6+
use devo_tools::tool_spec::ToolPreparationFeedback;
67
use devo_utils::git_op::extract_paths_from_patch;
78
use tokio::sync::mpsc;
89

910
struct PendingToolCall {
10-
item_id: ItemId,
11-
item_seq: u64,
11+
item_id: Option<ItemId>,
12+
item_seq: Option<u64>,
1213
input: serde_json::Value,
1314
is_command_execution: bool,
1415
command: String,
@@ -185,7 +186,7 @@ fn command_execution_item_id_for_progress(
185186
pending_tool_calls
186187
.get(tool_use_id)
187188
.filter(|pending| pending.is_command_execution)
188-
.map(|pending| pending.item_id)
189+
.and_then(|pending| pending.item_id)
189190
}
190191

191192
impl ServerRuntime {
@@ -377,7 +378,13 @@ impl ServerRuntime {
377378
}
378379
let is_command_execution = is_unified_exec_tool(&name);
379380
let command = command_display_from_input(&name, &input);
380-
let item_kind = if is_file_change_tool(&name) {
381+
let preparation_feedback =
382+
runtime.deps.registry.preparation_feedback(&name);
383+
let item_kind = if preparation_feedback
384+
== ToolPreparationFeedback::LiveOnly
385+
{
386+
ItemKind::ToolCall
387+
} else if is_file_change_tool(&name) {
381388
ItemKind::FileChange
382389
} else if is_command_execution {
383390
ItemKind::CommandExecution
@@ -386,9 +393,22 @@ impl ServerRuntime {
386393
} else {
387394
ItemKind::ToolCall
388395
};
389-
let started_payload = if is_file_change_tool(&name) {
396+
let started_payload = if preparation_feedback
397+
== ToolPreparationFeedback::LiveOnly
398+
{
399+
serde_json::to_value(ToolCallPayload {
400+
tool_call_id: id.clone(),
401+
tool_name: name.clone(),
402+
parameters: input.clone(),
403+
command_actions: command_actions_from_tool_input(
404+
&name, &command, &input,
405+
),
406+
})
407+
.expect("serialize tool call payload")
408+
} else if is_file_change_tool(&name) {
390409
serde_json::to_value(FileChangePayload {
391410
tool_call_id: id.clone(),
411+
tool_name: Some(name.clone()),
392412
changes: Vec::new(),
393413
is_error: false,
394414
})
@@ -422,14 +442,21 @@ impl ServerRuntime {
422442
})
423443
.expect("serialize tool call payload")
424444
};
425-
let (item_id, item_seq) = runtime
426-
.start_item(
427-
session_id,
428-
turn_for_events.turn_id,
429-
item_kind,
430-
started_payload,
431-
)
432-
.await;
445+
let (item_id, item_seq) = if preparation_feedback
446+
== ToolPreparationFeedback::LiveOnly
447+
{
448+
let (item_id, item_seq) = runtime
449+
.start_item(
450+
session_id,
451+
turn_for_events.turn_id,
452+
item_kind,
453+
started_payload,
454+
)
455+
.await;
456+
(Some(item_id), Some(item_seq))
457+
} else {
458+
(None, None)
459+
};
433460
pending_tool_calls.insert(
434461
id,
435462
PendingToolCall {
@@ -451,7 +478,102 @@ impl ServerRuntime {
451478
let tool_name = tool_names_by_id.get(&tool_use_id).cloned();
452479
// First complete the pending ToolCall item so its item/completed
453480
// arrives before the ToolResult item/completed.
454-
if let Some(pending) = pending_tool_calls.remove(&tool_use_id) {
481+
if let Some(mut pending) = pending_tool_calls.remove(&tool_use_id) {
482+
if pending.item_id.is_none() || pending.item_seq.is_none() {
483+
let started_payload = if let Some(tool_name) = tool_name.clone() {
484+
let item_kind = if runtime
485+
.deps
486+
.registry
487+
.preparation_feedback(&tool_name)
488+
== ToolPreparationFeedback::LiveOnly
489+
{
490+
ItemKind::ToolCall
491+
} else if is_file_change_tool(&tool_name) {
492+
ItemKind::FileChange
493+
} else if pending.is_command_execution {
494+
ItemKind::CommandExecution
495+
} else if is_plan_tool(&tool_name) {
496+
ItemKind::Plan
497+
} else {
498+
ItemKind::ToolCall
499+
};
500+
let payload = if runtime
501+
.deps
502+
.registry
503+
.preparation_feedback(&tool_name)
504+
== ToolPreparationFeedback::LiveOnly
505+
{
506+
serde_json::to_value(ToolCallPayload {
507+
tool_call_id: tool_use_id.clone(),
508+
tool_name: tool_name.clone(),
509+
parameters: pending.input.clone(),
510+
command_actions: command_actions_from_tool_input(
511+
&tool_name,
512+
&pending.command,
513+
&pending.input,
514+
),
515+
})
516+
.expect("serialize tool call payload")
517+
} else if is_file_change_tool(&tool_name) {
518+
serde_json::to_value(FileChangePayload {
519+
tool_call_id: tool_use_id.clone(),
520+
tool_name: Some(tool_name.clone()),
521+
changes: Vec::new(),
522+
is_error: false,
523+
})
524+
.expect("serialize file change payload")
525+
} else if pending.is_command_execution {
526+
serde_json::to_value(CommandExecutionPayload {
527+
tool_call_id: tool_use_id.clone(),
528+
tool_name: tool_name.clone(),
529+
command: pending.command.clone(),
530+
source: devo_protocol::protocol::ExecCommandSource::Agent,
531+
command_actions: command_actions_from_tool_input(
532+
&tool_name,
533+
&pending.command,
534+
&pending.input,
535+
),
536+
output: None,
537+
is_error: false,
538+
})
539+
.expect("serialize command execution payload")
540+
} else if is_plan_tool(&tool_name) {
541+
serde_json::json!({
542+
"title": "Plan",
543+
"text": ""
544+
})
545+
} else {
546+
serde_json::to_value(ToolCallPayload {
547+
tool_call_id: tool_use_id.clone(),
548+
tool_name: tool_name.clone(),
549+
parameters: pending.input.clone(),
550+
command_actions: command_actions_from_tool_input(
551+
&tool_name,
552+
&pending.command,
553+
&pending.input,
554+
),
555+
})
556+
.expect("serialize tool call payload")
557+
};
558+
let (item_id, item_seq) = runtime
559+
.start_item(
560+
session_id,
561+
turn_for_events.turn_id,
562+
item_kind.clone(),
563+
payload,
564+
)
565+
.await;
566+
pending.item_id = Some(item_id);
567+
pending.item_seq = Some(item_seq);
568+
item_kind
569+
} else {
570+
ItemKind::ToolCall
571+
};
572+
let _ = started_payload;
573+
}
574+
575+
let pending_item_id = pending.item_id.expect("pending item id");
576+
let pending_item_seq = pending.item_seq.expect("pending item seq");
455577
if let Some(tool_name) = tool_name.clone()
456578
&& is_plan_tool(&tool_name)
457579
{
@@ -479,8 +601,8 @@ impl ServerRuntime {
479601
.complete_item(
480602
session_id,
481603
turn_for_events.turn_id,
482-
pending.item_id,
483-
pending.item_seq,
604+
pending_item_id,
605+
pending_item_seq,
484606
ItemKind::Plan,
485607
TurnItem::Plan(TextItem {
486608
text: output_json.to_string(),
@@ -617,8 +739,8 @@ impl ServerRuntime {
617739
.complete_item(
618740
session_id,
619741
turn_for_events.turn_id,
620-
pending.item_id,
621-
pending.item_seq,
742+
pending_item_id,
743+
pending_item_seq,
622744
ItemKind::FileChange,
623745
TurnItem::ToolResult(ToolResultItem {
624746
tool_call_id: tool_use_id.clone(),
@@ -629,6 +751,7 @@ impl ServerRuntime {
629751
}),
630752
serde_json::to_value(FileChangePayload {
631753
tool_call_id: tool_use_id.clone(),
754+
tool_name: Some(tool_name.clone()),
632755
changes,
633756
is_error,
634757
})
@@ -669,8 +792,8 @@ impl ServerRuntime {
669792
.complete_item(
670793
session_id,
671794
turn_for_events.turn_id,
672-
pending.item_id,
673-
pending.item_seq,
795+
pending_item_id,
796+
pending_item_seq,
674797
ItemKind::CommandExecution,
675798
TurnItem::CommandExecution(CommandExecutionItem {
676799
tool_call_id: tool_use_id.clone(),
@@ -700,8 +823,8 @@ impl ServerRuntime {
700823
.complete_item(
701824
session_id,
702825
turn_for_events.turn_id,
703-
pending.item_id,
704-
pending.item_seq,
826+
pending_item_id,
827+
pending_item_seq,
705828
ItemKind::ToolCall,
706829
TurnItem::ToolCall(ToolCallItem {
707830
tool_call_id: tool_use_id.clone(),
@@ -1374,8 +1497,8 @@ mod tests {
13741497
pending_tool_calls.insert(
13751498
"exec".to_string(),
13761499
PendingToolCall {
1377-
item_id: command_item_id,
1378-
item_seq: 1,
1500+
item_id: Some(command_item_id),
1501+
item_seq: Some(1),
13791502
input: serde_json::json!({}),
13801503
is_command_execution: true,
13811504
command: "cargo test".to_string(),
@@ -1384,8 +1507,8 @@ mod tests {
13841507
pending_tool_calls.insert(
13851508
"read".to_string(),
13861509
PendingToolCall {
1387-
item_id: tool_item_id,
1388-
item_seq: 2,
1510+
item_id: Some(tool_item_id),
1511+
item_seq: Some(2),
13891512
input: serde_json::json!({}),
13901513
is_command_execution: false,
13911514
command: String::new(),

crates/server/tests/skills_integration.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ fn auto_review_registry(calls: Arc<std::sync::atomic::AtomicUsize>) -> Arc<ToolR
299299
execution_mode: ToolExecutionMode::Mutating,
300300
capability_tags: vec![devo_tools::ToolCapabilityTag::WriteFiles],
301301
supports_parallel: false,
302+
preparation_feedback: devo_tools::ToolPreparationFeedback::None,
302303
});
303304
Arc::new(builder.build())
304305
}
@@ -951,6 +952,7 @@ async fn turn_steer_injects_resolved_skill_into_next_model_request() -> Result<(
951952
execution_mode: ToolExecutionMode::ReadOnly,
952953
capability_tags: vec![],
953954
supports_parallel: true,
955+
preparation_feedback: devo_tools::ToolPreparationFeedback::None,
954956
});
955957
let registry = Arc::new(builder.build());
956958
let provider = Arc::new(SteerCapturingProvider::default());

0 commit comments

Comments
 (0)