Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions src/google/adk/utils/instructions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,29 @@

logger = logging.getLogger('google_adk.' + __name__)

_MISSING = object()


def _resolve_nested(state, path: str, optional: bool = False):
"""Traverse a nested dict/State using a dot-separated path.

Args:
state: The state mapping to traverse.
path: A dot-separated key path, e.g. "user.profile.name".
optional: If True, return _MISSING for missing keys instead of raising.

Returns:
The resolved value, or _MISSING if not found and optional is True.
"""
keys = path.split('.')
value = state
for key in keys:
if isinstance(value, dict) and key in value:
value = value[key]
else:
return _MISSING
return value


async def inject_session_state(
template: str,
Expand All @@ -36,6 +59,13 @@ async def inject_session_state(
This method is intended to be used in InstructionProvider based instruction
and global_instruction which are called with readonly_context.

Supports dot-separated paths for nested state values, e.g.
``{user.profile.name}`` resolves to
``session.state['user']['profile']['name']``.

Use ``?`` suffix for optional variables that may not exist, e.g.
``{user.preferences?}`` returns empty string if not found.

e.g.
```
...
Expand All @@ -45,7 +75,8 @@ async def build_instruction(
readonly_context: ReadonlyContext,
) -> str:
return await inject_session_state(
'You can inject a state variable like {var_name} or an artifact '
'You can inject a state variable like {var_name} or a nested '
'value like {user.profile.name} or an artifact '
'{artifact.file_name} into the instruction template.',
readonly_context,
)
Expand Down Expand Up @@ -106,12 +137,16 @@ async def _replace_match(match) -> str:
else:
if not _is_valid_state_name(var_name):
return match.group()
if var_name in invocation_context.session.state:
value = invocation_context.session.state[var_name]
if value is None:
return ''
return str(value)
state = invocation_context.session.state
if '.' in var_name and not var_name.startswith(
(State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX)
):
value = _resolve_nested(state, var_name, optional)
elif var_name in state:
value = state[var_name]
else:
value = _MISSING
if value is _MISSING:
if optional:
logger.debug(
'Context variable %s not found, replacing with empty string',
Expand All @@ -120,6 +155,9 @@ async def _replace_match(match) -> str:
return ''
else:
raise KeyError(f'Context variable not found: `{var_name}`.')
if value is None:
return ''
return str(value)

return await _async_sub(r'{+[^{}]*}+', _replace_match, template)

Expand All @@ -129,6 +167,7 @@ def _is_valid_state_name(var_name):

Valid state is either:
- Valid identifier
- Dot-separated valid identifiers (e.g. "user.profile.name")
- <Valid prefix>:<Valid identifier>
All the others will just return as it is.

Expand All @@ -140,7 +179,8 @@ def _is_valid_state_name(var_name):
"""
parts = var_name.split(':')
if len(parts) == 1:
return var_name.isidentifier()
# Support dot-separated nested paths like "user.profile.name"
return all(seg.isidentifier() for seg in var_name.split('.'))

if len(parts) == 2:
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
Expand Down
74 changes: 74 additions & 0 deletions tests/unittests/utils/test_instructions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,77 @@ async def test_inject_session_state_with_optional_missing_state_returns_empty():
instruction_template, invocation_context
)
assert populated_instruction == "Optional value: "


@pytest.mark.asyncio
async def test_inject_session_state_with_nested_value():
instruction_template = "Hello {user.profile.name}, age {user.profile.age}."
invocation_context = await _create_test_readonly_context(
state={"user": {"profile": {"name": "Alice", "age": 30}}}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Hello Alice, age 30."


@pytest.mark.asyncio
async def test_inject_session_state_with_nested_missing_raises_key_error():
instruction_template = "Value: {user.missing.key}"
invocation_context = await _create_test_readonly_context(
state={"user": {"profile": {"name": "Alice"}}}
)

with pytest.raises(
KeyError, match="Context variable not found: `user.missing.key`."
):
await instructions_utils.inject_session_state(
instruction_template, invocation_context
)


@pytest.mark.asyncio
async def test_inject_session_state_with_nested_optional_missing():
instruction_template = "Value: {user.missing.key?}"
invocation_context = await _create_test_readonly_context(
state={"user": {"profile": {"name": "Alice"}}}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Value: "


@pytest.mark.asyncio
async def test_inject_session_state_nested_does_not_conflict_with_artifact():
instruction_template = (
"Name: {user.name}, Artifact: {artifact.my_file}"
)
mock_artifact_service = MockArtifactService(
{"my_file": "artifact content"}
)
invocation_context = await _create_test_readonly_context(
state={"user": {"name": "Bob"}},
artifact_service=mock_artifact_service,
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Name: Bob, Artifact: artifact content"


@pytest.mark.asyncio
async def test_inject_session_state_flat_key_still_works():
"""Flat keys still work even when nested is supported."""
instruction_template = "Value: {simple_key}"
invocation_context = await _create_test_readonly_context(
state={"simple_key": "flat_value"}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Value: flat_value"