Skip to content

Commit 22fc332

Browse files
xuanyang15copybara-github
authored andcommitted
fix: Support resolving string annotations for find_context_parameter
Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 884686010
1 parent 8b97318 commit 22fc332

2 files changed

Lines changed: 36 additions & 1 deletion

File tree

src/google/adk/utils/context_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from contextlib import aclosing
2424
import inspect
25+
import typing
2526
from typing import Any
2627
from typing import Callable
2728
from typing import get_args
@@ -80,7 +81,17 @@ def find_context_parameter(func: Callable[..., Any]) -> str | None:
8081
signature = inspect.signature(func)
8182
except (ValueError, TypeError):
8283
return None
84+
# Resolve string annotations (e.g., 'Context')
85+
try:
86+
type_hints = typing.get_type_hints(func)
87+
except Exception:
88+
# get_type_hints can fail for various reasons (e.g., unresolvable forward
89+
# references). In such cases, we fall back to inspecting the parameter
90+
# annotations directly.
91+
type_hints = {}
92+
8393
for name, param in signature.parameters.items():
84-
if _is_context_type(param.annotation):
94+
annotation = type_hints.get(name, param.annotation)
95+
if _is_context_type(annotation):
8596
return name
8697
return None

tests/unittests/utils/test_context_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,30 @@ def my_tool(query: str, ctx: Context) -> str:
3333

3434
assert find_context_parameter(my_tool) == 'ctx'
3535

36+
def test_find_context_parameter_with_string_annotation(self):
37+
"""Test detection of string annotation 'Context'."""
38+
39+
def my_tool(query: str, ctx: 'Context') -> str:
40+
return query
41+
42+
assert find_context_parameter(my_tool) == 'ctx'
43+
44+
def test_find_context_parameter_with_string_tool_context(self):
45+
"""Test detection of string annotation 'ToolContext'."""
46+
47+
def my_tool(query: str, ctx: 'ToolContext') -> str:
48+
return query
49+
50+
assert find_context_parameter(my_tool) == 'ctx'
51+
52+
def test_find_context_parameter_with_string_optional_context(self):
53+
"""Test detection of string annotation 'Optional[Context]'."""
54+
55+
def my_tool(query: str, ctx: 'Optional[Context]' = None) -> str:
56+
return query
57+
58+
assert find_context_parameter(my_tool) == 'ctx'
59+
3660
def test_find_context_parameter_with_tool_context_type(self):
3761
"""Test detection of ToolContext type annotation."""
3862

0 commit comments

Comments
 (0)