|
1 | 1 | """LLM node for ReAct Agent graph.""" |
2 | 2 |
|
3 | | -from typing import Literal, Sequence, TypeVar |
| 3 | +from typing import Sequence, TypeVar |
4 | 4 |
|
5 | 5 | from langchain_core.language_models import BaseChatModel |
6 | 6 | from langchain_core.messages import AIMessage, AnyMessage, ToolCall |
|
14 | 14 | from uipath_langchain.agent.tools.structured_tool_with_argument_properties import ( |
15 | 15 | StructuredToolWithArgumentProperties, |
16 | 16 | ) |
| 17 | +from uipath_langchain.llm import get_payload_handler |
17 | 18 |
|
18 | 19 | from ..exceptions import AgentTerminationException |
19 | 20 | from .constants import ( |
|
23 | 24 | from .types import FLOW_CONTROL_TOOLS, AgentGraphState |
24 | 25 | from .utils import count_consecutive_thinking_messages, extract_input_data_from_state |
25 | 26 |
|
26 | | -OPENAI_COMPATIBLE_CHAT_MODELS = ( |
27 | | - "UiPathChatOpenAI", |
28 | | - "AzureChatOpenAI", |
29 | | - "ChatOpenAI", |
30 | | - "UiPathChat", |
31 | | - "UiPathAzureChatOpenAI", |
32 | | -) |
33 | | - |
34 | | - |
35 | | -def _get_required_tool_choice_by_model( |
36 | | - model: BaseChatModel, |
37 | | -) -> Literal["required", "any"]: |
38 | | - """Get the appropriate tool_choice value to enforce tool usage based on model type. |
39 | | -
|
40 | | - "required" - OpenAI compatible required tool_choice value |
41 | | - "any" - Vertex and Bedrock parameter for required tool_choice value |
42 | | - """ |
43 | | - model_class_name = model.__class__.__name__ |
44 | | - if model_class_name in OPENAI_COMPATIBLE_CHAT_MODELS: |
45 | | - return "required" |
46 | | - return "any" |
47 | | - |
48 | 27 |
|
49 | 28 | def _filter_control_flow_tool_calls( |
50 | 29 | tool_calls: list[ToolCall], |
@@ -82,7 +61,8 @@ def create_llm_node( |
82 | 61 | before enforcing tool usage. 0 = force tools every time. |
83 | 62 | """ |
84 | 63 | bindable_tools = list(tools) if tools else [] |
85 | | - tool_choice_required_value = _get_required_tool_choice_by_model(model) |
| 64 | + payload_handler = get_payload_handler(model) |
| 65 | + tool_choice_required_value = payload_handler.get_required_tool_choice() |
86 | 66 |
|
87 | 67 | async def llm_node(state: StateT): |
88 | 68 | messages: list[AnyMessage] = state.messages |
|
0 commit comments