|
1 | 1 | """LLM node for ReAct Agent graph.""" |
2 | 2 |
|
3 | | -from typing import Any, Sequence |
4 | | -from typing import Literal, Sequence, TypeVar |
| 3 | +from typing import Sequence, TypeVar |
5 | 4 |
|
6 | 5 | from langchain_core.language_models import BaseChatModel |
7 | 6 | from langchain_core.messages import AIMessage, AnyMessage, ToolCall |
8 | 7 | from langchain_core.tools import BaseTool |
9 | 8 | from pydantic import BaseModel |
10 | 9 | from uipath.runtime.errors import UiPathErrorCategory, UiPathErrorCode |
11 | 10 |
|
| 11 | +from uipath_langchain.agent.tools.static_args import ( |
| 12 | + apply_static_argument_properties_to_schema, |
| 13 | +) |
12 | 14 | from uipath_langchain.agent.tools.structured_tool_with_argument_properties import ( |
13 | 15 | StructuredToolWithArgumentProperties, |
14 | 16 | ) |
| 17 | +from uipath_langchain.llm import get_payload_handler |
15 | 18 |
|
16 | 19 | from ..exceptions import AgentTerminationException |
17 | 20 | from .constants import ( |
18 | 21 | DEFAULT_MAX_CONSECUTIVE_THINKING_MESSAGES, |
19 | 22 | DEFAULT_MAX_LLM_MESSAGES, |
20 | 23 | ) |
21 | 24 | from .types import FLOW_CONTROL_TOOLS, AgentGraphState |
22 | | -from uipath_langchain.chat.types import APIFlavor |
23 | | - |
24 | | -from .constants import MAX_CONSECUTIVE_THINKING_MESSAGES |
25 | | -from .types import AgentGraphState |
26 | | -from .utils import count_consecutive_thinking_messages |
27 | | - |
28 | | -OPENAI_COMPATIBLE_CHAT_MODELS = ( |
29 | | - "UiPathChatOpenAI", |
30 | | - "AzureChatOpenAI", |
31 | | - "ChatOpenAI", |
32 | | - "UiPathChat", |
33 | | - "UiPathAzureChatOpenAI", |
34 | | -) |
35 | | - |
36 | | - |
37 | | -def _get_required_tool_choice_by_model( |
38 | | - model: BaseChatModel, |
39 | | -) -> str | dict[str, Any]: |
40 | | - """Get the appropriate tool_choice value to enforce tool usage based on model type. |
41 | | -
|
42 | | - Returns: |
43 | | - - "required" for OpenAI compatible models |
44 | | - - "any" for Bedrock Converse and Vertex models (string format) |
45 | | - - {"type": "any"} for Bedrock Invoke API (dict format required) |
46 | | - """ |
47 | | - model_class_name = model.__class__.__name__ |
48 | | - if model_class_name in OPENAI_COMPATIBLE_CHAT_MODELS: |
49 | | - return "required" |
50 | | - |
51 | | - api_flavor = getattr(model, "api_flavor", None) |
52 | | - if api_flavor == APIFlavor.AWS_BEDROCK_INVOKE: |
53 | | - return {"type": "any"} |
54 | | - |
55 | | - return "any" |
| 25 | +from .utils import count_consecutive_thinking_messages, extract_input_data_from_state |
56 | 26 |
|
57 | 27 |
|
58 | 28 | def _filter_control_flow_tool_calls( |
@@ -91,7 +61,8 @@ def create_llm_node( |
91 | 61 | before enforcing tool usage. 0 = force tools every time. |
92 | 62 | """ |
93 | 63 | bindable_tools = list(tools) if tools else [] |
94 | | - tool_choice_required_value = _get_required_tool_choice_by_model(model) |
| 64 | + payload_handler = get_payload_handler(model) |
| 65 | + tool_choice_required_value = payload_handler.get_required_tool_choice() |
95 | 66 |
|
96 | 67 | async def llm_node(state: StateT): |
97 | 68 | messages: list[AnyMessage] = state.messages |
|
0 commit comments