Skip to content

Commit 42b1a08

Browse files
committed
fix(test): correct async context manager mocking for bedrock routing
1 parent 6d9e806 commit 42b1a08

1 file changed

Lines changed: 33 additions & 1 deletion

File tree

python/pathway/xpacks/llm/tests/test_llms.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_bedrock_empty_init_kwargs():
183183
assert llm.model is None
184184

185185

186-
BEDROCK_VALID_ARGS = ["max_tokens", "temperature", "top_p", "stop_sequences"]
186+
BEDROCK_VALID_ARGS = ["max_tokens", "temperature", "top_p", "stop_sequences", "top_k"]
187187
BEDROCK_INVALID_ARGS = ["made_up_arg", "logit_bias"]
188188

189189

@@ -197,3 +197,35 @@ def test_bedrock_call_args(model_id, call_arg):
197197

198198
# BedrockChat always returns based on supported_args, model_id doesn't affect it
199199
assert llm._accepts_call_arg(call_arg) is (call_arg in BEDROCK_VALID_ARGS)
200+
201+
202+
@pytest.mark.asyncio
203+
async def test_bedrock_dynamic_args_routing():
204+
from unittest.mock import AsyncMock, MagicMock, patch
205+
206+
llm = llms.BedrockChat(model_id="anthropic.claude-3", region_name="us-east-1")
207+
208+
mock_client = AsyncMock()
209+
mock_client.converse = AsyncMock(
210+
return_value={"output": {"message": {"content": [{"text": "mocked"}]}}}
211+
)
212+
213+
# Explicit async context manager returned by session.client(...)
214+
mock_client_cm = AsyncMock()
215+
mock_client_cm.__aenter__.return_value = mock_client
216+
mock_client_cm.__aexit__.return_value = None
217+
218+
mock_session = MagicMock()
219+
mock_session.client.return_value = mock_client_cm
220+
221+
with patch.object(llm, "_session", mock_session):
222+
await llm.__wrapped__(
223+
[{"role": "user", "content": "hi"}], top_k=250, temperature=0.7
224+
)
225+
226+
mock_client.converse.assert_called_once()
227+
call_kwargs = mock_client.converse.call_args.kwargs
228+
229+
assert call_kwargs["inferenceConfig"]["temperature"] == 0.7
230+
assert "additionalModelRequestFields" in call_kwargs
231+
assert call_kwargs["additionalModelRequestFields"]["top_k"] == 250

0 commit comments

Comments
 (0)