Skip to content

Commit eb4674b

Browse files
wukathy and copybara-github
authored and committed
feat: Add support for model endpoints in Agent Registry
Co-authored-by: Kathy Wu <wukathy@google.com>
PiperOrigin-RevId: 893585328
1 parent dcee290 commit eb4674b

File tree

7 files changed

+277
-18
lines changed

7 files changed

+277
-18
lines changed

contributing/samples/agent_registry_agent/agent.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from google.adk.agents.llm_agent import LlmAgent
2020
from google.adk.integrations.agent_registry import AgentRegistry
21+
from google.adk.models.google_llm import Gemini
2122

2223
# Project and location can be set via environment variables:
2324
# GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION
@@ -27,6 +28,8 @@
2728
# Initialize Agent Registry client
2829
registry = AgentRegistry(project_id=project_id, location=location)
2930

31+
# List agents, MCP servers, and endpoints resource names from the registry.
32+
# They can be used to initialize the agent, toolset, and model below.
3033
print(f"Listing agents in {project_id}/{location}...")
3134
agents = registry.list_agents()
3235
for agent in agents.get("agents", []):
@@ -37,6 +40,11 @@
3740
for server in mcp_servers.get("mcpServers", []):
3841
print(f"- MCP Server: {server.get('displayName')} ({server.get('name')})")
3942

43+
print(f"\nListing endpoints in {project_id}/{location}...")
44+
endpoints = registry.list_endpoints()
45+
for endpoint in endpoints.get("endpoints", []):
46+
print(f"- Endpoint: {endpoint.get('displayName')} ({endpoint.get('name')})")
47+
4048
# Example of using a specific agent or MCP server from the registry:
4149
# (Note: These names should be full resource names as returned by list methods)
4250

@@ -52,8 +60,19 @@
5260
f"projects/{project_id}/locations/{location}/mcpServers/MCP_SERVER_NAME"
5361
)
5462

63+
# 3. Getting a specific model endpoint configuration
64+
# This returns a string like:
65+
# "projects/adk12345/locations/us-central1/publishers/google/models/gemini-2.5-flash"
66+
# TODO: Replace ENDPOINT_NAME with your endpoint name
67+
model_name = registry.get_model_name(
68+
f"projects/{project_id}/locations/{location}/endpoints/ENDPOINT_NAME"
69+
)
70+
71+
# Initialize the model using the resolved model name from registry.
72+
gemini_model = Gemini(model=model_name)
73+
5574
root_agent = LlmAgent(
56-
model="gemini-2.5-flash",
75+
model=gemini_model,
5776
name="discovery_agent",
5877
instruction=(
5978
"You have access to tools and sub-agents discovered via Registry."

src/google/adk/integrations/agent_registry/agent_registry.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
from typing import Callable
2525
from typing import Dict
2626
from typing import List
27+
from typing import Mapping
2728
from typing import Optional
2829
from typing import Sequence
30+
from typing import TypedDict
2931
from typing import Union
3032
from urllib.parse import parse_qs
3133
from urllib.parse import urlparse
@@ -109,6 +111,26 @@ class _ProtocolType(str, Enum):
109111
CUSTOM = "CUSTOM"
110112

111113

114+
class Interface(TypedDict, total=False):
115+
"""Details for a single connection interface."""
116+
117+
url: str
118+
protocolBinding: str
119+
120+
121+
class Endpoint(TypedDict, total=False):
122+
"""Full metadata for a registered Endpoint."""
123+
124+
name: str
125+
endpointId: str
126+
displayName: str
127+
description: str
128+
interfaces: List[Interface]
129+
createTime: str
130+
updateTime: str
131+
attributes: Dict[str, Any]
132+
133+
112134
class AgentRegistry:
113135
"""Client for interacting with the Google Cloud Agent Registry service.
114136
@@ -194,7 +216,7 @@ def _make_request(
194216

195217
def _get_connection_uri(
196218
self,
197-
resource_details: Dict[str, Any],
219+
resource_details: Mapping[str, Any],
198220
protocol_type: Optional[_ProtocolType] = None,
199221
protocol_binding: Optional[A2ATransport] = None,
200222
) -> Optional[str]:
@@ -273,6 +295,56 @@ def get_mcp_toolset(self, mcp_server_name: str) -> McpToolset:
273295
header_provider=self._header_provider,
274296
)
275297

298+
# --- Endpoint Methods ---
299+
300+
def list_endpoints(
301+
self,
302+
filter_str: Optional[str] = None,
303+
page_size: Optional[int] = None,
304+
page_token: Optional[str] = None,
305+
) -> Dict[str, Any]:
306+
"""Fetches a list of Endpoints."""
307+
params = {}
308+
if filter_str:
309+
params["filter"] = filter_str
310+
if page_size:
311+
params["pageSize"] = str(page_size)
312+
if page_token:
313+
params["pageToken"] = page_token
314+
return self._make_request("endpoints", params=params)
315+
316+
def get_endpoint(self, name: str) -> Endpoint:
317+
"""Retrieves details of a specific Endpoint."""
318+
return self._make_request(name) # type: ignore
319+
320+
def get_model_name(self, endpoint_name: str) -> str:
321+
"""Retrieves and parses an endpoint into a model resource name.
322+
323+
Args:
324+
endpoint_name: The full resource name of the endpoint.
325+
326+
Returns:
327+
The resolved model resource name string (e.g.
328+
projects/.../locations/.../publishers/google/models/...).
329+
"""
330+
endpoint_details = self.get_endpoint(endpoint_name)
331+
uri = self._get_connection_uri(endpoint_details)
332+
if not uri:
333+
raise ValueError(
334+
f"Connection URI not found for endpoint: {endpoint_name}"
335+
)
336+
337+
uri = re.sub(r":\w+$", "", uri)
338+
339+
if uri.startswith("projects/"):
340+
return uri
341+
342+
match = re.search(r"(projects/.+)", uri)
343+
if match:
344+
return match.group(1)
345+
346+
return uri
347+
276348
# --- Agent Methods ---
277349

278350
def list_agents(

src/google/adk/models/anthropic_llm.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import json
2424
import logging
2525
import os
26+
import re
2627
from typing import Any
2728
from typing import AsyncGenerator
2829
from typing import Iterable
@@ -364,10 +365,23 @@ class AnthropicLlm(BaseLlm):
364365
def supported_models(cls) -> list[str]:
365366
return [r"claude-3-.*", r"claude-.*-4.*"]
366367

368+
def _resolve_model_name(self, model: Optional[str]) -> str:
369+
if not model:
370+
return self.model
371+
if model.startswith("projects/"):
372+
match = re.search(
373+
r"projects/[^/]+/locations/[^/]+/(?:publishers/anthropic/models|endpoints)/([^/:]+)",
374+
model,
375+
)
376+
if match:
377+
return match.group(1)
378+
return model
379+
367380
@override
368381
async def generate_content_async(
369382
self, llm_request: LlmRequest, stream: bool = False
370383
) -> AsyncGenerator[LlmResponse, None]:
384+
model_to_use = self._resolve_model_name(llm_request.model)
371385
messages = [
372386
content_to_message_param(content)
373387
for content in llm_request.contents or []
@@ -390,7 +404,7 @@ async def generate_content_async(
390404

391405
if not stream:
392406
message = await self._anthropic_client.messages.create(
393-
model=llm_request.model,
407+
model=model_to_use,
394408
system=llm_request.config.system_instruction,
395409
messages=messages,
396410
tools=tools,
@@ -416,8 +430,9 @@ async def _generate_content_streaming(
416430
Yields partial LlmResponse objects as content arrives, followed by
417431
a final aggregated LlmResponse with all content.
418432
"""
433+
model_to_use = self._resolve_model_name(llm_request.model)
419434
raw_stream = await self._anthropic_client.messages.create(
420-
model=llm_request.model,
435+
model=model_to_use,
421436
system=llm_request.config.system_instruction,
422437
messages=messages,
423438
tools=tools,
@@ -511,17 +526,26 @@ class Claude(AnthropicLlm):
511526
@cached_property
512527
@override
513528
def _anthropic_client(self) -> AsyncAnthropicVertex:
514-
if (
515-
"GOOGLE_CLOUD_PROJECT" not in os.environ
516-
or "GOOGLE_CLOUD_LOCATION" not in os.environ
517-
):
529+
project_id = os.environ.get("GOOGLE_CLOUD_PROJECT")
530+
location = os.environ.get("GOOGLE_CLOUD_LOCATION")
531+
532+
if self.model.startswith("projects/"):
533+
match = re.search(
534+
r"projects/([^/]+)/locations/([^/]+)/",
535+
self.model,
536+
)
537+
if match:
538+
project_id = match.group(1)
539+
location = match.group(2)
540+
541+
if not project_id or not location:
518542
raise ValueError(
519543
"GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
520544
" Anthropic on Vertex."
521545
)
522546

523547
return AsyncAnthropicVertex(
524-
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
525-
region=os.environ["GOOGLE_CLOUD_LOCATION"],
548+
project_id=project_id,
549+
region=location,
526550
default_headers=get_tracking_headers(),
527551
)

src/google/adk/models/google_llm.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -304,13 +304,19 @@ def api_client(self) -> Client:
304304
"""
305305
from google.genai import Client
306306

307-
return Client(
308-
http_options=types.HttpOptions(
307+
base_url = self.base_url
308+
309+
kwargs: dict[str, Any] = {
310+
'http_options': types.HttpOptions(
309311
headers=self._tracking_headers(),
310312
retry_options=self.retry_options,
311-
base_url=self.base_url,
313+
base_url=base_url,
312314
)
313-
)
315+
}
316+
if self.model.startswith('projects/'):
317+
kwargs['vertexai'] = True
318+
319+
return Client(**kwargs)
314320

315321
@cached_property
316322
def _api_backend(self) -> GoogleLLMVariant:
@@ -336,11 +342,19 @@ def _live_api_version(self) -> str:
336342
def _live_api_client(self) -> Client:
337343
from google.genai import Client
338344

339-
return Client(
340-
http_options=types.HttpOptions(
341-
headers=self._tracking_headers(), api_version=self._live_api_version
345+
base_url = self.base_url
346+
347+
kwargs: dict[str, Any] = {
348+
'http_options': types.HttpOptions(
349+
headers=self._tracking_headers(),
350+
api_version=self._live_api_version,
351+
base_url=base_url,
342352
)
343-
)
353+
}
354+
if self.model.startswith('projects/'):
355+
kwargs['vertexai'] = True
356+
357+
return Client(**kwargs)
344358

345359
@contextlib.asynccontextmanager
346360
async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:

tests/unittests/integrations/agent_registry/test_agent_registry.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,37 @@ def test_get_mcp_server(self, mock_httpx, registry):
272272
server = registry.get_mcp_server("test-mcp")
273273
assert server == {"name": "test-mcp"}
274274

275+
@patch("httpx.Client")
276+
def test_list_endpoints(self, mock_httpx, registry):
277+
mock_response = MagicMock()
278+
mock_response.json.return_value = {"endpoints": []}
279+
mock_response.raise_for_status = MagicMock()
280+
mock_httpx.return_value.__enter__.return_value.get.return_value = (
281+
mock_response
282+
)
283+
284+
# Mock auth refresh
285+
registry._credentials.token = "token"
286+
registry._credentials.refresh = MagicMock()
287+
288+
endpoints = registry.list_endpoints()
289+
assert endpoints == {"endpoints": []}
290+
291+
@patch("httpx.Client")
292+
def test_get_endpoint(self, mock_httpx, registry):
293+
mock_response = MagicMock()
294+
mock_response.json.return_value = {"name": "test-endpoint"}
295+
mock_response.raise_for_status = MagicMock()
296+
mock_httpx.return_value.__enter__.return_value.get.return_value = (
297+
mock_response
298+
)
299+
300+
registry._credentials.token = "token"
301+
registry._credentials.refresh = MagicMock()
302+
303+
server = registry.get_endpoint("test-endpoint")
304+
assert server == {"name": "test-endpoint"}
305+
275306
@patch("httpx.Client")
276307
def test_get_mcp_toolset(self, mock_httpx, registry):
277308
mock_response = MagicMock()
@@ -420,3 +451,41 @@ def test_make_request_raises_generic_exception(self, mock_httpx, registry):
420451

421452
with pytest.raises(RuntimeError, match="API request failed: Generic error"):
422453
registry._make_request("test-path")
454+
455+
@patch.object(AgentRegistry, "get_endpoint")
456+
def test_get_model_name_starts_with_projects(
457+
self, mock_get_endpoint, registry
458+
):
459+
mock_get_endpoint.return_value = {
460+
"interfaces": [{"url": "projects/p1/locations/l1/models/m1"}]
461+
}
462+
model_name = registry.get_model_name("test-endpoint")
463+
assert model_name == "projects/p1/locations/l1/models/m1"
464+
465+
@patch.object(AgentRegistry, "get_endpoint")
466+
def test_get_model_name_contains_projects(self, mock_get_endpoint, registry):
467+
mock_get_endpoint.return_value = {
468+
"interfaces": [{
469+
"url": (
470+
"https://vertexai.googleapis.com/v1/projects/p1/locations/l1/models/m1"
471+
)
472+
}]
473+
}
474+
model_name = registry.get_model_name("test-endpoint")
475+
assert model_name == "projects/p1/locations/l1/models/m1"
476+
477+
@patch.object(AgentRegistry, "get_endpoint")
478+
def test_get_model_name_strips_suffix(self, mock_get_endpoint, registry):
479+
mock_get_endpoint.return_value = {
480+
"interfaces": [{"url": "projects/p1/locations/l1/models/m1:predict"}]
481+
}
482+
model_name = registry.get_model_name("test-endpoint")
483+
assert model_name == "projects/p1/locations/l1/models/m1"
484+
485+
@patch.object(AgentRegistry, "get_endpoint")
486+
def test_get_model_name_raises_value_error_if_no_uri(
487+
self, mock_get_endpoint, registry
488+
):
489+
mock_get_endpoint.return_value = {}
490+
with pytest.raises(ValueError, match="Connection URI not found"):
491+
registry.get_model_name("test-endpoint")

0 commit comments

Comments (0)