Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "uipath-langchain"
version = "0.4.10"
version = "0.4.11"
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.11"
Expand Down
2 changes: 2 additions & 0 deletions src/uipath_langchain/agent/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .exceptions import (
AgentNodeRoutingException,
AgentStateException,
AgentTerminationException,
)

__all__ = [
"AgentNodeRoutingException",
"AgentStateException",
"AgentTerminationException",
]
4 changes: 4 additions & 0 deletions src/uipath_langchain/agent/exceptions/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ class AgentNodeRoutingException(Exception):

class AgentTerminationException(UiPathRuntimeError):
    """Signals that the agent run must terminate.

    Subclasses ``UiPathRuntimeError`` (imported from ``uipath.runtime.errors``
    elsewhere in this file) — presumably so termination surfaces through the
    UiPath runtime's structured error handling; confirm against the runtime's
    error contract.
    """

    pass


class AgentStateException(Exception):
    """Raised when the agent's message/graph state is inconsistent.

    Guardrail actions raise this when the tool call located in the message
    history does not carry the tool name the action expected.
    """

    pass
37 changes: 27 additions & 10 deletions src/uipath_langchain/agent/guardrails/actions/escalate_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
)
from uipath.runtime.errors import UiPathErrorCode

from ...exceptions import AgentTerminationException
from ...exceptions import AgentStateException, AgentTerminationException
from ...react.types import AgentGuardrailsGraphState
from ...react.utils import extract_current_tool_call_index, find_latest_ai_message
from ..types import ExecutionStage
from ..utils import _extract_tool_args_from_message, get_message_content
from .base_action import GuardrailAction, GuardrailActionNode
Expand Down Expand Up @@ -420,9 +421,10 @@ def _process_tool_escalation_response(
if not msgs or reviewed_field not in escalation_result:
return {}

last_message = msgs[-1]
if execution_stage == ExecutionStage.PRE_EXECUTION:
if not isinstance(last_message, AIMessage):
# Find the latest AI message instead of assuming last message is AI
ai_message = find_latest_ai_message(msgs)
if not ai_message:
return {}

# Get reviewed tool calls args from escalation result
Expand All @@ -434,25 +436,40 @@ def _process_tool_escalation_response(
if not isinstance(reviewed_tool_calls_args, dict):
return {}

# Find and update only the tool call with matching name
if last_message.tool_calls:
tool_calls = list(last_message.tool_calls)
for tool_call in tool_calls:
# Find the current tool call index for the specific tool
if ai_message.tool_calls:
tool_calls = list(ai_message.tool_calls)
current_index = extract_current_tool_call_index(msgs, tool_name)

# If we found the current index and it's valid
if current_index is not None and current_index < len(tool_calls):
tool_call = tool_calls[current_index]
call_name = (
tool_call.get("name")
if isinstance(tool_call, dict)
else getattr(tool_call, "name", None)
)

# Verify this is the correct tool by name
if call_name == tool_name:
# Update args for the matching tool call
# Update args for the specific tool call at current index
if isinstance(reviewed_tool_calls_args, dict):
if isinstance(tool_call, dict):
tool_call["args"] = reviewed_tool_calls_args
else:
tool_call.args = reviewed_tool_calls_args
break
last_message.tool_calls = tool_calls

ai_message.tool_calls = tool_calls
else:
raise AgentStateException(
f"Tool call name [{call_name}] does not match expected tool name [{tool_name}]."
)
else:
return {}

else:
# POST_EXECUTION: last message should be ToolMessage for tool escalation
last_message = msgs[-1]
if not isinstance(last_message, ToolMessage):
return {}

Expand Down
84 changes: 46 additions & 38 deletions src/uipath_langchain/agent/guardrails/actions/filter_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
from uipath.runtime.errors import UiPathErrorCategory, UiPathErrorCode

from uipath_langchain.agent.guardrails.types import ExecutionStage
from uipath_langchain.agent.react.utils import (
extract_current_tool_call_index,
find_latest_ai_message,
)

from ...exceptions import AgentTerminationException
from ...exceptions import AgentStateException, AgentTerminationException
from ...react.types import AgentGuardrailsGraphState
from .base_action import GuardrailAction, GuardrailActionNode

Expand Down Expand Up @@ -149,12 +153,9 @@ def _filter_tool_input_fields(

# Find the AIMessage with tool calls
# At PRE_EXECUTION, this is always the last message
ai_message = None
for i in range(len(msgs) - 1, -1, -1):
msg = msgs[i]
if isinstance(msg, AIMessage) and msg.tool_calls:
ai_message = msg
break
ai_message = find_latest_ai_message(msgs)
if ai_message is None or not ai_message.tool_calls:
return {}

if ai_message is None:
return {}
Expand All @@ -165,40 +166,47 @@ def _filter_tool_input_fields(
tool_calls = list(ai_message.tool_calls)
modified = False

for tool_call in tool_calls:
call_name = (
tool_call.get("name")
current_tool_call_index = extract_current_tool_call_index(msgs, tool_name)
if current_tool_call_index is None:
return {}

tool_call = tool_calls[current_tool_call_index]

call_name = (
tool_call.get("name")
if isinstance(tool_call, dict)
else getattr(tool_call, "name", None)
)

if call_name == tool_name:
# Get the current args
args = (
tool_call.get("args")
if isinstance(tool_call, dict)
else getattr(tool_call, "name", None)
else getattr(tool_call, "args", None)
)

if call_name == tool_name:
# Get the current args
args = (
tool_call.get("args")
if isinstance(tool_call, dict)
else getattr(tool_call, "args", None)
)

if args and isinstance(args, dict):
# Filter out the specified input fields
filtered_args = args.copy()
for field_ref in fields_to_filter:
# Only filter input fields
if (
field_ref.source == FieldSource.INPUT
and field_ref.path in filtered_args
):
del filtered_args[field_ref.path]
modified = True

# Update the tool call with filtered args
if isinstance(tool_call, dict):
tool_call["args"] = filtered_args
else:
tool_call.args = filtered_args

break
if args and isinstance(args, dict):
# Filter out the specified input fields
filtered_args = args.copy()
for field_ref in fields_to_filter:
# Only filter input fields
if (
field_ref.source == FieldSource.INPUT
and field_ref.path in filtered_args
):
del filtered_args[field_ref.path]
modified = True

# Update the tool call with filtered args
if isinstance(tool_call, dict):
tool_call["args"] = filtered_args
else:
tool_call.args = filtered_args
else:
raise AgentStateException(
f"Tool call name [{call_name}] does not match expected tool name [{tool_name}]."
)

if modified:
ai_message.tool_calls = tool_calls
Expand Down
3 changes: 1 addition & 2 deletions src/uipath_langchain/agent/react/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ def create_agent(
)

for tool_name in tool_node_names:
builder.add_edge(tool_name, AgentGraphNode.AGENT)

builder.add_conditional_edges(tool_name, route_agent, target_node_names)
builder.add_edge(AgentGraphNode.TERMINATE, END)

return builder
21 changes: 19 additions & 2 deletions src/uipath_langchain/agent/react/llm_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from typing import Literal, Sequence

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.messages import AIMessage, AnyMessage, ToolCall
from langchain_core.tools import BaseTool

from .constants import MAX_CONSECUTIVE_THINKING_MESSAGES
from .types import AgentGraphState
from .types import FLOW_CONTROL_TOOLS, AgentGraphState
from .utils import count_consecutive_thinking_messages

OPENAI_COMPATIBLE_CHAT_MODELS = (
Expand All @@ -33,6 +33,16 @@ def _get_required_tool_choice_by_model(
return "any"


def _filter_control_flow_tool_calls(
tool_calls: list[ToolCall],
) -> list[ToolCall]:
"""Remove control flow tools when multiple tool calls exist."""
if len(tool_calls) <= 1:
return tool_calls

return [tc for tc in tool_calls if tc.get("name") not in FLOW_CONTROL_TOOLS]


def create_llm_node(
model: BaseChatModel,
tools: Sequence[BaseTool] | None = None,
Expand Down Expand Up @@ -74,6 +84,13 @@ async def llm_node(state: AgentGraphState):
f"LLM returned {type(response).__name__} instead of AIMessage"
)

# filter out flow control tools when multiple tool calls exist
if response.tool_calls:
filtered_tool_calls = _filter_control_flow_tool_calls(response.tool_calls)
if len(filtered_tool_calls) != len(response.tool_calls):
# todo: this does not actually work, but fixing tool call modifying is a separate task
response.tool_calls = filtered_tool_calls

return {"messages": [response]}

return llm_node
95 changes: 45 additions & 50 deletions src/uipath_langchain/agent/react/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,13 @@

from typing import Literal

from langchain_core.messages import ToolCall
from uipath.agent.react import END_EXECUTION_TOOL, RAISE_ERROR_TOOL

from ..exceptions import AgentNodeRoutingException
from .router_utils import validate_last_message_is_AI
from .types import AgentGraphNode, AgentGraphState
from .utils import count_consecutive_thinking_messages

FLOW_CONTROL_TOOLS = [END_EXECUTION_TOOL.name, RAISE_ERROR_TOOL.name]


def __filter_control_flow_tool_calls(
tool_calls: list[ToolCall],
) -> list[ToolCall]:
"""Remove control flow tools when multiple tool calls exist."""
if len(tool_calls) <= 1:
return tool_calls

return [tc for tc in tool_calls if tc.get("name") not in FLOW_CONTROL_TOOLS]


def __has_control_flow_tool(tool_calls: list[ToolCall]) -> bool:
"""Check if any tool call is of a control flow tool."""
return any(tc.get("name") in FLOW_CONTROL_TOOLS for tc in tool_calls)
from .types import FLOW_CONTROL_TOOLS, AgentGraphNode, AgentGraphState
from .utils import (
count_consecutive_thinking_messages,
extract_current_tool_call_index,
find_latest_ai_message,
)


def create_route_agent(thinking_messages_limit: int = 0):
Expand All @@ -40,50 +23,62 @@ def create_route_agent(thinking_messages_limit: int = 0):

def route_agent(
state: AgentGraphState,
) -> list[str] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]:
"""Route after agent: handles all routing logic including control flow detection.
) -> str | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]:
"""Route after agent: handles sequential tool execution.

Routing logic:
1. If multiple tool calls exist, filter out control flow tools (EndExecution, RaiseError)
2. If control flow tool(s) remain, route to TERMINATE
3. If regular tool calls remain, route to specific tool nodes (return list of tool names)
4. If no tool calls, handle consecutive completions
1. Get current tool call index from messages
2. If current tool call index is None (all tools completed), route to AGENT
3. If current tool call is a flow control tool, route to TERMINATE
4. Otherwise, route to the specific tool node

Returns:
- list[str]: Tool node names for parallel execution
- AgentGraphNode.AGENT: For consecutive completions
- str: Single tool node name for sequential execution
- AgentGraphNode.AGENT: When all tool calls completed or no tool calls
- AgentGraphNode.TERMINATE: For control flow termination

Raises:
AgentNodeRoutingException: When encountering unexpected state (empty messages, non-AIMessage, or excessive completions)
AgentNodeRoutingException: When encountering unexpected state
"""
messages = state.messages
last_message = validate_last_message_is_AI(messages)

tool_calls = list(last_message.tool_calls) if last_message.tool_calls else []
tool_calls = __filter_control_flow_tool_calls(tool_calls)
last_message = find_latest_ai_message(messages)
if last_message is None:
raise AgentNodeRoutingException(
"No AIMessage found in messages for routing."
)

if tool_calls and __has_control_flow_tool(tool_calls):
return AgentGraphNode.TERMINATE
if not last_message.tool_calls:
consecutive_thinking_messages = count_consecutive_thinking_messages(
messages
)

if tool_calls:
return [tc["name"] for tc in tool_calls]
if consecutive_thinking_messages > thinking_messages_limit:
raise AgentNodeRoutingException(
f"Agent exceeded consecutive completions limit without producing tool calls "
f"(completions: {consecutive_thinking_messages}, max: {thinking_messages_limit}). "
f"This should not happen as tool_choice='required' is enforced at the limit."
)

consecutive_thinking_messages = count_consecutive_thinking_messages(messages)
if last_message.content:
return AgentGraphNode.AGENT

if consecutive_thinking_messages > thinking_messages_limit:
raise AgentNodeRoutingException(
f"Agent exceeded consecutive completions limit without producing tool calls "
f"(completions: {consecutive_thinking_messages}, max: {thinking_messages_limit}). "
f"This should not happen as tool_choice='required' is enforced at the limit."
f"Agent produced empty response without tool calls "
f"(completions: {consecutive_thinking_messages}, has_content: False)"
)

if last_message.content:
current_index = extract_current_tool_call_index(messages)

# all tool calls completed, go back to agent
if current_index is None:
return AgentGraphNode.AGENT

raise AgentNodeRoutingException(
f"Agent produced empty response without tool calls "
f"(completions: {consecutive_thinking_messages}, has_content: False)"
)
current_tool_call = last_message.tool_calls[current_index]
current_tool_name = current_tool_call["name"]

if current_tool_name in FLOW_CONTROL_TOOLS:
return AgentGraphNode.TERMINATE

return current_tool_name

return route_agent
Loading