Skip to content

Commit 0c640e8

Browse files
authored
feat: allow custom-client for OpenAIModel and GeminiModel (strands-agents#1366)
1 parent 894ba80 commit 0c640e8

4 files changed

Lines changed: 249 additions & 10 deletions

File tree

src/strands/models/gemini.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,44 @@ class GeminiConfig(TypedDict, total=False):
5454
def __init__(
5555
self,
5656
*,
57+
client: Optional[genai.Client] = None,
5758
client_args: Optional[dict[str, Any]] = None,
5859
**model_config: Unpack[GeminiConfig],
5960
) -> None:
6061
"""Initialize provider instance.
6162
6263
Args:
64+
client: Pre-configured Gemini client to reuse across requests.
65+
When provided, this client will be reused for all requests and will NOT be closed
66+
by the model. The caller is responsible for managing the client lifecycle.
67+
This is useful for:
68+
- Injecting custom client wrappers
69+
- Reusing connection pools within a single event loop/worker
70+
- Centralizing observability, retries, and networking policy
71+
Note: The client should not be shared across different asyncio event loops.
6372
client_args: Arguments for the underlying Gemini client (e.g., api_key).
6473
For a complete list of supported arguments, see https://googleapis.github.io/python-genai/.
6574
**model_config: Configuration options for the Gemini model.
75+
76+
Raises:
77+
ValueError: If both `client` and `client_args` are provided.
6678
"""
6779
validate_config_keys(model_config, GeminiModel.GeminiConfig)
6880
self.config = GeminiModel.GeminiConfig(**model_config)
6981

82+
# Validate that only one client configuration method is provided
83+
if client is not None and client_args is not None and len(client_args) > 0:
84+
raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
85+
86+
self._custom_client = client
87+
self.client_args = client_args or {}
88+
7089
# Validate gemini_tools if provided
7190
if "gemini_tools" in self.config:
7291
self._validate_gemini_tools(self.config["gemini_tools"])
7392

7493
logger.debug("config=<%s> | initializing", self.config)
7594

76-
self.client_args = client_args or {}
77-
7895
@override
7996
def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type: ignore[override]
8097
"""Update the Gemini model configuration with the provided arguments.
@@ -97,6 +114,24 @@ def get_config(self) -> GeminiConfig:
97114
"""
98115
return self.config
99116

117+
def _get_client(self) -> genai.Client:
118+
"""Get a Gemini client for making requests.
119+
120+
This method handles client lifecycle management:
121+
- If an injected client was provided during initialization, it returns that client
122+
without managing its lifecycle (caller is responsible for cleanup).
123+
- Otherwise, creates a new genai.Client from client_args.
124+
125+
Returns:
126+
genai.Client: A Gemini client instance.
127+
"""
128+
if self._custom_client is not None:
129+
# Use the injected client (caller manages lifecycle)
130+
return self._custom_client
131+
else:
132+
# Create a new client from client_args
133+
return genai.Client(**self.client_args)
134+
100135
def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part:
101136
"""Format content block into a Gemini part instance.
102137
@@ -382,7 +417,8 @@ async def stream(
382417
"""
383418
request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"))
384419

385-
client = genai.Client(**self.client_args).aio
420+
client = self._get_client().aio
421+
386422
try:
387423
response = await client.models.generate_content_stream(**request)
388424

@@ -465,7 +501,7 @@ async def structured_output(
465501
"response_schema": output_model.model_json_schema(),
466502
}
467503
request = self._format_request(prompt, None, system_prompt, params)
468-
client = genai.Client(**self.client_args).aio
504+
client = self._get_client().aio
469505
response = await client.models.generate_content(**request)
470506
yield {"output": output_model.model_validate(response.parsed)}
471507

src/strands/models/openai.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import json
88
import logging
99
import mimetypes
10-
from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
10+
from contextlib import asynccontextmanager
11+
from typing import Any, AsyncGenerator, AsyncIterator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
1112

1213
import openai
1314
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
@@ -55,16 +56,39 @@ class OpenAIConfig(TypedDict, total=False):
5556
model_id: str
5657
params: Optional[dict[str, Any]]
5758

58-
def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
59+
def __init__(
60+
self,
61+
client: Optional[Client] = None,
62+
client_args: Optional[dict[str, Any]] = None,
63+
**model_config: Unpack[OpenAIConfig],
64+
) -> None:
5965
"""Initialize provider instance.
6066
6167
Args:
62-
client_args: Arguments for the OpenAI client.
68+
client: Pre-configured OpenAI-compatible client to reuse across requests.
69+
When provided, this client will be reused for all requests and will NOT be closed
70+
by the model. The caller is responsible for managing the client lifecycle.
71+
This is useful for:
72+
- Injecting custom client wrappers (e.g., GuardrailsAsyncOpenAI)
73+
- Reusing connection pools within a single event loop/worker
74+
- Centralizing observability, retries, and networking policy
75+
- Pointing to custom model gateways
76+
Note: The client should not be shared across different asyncio event loops.
77+
client_args: Arguments for the OpenAI client (legacy approach).
6378
For a complete list of supported arguments, see https://pypi.org/project/openai/.
6479
**model_config: Configuration options for the OpenAI model.
80+
81+
Raises:
82+
ValueError: If both `client` and `client_args` are provided.
6583
"""
6684
validate_config_keys(model_config, self.OpenAIConfig)
6785
self.config = dict(model_config)
86+
87+
# Validate that only one client configuration method is provided
88+
if client is not None and client_args is not None and len(client_args) > 0:
89+
raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
90+
91+
self._custom_client = client
6892
self.client_args = client_args or {}
6993

7094
logger.debug("config=<%s> | initializing", self.config)
@@ -422,6 +446,34 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
422446
case _:
423447
raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
424448

449+
@asynccontextmanager
450+
async def _get_client(self) -> AsyncIterator[Any]:
451+
"""Get an OpenAI client for making requests.
452+
453+
This context manager handles client lifecycle management:
454+
- If an injected client was provided during initialization, it yields that client
455+
without closing it (caller manages lifecycle).
456+
- Otherwise, creates a new AsyncOpenAI client from client_args and automatically
457+
closes it when the context exits.
458+
459+
Note: We create a new client per request to avoid connection sharing in the underlying
460+
httpx client, as the asyncio event loop does not allow connections to be shared.
461+
For more details, see https://github.com/encode/httpx/discussions/2959.
462+
463+
Yields:
464+
Client: An OpenAI-compatible client instance.
465+
"""
466+
if self._custom_client is not None:
467+
# Use the injected client (caller manages lifecycle)
468+
yield self._custom_client
469+
else:
470+
# Create a new client from client_args
471+
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
472+
# httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
473+
# refer to https://github.com/encode/httpx/discussions/2959.
474+
async with openai.AsyncOpenAI(**self.client_args) as client:
475+
yield client
476+
425477
@override
426478
async def stream(
427479
self,
@@ -457,7 +509,7 @@ async def stream(
457509
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
458510
# client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
459511
# https://github.com/encode/httpx/discussions/2959.
460-
async with openai.AsyncOpenAI(**self.client_args) as client:
512+
async with self._get_client() as client:
461513
try:
462514
response = await client.chat.completions.create(**request)
463515
except openai.BadRequestError as e:
@@ -576,7 +628,7 @@ async def structured_output(
576628
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
577629
# client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
578630
# https://github.com/encode/httpx/discussions/2959.
579-
async with openai.AsyncOpenAI(**self.client_args) as client:
631+
async with self._get_client() as client:
580632
try:
581633
response: ParsedChatCompletion = await client.beta.chat.completions.parse(
582634
model=self.get_config()["model_id"],

tests/strands/models/test_gemini.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,3 +720,77 @@ async def test_stream_handles_non_json_error(gemini_client, model, messages, cap
720720

721721
assert "Gemini API returned non-JSON error" in caplog.text
722722
assert f"error_message=<{error_message}>" in caplog.text
723+
724+
725+
@pytest.mark.asyncio
726+
async def test_stream_with_injected_client(model_id, agenerator, alist):
727+
"""Test that stream works with an injected client and doesn't close it."""
728+
# Create a mock injected client
729+
mock_injected_client = unittest.mock.Mock()
730+
mock_injected_client.aio = unittest.mock.AsyncMock()
731+
732+
mock_injected_client.aio.models.generate_content_stream.return_value = agenerator(
733+
[
734+
genai.types.GenerateContentResponse(
735+
candidates=[
736+
genai.types.Candidate(
737+
content=genai.types.Content(
738+
parts=[genai.types.Part(text="Hello")],
739+
),
740+
finish_reason="STOP",
741+
),
742+
],
743+
usage_metadata=genai.types.GenerateContentResponseUsageMetadata(
744+
prompt_token_count=1,
745+
total_token_count=3,
746+
),
747+
),
748+
]
749+
)
750+
751+
# Create model with injected client
752+
model = GeminiModel(client=mock_injected_client, model_id=model_id)
753+
754+
messages = [{"role": "user", "content": [{"text": "test"}]}]
755+
response = model.stream(messages)
756+
tru_events = await alist(response)
757+
758+
# Verify events were generated
759+
assert len(tru_events) > 0
760+
761+
# Verify the injected client was used
762+
mock_injected_client.aio.models.generate_content_stream.assert_called_once()
763+
764+
765+
@pytest.mark.asyncio
766+
async def test_structured_output_with_injected_client(model_id, weather_output, alist):
767+
"""Test that structured_output works with an injected client and doesn't close it."""
768+
# Create a mock injected client
769+
mock_injected_client = unittest.mock.Mock()
770+
mock_injected_client.aio = unittest.mock.AsyncMock()
771+
772+
mock_injected_client.aio.models.generate_content.return_value = unittest.mock.Mock(
773+
parsed=weather_output.model_dump()
774+
)
775+
776+
# Create model with injected client
777+
model = GeminiModel(client=mock_injected_client, model_id=model_id)
778+
779+
messages = [{"role": "user", "content": [{"text": "Generate weather"}]}]
780+
stream = model.structured_output(type(weather_output), messages)
781+
events = await alist(stream)
782+
783+
# Verify output was generated
784+
assert len(events) == 1
785+
assert events[0] == {"output": weather_output}
786+
787+
# Verify the injected client was used
788+
mock_injected_client.aio.models.generate_content.assert_called_once()
789+
790+
791+
def test_init_with_both_client_and_client_args_raises_error():
792+
"""Test that providing both client and client_args raises ValueError."""
793+
mock_client = unittest.mock.Mock()
794+
795+
with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"):
796+
GeminiModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model")

tests/strands/models/test_openai.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
def openai_client():
1414
with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls:
1515
mock_client = unittest.mock.AsyncMock()
16-
mock_client_cls.return_value.__aenter__.return_value = mock_client
16+
# Make the mock client work as an async context manager
17+
mock_client.__aenter__ = unittest.mock.AsyncMock(return_value=mock_client)
18+
mock_client.__aexit__ = unittest.mock.AsyncMock(return_value=None)
19+
mock_client_cls.return_value = mock_client
1720
yield mock_client
1821

1922

@@ -986,3 +989,77 @@ def test_format_request_messages_drops_cache_points():
986989
]
987990

988991
assert result == expected
992+
993+
994+
@pytest.mark.asyncio
995+
async def test_stream_with_injected_client(model_id, agenerator, alist):
996+
"""Test that stream works with an injected client and doesn't close it."""
997+
# Create a mock injected client
998+
mock_injected_client = unittest.mock.AsyncMock()
999+
mock_injected_client.close = unittest.mock.AsyncMock()
1000+
1001+
mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None)
1002+
mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
1003+
mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
1004+
mock_event_3 = unittest.mock.Mock()
1005+
1006+
mock_injected_client.chat.completions.create = unittest.mock.AsyncMock(
1007+
return_value=agenerator([mock_event_1, mock_event_2, mock_event_3])
1008+
)
1009+
1010+
# Create model with injected client
1011+
model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1})
1012+
1013+
messages = [{"role": "user", "content": [{"text": "test"}]}]
1014+
response = model.stream(messages)
1015+
tru_events = await alist(response)
1016+
1017+
# Verify events were generated
1018+
assert len(tru_events) > 0
1019+
1020+
# Verify the injected client was used
1021+
mock_injected_client.chat.completions.create.assert_called_once()
1022+
1023+
# Verify the injected client was NOT closed
1024+
mock_injected_client.close.assert_not_called()
1025+
1026+
1027+
@pytest.mark.asyncio
1028+
async def test_structured_output_with_injected_client(model_id, test_output_model_cls, alist):
1029+
"""Test that structured_output works with an injected client and doesn't close it."""
1030+
# Create a mock injected client
1031+
mock_injected_client = unittest.mock.AsyncMock()
1032+
mock_injected_client.close = unittest.mock.AsyncMock()
1033+
1034+
mock_parsed_instance = test_output_model_cls(name="John", age=30)
1035+
mock_choice = unittest.mock.Mock()
1036+
mock_choice.message.parsed = mock_parsed_instance
1037+
mock_response = unittest.mock.Mock()
1038+
mock_response.choices = [mock_choice]
1039+
1040+
mock_injected_client.beta.chat.completions.parse = unittest.mock.AsyncMock(return_value=mock_response)
1041+
1042+
# Create model with injected client
1043+
model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1})
1044+
1045+
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
1046+
stream = model.structured_output(test_output_model_cls, messages)
1047+
events = await alist(stream)
1048+
1049+
# Verify output was generated
1050+
assert len(events) == 1
1051+
assert events[0] == {"output": test_output_model_cls(name="John", age=30)}
1052+
1053+
# Verify the injected client was used
1054+
mock_injected_client.beta.chat.completions.parse.assert_called_once()
1055+
1056+
# Verify the injected client was NOT closed
1057+
mock_injected_client.close.assert_not_called()
1058+
1059+
1060+
def test_init_with_both_client_and_client_args_raises_error():
1061+
"""Test that providing both client and client_args raises ValueError."""
1062+
mock_client = unittest.mock.AsyncMock()
1063+
1064+
with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"):
1065+
OpenAIModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model")

0 commit comments

Comments
 (0)