Skip to content

Commit 12fa56c

Browse files
Store CodeAgent code outputs in ActionStep (TransformerOptimus#1463)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
1 parent 0028149 commit 12fa56c

4 files changed

Lines changed: 24 additions & 0 deletions

File tree

src/smolagents/agents.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,6 +1668,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
16681668
else:
16691669
code_action = parse_code_blobs(output_text)
16701670
code_action = fix_final_answer_code(code_action)
1671+
memory_step.code_action = code_action
16711672
except Exception as e:
16721673
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
16731674
raise AgentParsingError(error_msg, self.logger)

src/smolagents/memory.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class ActionStep(MemoryStep):
5252
error: AgentError | None = None
5353
model_output_message: ChatMessage | None = None
5454
model_output: str | None = None
55+
code_action: str | None = None
5556
observations: str | None = None
5657
observations_images: list["PIL.Image.Image"] | None = None
5758
action_output: Any = None
@@ -68,6 +69,7 @@ def dict(self):
6869
"error": self.error.dict() if self.error else None,
6970
"model_output_message": self.model_output_message.dict() if self.model_output_message else None,
7071
"model_output": self.model_output,
72+
"code_action": self.code_action,
7173
"observations": self.observations,
7274
"observations_images": [image.tobytes() for image in self.observations_images]
7375
if self.observations_images
@@ -245,5 +247,11 @@ def replay(self, logger: AgentLogger, detailed: bool = False):
245247
logger.log_messages(step.model_input_messages, level=LogLevel.ERROR)
246248
logger.log_markdown(title="Agent output:", content=step.plan, level=LogLevel.ERROR)
247249

250+
def return_full_code(self) -> str:
251+
"""Returns all code actions from the agent's steps, concatenated as a single script."""
252+
return "\n\n".join(
253+
[step.code_action for step in self.steps if isinstance(step, ActionStep) and step.code_action is not None]
254+
)
255+
248256

249257
__all__ = ["AgentMemory"]

tests/test_agents.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from contextlib import nullcontext as does_not_raise
2323
from dataclasses import dataclass
2424
from pathlib import Path
25+
from textwrap import dedent
2526
from typing import Optional
2627
from unittest.mock import MagicMock, patch
2728

@@ -1580,6 +1581,11 @@ def test_syntax_error_show_offending_lines(self):
15801581
assert isinstance(output, AgentText)
15811582
assert output == "got an error"
15821583
assert ' print("Failing due to unexpected indent")' in str(agent.memory.steps)
1584+
assert isinstance(agent.memory.steps[-2], ActionStep)
1585+
assert agent.memory.steps[-2].code_action == dedent("""a = 2
1586+
b = a * 2
1587+
print("Failing due to unexpected indent")
1588+
print("Ok, calculation done!")""")
15831589

15841590
def test_end_code_appending(self):
15851591
# Checking original output message

tests/test_memory.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ def test_initialization(self):
2222
assert memory.system_prompt.system_prompt == system_prompt
2323
assert memory.steps == []
2424

25+
def test_return_all_code_actions(self):
26+
memory = AgentMemory(system_prompt="This is a system prompt.")
27+
memory.steps = [
28+
ActionStep(step_number=1, timing=Timing(start_time=0.0, end_time=1.0), code_action="print('Hello')"),
29+
ActionStep(step_number=2, timing=Timing(start_time=0.0, end_time=1.0), code_action=None),
30+
ActionStep(step_number=3, timing=Timing(start_time=0.0, end_time=1.0), code_action="print('World')"),
31+
] # type: ignore
32+
assert memory.return_full_code() == "print('Hello')\n\nprint('World')"
33+
2534

2635
class TestMemoryStep:
2736
def test_initialization(self):

0 commit comments

Comments
 (0)