Skip to content
9 changes: 9 additions & 0 deletions src/google/adk/models/apigee_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from typing import TYPE_CHECKING

from google.adk import version as adk_version
import google.auth
from google.genai import types
import httpx
import tenacity
Expand All @@ -52,6 +53,11 @@
_PROJECT_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_PROJECT'
_LOCATION_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_LOCATION'

_APIGEE_SCOPES = [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email',
]

_CUSTOM_METADATA_FIELDS = (
'id',
'created',
Expand Down Expand Up @@ -234,13 +240,16 @@ def api_client(self) -> Client:
**kwargs_for_http_options,
)

credentials, _ = google.auth.default(scopes=_APIGEE_SCOPES)

kwargs_for_client = {}
kwargs_for_client['vertexai'] = self._isvertexai
if self._isvertexai:
kwargs_for_client['project'] = self._project
kwargs_for_client['location'] = self._location

return Client(
credentials=credentials,
http_options=http_options,
**kwargs_for_client,
)
Expand Down
47 changes: 47 additions & 0 deletions tests/unittests/models/test_apigee_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from unittest import mock
from unittest.mock import AsyncMock

from google.adk.models.apigee_llm import _APIGEE_SCOPES
from google.adk.models.apigee_llm import ApigeeLlm
from google.adk.models.apigee_llm import CompletionsHTTPClient
from google.adk.models.llm_request import LlmRequest
Expand All @@ -33,6 +34,16 @@
PROXY_URL = 'https://test.apigee.net'


@pytest.fixture(autouse=True)
def mock_google_auth_default():
"""Mocks google.auth.default to avoid requiring real credentials in tests."""
with mock.patch(
'google.adk.models.apigee_llm.google.auth.default'
) as mock_auth:
mock_auth.return_value = (mock.Mock(), 'test-project')
yield mock_auth


@pytest.fixture
def llm_request():
"""Provides a sample LlmRequest for testing."""
Expand Down Expand Up @@ -651,6 +662,42 @@ def test_parse_response_usage_metadata():
assert llm_response.usage_metadata.thoughts_token_count == 4


@pytest.mark.asyncio
@mock.patch('google.genai.Client')
async def test_api_client_requests_userinfo_email_scope(
mock_client_constructor, llm_request, mock_google_auth_default
):
"""Tests that api_client requests userinfo.email scope for Apigee Gateway tokeninfo."""
mock_credentials = mock.Mock()
mock_google_auth_default.return_value = (mock_credentials, 'test-project')

mock_client_instance = mock.Mock()
mock_client_instance.aio.models.generate_content = AsyncMock(
return_value=types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
parts=[Part.from_text(text='Test response')],
role='model',
)
)
]
)
)
mock_client_constructor.return_value = mock_client_instance

apigee_llm = ApigeeLlm(
model=APIGEE_GEMINI_MODEL_ID,
proxy_url=PROXY_URL,
)
_ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]

mock_google_auth_default.assert_called_once_with(scopes=_APIGEE_SCOPES)

_, kwargs = mock_client_constructor.call_args
assert kwargs['credentials'] is mock_credentials


def test_parse_response_with_refusal():
"""Tests that CompletionsHTTPClient parses refusal correctly."""
client = CompletionsHTTPClient(base_url='http://test')
Expand Down