Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/uipath_langchain/runtime/runtime.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextvars
import logging
import os
from collections.abc import Iterator
Expand All @@ -18,6 +19,7 @@
UiPathRuntimeStorageProtocol,
UiPathStreamOptions,
)
from uipath.tracing import ReferenceContext, ReferenceContextAccessor
from uipath.runtime.errors import (
UiPathBaseRuntimeError,
UiPathErrorCategory,
Expand Down Expand Up @@ -75,12 +77,31 @@ def __init__(
self.chat.client_side_tools = self._get_client_side_tools()
self._middleware_node_names: set[str] = self._detect_middleware_nodes()

def _push_reference_context(self) -> contextvars.Token:
"""Append this runtime's own entry to the ambient ReferenceContext.

Reads any parent context already in the accessor (e.g. set by an
upstream middleware or the agents-python runtime), then appends a
``langgraph`` entry for this runtime. Returns the ContextVar token
so the caller can reset in a ``finally`` block.
"""
agent_id = os.environ.get("UIPATH_AGENT_ID")
agent_version = os.environ.get("UIPATH_PROCESS_VERSION") or None
parent_ctx = ReferenceContextAccessor.get() or ReferenceContext.Empty
ref_ctx = (
parent_ctx.add("langgraph", agent_id, agent_version)
if agent_id
else parent_ctx
)
return ReferenceContextAccessor.set(ref_ctx)

async def execute(
self,
input: dict[str, Any] | None = None,
options: UiPathExecuteOptions | None = None,
) -> UiPathRuntimeResult:
"""Execute the graph with the provided input and configuration."""
ref_ctx_token = self._push_reference_context()
try:
graph_input = await self._get_graph_input(input, options)
graph_config = self._get_graph_config()
Expand All @@ -99,6 +120,8 @@ async def execute(

except Exception as e:
raise self.create_runtime_error(e) from e
finally:
ReferenceContextAccessor.reset(ref_ctx_token)

async def stream(
self,
Expand Down Expand Up @@ -133,6 +156,7 @@ async def stream(
Raises:
LangGraphRuntimeError: If execution fails
"""
ref_ctx_token = self._push_reference_context()
try:
graph_input = await self._get_graph_input(input, options)
graph_config = self._get_graph_config()
Expand Down Expand Up @@ -230,6 +254,8 @@ async def stream(

except Exception as e:
raise self.create_runtime_error(e) from e
finally:
ReferenceContextAccessor.reset(ref_ctx_token)

async def get_schema(self) -> UiPathRuntimeSchema:
"""Get schema for this LangGraph runtime."""
Expand Down
205 changes: 205 additions & 0 deletions tests/runtime/test_reference_context_wiring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
"""Tests for ReferenceContext wiring in UiPathLangGraphRuntime."""

import os
import tempfile
from typing import Any, TypedDict
Comment on lines +3 to +5

import pytest
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.graph import END, START, StateGraph

from uipath.platform.common._reference_context import (
ReferenceContext,
ReferenceContextAccessor,
)
Comment on lines +11 to +14
from uipath_langchain.runtime.runtime import UiPathLangGraphRuntime


# ---------------------------------------------------------------------------
# Minimal graph fixture
# ---------------------------------------------------------------------------

class _State(TypedDict):
value: str


def _build_graph() -> Any:
graph = StateGraph(_State)
graph.add_node("step", lambda s: {"value": s.get("value", "") + "_done"})
graph.add_edge(START, "step")
graph.add_edge("step", END)
return graph


def _clear_accessor() -> None:
token = ReferenceContextAccessor.set(None)
ReferenceContextAccessor.reset(token)
Comment on lines +34 to +36


# ---------------------------------------------------------------------------
# _push_reference_context — unit tests (no graph needed)
# ---------------------------------------------------------------------------

class TestPushReferenceContext:
def setup_method(self) -> None:
_clear_accessor()

def teardown_method(self) -> None:
_clear_accessor()

def test_sets_langgraph_entry_when_agent_id_present(
self, monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020")
monkeypatch.delenv("UIPATH_PROCESS_VERSION", raising=False)

from langgraph.graph import StateGraph
graph = _build_graph().compile()
runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="t")

token = runtime._push_reference_context()
try:
ctx = ReferenceContextAccessor.get()
assert ctx is not None
assert len(ctx) == 1
assert ctx.entries[0].service_type == "langgraph"
assert ctx.entries[0].reference_id == "550e8400-e29b-41d4-a716-446655440020"
assert ctx.entries[0].version is None
finally:
ReferenceContextAccessor.reset(token)

def test_includes_version_when_env_set(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020")
monkeypatch.setenv("UIPATH_PROCESS_VERSION", "3.1.0")

graph = _build_graph().compile()
runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="t")

token = runtime._push_reference_context()
try:
ctx = ReferenceContextAccessor.get()
assert ctx is not None
assert ctx.entries[0].version == "3.1.0"
finally:
ReferenceContextAccessor.reset(token)

def test_no_entry_when_agent_id_absent(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.delenv("UIPATH_AGENT_ID", raising=False)
monkeypatch.delenv("UIPATH_PROCESS_VERSION", raising=False)

graph = _build_graph().compile()
runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="t")

token = runtime._push_reference_context()
try:
ctx = ReferenceContextAccessor.get()
assert ctx is not None
assert len(ctx) == 0
finally:
ReferenceContextAccessor.reset(token)

def test_stacks_on_top_of_parent_context(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020")
monkeypatch.delenv("UIPATH_PROCESS_VERSION", raising=False)

parent = ReferenceContext.Empty.add(
"agent", "550e8400-e29b-41d4-a716-446655440001", "1.0"
)
parent_token = ReferenceContextAccessor.set(parent)

graph = _build_graph().compile()
runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="t")

token = runtime._push_reference_context()
try:
ctx = ReferenceContextAccessor.get()
assert ctx is not None
assert len(ctx) == 2
assert ctx.entries[0].service_type == "agent"
assert ctx.entries[1].service_type == "langgraph"
finally:
ReferenceContextAccessor.reset(token)
ReferenceContextAccessor.reset(parent_token)


# ---------------------------------------------------------------------------
# execute() — context cleared after run
# ---------------------------------------------------------------------------

async def test_context_cleared_after_execute(
monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
_clear_accessor()
monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020")
monkeypatch.delenv("UIPATH_PROCESS_VERSION", raising=False)

with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db = f.name

Comment on lines +142 to +144
async with AsyncSqliteSaver.from_conn_string(db) as memory:
await memory.setup()
graph = _build_graph().compile(checkpointer=memory)
runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="exec-run")
await runtime.execute(input={"value": "hello"})

assert ReferenceContextAccessor.get() is None


async def test_context_cleared_after_execute_on_error(
monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
_clear_accessor()
monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020")

class _S(TypedDict):
v: str

def _boom(s: _S) -> _S:
raise ValueError("explode")

g = StateGraph(_S)
g.add_node("boom", _boom)
g.add_edge(START, "boom")
g.add_edge("boom", END)

with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db = f.name

Comment on lines +171 to +173
async with AsyncSqliteSaver.from_conn_string(db) as memory:
await memory.setup()
compiled = g.compile(checkpointer=memory)
runtime = UiPathLangGraphRuntime(graph=compiled, runtime_id="err-run")
with pytest.raises(Exception):
await runtime.execute(input={"v": "x"})

assert ReferenceContextAccessor.get() is None


# ---------------------------------------------------------------------------
# stream() — context cleared after run
# ---------------------------------------------------------------------------

async def test_context_cleared_after_stream(
monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
_clear_accessor()
monkeypatch.setenv("UIPATH_AGENT_ID", "550e8400-e29b-41d4-a716-446655440020")
monkeypatch.delenv("UIPATH_PROCESS_VERSION", raising=False)

with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db = f.name

Comment on lines +195 to +197
async with AsyncSqliteSaver.from_conn_string(db) as memory:
await memory.setup()
graph = _build_graph().compile(checkpointer=memory)
runtime = UiPathLangGraphRuntime(graph=graph, runtime_id="stream-run")
async for _ in runtime.stream(input={"value": "hi"}):
pass

assert ReferenceContextAccessor.get() is None
Loading