Skip to content

Commit 004318a

Browse files
STHITAPRAJNASclaude
authored andcommitted
feat(a2a): expose a2a_task_store and a2a_push_config_store in get_fast_api_app
## Problem `get_fast_api_app()` unconditionally creates an `InMemoryTaskStore` and `InMemoryPushNotificationConfigStore`, making it impossible for callers to inject persistent or shared stores without patching ADK internals. This is especially painful in production deployments where: - Multiple replicas need a shared task store to route A2A callbacks correctly - Operators want task state to survive server restarts (e.g. SQLite/Postgres) ## Solution Adds two new optional keyword arguments to `get_fast_api_app()`: - `a2a_task_store: Optional[Any] = None` - `a2a_push_config_store: Optional[Any] = None` When `None` (the default), the existing `InMemory*` defaults are used — fully backward-compatible. When provided, the caller-supplied instances are forwarded directly to `DefaultRequestHandler`. This mirrors the pattern introduced for the lower-level `to_a2a()` helper in PR #3839. ## Usage from a2a.server.tasks import DatabaseTaskStore from sqlalchemy.ext.asyncio import create_async_engine engine = create_async_engine("postgresql+asyncpg://user:pw@host/db") app = get_fast_api_app( agents_dir="agents/", web=True, a2a=True, a2a_task_store=DatabaseTaskStore(engine), ) ## Tests Two new test cases added to tests/unittests/cli/test_fast_api.py: - test_a2a_uses_in_memory_task_store_by_default - test_a2a_custom_task_store_bypasses_in_memory_default
1 parent f95ac48 commit 004318a

2 files changed

Lines changed: 158 additions & 2 deletions

File tree

src/google/adk/cli/fast_api.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def get_fast_api_app(
8585
allow_origins: Optional[list[str]] = None,
8686
web: bool,
8787
a2a: bool = False,
88+
a2a_task_store: Optional[Any] = None,
89+
a2a_push_config_store: Optional[Any] = None,
8890
host: str = "127.0.0.1",
8991
port: int = 8000,
9092
url_prefix: Optional[str] = None,
@@ -128,6 +130,23 @@ def get_fast_api_app(
128130
allow_origins: List of allowed origins for CORS.
129131
web: Whether to enable the web UI and serve its assets.
130132
a2a: Whether to enable Agent-to-Agent (A2A) protocol support.
133+
a2a_task_store: Optional A2A TaskStore instance. Defaults to
134+
InMemoryTaskStore when a2a=True. Pass a DatabaseTaskStore (from the
135+
a2a-sdk) for persistence across server restarts and horizontal replicas.
136+
Example::
137+
138+
from a2a.server.tasks import DatabaseTaskStore
139+
from sqlalchemy.ext.asyncio import create_async_engine
140+
141+
engine = create_async_engine("postgresql+asyncpg://user:pw@host/db")
142+
app = get_fast_api_app(
143+
agents_dir="agents/", web=True, a2a=True,
144+
a2a_task_store=DatabaseTaskStore(engine),
145+
)
146+
147+
a2a_push_config_store: Optional A2A PushNotificationConfigStore instance.
148+
Defaults to InMemoryPushNotificationConfigStore when a2a=True. Pass a
149+
DatabasePushNotificationConfigStore for persistence across restarts.
131150
host: Host address for the server (defaults to 127.0.0.1).
132151
port: Port number for the server (defaults to 8000).
133152
url_prefix: Optional prefix for all URL routes.
@@ -598,7 +617,8 @@ async def get_agent_builder(
598617
base_path = Path.cwd() / agents_dir
599618
# the root agents directory should be an existing folder
600619
if base_path.exists() and base_path.is_dir():
601-
a2a_task_store = InMemoryTaskStore()
620+
if a2a_task_store is None:
621+
a2a_task_store = InMemoryTaskStore()
602622

603623
def create_a2a_runner_loader(captured_app_name: str):
604624
"""Factory function to create A2A runner with proper closure."""
@@ -626,7 +646,11 @@ async def _get_a2a_runner_async() -> Runner:
626646
runner=create_a2a_runner_loader(app_name),
627647
)
628648

629-
push_config_store = InMemoryPushNotificationConfigStore()
649+
push_config_store = (
650+
a2a_push_config_store
651+
if a2a_push_config_store is not None
652+
else InMemoryPushNotificationConfigStore()
653+
)
630654

631655
request_handler = DefaultRequestHandler(
632656
agent_executor=agent_executor,

tests/unittests/cli/test_fast_api.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,6 +1866,138 @@ def test_a2a_disabled_by_default(test_app):
18661866
logger.info("A2A disabled by default test passed")
18671867

18681868

1869+
def test_a2a_uses_in_memory_task_store_by_default(
1870+
mock_session_service,
1871+
mock_artifact_service,
1872+
mock_memory_service,
1873+
mock_agent_loader,
1874+
mock_eval_sets_manager,
1875+
mock_eval_set_results_manager,
1876+
temp_agents_dir_with_a2a,
1877+
monkeypatch,
1878+
):
1879+
"""Test that InMemoryTaskStore is created when no task_store is provided."""
1880+
with (
1881+
patch("signal.signal", return_value=None),
1882+
patch(
1883+
"google.adk.cli.fast_api.create_session_service_from_options",
1884+
return_value=mock_session_service,
1885+
),
1886+
patch(
1887+
"google.adk.cli.fast_api.create_artifact_service_from_options",
1888+
return_value=mock_artifact_service,
1889+
),
1890+
patch(
1891+
"google.adk.cli.fast_api.create_memory_service_from_options",
1892+
return_value=mock_memory_service,
1893+
),
1894+
patch(
1895+
"google.adk.cli.fast_api.AgentLoader",
1896+
return_value=mock_agent_loader,
1897+
),
1898+
patch(
1899+
"google.adk.cli.fast_api.LocalEvalSetsManager",
1900+
return_value=mock_eval_sets_manager,
1901+
),
1902+
patch(
1903+
"google.adk.cli.fast_api.LocalEvalSetResultsManager",
1904+
return_value=mock_eval_set_results_manager,
1905+
),
1906+
patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store_class,
1907+
patch("a2a.server.tasks.InMemoryPushNotificationConfigStore"),
1908+
patch("google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor"),
1909+
patch("a2a.server.request_handlers.DefaultRequestHandler"),
1910+
patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app,
1911+
):
1912+
mock_a2a_app.return_value.routes.return_value = []
1913+
monkeypatch.chdir(temp_agents_dir_with_a2a)
1914+
1915+
_ = get_fast_api_app(
1916+
agents_dir=".",
1917+
web=True,
1918+
session_service_uri="",
1919+
artifact_service_uri="",
1920+
memory_service_uri="",
1921+
allow_origins=["*"],
1922+
a2a=True,
1923+
host="127.0.0.1",
1924+
port=8000,
1925+
)
1926+
1927+
mock_task_store_class.assert_called_once()
1928+
1929+
1930+
def test_a2a_custom_task_store_bypasses_in_memory_default(
1931+
mock_session_service,
1932+
mock_artifact_service,
1933+
mock_memory_service,
1934+
mock_agent_loader,
1935+
mock_eval_sets_manager,
1936+
mock_eval_set_results_manager,
1937+
temp_agents_dir_with_a2a,
1938+
monkeypatch,
1939+
):
1940+
"""Test that a custom task_store is forwarded and InMemoryTaskStore is not created."""
1941+
custom_task_store = MagicMock()
1942+
1943+
with (
1944+
patch("signal.signal", return_value=None),
1945+
patch(
1946+
"google.adk.cli.fast_api.create_session_service_from_options",
1947+
return_value=mock_session_service,
1948+
),
1949+
patch(
1950+
"google.adk.cli.fast_api.create_artifact_service_from_options",
1951+
return_value=mock_artifact_service,
1952+
),
1953+
patch(
1954+
"google.adk.cli.fast_api.create_memory_service_from_options",
1955+
return_value=mock_memory_service,
1956+
),
1957+
patch(
1958+
"google.adk.cli.fast_api.AgentLoader",
1959+
return_value=mock_agent_loader,
1960+
),
1961+
patch(
1962+
"google.adk.cli.fast_api.LocalEvalSetsManager",
1963+
return_value=mock_eval_sets_manager,
1964+
),
1965+
patch(
1966+
"google.adk.cli.fast_api.LocalEvalSetResultsManager",
1967+
return_value=mock_eval_set_results_manager,
1968+
),
1969+
patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store_class,
1970+
patch("a2a.server.tasks.InMemoryPushNotificationConfigStore"),
1971+
patch("google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor"),
1972+
patch(
1973+
"a2a.server.request_handlers.DefaultRequestHandler"
1974+
) as mock_handler,
1975+
patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app,
1976+
):
1977+
mock_a2a_app.return_value.routes.return_value = []
1978+
monkeypatch.chdir(temp_agents_dir_with_a2a)
1979+
1980+
_ = get_fast_api_app(
1981+
agents_dir=".",
1982+
web=True,
1983+
session_service_uri="",
1984+
artifact_service_uri="",
1985+
memory_service_uri="",
1986+
allow_origins=["*"],
1987+
a2a=True,
1988+
a2a_task_store=custom_task_store,
1989+
host="127.0.0.1",
1990+
port=8000,
1991+
)
1992+
1993+
# InMemoryTaskStore must NOT be instantiated when a custom store is supplied
1994+
mock_task_store_class.assert_not_called()
1995+
1996+
# The custom store must be passed through to DefaultRequestHandler
1997+
call_kwargs = mock_handler.call_args.kwargs
1998+
assert call_kwargs["task_store"] is custom_task_store
1999+
2000+
18692001
def test_patch_memory(test_app, create_test_session, mock_memory_service):
18702002
"""Test adding a session to memory."""
18712003
info = create_test_session

0 commit comments

Comments
 (0)