Skip to content

Commit 01e59d7

Browse files
committed
fix: use durable_interrupt for tool confirmation
1 parent 115bcfd commit 01e59d7

File tree

7 files changed

+29
-205
lines changed

7 files changed

+29
-205
lines changed

src/uipath_langchain/agent/tools/durable_interrupt/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,12 @@
22

33
from .decorator import (
44
_durable_state,
5-
_interrupt_offset,
6-
add_interrupt_offset,
75
durable_interrupt,
86
)
97
from .skip_interrupt import SkipInterruptValue
108

119
__all__ = [
12-
"add_interrupt_offset",
1310
"durable_interrupt",
1411
"SkipInterruptValue",
1512
"_durable_state",
16-
"_interrupt_offset",
1713
]

src/uipath_langchain/agent/tools/durable_interrupt/decorator.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,6 @@ async def start_job():
4848
_durable_state: contextvars.ContextVar[tuple[int, int] | None] = contextvars.ContextVar(
4949
"_durable_interrupt_state", default=None
5050
)
51-
# Number of interrupt() calls before the first @durable_interrupt
52-
_interrupt_offset: contextvars.ContextVar[int] = contextvars.ContextVar(
53-
"_durable_interrupt_offset", default=0
54-
)
55-
56-
57-
def add_interrupt_offset(n: int = 1) -> None:
58-
"""Increment durable_interrupt's starting index offset by n"""
59-
_interrupt_offset.set(_interrupt_offset.get(0) + n)
6051

6152

6253
def _next_durable_index() -> tuple[Any, int]:
@@ -76,8 +67,7 @@ def _next_durable_index() -> tuple[Any, int]:
7667
state = _durable_state.get()
7768

7869
if state is None or state[0] != sp_id:
79-
idx = _interrupt_offset.get(0)
80-
_interrupt_offset.set(0) # consume offset
70+
idx = 0
8171
else:
8272
idx = state[1]
8373

src/uipath_langchain/chat/hitl.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from langchain_core.messages.tool import ToolCall, ToolMessage
88
from langchain_core.tools import BaseTool, InjectedToolCallId
99
from langchain_core.tools import tool as langchain_tool
10-
from langgraph.types import interrupt
1110
from uipath.core.chat import (
1211
UiPathConversationToolCallConfirmationValue,
1312
)
@@ -119,21 +118,20 @@ def request_approval(
119118
if tool_call_schema is not None:
120119
input_schema = tool_call_schema.model_json_schema()
121120

122-
response = interrupt(
123-
UiPathConversationToolCallConfirmationValue(
121+
# Lazy import to avoid circular dependency:
122+
# hitl -> agent.tools.durable_interrupt -> agent.tools -> tool_node -> hitl
123+
from uipath_langchain.agent.tools.durable_interrupt import durable_interrupt
124+
125+
@durable_interrupt
126+
def ask_confirmation():
127+
return UiPathConversationToolCallConfirmationValue(
124128
tool_call_id=tool_call_id,
125129
tool_name=tool.name,
126130
input_schema=input_schema,
127131
input_value=tool_args,
128132
)
129-
)
130-
# Lazy import to avoid circular dependency:
131-
# hitl -> agent.tools.durable_interrupt -> agent.tools -> tool_node -> hitl
132-
from uipath_langchain.agent.tools.durable_interrupt import add_interrupt_offset
133133

134-
# Workaround for langgraph#6792 — remove when subgraph @task + interrupt()
135-
# checkpoint caching is fixed upstream
136-
add_interrupt_offset()
134+
response = ask_confirmation()
137135

138136
# The resume payload from CAS has shape:
139137
# {"type": "uipath_cas_tool_call_confirmation",

src/uipath_langchain/runtime/messages.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,7 @@ async def map_tool_message_to_events(
433433
# Keep as string if not valid JSON
434434
pass
435435

436-
events: list[UiPathConversationMessageEvent] = []
437-
438-
events.append(
436+
events = [
439437
UiPathConversationMessageEvent(
440438
message_id=message_id,
441439
tool_call=UiPathConversationToolCallEvent(
@@ -447,7 +445,7 @@ async def map_tool_message_to_events(
447445
),
448446
),
449447
)
450-
)
448+
]
451449

452450
if is_last_tool_call:
453451
events.append(self.map_to_message_end_event(message_id))

tests/agent/tools/test_durable_interrupt.py

Lines changed: 3 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
from uipath_langchain.agent.tools.durable_interrupt import (
1111
_durable_state,
12-
_interrupt_offset,
13-
add_interrupt_offset,
1412
durable_interrupt,
1513
)
1614

@@ -35,12 +33,10 @@ def _make_config(scratchpad: FakeScratchpad | None = None) -> dict[str, Any]:
3533

3634
@pytest.fixture(autouse=True)
3735
def _reset_durable_state() -> Generator[None]:
38-
"""Reset per-node counter and offset between tests for isolation."""
39-
state_token = _durable_state.set(None)
40-
offset_token = _interrupt_offset.set(0)
36+
"""Reset per-node counter between tests for isolation."""
37+
token = _durable_state.set(None)
4138
yield
42-
_durable_state.reset(state_token)
43-
_interrupt_offset.reset(offset_token)
39+
_durable_state.reset(token)
4440

4541

4642
class TestAsyncFirstExecution:
@@ -515,161 +511,3 @@ def my_sync_fn() -> str:
515511
return "value"
516512

517513
assert not asyncio.iscoroutinefunction(my_sync_fn)
518-
519-
520-
class TestInterruptOffsetWithPriorInterrupts:
521-
"""Offset accounts for prior interrupt() calls (e.g. HITL) in the same node."""
522-
523-
@patch(PATCH_INTERRUPT)
524-
@patch(PATCH_GET_CONFIG)
525-
def test_offset_skips_hitl_resume_slot(
526-
self, mock_get_config: MagicMock, mock_interrupt: MagicMock
527-
) -> None:
528-
"""With resume=["hitl_value"] and offset=1, durable's idx=1, 1 < 1 is False → body runs."""
529-
scratchpad = FakeScratchpad(resume=["hitl_value"])
530-
mock_get_config.return_value = _make_config(scratchpad)
531-
mock_interrupt.side_effect = lambda v: v
532-
533-
action = MagicMock(return_value="job-started")
534-
535-
@durable_interrupt
536-
def start_job() -> str:
537-
return action()
538-
539-
add_interrupt_offset()
540-
start_job()
541-
542-
action.assert_called_once()
543-
mock_interrupt.assert_called_once_with("job-started")
544-
545-
@patch(PATCH_INTERRUPT, return_value="durable-result")
546-
@patch(PATCH_GET_CONFIG)
547-
def test_offset_full_resume(
548-
self, mock_get_config: MagicMock, mock_interrupt: MagicMock
549-
) -> None:
550-
"""With resume=["hitl_value", "durable_result"] and offset=1, durable reads index 1."""
551-
scratchpad = FakeScratchpad(resume=["hitl_value", "durable_result"])
552-
mock_get_config.return_value = _make_config(scratchpad)
553-
554-
action = MagicMock()
555-
556-
@durable_interrupt
557-
def start_job() -> Any:
558-
return action()
559-
560-
add_interrupt_offset()
561-
result = start_job()
562-
563-
action.assert_not_called() # body skipped — idx=1 < len(resume)=2
564-
mock_interrupt.assert_called_once_with(None)
565-
assert result == "durable-result"
566-
567-
@patch(PATCH_INTERRUPT, return_value="resumed")
568-
@patch(PATCH_GET_CONFIG)
569-
def test_offset_resets_on_scratchpad_change(
570-
self, mock_get_config: MagicMock, mock_interrupt: MagicMock
571-
) -> None:
572-
"""After consuming offset on first scratchpad, a new scratchpad starts at 0."""
573-
sp1 = FakeScratchpad(resume=["hitl", "durable"])
574-
mock_get_config.return_value = _make_config(sp1)
575-
576-
@durable_interrupt
577-
def task_a() -> str:
578-
return "should-not-run"
579-
580-
add_interrupt_offset()
581-
task_a() # consumes offset, idx=1
582-
583-
# New scratchpad — offset should have been consumed/reset
584-
sp2 = FakeScratchpad(resume=["val"])
585-
mock_get_config.return_value = _make_config(sp2)
586-
587-
action = MagicMock()
588-
589-
@durable_interrupt
590-
def task_b() -> Any:
591-
return action()
592-
593-
task_b() # idx should be 0, not 1
594-
595-
action.assert_not_called() # idx=0 < len(resume)=1 → skipped
596-
597-
@patch(PATCH_INTERRUPT)
598-
@patch(PATCH_GET_CONFIG)
599-
def test_offset_with_multiple_durable_interrupts(
600-
self, mock_get_config: MagicMock, mock_interrupt: MagicMock
601-
) -> None:
602-
"""HITL + two durable_interrupts: offset shifts both indices correctly."""
603-
# resume[0] = HITL value, resume[1] = durable_a result
604-
# durable_a at idx=1 → resumed, durable_b at idx=2 → body runs
605-
scratchpad = FakeScratchpad(resume=["hitl_approval", "durable_a_result"])
606-
mock_get_config.return_value = _make_config(scratchpad)
607-
mock_interrupt.side_effect = lambda v: f"interrupt({v})"
608-
609-
action_a = MagicMock()
610-
action_b = MagicMock(return_value="job-B")
611-
612-
@durable_interrupt
613-
def durable_a() -> Any:
614-
return action_a()
615-
616-
@durable_interrupt
617-
def durable_b() -> str:
618-
return action_b()
619-
620-
add_interrupt_offset()
621-
result_a = durable_a() # idx=1, 1 < 2 → skipped
622-
result_b = durable_b() # idx=2, 2 < 2 → False → body runs
623-
624-
action_a.assert_not_called()
625-
action_b.assert_called_once()
626-
assert result_a == "interrupt(None)"
627-
assert result_b == "interrupt(job-B)"
628-
629-
@patch(PATCH_INTERRUPT)
630-
@patch(PATCH_GET_CONFIG)
631-
def test_multiple_offsets_accumulate(
632-
self, mock_get_config: MagicMock, mock_interrupt: MagicMock
633-
) -> None:
634-
"""Two prior interrupt() calls → offset accumulates to 2."""
635-
# resume[0] = first confirmation, resume[1] = second confirmation
636-
# durable at idx=2 → body runs (2 < 2 is False)
637-
scratchpad = FakeScratchpad(resume=["confirm_1", "confirm_2"])
638-
mock_get_config.return_value = _make_config(scratchpad)
639-
mock_interrupt.side_effect = lambda v: v
640-
641-
action = MagicMock(return_value="job-started")
642-
643-
@durable_interrupt
644-
def start_job() -> str:
645-
return action()
646-
647-
add_interrupt_offset() # first confirmation
648-
add_interrupt_offset() # second confirmation
649-
start_job()
650-
651-
action.assert_called_once() # idx=2, 2 < 2 → False → body runs
652-
mock_interrupt.assert_called_once_with("job-started")
653-
654-
@patch(PATCH_INTERRUPT, return_value="durable-result")
655-
@patch(PATCH_GET_CONFIG)
656-
def test_multiple_offsets_full_resume(
657-
self, mock_get_config: MagicMock, mock_interrupt: MagicMock
658-
) -> None:
659-
"""Two prior interrupts fully resumed + durable resumed → body skipped."""
660-
scratchpad = FakeScratchpad(resume=["confirm_1", "confirm_2", "durable-result"])
661-
mock_get_config.return_value = _make_config(scratchpad)
662-
663-
action = MagicMock()
664-
665-
@durable_interrupt
666-
def start_job() -> Any:
667-
return action()
668-
669-
add_interrupt_offset()
670-
add_interrupt_offset()
671-
result = start_job()
672-
673-
action.assert_not_called() # idx=2, 2 < 3 → True → skipped
674-
mock_interrupt.assert_called_once_with(None)
675-
assert result == "durable-result"

tests/agent/tools/test_tool_node.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
AgentRuntimeError,
1616
AgentRuntimeErrorCode,
1717
)
18-
from uipath_langchain.agent.react.types import AgentGraphState
1918
from uipath_langchain.agent.tools.tool_node import (
2019
ToolWrapperMixin,
2120
UiPathToolNode,
@@ -73,9 +72,10 @@ class FilteredState(BaseModel):
7372
session_id: str = "test_session"
7473

7574

76-
class MockState(AgentGraphState):
75+
class MockState(BaseModel):
7776
"""Mock state for testing."""
7877

78+
messages: list[Any] = []
7979
user_id: str = "test_user"
8080
session_id: str = "test_session"
8181

@@ -316,7 +316,8 @@ def test_tool_error_propagates_when_handle_errors_false(self, mock_state):
316316
node = UiPathToolNode(failing_tool, handle_tool_errors=False)
317317

318318
with pytest.raises(ValueError) as exc_info:
319-
node._func(state)
319+
node._func(state) # type: ignore[arg-type]
320+
320321
assert "Tool execution failed: test input" in str(exc_info.value)
321322

322323
async def test_async_tool_error_propagates_when_handle_errors_false(self):
@@ -333,7 +334,8 @@ async def test_async_tool_error_propagates_when_handle_errors_false(self):
333334
node = UiPathToolNode(failing_tool, handle_tool_errors=False)
334335

335336
with pytest.raises(ValueError) as exc_info:
336-
await node._afunc(state)
337+
await node._afunc(state) # type: ignore[arg-type]
338+
337339
assert "Async tool execution failed: test input" in str(exc_info.value)
338340

339341
def test_tool_error_captured_when_handle_errors_true(self):
@@ -349,7 +351,8 @@ def test_tool_error_captured_when_handle_errors_true(self):
349351

350352
node = UiPathToolNode(failing_tool, handle_tool_errors=True)
351353

352-
result = node._func(state)
354+
result = node._func(state) # type: ignore[arg-type]
355+
353356
assert result is not None
354357
assert isinstance(result, dict)
355358
assert "messages" in result
@@ -375,7 +378,8 @@ async def test_async_tool_error_captured_when_handle_errors_true(self):
375378

376379
node = UiPathToolNode(failing_tool, handle_tool_errors=True)
377380

378-
result = await node._afunc(state)
381+
result = await node._afunc(state) # type: ignore[arg-type]
382+
379383
assert result is not None
380384
assert isinstance(result, dict)
381385
assert "messages" in result

tests/chat/test_hitl.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,39 +146,39 @@ def test_annotate_wraps_content_when_modified(self):
146146
class TestRequestApprovalTruthiness:
147147
"""Tests for the truthiness fix in request_approval."""
148148

149-
@patch("uipath_langchain.chat.hitl.interrupt")
149+
@patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt")
150150
def test_empty_dict_input_preserved(self, mock_interrupt):
151151
"""Empty dict from user edits should not be replaced by original args."""
152152
mock_interrupt.return_value = {"value": {"approved": True, "input": {}}}
153153
tool = MockTool()
154154
result = request_approval({"query": "test", "tool_call_id": "c1"}, tool)
155155
assert result == {}
156156

157-
@patch("uipath_langchain.chat.hitl.interrupt")
157+
@patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt")
158158
def test_empty_list_input_preserved(self, mock_interrupt):
159159
"""Empty list from user edits should not be replaced by original args."""
160160
mock_interrupt.return_value = {"value": {"approved": True, "input": []}}
161161
tool = MockTool()
162162
result = request_approval({"query": "test", "tool_call_id": "c1"}, tool)
163163
assert result == []
164164

165-
@patch("uipath_langchain.chat.hitl.interrupt")
165+
@patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt")
166166
def test_none_input_falls_back_to_original(self, mock_interrupt):
167167
"""None input should fall back to original tool_args."""
168168
mock_interrupt.return_value = {"value": {"approved": True, "input": None}}
169169
tool = MockTool()
170170
result = request_approval({"query": "test", "tool_call_id": "c1"}, tool)
171171
assert result == {"query": "test"}
172172

173-
@patch("uipath_langchain.chat.hitl.interrupt")
173+
@patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt")
174174
def test_missing_input_falls_back_to_original(self, mock_interrupt):
175175
"""Missing input key should fall back to original tool_args."""
176176
mock_interrupt.return_value = {"value": {"approved": True}}
177177
tool = MockTool()
178178
result = request_approval({"query": "test", "tool_call_id": "c1"}, tool)
179179
assert result == {"query": "test"}
180180

181-
@patch("uipath_langchain.chat.hitl.interrupt")
181+
@patch("uipath_langchain.agent.tools.durable_interrupt.decorator.interrupt")
182182
def test_rejected_returns_none(self, mock_interrupt):
183183
"""Rejected approval returns None."""
184184
mock_interrupt.return_value = {"value": {"approved": False}}

0 commit comments

Comments
 (0)