Skip to content

Commit 294dee7

Browse files
authored
fix(llms): dynamic routing of Bedrock model-specific inference args (e.g., top_k) (#248)
1 parent a527f43 commit 294dee7

3 files changed

Lines changed: 43 additions & 3 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
1010
- `pw.io.postgres.write` now streams each batch into PostgreSQL through the binary `COPY` protocol instead of issuing one `INSERT` per row, giving a large throughput improvement (up to ~100x) on bulk writes. Both output modes use it: stream-of-changes copies straight into the target, while snapshot mode stages each batch in a temporary table and merges it with a single set-based upsert/delete.
1111

1212
### Fixed
13+
- `BedrockChat` now correctly routes `top_k` and other model-specific arguments to the AWS Converse API via `additionalModelRequestFields`.
1314
- Improved concurrent write handling in pw.io.sqlite.write for SQLite databases. Writes to the same database file now produce deterministic output in multi-worker and multi-table setups.
1415
- `pw.io.elasticsearch.write` no longer fails when a minibatch is big enough that its Elasticsearch `_bulk` request would exceed a server-side limit. The connector reads both the cluster's `http.max_content_length` (the `413 Request Entity Too Large` limit) and `indexing_pressure.memory.limit` (the `429 Too Many Requests` limit, which on a small-heap node trips well below 100 MB) at start-up, and splits the buffered documents across as many bulk requests as needed to stay under whichever is hit first — so large batches are still written in as few requests as possible instead of being rejected. (Both limits fall back to a conservative default if they cannot be read.)
1516
- `pw.io.elasticsearch.write` now retries transient bulk failures with backoff instead of failing the run on the first hiccup. A whole-request rejection or an individual document failing with `429`/`503` (back-pressure / temporary unavailability) is retried — resending only the documents the server reports as not yet applied, so a retry never duplicates data — while deterministic per-document failures (e.g. a type-mismatched value rejected with `400`) are now logged and skipped rather than silently dropped.

python/pathway/xpacks/llm/llms.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,7 @@ class BedrockChat(BaseChat):
795795
max_tokens: Maximum number of tokens to generate. Defaults to ``1024``.
796796
temperature: Sampling temperature (``0.0`` to ``1.0``).
797797
top_p: Top-p sampling parameter.
798+
top_k: Top-k sampling parameter (supported by Anthropic models).
798799
stop_sequences: List of sequences that will stop generation.
799800
800801
Example:
@@ -818,6 +819,9 @@ class BedrockChat(BaseChat):
818819
ROLE_SYSTEM = "system"
819820
_SUPPORTED_ROLES = {ROLE_USER, ROLE_ASSISTANT, ROLE_SYSTEM}
820821

822+
# Arguments specific to certain models (sent via additionalModelRequestFields)
823+
_MODEL_SPECIFIC_ARGS = {"top_k"}
824+
821825
@staticmethod
822826
def _convert_messages_to_bedrock_format(messages: list[dict]) -> list[dict]:
823827
"""Convert OpenAI-style messages to AWS Bedrock Converse API format."""
@@ -971,6 +975,15 @@ async def __wrapped__(self, messages: list[dict] | pw.Json, **kwargs) -> str | N
971975
"inferenceConfig": inference_config,
972976
}
973977

978+
# Extract model-specific parameters (like top_k) into additionalModelRequestFields
979+
additional_fields = {}
980+
for arg in self._MODEL_SPECIFIC_ARGS:
981+
if arg in kwargs:
982+
additional_fields[arg] = kwargs.pop(arg)
983+
984+
if additional_fields:
985+
converse_kwargs["additionalModelRequestFields"] = additional_fields
986+
974987
if system_prompts:
975988
converse_kwargs["system"] = system_prompts
976989

@@ -1024,8 +1037,7 @@ def _accepts_call_arg(self, arg_name: str) -> bool:
10241037
"temperature",
10251038
"top_p",
10261039
"stop_sequences",
1027-
"top_k", # Some models support this
1028-
}
1040+
}.union(self._MODEL_SPECIFIC_ARGS)
10291041
return arg_name in supported_args
10301042

10311043

python/pathway/xpacks/llm/tests/test_llms.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_bedrock_empty_init_kwargs():
183183
assert llm.model is None
184184

185185

186-
BEDROCK_VALID_ARGS = ["max_tokens", "temperature", "top_p", "stop_sequences"]
186+
BEDROCK_VALID_ARGS = ["max_tokens", "temperature", "top_p", "stop_sequences", "top_k"]
187187
BEDROCK_INVALID_ARGS = ["made_up_arg", "logit_bias"]
188188

189189

@@ -197,3 +197,30 @@ def test_bedrock_call_args(model_id, call_arg):
197197

198198
# BedrockChat always returns based on supported_args, model_id doesn't affect it
199199
assert llm._accepts_call_arg(call_arg) is (call_arg in BEDROCK_VALID_ARGS)
200+
201+
202+
@pytest.mark.asyncio
203+
async def test_bedrock_dynamic_args_routing():
204+
from unittest.mock import AsyncMock, patch
205+
206+
llm = llms.BedrockChat(model_id="anthropic.claude-3", region_name="us-east-1")
207+
208+
mock_client = AsyncMock()
209+
mock_client.converse = AsyncMock(
210+
return_value={"output": {"message": {"content": [{"text": "mocked"}]}}}
211+
)
212+
213+
mock_session = AsyncMock()
214+
mock_session.client.return_value.__aenter__.return_value = mock_client
215+
216+
with patch.object(llm, "_session", mock_session):
217+
await llm.__wrapped__(
218+
[{"role": "user", "content": "hi"}], top_k=250, temperature=0.7
219+
)
220+
221+
mock_client.converse.assert_called_once()
222+
call_kwargs = mock_client.converse.call_args.kwargs
223+
224+
assert call_kwargs["inferenceConfig"]["temperature"] == 0.7
225+
assert "additionalModelRequestFields" in call_kwargs
226+
assert call_kwargs["additionalModelRequestFields"]["top_k"] == 250

0 commit comments

Comments
 (0)