diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a1c52101..cca4d9418 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Fixed - `pw.io.milvus.write` no longer intermittently fails with a "server unavailable" / "connect failed" error when pointed at a local `.db` file. The embedded local Milvus server reports itself as started before it actually accepts connections, so under load the first connection could lose the race against the server coming up; the connector now retries the initial connection until the local server is ready. +- `BedrockChat` now correctly routes model-specific arguments (currently only `top_k`) to the AWS Converse API via `additionalModelRequestFields`. - Improved concurrent write handling in pw.io.sqlite.write for SQLite databases. Writes to the same database file now produce deterministic output in multi-worker and multi-table setups. - `pw.io.elasticsearch.write` no longer fails when a minibatch is big enough that its Elasticsearch `_bulk` request would exceed a server-side limit. The connector reads both the cluster's `http.max_content_length` (the `413 Request Entity Too Large` limit) and `indexing_pressure.memory.limit` (the `429 Too Many Requests` limit, which on a small-heap node trips well below 100 MB) at start-up, and splits the buffered documents across as many bulk requests as needed to stay under whichever is hit first — so large batches are still written in as few requests as possible instead of being rejected. (Both limits fall back to a conservative default if they cannot be read.) - `pw.io.elasticsearch.write` now retries transient bulk failures with backoff instead of failing the run on the first hiccup. 
A whole-request rejection or an individual document failing with `429`/`503` (back-pressure / temporary unavailability) is retried — resending only the documents the server reports as not yet applied, so a retry never duplicates data — while deterministic per-document failures (e.g. a type-mismatched value rejected with `400`) are now logged and skipped rather than silently dropped. diff --git a/python/pathway/xpacks/llm/llms.py b/python/pathway/xpacks/llm/llms.py index 0970bd618..405a9e4d1 100644 --- a/python/pathway/xpacks/llm/llms.py +++ b/python/pathway/xpacks/llm/llms.py @@ -795,6 +795,7 @@ class BedrockChat(BaseChat): max_tokens: Maximum number of tokens to generate. Defaults to ``1024``. temperature: Sampling temperature (``0.0`` to ``1.0``). top_p: Top-p sampling parameter. + top_k: Top-k sampling parameter (supported by Anthropic models). stop_sequences: List of sequences that will stop generation. Example: @@ -818,6 +819,9 @@ class BedrockChat(BaseChat): ROLE_SYSTEM = "system" _SUPPORTED_ROLES = {ROLE_USER, ROLE_ASSISTANT, ROLE_SYSTEM} + # Arguments specific to certain models (sent via additionalModelRequestFields) + _MODEL_SPECIFIC_ARGS = {"top_k"} + @staticmethod def _convert_messages_to_bedrock_format(messages: list[dict]) -> list[dict]: """Convert OpenAI-style messages to AWS Bedrock Converse API format.""" @@ -971,6 +975,15 @@ async def __wrapped__(self, messages: list[dict] | pw.Json, **kwargs) -> str | N "inferenceConfig": inference_config, } + # Extract model-specific parameters (like top_k) into additionalModelRequestFields + additional_fields = {} + for arg in self._MODEL_SPECIFIC_ARGS: + if arg in kwargs: + additional_fields[arg] = kwargs.pop(arg) + + if additional_fields: + converse_kwargs["additionalModelRequestFields"] = additional_fields + if system_prompts: converse_kwargs["system"] = system_prompts @@ -1024,8 +1037,7 @@ def _accepts_call_arg(self, arg_name: str) -> bool: "temperature", "top_p", "stop_sequences", - "top_k", # Some models 
support this - } + }.union(self._MODEL_SPECIFIC_ARGS) return arg_name in supported_args diff --git a/python/pathway/xpacks/llm/tests/test_llms.py b/python/pathway/xpacks/llm/tests/test_llms.py index 2a7a1aec7..82f084a00 100644 --- a/python/pathway/xpacks/llm/tests/test_llms.py +++ b/python/pathway/xpacks/llm/tests/test_llms.py @@ -183,7 +183,7 @@ def test_bedrock_empty_init_kwargs(): assert llm.model is None -BEDROCK_VALID_ARGS = ["max_tokens", "temperature", "top_p", "stop_sequences"] +BEDROCK_VALID_ARGS = ["max_tokens", "temperature", "top_p", "stop_sequences", "top_k"] BEDROCK_INVALID_ARGS = ["made_up_arg", "logit_bias"] @@ -197,3 +197,30 @@ def test_bedrock_call_args(model_id, call_arg): # BedrockChat always returns based on supported_args, model_id doesn't affect it assert llm._accepts_call_arg(call_arg) is (call_arg in BEDROCK_VALID_ARGS) + + +@pytest.mark.asyncio +async def test_bedrock_dynamic_args_routing(): + from unittest.mock import AsyncMock, patch + + llm = llms.BedrockChat(model_id="anthropic.claude-3", region_name="us-east-1") + + mock_client = AsyncMock() + mock_client.converse = AsyncMock( + return_value={"output": {"message": {"content": [{"text": "mocked"}]}}} + ) + + mock_session = AsyncMock() + mock_session.client.return_value.__aenter__.return_value = mock_client + + with patch.object(llm, "_session", mock_session): + await llm.__wrapped__( + [{"role": "user", "content": "hi"}], top_k=250, temperature=0.7 + ) + + mock_client.converse.assert_called_once() + call_kwargs = mock_client.converse.call_args.kwargs + + assert call_kwargs["inferenceConfig"]["temperature"] == 0.7 + assert "additionalModelRequestFields" in call_kwargs + assert call_kwargs["additionalModelRequestFields"]["top_k"] == 250