Skip to content

Commit 8ba0e6a

Browse files
xuanyang15copybara-github
authored andcommitted
feat: Add mTLS support for DiscoveryEngineSearchTool
Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 938274974
1 parent 6b831d5 commit 8ba0e6a

2 files changed

Lines changed: 24 additions & 5 deletions

File tree

src/google/adk/tools/discovery_engine_search_tool.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.cloud import discoveryengine_v1beta as discoveryengine
2929
from google.genai import types
3030

31+
from ..utils._mtls_utils import get_api_endpoint
3132
from .function_tool import FunctionTool
3233

3334
logger = logging.getLogger('google_adk.' + __name__)
@@ -37,6 +38,7 @@
3738
)
3839

3940
_DEFAULT_ENDPOINT = 'discoveryengine.googleapis.com'
41+
_DEFAULT_MTLS_ENDPOINT = 'discoveryengine.mtls.googleapis.com'
4042
_GLOBAL_LOCATION = 'global'
4143
_LOCATION_PATTERN = re.compile(
4244
r'/locations/([a-z0-9-]+)(?:/|$)', flags=re.IGNORECASE
@@ -85,6 +87,18 @@ def _resolve_location(resource_id: str, location: Optional[str]) -> str:
8587
return _GLOBAL_LOCATION
8688

8789

90+
def _get_api_endpoint(location: str) -> str:
91+
"""Returns API endpoint based on mTLS configuration and cert availability."""
92+
default_template = '{location}-' + _DEFAULT_ENDPOINT
93+
mtls_template = '{location}-' + _DEFAULT_MTLS_ENDPOINT
94+
95+
return get_api_endpoint(
96+
location=location,
97+
default_template=default_template,
98+
mtls_template=mtls_template,
99+
)
100+
101+
88102
def _build_client_options(
89103
resource_id: str,
90104
quota_project_id: Optional[str],
@@ -95,9 +109,7 @@ def _build_client_options(
95109
resolved_location = _resolve_location(resource_id, location)
96110

97111
if resolved_location != _GLOBAL_LOCATION:
98-
client_options_kwargs['api_endpoint'] = (
99-
f'{resolved_location}-{_DEFAULT_ENDPOINT}'
100-
)
112+
client_options_kwargs['api_endpoint'] = _get_api_endpoint(resolved_location)
101113
if quota_project_id:
102114
client_options_kwargs['quota_project_id'] = quota_project_id
103115

tests/unittests/tools/test_discovery_engine_search_tool.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,19 @@ def test_init_with_data_store_specs_without_search_engine_id_raises_error(
112112
),
113113
],
114114
)
115+
@mock.patch.object(discovery_engine_search_tool, "_get_api_endpoint")
115116
@mock.patch.object(discovery_engine_search_tool, "client_options")
116117
@mock.patch.object(discoveryengine, "SearchServiceClient")
117118
def test_init_with_regional_location_uses_regional_endpoint(
118119
self,
119120
mock_search_client,
120121
mock_client_options,
122+
mock_get_api_endpoint,
121123
tool_kwargs,
122124
expected_endpoint,
123125
):
124126
"""Test initialization uses the expected regional API endpoint."""
127+
mock_get_api_endpoint.return_value = expected_endpoint
125128
DiscoveryEngineSearchTool(**tool_kwargs)
126129

127130
mock_client_options.ClientOptions.assert_called_once_with(
@@ -132,12 +135,14 @@ def test_init_with_regional_location_uses_regional_endpoint(
132135
client_options=mock_client_options.ClientOptions.return_value,
133136
)
134137

138+
@mock.patch.object(discovery_engine_search_tool, "_get_api_endpoint")
135139
@mock.patch.object(discovery_engine_search_tool, "client_options")
136140
@mock.patch.object(discoveryengine, "SearchServiceClient")
137141
def test_init_with_explicit_location_override_uses_input_location(
138-
self, mock_search_client, mock_client_options
142+
self, mock_search_client, mock_client_options, mock_get_api_endpoint
139143
):
140144
"""Test initialization uses explicit location when resource has none."""
145+
mock_get_api_endpoint.return_value = "eu-discoveryengine.googleapis.com"
141146
DiscoveryEngineSearchTool(
142147
data_store_id="test_data_store",
143148
location="eu",
@@ -239,12 +244,14 @@ def test_init_with_global_location_keeps_default_endpoint(
239244
credentials="credentials", client_options=None
240245
)
241246

247+
@mock.patch.object(discovery_engine_search_tool, "_get_api_endpoint")
242248
@mock.patch.object(discovery_engine_search_tool, "client_options")
243249
@mock.patch.object(discoveryengine, "SearchServiceClient")
244250
def test_init_with_regional_location_and_quota_project_id(
245-
self, mock_search_client, mock_client_options
251+
self, mock_search_client, mock_client_options, mock_get_api_endpoint
246252
):
247253
"""Test initialization uses endpoint and quota project id together."""
254+
mock_get_api_endpoint.return_value = "eu-discoveryengine.googleapis.com"
248255
mock_credentials = mock.MagicMock()
249256
mock_credentials.quota_project_id = "test-quota-project"
250257

0 commit comments

Comments
 (0)