diff --git a/src/google/adk/utils/instructions_utils.py b/src/google/adk/utils/instructions_utils.py index 505b5cf128..f340968f32 100644 --- a/src/google/adk/utils/instructions_utils.py +++ b/src/google/adk/utils/instructions_utils.py @@ -26,6 +26,29 @@ logger = logging.getLogger('google_adk.' + __name__) +_MISSING = object() + + +def _resolve_nested(state, path: str, optional: bool = False): + """Traverse a nested dict/State using a dot-separated path. + + Args: + state: The state mapping to traverse. + path: A dot-separated key path, e.g. "user.profile.name". + optional: If True, return _MISSING for missing keys instead of raising. + + Returns: + The resolved value, or _MISSING if not found and optional is True. + """ + keys = path.split('.') + value = state + for key in keys: + if isinstance(value, dict) and key in value: + value = value[key] + else: + return _MISSING + return value + async def inject_session_state( template: str, @@ -36,6 +59,13 @@ async def inject_session_state( This method is intended to be used in InstructionProvider based instruction and global_instruction which are called with readonly_context. + Supports dot-separated paths for nested state values, e.g. + ``{user.profile.name}`` resolves to + ``session.state['user']['profile']['name']``. + + Use ``?`` suffix for optional variables that may not exist, e.g. + ``{user.preferences?}`` returns empty string if not found. + e.g. ``` ... @@ -45,7 +75,8 @@ async def build_instruction( readonly_context: ReadonlyContext, ) -> str: return await inject_session_state( - 'You can inject a state variable like {var_name} or an artifact ' + 'You can inject a state variable like {var_name} or a nested ' + 'value like {user.profile.name} or an artifact ' '{artifact.file_name} into the instruction template.', readonly_context, ) @@ -106,12 +137,16 @@ async def _replace_match(match) -> str: else: if not _is_valid_state_name(var_name): return match.group() - if var_name in invocation_context.session.state: - value = invocation_context.session.state[var_name] - if value is None: - return '' - return str(value) + state = invocation_context.session.state + if '.' in var_name and not var_name.startswith( + (State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX) + ): + value = _resolve_nested(state, var_name, optional) + elif var_name in state: + value = state[var_name] else: + value = _MISSING + if value is _MISSING: if optional: logger.debug( 'Context variable %s not found, replacing with empty string', @@ -120,6 +155,9 @@ async def _replace_match(match) -> str: return '' else: raise KeyError(f'Context variable not found: `{var_name}`.') + if value is None: + return '' + return str(value) return await _async_sub(r'{+[^{}]*}+', _replace_match, template) @@ -129,6 +167,7 @@ def _is_valid_state_name(var_name): Valid state is either: - Valid identifier + - Dot-separated valid identifiers (e.g. "user.profile.name") - : All the others will just return as it is. @@ -140,7 +179,8 @@ def _is_valid_state_name(var_name): """ parts = var_name.split(':') if len(parts) == 1: - return var_name.isidentifier() + # Support dot-separated nested paths like "user.profile.name" + return all(seg.isidentifier() for seg in var_name.split('.')) if len(parts) == 2: prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX] diff --git a/tests/unittests/utils/test_instructions_utils.py b/tests/unittests/utils/test_instructions_utils.py index d76e5032ec..3fa87c1ac0 100644 --- a/tests/unittests/utils/test_instructions_utils.py +++ b/tests/unittests/utils/test_instructions_utils.py @@ -267,3 +267,77 @@ async def test_inject_session_state_with_optional_missing_state_returns_empty(): instruction_template, invocation_context ) assert populated_instruction == "Optional value: " + + +@pytest.mark.asyncio +async def test_inject_session_state_with_nested_value(): + instruction_template = "Hello {user.profile.name}, age {user.profile.age}." + invocation_context = await _create_test_readonly_context( + state={"user": {"profile": {"name": "Alice", "age": 30}}} + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Hello Alice, age 30." + + +@pytest.mark.asyncio +async def test_inject_session_state_with_nested_missing_raises_key_error(): + instruction_template = "Value: {user.missing.key}" + invocation_context = await _create_test_readonly_context( + state={"user": {"profile": {"name": "Alice"}}} + ) + + with pytest.raises( + KeyError, match="Context variable not found: `user.missing.key`." + ): + await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + + +@pytest.mark.asyncio +async def test_inject_session_state_with_nested_optional_missing(): + instruction_template = "Value: {user.missing.key?}" + invocation_context = await _create_test_readonly_context( + state={"user": {"profile": {"name": "Alice"}}} + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Value: " + + +@pytest.mark.asyncio +async def test_inject_session_state_nested_does_not_conflict_with_artifact(): + instruction_template = ( + "Name: {user.name}, Artifact: {artifact.my_file}" + ) + mock_artifact_service = MockArtifactService( + {"my_file": "artifact content"} + ) + invocation_context = await _create_test_readonly_context( + state={"user": {"name": "Bob"}}, + artifact_service=mock_artifact_service, + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Name: Bob, Artifact: artifact content" + + +@pytest.mark.asyncio +async def test_inject_session_state_flat_key_still_works(): + """Flat keys still work even when nested is supported.""" + instruction_template = "Value: {simple_key}" + invocation_context = await _create_test_readonly_context( + state={"simple_key": "flat_value"} + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Value: flat_value"