|
1 | | -import re |
2 | 1 | from functools import cached_property |
3 | 2 | from typing import Any, Self |
4 | 3 |
|
|
14 | 13 | UiPathAPIConfig, |
15 | 14 | VendorType, |
16 | 15 | ) |
| 16 | +from uipath_langchain_client.utils import ( |
| 17 | + CLAUDE_OPUS_4_UNSUPPORTED_SAMPLING_PARAMS, |
| 18 | + is_claude_opus_4_or_above, |
| 19 | +) |
17 | 20 |
|
18 | 21 | try: |
19 | 22 | from anthropic import AnthropicBedrock, AsyncAnthropicBedrock |
@@ -51,16 +54,6 @@ def _patched_format_data_content_block(block: dict) -> dict: |
51 | 54 | ) from e |
52 | 55 |
|
53 | 56 |
|
54 | | -# Sampling parameters that Claude Opus 4+ (reasoning models) do not support. |
55 | | -# The API returns 400 if any of these are present in the request payload. |
56 | | -_CLAUDE_OPUS_4_UNSUPPORTED_PARAMS: frozenset[str] = frozenset({"temperature", "top_k", "top_p"}) |
57 | | - |
58 | | - |
59 | | -def _is_claude_opus_4_or_above(model_name: str) -> bool: |
60 | | - """Return True for Claude Opus 4+ models that reject sampling parameters.""" |
61 | | - return bool(re.search(r"claude-opus-4", model_name, re.IGNORECASE)) |
62 | | - |
63 | | - |
64 | 57 | class UiPathChatBedrockConverse(UiPathBaseChatModel, ChatBedrockConverse): # type: ignore[override] |
65 | 58 | api_config: UiPathAPIConfig = UiPathAPIConfig( |
66 | 59 | api_type=ApiType.COMPLETIONS, |
@@ -89,6 +82,16 @@ def setup_uipath_client(self) -> Self: |
89 | 82 | self.client = WrappedBotoClient(self.uipath_sync_client) |
90 | 83 | return self |
91 | 84 |
|
| 85 | + @override |
| 86 | + def _converse_params(self, **kwargs: Any) -> dict: |
| 87 | + params = super()._converse_params(**kwargs) |
| 88 | + if is_claude_opus_4_or_above(self.model_id): |
| 89 | + inference = params.get("inferenceConfig") |
| 90 | + if isinstance(inference, dict): |
| 91 | + inference.pop("temperature", None) |
| 92 | + inference.pop("topP", None) |
| 93 | + return params |
| 94 | + |
92 | 95 |
|
93 | 96 | class UiPathChatBedrock(UiPathBaseChatModel, ChatBedrock): # type: ignore[override] |
94 | 97 | api_config: UiPathAPIConfig = UiPathAPIConfig( |
@@ -116,6 +119,14 @@ def setup_model_id(cls, values: Any) -> Any: |
116 | 119 | @model_validator(mode="after") |
117 | 120 | def setup_uipath_client(self) -> Self: |
118 | 121 | self.client = WrappedBotoClient(self.uipath_sync_client) |
| 122 | + if is_claude_opus_4_or_above(self.model_id): |
| 123 | + self.temperature = None |
| 124 | + if self.model_kwargs: |
| 125 | + self.model_kwargs = { |
| 126 | + k: v |
| 127 | + for k, v in self.model_kwargs.items() |
| 128 | + if k not in CLAUDE_OPUS_4_UNSUPPORTED_SAMPLING_PARAMS |
| 129 | + } |
119 | 130 | return self |
120 | 131 |
|
121 | 132 | @property |
@@ -170,7 +181,7 @@ def _get_request_payload( |
170 | 181 | **kwargs: Any, |
171 | 182 | ) -> dict: |
172 | 183 | payload = super()._get_request_payload(input_, stop=stop, **kwargs) |
173 | | - if _is_claude_opus_4_or_above(self.model): |
174 | | - for param in _CLAUDE_OPUS_4_UNSUPPORTED_PARAMS: |
| 184 | + if is_claude_opus_4_or_above(self.model): |
| 185 | + for param in CLAUDE_OPUS_4_UNSUPPORTED_SAMPLING_PARAMS: |
175 | 186 | payload.pop(param, None) |
176 | 187 | return payload |
0 commit comments