Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
465 changes: 465 additions & 0 deletions tests/agent/test_tool_confirmation.py

Large diffs are not rendered by default.

113 changes: 105 additions & 8 deletions trae_agent/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from trae_agent.tools.ckg.ckg_database import clear_older_ckg
from trae_agent.tools.docker_tool_executor import DockerToolExecutor
from trae_agent.utils.cli import CLIConsole
from trae_agent.utils.config import AgentConfig, ModelConfig
from trae_agent.utils.cli.cli_console import ToolConfirmationResult
from trae_agent.utils.config import AgentConfig, ModelConfig, ToolConfirmationConfig
from trae_agent.utils.llm_clients.llm_basics import LLMMessage, LLMResponse
from trae_agent.utils.llm_clients.llm_client import LLMClient
from trae_agent.utils.trajectory_recorder import TrajectoryRecorder
Expand Down Expand Up @@ -74,6 +75,12 @@ def __init__(

self._cli_console: CLIConsole | None = None

# Tool confirmation state
self._tool_confirmation_config: ToolConfirmationConfig = agent_config.tool_confirmation
self._tool_confirmation_approved_all: bool = False
self._allowed_command_prefixes: list[str] = [] # For bash prefix matching
self._allowed_tool_names: set[str] = set() # For non-bash tool name matching

# Trajectory recorder
self._trajectory_recorder: TrajectoryRecorder | None = None

Expand Down Expand Up @@ -328,18 +335,58 @@ async def _tool_call_handler(
step.tool_calls = tool_calls
self._update_cli_console(step)

if self._model_config.parallel_tool_calls:
tool_results = await self._tool_caller.parallel_tool_call(tool_calls)
# Tool confirmation logic
approved_calls: list[ToolCall] = []
rejected_results: list[ToolResult] = []

if (
self._tool_confirmation_config.enabled
and not self._tool_confirmation_approved_all
and self._cli_console is not None
):
for tool_call in tool_calls:
if self._is_tool_call_allowed(tool_call):
approved_calls.append(tool_call)
elif self._should_confirm_tool(tool_call.name):
confirmation = self._cli_console.get_tool_confirmation(tool_call)
if confirmation == ToolConfirmationResult.APPROVE:
approved_calls.append(tool_call)
elif confirmation == ToolConfirmationResult.APPROVE_ALL:
self._add_allowed_pattern(tool_call)
approved_calls.append(tool_call)
else: # REJECT
rejected_results.append(
ToolResult(
call_id=tool_call.call_id,
name=tool_call.name,
success=False,
error=f"Tool call '{tool_call.name}' was rejected by the user.",
id=tool_call.id,
)
)
else:
approved_calls.append(tool_call)
else:
tool_results = await self._tool_caller.sequential_tool_call(tool_calls)
step.tool_results = tool_results
approved_calls = list(tool_calls)

# Execute approved tool calls
if approved_calls:
if self._model_config.parallel_tool_calls:
tool_results = await self._tool_caller.parallel_tool_call(approved_calls)
else:
tool_results = await self._tool_caller.sequential_tool_call(approved_calls)
else:
tool_results = []

all_results = tool_results + rejected_results
step.tool_results = all_results
self._update_cli_console(step)
for tool_result in tool_results:
# Add tool result to conversation

for tool_result in all_results:
message = LLMMessage(role="user", tool_result=tool_result)
messages.append(message)

reflection = self.reflect_on_result(tool_results)
reflection = self.reflect_on_result(all_results)
if reflection:
step.state = AgentStepState.REFLECTING
step.reflection = reflection
Expand All @@ -350,3 +397,53 @@ async def _tool_call_handler(
messages.append(LLMMessage(role="assistant", content=reflection))

return messages

def _should_confirm_tool(self, tool_name: str) -> bool:
"""Check whether a tool with the given name requires user confirmation."""
config = self._tool_confirmation_config
if not config.enabled:
return False
required_list = config.tools_requiring_confirmation
if required_list is None:
return True
normalized_name = tool_name.lower().replace("_", "")
return any(
normalized_name == required.lower().replace("_", "") for required in required_list
)

def _is_tool_call_allowed(self, tool_call: ToolCall) -> bool:
"""Check if a tool call matches a previously approved pattern."""
if self._tool_confirmation_approved_all:
return True

tool_name = tool_call.name.lower().replace("_", "")

# For bash: check command prefix matching
if tool_name == "bash":
command = str(tool_call.arguments.get("command", ""))
for prefix in self._allowed_command_prefixes:
if command.startswith(prefix):
return True

# For non-bash tools: check tool name matching
return tool_name in self._allowed_tool_names

def _add_allowed_pattern(self, tool_call: ToolCall) -> None:
"""Add an allowed pattern based on the tool call (APPROVE_ALL result)."""
tool_name = tool_call.name.lower().replace("_", "")

# For bash: add command prefix (first two tokens)
if tool_name == "bash":
command = str(tool_call.arguments.get("command", ""))
tokens = command.split()
prefix = " ".join(tokens[:2]) if len(tokens) >= 2 else command
self._allowed_command_prefixes.append(prefix)
else:
# For non-bash tools: allow by tool name
self._allowed_tool_names.add(tool_name)

def reset_tool_confirmation_state(self) -> None:
"""Reset the tool confirmation state for a new task."""
self._tool_confirmation_approved_all = False
self._allowed_command_prefixes.clear()
self._allowed_tool_names.clear()
1 change: 1 addition & 0 deletions trae_agent/agent/trae_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def new_task(
tool_names: list[str] | None = None,
):
"""Create a new task."""
self.reset_tool_confirmation_state()
self._task: str = task

if tool_names is None and len(self._tools) == 0:
Expand Down
32 changes: 32 additions & 0 deletions trae_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ def cli():
help="Type of agent to use (trae_agent)",
default="trae_agent",
)
@click.option(
"--confirm-tools",
is_flag=True,
default=False,
help="Require user confirmation before executing tool calls",
)
def run(
task: str | None,
file_path: str | None,
Expand All @@ -204,6 +210,7 @@ def run(
trajectory_file: str | None = None,
console_type: str | None = "simple",
agent_type: str | None = "trae_agent",
confirm_tools: bool = False,
# --- Add Docker Mode ---
docker_image: str | None = None,
docker_container_id: str | None = None,
Expand Down Expand Up @@ -305,6 +312,15 @@ def run(
console.print("[red]Error: agent_type is required.[/red]")
sys.exit(1)

# Apply --confirm-tools flag to config
if confirm_tools and config.trae_agent:
from trae_agent.utils.config import ToolConfirmationConfig

config.trae_agent.tool_confirmation = ToolConfirmationConfig(
enabled=True,
tools_requiring_confirmation=["bash", "str_replace_based_edit_tool", "json_edit_tool"],
)

# Create CLI Console
console_mode = ConsoleMode.RUN
if console_type:
Expand Down Expand Up @@ -437,6 +453,12 @@ def run(
help="Type of agent to use (trae_agent)",
default="trae_agent",
)
@click.option(
"--confirm-tools",
is_flag=True,
default=False,
help="Require user confirmation before executing tool calls",
)
def interactive(
provider: str | None = None,
model: str | None = None,
Expand All @@ -447,6 +469,7 @@ def interactive(
trajectory_file: str | None = None,
console_type: str | None = "simple",
agent_type: str | None = "trae_agent",
confirm_tools: bool = False,
):
"""
This function starts an interactive session with Trae Agent.
Expand All @@ -472,6 +495,15 @@ def interactive(
console.print("[red]Error: trae_agent configuration is required in the config file.[/red]")
sys.exit(1)

# Apply --confirm-tools flag to config
if confirm_tools and config.trae_agent:
from trae_agent.utils.config import ToolConfirmationConfig

config.trae_agent.tool_confirmation = ToolConfirmationConfig(
enabled=True,
tools_requiring_confirmation=["bash", "str_replace_based_edit_tool", "json_edit_tool"],
)

# Create CLI Console for interactive mode
console_mode = ConsoleMode.INTERACTIVE
if console_type:
Expand Down
3 changes: 2 additions & 1 deletion trae_agent/utils/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""CLI console module for Trae Agent."""

from .cli_console import CLIConsole, ConsoleMode, ConsoleType
from .cli_console import CLIConsole, ConsoleMode, ConsoleType, ToolConfirmationResult
from .console_factory import ConsoleFactory
from .rich_console import RichCLIConsole
from .simple_console import SimpleCLIConsole
Expand All @@ -12,6 +12,7 @@
"CLIConsole",
"ConsoleMode",
"ConsoleType",
"ToolConfirmationResult",
"SimpleCLIConsole",
"RichCLIConsole",
"ConsoleFactory",
Expand Down
21 changes: 21 additions & 0 deletions trae_agent/utils/cli/cli_console.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,19 @@
from rich.table import Table

from trae_agent.agent.agent_basics import AgentExecution, AgentStep, AgentStepState
from trae_agent.tools.base import ToolCall
from trae_agent.utils.config import LakeviewConfig
from trae_agent.utils.lake_view import LakeView


class ToolConfirmationResult(Enum):
"""Result of a tool confirmation request."""

APPROVE = "approve"
REJECT = "reject"
APPROVE_ALL = "approve_all" # Approve this and all matching future tool calls


class ConsoleMode(Enum):
"""Console operation modes."""

Expand Down Expand Up @@ -115,6 +124,18 @@ def stop(self):
"""Stop the console and cleanup resources."""
pass

@abstractmethod
def get_tool_confirmation(self, tool_call: ToolCall) -> ToolConfirmationResult:
"""Ask the user for confirmation before executing a tool call.

Args:
tool_call: The tool call that is about to be executed.

Returns:
ToolConfirmationResult indicating user's decision.
"""
pass

def set_lakeview(self, lakeview_config: LakeviewConfig | None = None):
"""Set the lakeview configuration for the console."""
if lakeview_config:
Expand Down
36 changes: 36 additions & 0 deletions trae_agent/utils/cli/rich_console.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
CLIConsole,
ConsoleMode,
ConsoleStep,
ToolConfirmationResult,
generate_agent_step_table,
)
from trae_agent.utils.config import LakeviewConfig
Expand Down Expand Up @@ -348,6 +349,41 @@ def get_working_dir_input(self) -> str:
# For now, return current directory. Could be enhanced with a dialog
return os.getcwd()

@override
def get_tool_confirmation(self, tool_call) -> ToolConfirmationResult:
"""Ask the user for confirmation before executing a tool call."""
tool_name = tool_call.name
if tool_name == "bash":
command = tool_call.arguments.get("command", "")
detail = f"[bold]Command:[/bold] {command}"
else:
detail = f"[bold]Arguments:[/bold] {tool_call.arguments}"

# Display in the TUI log
if self.app and self.app.execution_log:
_ = self.app.execution_log.write(
Panel(
f"[bold]Tool:[/bold] {tool_name}\n{detail}",
title="Tool Confirmation Required",
border_style="yellow",
)
)
_ = self.app.execution_log.write(
"[bold]Options:[/bold] (y)es / (n)o / (a)lways approve this pattern"
)

while True:
try:
response = input("[y/n/a]: ").strip().lower()
if response in ("y", "yes"):
return ToolConfirmationResult.APPROVE
elif response in ("n", "no"):
return ToolConfirmationResult.REJECT
elif response in ("a", "always", "all"):
return ToolConfirmationResult.APPROVE_ALL
except (EOFError, KeyboardInterrupt):
return ToolConfirmationResult.REJECT

@override
def stop(self):
"""Stop the console and cleanup resources."""
Expand Down
37 changes: 37 additions & 0 deletions trae_agent/utils/cli/simple_console.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
CLIConsole,
ConsoleMode,
ConsoleStep,
ToolConfirmationResult,
generate_agent_step_table,
)
from trae_agent.utils.config import LakeviewConfig
Expand Down Expand Up @@ -213,6 +214,42 @@ def get_working_dir_input(self) -> str:
except (EOFError, KeyboardInterrupt):
return ""

@override
def get_tool_confirmation(self, tool_call) -> ToolConfirmationResult:
"""Ask the user for confirmation before executing a tool call."""
tool_name = tool_call.name
# Show command for bash, arguments summary for others
if tool_name == "bash":
command = tool_call.arguments.get("command", "")
detail = f"[bold]Command:[/bold] {command}"
else:
detail = f"[bold]Arguments:[/bold] {tool_call.arguments}"

self.console.print(
Panel(
f"[bold]Tool:[/bold] {tool_name}\n{detail}",
title="Tool Confirmation Required",
border_style="yellow",
)
)
self.console.print(
"[bold]Options:[/bold] (y)es / (n)o / (a)lways approve this pattern"
)

while True:
try:
response = input("[y/n/a]: ").strip().lower()
if response in ("y", "yes"):
return ToolConfirmationResult.APPROVE
elif response in ("n", "no"):
return ToolConfirmationResult.REJECT
elif response in ("a", "always", "all"):
return ToolConfirmationResult.APPROVE_ALL
else:
self.console.print("[yellow]Please enter y, n, or a[/yellow]")
except (EOFError, KeyboardInterrupt):
return ToolConfirmationResult.REJECT

@override
def stop(self):
"""Stop the console and cleanup resources."""
Expand Down
Loading