Skip to content

Commit 9d2b9a1

Browse files
committed
add _process_content_object function in _rehydration_utils file to extract output from event.content object before assigning it to child.output, in _reconstruct_node_states
1 parent 4006fe4 commit 9d2b9a1

2 files changed

Lines changed: 64 additions & 2 deletions

File tree

src/google/adk/workflow/utils/_rehydration_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,25 @@ def _extract_schema_from_event(event: Event, interrupt_id: str) -> Any | None:
9696
return None
9797

9898

99+
def _process_content_object(event: Event) -> Any:
100+
"""Extracts output from event.content."""
101+
if not event.content or not getattr(event.content, 'parts', None):
102+
return None
103+
104+
text = ''.join(
105+
p.text for p in event.content.parts if p.text and not p.thought
106+
)
107+
text = text.strip()
108+
109+
if not text:
110+
return None
111+
112+
try:
113+
return json.loads(text)
114+
except (json.JSONDecodeError, ValueError):
115+
return text
116+
117+
99118
def _validate_resume_response(response_data: Any, schema: Any) -> Any:
100119
"""Validates and coerces resume response data against a schema.
101120
@@ -275,7 +294,7 @@ def get_owner_key(event_path_builder: _NodePathBuilder) -> str | None:
275294
child.output = event.output
276295
child.branch = event.branch
277296
elif use_message_as_output:
278-
child.output = event.content
297+
child.output = _process_content_object(event)
279298
if event.actions and event.actions.route is not None:
280299
child.route = event.actions.route
281300
if event.actions and event.actions.transfer_to_agent is not None:

tests/unittests/workflow/utils/test_rehydration_utils.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from google.adk.events.event import NodeInfo
1919
from google.adk.events.request_input import RequestInput
2020
from google.adk.workflow.utils._rehydration_utils import _ChildScanState
21+
from google.adk.workflow.utils._rehydration_utils import _process_content_object
2122
from google.adk.workflow.utils._rehydration_utils import _reconstruct_node_states
2223
from google.adk.workflow.utils._rehydration_utils import _unwrap_response
2324
from google.adk.workflow.utils._rehydration_utils import _validate_resume_response
@@ -103,6 +104,48 @@ def test_roundtrip_wrap_unwrap_dict(self):
103104
assert _unwrap_response(_wrap_response(d)) == d
104105

105106

107+
# --- _process_content_object ---
108+
109+
110+
class TestProcessContentObject:
111+
112+
def test_extracts_plain_text(self):
113+
content = types.Content(parts=[types.Part(text="hello world")])
114+
event = Event(content=content, invocation_id="id")
115+
assert _process_content_object(event) == "hello world"
116+
117+
def test_parses_json_text(self):
118+
content = types.Content(parts=[types.Part(text='{"foo": "bar"}')])
119+
event = Event(content=content, invocation_id="id")
120+
assert _process_content_object(event) == {"foo": "bar"}
121+
122+
def test_joins_multiple_parts(self):
123+
content = types.Content(
124+
parts=[types.Part(text="hello "), types.Part(text="world")]
125+
)
126+
event = Event(content=content, invocation_id="id")
127+
assert _process_content_object(event) == "hello world"
128+
129+
def test_filters_thought_parts(self):
130+
content = types.Content(
131+
parts=[
132+
types.Part(text="thinking...", thought=True),
133+
types.Part(text='{"answer": 42}'),
134+
]
135+
)
136+
event = Event(content=content, invocation_id="id")
137+
assert _process_content_object(event) == {"answer": 42}
138+
139+
def test_returns_none_for_no_content(self):
140+
event = Event(invocation_id="id")
141+
assert _process_content_object(event) is None
142+
143+
def test_returns_none_for_empty_text(self):
144+
content = types.Content(parts=[types.Part(text=" ")])
145+
event = Event(content=content, invocation_id="id")
146+
assert _process_content_object(event) is None
147+
148+
106149
# --- _validate_resume_response ---
107150

108151

@@ -192,7 +235,7 @@ def test_scan_message_as_output(self):
192235
)
193236

194237
assert "node_a@1" in results
195-
assert results["node_a@1"].output == content
238+
assert results["node_a@1"].output == "hello"
196239

197240
def test_scan_descendant_interrupts(self):
198241
event = Event(

0 commit comments

Comments
 (0)