Skip to content

Commit 612b07e

Browse files
authored
fix: Populate tool_args correctly for steering (#1531)
1 parent 66d3db2 commit 612b07e

5 files changed

Lines changed: 59 additions & 5 deletions

File tree

src/strands/experimental/steering/context_providers/ledger_provider.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ def __call__(self, event: BeforeToolCallEvent, steering_context: SteeringContext
4747
tool_call_entry = {
4848
"timestamp": datetime.now().isoformat(),
4949
"tool_name": event.tool_use.get("name"),
50-
"tool_args": event.tool_use.get("arguments", {}),
50+
"tool_args": event.tool_use.get("input", {}),
5151
"status": "pending",
5252
}
5353
ledger["tool_calls"].append(tool_call_entry)

tests/strands/experimental/steering/context_providers/test_ledger_provider.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def test_ledger_before_tool_call_new_ledger(mock_datetime):
3030
callback = LedgerBeforeToolCall()
3131
steering_context = SteeringContext()
3232

33-
tool_use = {"name": "test_tool", "arguments": {"param": "value"}}
33+
tool_use = {"name": "test_tool", "input": {"param": "value"}}
3434
event = Mock(spec=BeforeToolCallEvent)
3535
event.tool_use = tool_use
3636

@@ -65,7 +65,7 @@ def test_ledger_before_tool_call_existing_ledger(mock_datetime):
6565
}
6666
steering_context.data.set("ledger", existing_ledger)
6767

68-
tool_use = {"name": "new_tool", "arguments": {"param": "value"}}
68+
tool_use = {"name": "new_tool", "input": {"param": "value"}}
6969
event = Mock(spec=BeforeToolCallEvent)
7070
event.tool_use = tool_use
7171

tests/strands/experimental/steering/core/test_handler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -241,7 +241,7 @@ def test_context_callbacks_receive_steering_context():
241241

242242
# Create a mock event and call the callback
243243
event = Mock(spec=BeforeToolCallEvent)
244-
event.tool_use = {"name": "test_tool", "arguments": {}}
244+
event.tool_use = {"name": "test_tool", "input": {}}
245245

246246
# The callback should execute without error and update the steering context
247247
before_callback(event)

tests_integ/steering/test_model_steering.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
"""Integration tests for model steering (steer_after_model)."""
22

33
from strands import Agent, tool
4+
from strands.experimental.steering.context_providers.ledger_provider import LedgerProvider
45
from strands.experimental.steering.core.action import Guide, ModelSteeringAction, Proceed
56
from strands.experimental.steering.core.handler import SteeringHandler
67
from strands.types.content import Message
@@ -154,7 +155,7 @@ class ForceToolUsageHandler(SteeringHandler):
154155
"""Handler that forces a specific tool to be used before allowing termination."""
155156

156157
def __init__(self, required_tool: str):
157-
super().__init__()
158+
super().__init__(context_providers=[LedgerProvider()])
158159
self.required_tool = required_tool
159160
self.tool_was_used = False
160161
self.guidance_given = False
@@ -171,6 +172,15 @@ async def steer_after_model(
171172
for block in content_blocks:
172173
if "toolUse" in block and block["toolUse"].get("name") == self.required_tool:
173174
self.tool_was_used = True
175+
176+
# Verify tool is in the ledger
177+
ledger = self.steering_context.data.get("ledger")
178+
if ledger:
179+
tool_calls = ledger.get("tool_calls", [])
180+
assert any(tc.get("tool_name") == self.required_tool for tc in tool_calls), (
181+
f"{self.required_tool} should be in ledger when tool_was_used=True"
182+
)
183+
174184
return Proceed(reason="Required tool was used")
175185

176186
# If tool wasn't used and we haven't guided yet, force its usage

tests_integ/steering/test_tool_steering.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,9 @@
33
import pytest
44

55
from strands import Agent, tool
6+
from strands.experimental.steering.context_providers.ledger_provider import LedgerProvider
67
from strands.experimental.steering.core.action import Guide, Interrupt, Proceed
8+
from strands.experimental.steering.core.handler import SteeringHandler
79
from strands.experimental.steering.handlers.llm.llm_handler import LLMSteeringHandler
810

911

@@ -98,3 +100,45 @@ def test_agent_with_tool_steering_e2e():
98100
notification_metrics = tool_metrics["send_notification"]
99101
assert notification_metrics.call_count >= 1, "send_notification should have been called"
100102
assert notification_metrics.success_count >= 1, "send_notification should have succeeded"
103+
104+
105+
def test_ledger_captures_tool_calls():
106+
"""Test that ledger correctly captures tool call information."""
107+
108+
class LedgerCheckingHandler(SteeringHandler):
109+
def __init__(self):
110+
super().__init__(context_providers=[LedgerProvider()])
111+
112+
async def steer_before_tool(self, *, agent, tool_use, **kwargs):
113+
ledger = self.steering_context.data.get("ledger")
114+
assert ledger is not None, "Ledger should exist"
115+
assert "tool_calls" in ledger, "Ledger should have tool_calls"
116+
117+
# Find the current tool call in the ledger
118+
tool_calls = ledger["tool_calls"]
119+
current_call = next((tc for tc in tool_calls if tc["tool_name"] == tool_use["name"]), None)
120+
assert current_call is not None, f"{tool_use['name']} should be in ledger"
121+
assert current_call["tool_args"] == tool_use["input"], "tool_args should match input"
122+
assert current_call["status"] == "pending", "Status should be pending before execution"
123+
124+
return Proceed(reason="Ledger verified")
125+
126+
handler = LedgerCheckingHandler()
127+
agent = Agent(tools=[send_notification], hooks=[handler])
128+
129+
agent("Send a notification to alice saying test message")
130+
131+
# Verify the ledger has the completed tool call
132+
ledger = handler.steering_context.data.get("ledger")
133+
assert ledger is not None
134+
assert len(ledger["tool_calls"]) >= 1, "At least one tool call should be recorded"
135+
136+
# Check the tool call details
137+
tool_call = ledger["tool_calls"][-1]
138+
assert tool_call["tool_name"] == "send_notification"
139+
assert "tool_args" in tool_call
140+
assert tool_call["tool_args"]["recipient"] == "alice"
141+
assert tool_call["tool_args"]["message"] == "test message"
142+
assert tool_call["status"] == "success"
143+
assert "completion_timestamp" in tool_call
144+
assert tool_call["error"] is None

0 commit comments

Comments
 (0)