Skip to content

Commit 428e789

Browse files
wyf7107 and copybara-github
authored and committed
fix(workflow): Resolve raw Content output crash on rehydration
Port of GitHub PR #5909. Centralizes text extraction and schema validation for rehydrated output via a shared helper `extract_text_from_content`. If a stored output fails validation against the node's schema due to schema drift, gracefully fall back to parsing unvalidated JSON to avoid blocking resumption, rather than crashing. Co-authored-by: Yifan Wang <wanyif@google.com>. PiperOrigin-RevId: 927408084
1 parent 9133858 commit 428e789

5 files changed

Lines changed: 143 additions & 3 deletions

File tree

src/google/adk/utils/content_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,10 @@ def filter_audio_parts(content: types.Content) -> types.Content | None:
3636
if not filtered_parts:
3737
return None
3838
return types.Content(role=content.role, parts=filtered_parts)
39+
40+
41+
def extract_text_from_content(content: types.Content | None) -> str:
  """Returns the concatenated non-thought text of a Content object.

  Args:
    content: The Content to read, or None.

  Returns:
    The `text` of every part that has truthy text and a falsy `thought` flag,
    joined in order with no separator. An empty string is returned when
    `content` is falsy or has no parts.
  """
  if content and content.parts:
    # Keep only parts that carry text and are not flagged as thoughts.
    visible_texts = [
        part.text for part in content.parts if part.text and not part.thought
    ]
    return ''.join(visible_texts)
  return ''

src/google/adk/workflow/_base_node.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pydantic import ValidationError
2929

3030
from ..utils._schema_utils import SchemaType
31+
from ..utils.content_utils import extract_text_from_content
3132
from ._retry_config import RetryConfig
3233

3334
if TYPE_CHECKING:
@@ -143,7 +144,7 @@ def _validate_input_data(self, data: Any) -> Any:
143144
"""Validates data against input_schema if set."""
144145
if self.input_schema and isinstance(data, types.Content):
145146
# Extract text from Content (e.g. user input from START node).
146-
text = ''.join(part.text for part in data.parts if part.text)
147+
text = extract_text_from_content(data)
147148
if self.input_schema is str:
148149
return text
149150
# If schema is defined, try to parse the text as JSON.
@@ -168,7 +169,7 @@ def _validate_output_data(self, data: Any) -> Any:
168169
except ValidationError as e:
169170
# 2. If failed, try to parse JSON ONLY if it's Content
170171
if isinstance(data, types.Content):
171-
text = ''.join(part.text for part in data.parts if part.text)
172+
text = extract_text_from_content(data)
172173
if self.output_schema is str:
173174
return text
174175
if text.strip():

src/google/adk/workflow/utils/_rehydration_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,23 @@
2020
from dataclasses import dataclass
2121
from dataclasses import field
2222
import json
23+
import logging
2324
from typing import Any
25+
from typing import TYPE_CHECKING
2426

27+
from google.genai import types
2528
from pydantic import TypeAdapter
2629
from pydantic import ValidationError
2730

2831
from ...events._node_path_builder import _NodePathBuilder
2932
from ...events.event import Event
3033
from ._workflow_hitl_utils import REQUEST_INPUT_FUNCTION_CALL_NAME
3134

35+
if TYPE_CHECKING:
36+
from .._base_node import BaseNode
37+
38+
logger = logging.getLogger('google_adk.' + __name__)
39+
3240
_RESULT_KEY = 'result'
3341

3442

@@ -96,6 +104,49 @@ def _extract_schema_from_event(event: Event, interrupt_id: str) -> Any | None:
96104
return None
97105

98106

107+
def _process_rehydrated_output(node: BaseNode, output: Any) -> Any:
  """Normalizes a recovered output according to the node's output schema.

  Only `types.Content` values are transformed; any other value is returned
  untouched. For Content, the text is extracted with
  `extract_text_from_content` and stripped, then:

  * empty text yields `None`;
  * no output schema, or a `str` schema, yields the stripped text;
  * any other schema is validated as JSON and the result is passed through
    `node._to_serializable`.

  When validation fails, the text is parsed as plain JSON and returned
  unvalidated (a warning is logged) so that schema drift does not block
  resumption.

  Args:
    node: The node whose `output_schema` decides how the text is interpreted.
    output: The recovered output, possibly a `types.Content`.

  Returns:
    The original `output` when it is not Content; otherwise `None`, the
    stripped text, the serialized validated value, or the unvalidated parsed
    JSON, as described above.

  Raises:
    ValueError: If schema validation fails and the text is not valid JSON.
  """
  if not isinstance(output, types.Content):
    return output

  # Function-scope import.
  # NOTE(review): presumably avoids an import cycle -- confirm.
  from google.adk.utils.content_utils import extract_text_from_content

  stripped = extract_text_from_content(output).strip()
  if not stripped:
    return None

  # Without a schema, or with a plain `str` schema, the raw text is the output.
  if not node.output_schema or node.output_schema is str:
    return stripped

  try:
    adapter = TypeAdapter(node.output_schema)
    return node._to_serializable(adapter.validate_json(stripped))
  except ValidationError as err:
    # Fallback to unvalidated JSON parsing on validation failure
    # to prevent blocking resumption on schema drift.
    try:
      fallback = json.loads(stripped)
      logger.warning(
          'Validation failed for rehydrated output against schema: %s. '
          'Falling back to unvalidated JSON output to allow resumption.',
          err,
      )
      return fallback
    except ValueError:
      raise ValueError(
          f'Validation failed for rehydrated output against schema: {err}'
      ) from err
148+
149+
99150
def _validate_resume_response(response_data: Any, schema: Any) -> Any:
100151
"""Validates and coerces resume response data against a schema.
101152

src/google/adk/workflow/utils/_replay_interceptor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .._node_state import NodeState
2828
from .._node_status import NodeStatus
2929
from ._rehydration_utils import _ChildScanState
30+
from ._rehydration_utils import _process_rehydrated_output
3031

3132
if TYPE_CHECKING:
3233
from .._dynamic_node_scheduler import DynamicNodeRun
@@ -112,7 +113,7 @@ def check_interception(
112113
):
113114
# Case 3: Cross-turn successfully completed in a prior turn (fast-forward).
114115
# Bypass execution completely and return the cached output and route.
115-
output = recovered.output
116+
output = _process_rehydrated_output(node, recovered.output)
116117
route = recovered.route
117118

118119
elif recovered.interrupt_ids:

tests/unittests/workflow/utils/test_rehydration_utils.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from google.adk.events.event import Event
1818
from google.adk.events.event import NodeInfo
1919
from google.adk.events.request_input import RequestInput
20+
from google.adk.workflow._base_node import BaseNode
2021
from google.adk.workflow.utils._rehydration_utils import _ChildScanState
22+
from google.adk.workflow.utils._rehydration_utils import _process_rehydrated_output
2123
from google.adk.workflow.utils._rehydration_utils import _reconstruct_node_states
2224
from google.adk.workflow.utils._rehydration_utils import _unwrap_response
2325
from google.adk.workflow.utils._rehydration_utils import _validate_resume_response
@@ -103,6 +105,84 @@ def test_roundtrip_wrap_unwrap_dict(self):
103105
assert _unwrap_response(_wrap_response(d)) == d
104106

105107

108+
# --- _process_rehydrated_output ---
109+
110+
111+
class TestProcessRehydratedOutput:
  """Tests for `_process_rehydrated_output`."""

  def test_extracts_plain_text_without_schema(self):
    """A single text part is returned as-is when no schema is set."""
    plain_node = BaseNode(name="dummy")
    stored = types.Content(parts=[types.Part(text="hello world")])

    result = _process_rehydrated_output(plain_node, stored)

    assert result == "hello world"

  def test_returns_plain_text_even_if_json_when_no_schema(self):
    """JSON-looking text stays a string when no schema is set."""
    plain_node = BaseNode(name="dummy")
    stored = types.Content(parts=[types.Part(text='{"foo": "bar"}')])

    result = _process_rehydrated_output(plain_node, stored)

    assert result == '{"foo": "bar"}'

  def test_parses_json_text_with_output_schema(self):
    """Text is parsed into a dict when it matches the output schema."""

    class MySchema(BaseModel):
      foo: str

    schema_node = BaseNode(name="dummy", output_schema=MySchema)
    stored = types.Content(parts=[types.Part(text='{"foo": "bar"}')])

    result = _process_rehydrated_output(schema_node, stored)

    assert result == {"foo": "bar"}

  def test_joins_multiple_parts(self):
    """Text from several parts is concatenated in order."""
    plain_node = BaseNode(name="dummy")
    first_part = types.Part(text="hello ")
    second_part = types.Part(text="world")
    stored = types.Content(parts=[first_part, second_part])

    result = _process_rehydrated_output(plain_node, stored)

    assert result == "hello world"

  def test_filters_thought_parts(self):
    """Parts flagged as thoughts are ignored before parsing."""

    class MySchema(BaseModel):
      answer: int

    schema_node = BaseNode(name="dummy", output_schema=MySchema)
    thought_part = types.Part(text="thinking...", thought=True)
    answer_part = types.Part(text='{"answer": 42}')
    stored = types.Content(parts=[thought_part, answer_part])

    result = _process_rehydrated_output(schema_node, stored)

    assert result == {"answer": 42}

  def test_returns_none_for_empty_text(self):
    """Whitespace-only text produces None."""
    plain_node = BaseNode(name="dummy")
    stored = types.Content(parts=[types.Part(text="   ")])

    result = _process_rehydrated_output(plain_node, stored)

    assert result is None

  def test_gracefully_falls_back_on_schema_mismatch(self, caplog):
    """Valid JSON that fails the schema is returned unvalidated with a warning."""

    class MySchema(BaseModel):
      foo: str
      bar: int  # Required field that is missing in the stored output

    schema_node = BaseNode(name="dummy", output_schema=MySchema)
    stored = types.Content(parts=[types.Part(text='{"foo": "only"}')])

    # Should NOT raise ValueError, but fallback to unvalidated parsed dict
    result = _process_rehydrated_output(schema_node, stored)

    assert result == {"foo": "only"}
    expected_log = "Validation failed for rehydrated output against schema"
    assert expected_log in caplog.text

  def test_raises_value_error_if_not_valid_json_on_schema_mismatch(self):
    """Text that is neither schema-valid nor JSON raises ValueError."""

    class MySchema(BaseModel):
      foo: str

    schema_node = BaseNode(name="dummy", output_schema=MySchema)
    stored = types.Content(parts=[types.Part(text="invalid json")])

    # Should raise ValueError because it's not valid JSON
    expected_message = "Validation failed for rehydrated output against schema"
    with pytest.raises(ValueError, match=expected_message):
      _process_rehydrated_output(schema_node, stored)
184+
185+
106186
# --- _validate_resume_response ---
107187

108188

0 commit comments

Comments
 (0)