Skip to content
10 changes: 10 additions & 0 deletions src/google/adk/models/apigee_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .llm_response import LlmResponse

if TYPE_CHECKING:
from google.auth.credentials import Credentials
from google.genai import Client

from .llm_request import LlmRequest
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
custom_headers: dict[str, str] | None = None,
retry_options: Optional[types.HttpRetryOptions] = None,
api_type: ApiType | str = ApiType.UNKNOWN,
credentials: Optional[Credentials] = None,
):
"""Initializes the Apigee LLM backend.

Expand Down Expand Up @@ -123,6 +125,11 @@ def __init__(
authorization headers in Vertex AI and Gemini API calls.
retry_options: Allow google-genai to retry failed responses.
api_type: The type of API to use. One of `ApiType` or string.
credentials: Optional google-auth credentials passed through to the
underlying `genai.Client`. Use this when the Apigee proxy requires
additional OAuth scopes (e.g., `userinfo.email` for tokeninfo-based
caller identification). When omitted, the default `genai.Client`
authentication flow is used.
""" # fmt: skip

super().__init__(model=model, retry_options=retry_options)
Expand Down Expand Up @@ -165,6 +172,7 @@ def __init__(
)
self._custom_headers = custom_headers or {}
self._user_agent = f'google-adk/{adk_version.__version__}'
self._credentials = credentials

@classmethod
@override
Expand Down Expand Up @@ -239,6 +247,8 @@ def api_client(self) -> Client:
if self._isvertexai:
kwargs_for_client['project'] = self._project
kwargs_for_client['location'] = self._location
if self._credentials is not None:
kwargs_for_client['credentials'] = self._credentials

return Client(
http_options=http_options,
Expand Down
65 changes: 65 additions & 0 deletions tests/unittests/models/test_apigee_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,71 @@ def test_parse_response_usage_metadata():
assert llm_response.usage_metadata.thoughts_token_count == 4


@pytest.mark.asyncio
@mock.patch('google.genai.Client')
async def test_api_client_passes_credentials_when_provided(
mock_client_constructor, llm_request
):
"""Tests that credentials passed to __init__ are forwarded to genai.Client."""
mock_credentials = mock.Mock()

mock_client_instance = mock.Mock()
mock_client_instance.aio.models.generate_content = AsyncMock(
return_value=types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
parts=[Part.from_text(text='Test response')],
role='model',
)
)
]
)
)
mock_client_constructor.return_value = mock_client_instance

apigee_llm = ApigeeLlm(
model=APIGEE_GEMINI_MODEL_ID,
proxy_url=PROXY_URL,
credentials=mock_credentials,
)
_ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]

_, kwargs = mock_client_constructor.call_args
assert kwargs['credentials'] is mock_credentials


@pytest.mark.asyncio
@mock.patch('google.genai.Client')
async def test_api_client_omits_credentials_when_not_provided(
mock_client_constructor, llm_request
):
"""Tests that credentials kwarg is not forwarded when not supplied."""
mock_client_instance = mock.Mock()
mock_client_instance.aio.models.generate_content = AsyncMock(
return_value=types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
parts=[Part.from_text(text='Test response')],
role='model',
)
)
]
)
)
mock_client_constructor.return_value = mock_client_instance

apigee_llm = ApigeeLlm(
model=APIGEE_GEMINI_MODEL_ID,
proxy_url=PROXY_URL,
)
_ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]

_, kwargs = mock_client_constructor.call_args
assert 'credentials' not in kwargs


def test_parse_response_with_refusal():
"""Tests that CompletionsHTTPClient parses refusal correctly."""
client = CompletionsHTTPClient(base_url='http://test')
Expand Down