-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathfactory.py
More file actions
350 lines (293 loc) · 12 KB
/
factory.py
File metadata and controls
350 lines (293 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
import asyncio
import inspect
import os
from typing import Any, AsyncContextManager
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.graph.state import CompiledStateGraph, StateGraph
from openinference.instrumentation.langchain import (
LangChainInstrumentor,
get_ancestor_spans,
get_current_span,
)
from uipath.core.tracing import UiPathSpanUtils, UiPathTraceManager
from uipath.platform.resume_triggers import (
UiPathResumeTriggerHandler,
)
from uipath.runtime import (
UiPathResumableRuntime,
UiPathRuntimeContext,
UiPathRuntimeFactorySettings,
UiPathRuntimeProtocol,
UiPathRuntimeStorageProtocol,
)
from uipath.runtime.errors import UiPathErrorCategory
from uipath_langchain._tracing import _instrument_traceable_attributes
from uipath_langchain.runtime.config import LangGraphConfig
from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError
from uipath_langchain.runtime.graph import LangGraphLoader
from uipath_langchain.runtime.runtime import UiPathLangGraphRuntime
from uipath_langchain.runtime.storage import SqliteResumableStorage
def _collect_sdk_interrupt_modules() -> list[tuple[str, str]]:
    """List the SDK's interrupt model classes as ``(module, class_name)`` pairs.

    Only classes whose ``__module__`` is exactly the ``interrupt_models``
    module are included. Classes that module merely imports from elsewhere
    are skipped. Pairs follow ``inspect.getmembers`` order (sorted by
    attribute name).
    """
    from uipath.platform.common import interrupt_models

    home_module = interrupt_models.__name__
    pairs: list[tuple[str, str]] = []
    for _attr_name, member in inspect.getmembers(interrupt_models, inspect.isclass):
        if member.__module__ != home_module:
            # Re-exported from another module; not an SDK interrupt model.
            continue
        pairs.append((member.__module__, member.__name__))
    return pairs
class UiPathLangGraphRuntimeFactory:
    """Factory for creating LangGraph runtimes from langgraph.json configuration.

    A factory owns one shared ``AsyncSqliteSaver`` checkpointer, created
    lazily on first use. It also keeps a per-entrypoint cache of compiled
    graphs together with the loaders that produced them. Call
    :meth:`dispose` to clean up the loaders and exit the checkpointer's
    context manager.
    """

    def __init__(
        self,
        context: UiPathRuntimeContext,
    ):
        """
        Initialize the factory.

        Args:
            context: UiPathRuntimeContext to use for runtime creation
        """
        self.context = context
        # Parsed langgraph.json; populated lazily by _load_config().
        self._config: LangGraphConfig | None = None
        # Shared checkpointer plus the context manager that owns its connection.
        self._memory: AsyncSqliteSaver | None = None
        self._memory_cm: AsyncContextManager[AsyncSqliteSaver] | None = None
        self._memory_lock = asyncio.Lock()
        # Compiled graphs and their loaders, keyed by entrypoint name.
        self._graph_cache: dict[str, CompiledStateGraph[Any, Any, Any, Any]] = {}
        self._graph_loaders: dict[str, LangGraphLoader] = {}
        self._graph_lock = asyncio.Lock()
        self._setup_instrumentation(self.context.trace_manager)

    def _setup_instrumentation(self, trace_manager: UiPathTraceManager | None) -> None:
        """Setup tracing and instrumentation.

        Instruments traceable attributes and LangChain. Registers the
        OpenInference current-span and ancestor-span providers with
        ``UiPathSpanUtils``.

        Args:
            trace_manager: Currently not used by this method.
        """
        _instrument_traceable_attributes()
        LangChainInstrumentor().instrument()
        UiPathSpanUtils.register_current_span_provider(get_current_span)
        UiPathSpanUtils.register_current_span_ancestors_provider(get_ancestor_spans)

    def _get_connection_string(self) -> str:
        """Get the database connection string (path of the SQLite state file).

        Resolution order:

        1. ``context.state_file_path`` when it is not ``None``, returned as-is.
        2. ``context.runtime_dir`` joined with ``context.state_file`` when both
           are truthy. On a fresh run (not resuming, no job id, and
           ``keep_state_file`` unset), any previous file at that path is
           deleted first (best effort). ``runtime_dir`` is created if missing.
        3. Otherwise ``__uipath/state.db``, creating ``__uipath`` if missing.

        Returns:
            The path to pass to ``AsyncSqliteSaver.from_conn_string``.
        """
        if self.context.state_file_path is not None:
            return self.context.state_file_path

        if self.context.runtime_dir and self.context.state_file:
            path = os.path.join(self.context.runtime_dir, self.context.state_file)
            is_fresh_run = (
                not self.context.resume
                and self.context.job_id is None
                and not self.context.keep_state_file
            )
            if is_fresh_run:
                # If not resuming and no job id, delete the previous state file.
                # A missing file raises FileNotFoundError (an OSError subclass)
                # and is ignored. The file may also be held by another process.
                try:
                    os.remove(path)
                except OSError:
                    pass
            os.makedirs(self.context.runtime_dir, exist_ok=True)
            return path

        default_path = os.path.join("__uipath", "state.db")
        os.makedirs(os.path.dirname(default_path), exist_ok=True)
        return default_path

    async def _get_memory(self) -> AsyncSqliteSaver:
        """Get or create the shared memory (checkpointer) instance.

        The saver is stored on the factory only after ``setup()`` and
        ``_apply_msgpack_allowlist()`` have both completed. If either step
        raises, the connection's context manager is exited and the error is
        re-raised. A later call then starts over instead of returning a
        half-initialised saver.

        Returns:
            The shared ``AsyncSqliteSaver``.
        """
        async with self._memory_lock:
            if self._memory is None:
                connection_string = self._get_connection_string()
                memory_cm = AsyncSqliteSaver.from_conn_string(connection_string)
                memory = await memory_cm.__aenter__()
                try:
                    await memory.setup()
                    self._apply_msgpack_allowlist(memory)
                except BaseException:
                    # Don't leak the open connection; always re-raise.
                    await memory_cm.__aexit__(None, None, None)
                    raise
                self._memory_cm = memory_cm
                self._memory = memory
            return self._memory

    def _apply_msgpack_allowlist(self, memory: AsyncSqliteSaver) -> None:
        """Apply the user's msgpack allowlist (unioned with SDK interrupt models).

        If ``allowed_msgpack_modules`` is ``None`` in the config, the saver's
        serializer is left untouched. Otherwise ``memory.serde`` is replaced
        with a ``JsonPlusSerializer``. Its allowlist holds the SDK interrupt
        models first, followed by the user's entries.

        Args:
            memory: The saver whose ``serde`` may be replaced.
        """
        user_modules = self._load_config().allowed_msgpack_modules
        if user_modules is None:
            return
        sdk_modules = _collect_sdk_interrupt_modules()
        memory.serde = JsonPlusSerializer(
            allowed_msgpack_modules=[*sdk_modules, *user_modules],
        )

    def _load_config(self) -> LangGraphConfig:
        """Load langgraph.json configuration (once; the result is cached)."""
        if self._config is None:
            self._config = LangGraphConfig()
        return self._config

    async def _load_graph(
        self, entrypoint: str, **kwargs
    ) -> StateGraph[Any, Any, Any] | CompiledStateGraph[Any, Any, Any, Any]:
        """
        Load a graph for the given entrypoint.

        The loader is registered in ``self._graph_loaders`` before loading,
        so ``dispose`` can clean it up even if loading fails.

        Args:
            entrypoint: Name of the graph to load
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The loaded StateGraph or CompiledStateGraph

        Raises:
            LangGraphRuntimeError: If graph cannot be loaded. Errors the
                loader already raised as ``LangGraphRuntimeError`` propagate
                unchanged. Other exceptions are wrapped with an error code
                chosen by exception type.
        """
        config = self._load_config()
        if not config.exists:
            raise LangGraphRuntimeError(
                LangGraphErrorCode.CONFIG_MISSING,
                "Invalid configuration",
                "Failed to load configuration",
                UiPathErrorCategory.DEPLOYMENT,
            )

        if entrypoint not in config.graphs:
            available = ", ".join(config.entrypoints)
            raise LangGraphRuntimeError(
                LangGraphErrorCode.GRAPH_NOT_FOUND,
                "Graph not found",
                f"Graph '{entrypoint}' not found. Available: {available}",
                UiPathErrorCategory.DEPLOYMENT,
            )

        path = config.graphs[entrypoint]
        graph_loader = LangGraphLoader.from_path_string(entrypoint, path)
        self._graph_loaders[entrypoint] = graph_loader

        try:
            return await graph_loader.load()
        except LangGraphRuntimeError:
            # Already carries a specific error code; don't re-wrap it as a
            # generic GRAPH_LOAD_ERROR below.
            raise
        except ImportError as e:
            raise LangGraphRuntimeError(
                LangGraphErrorCode.GRAPH_IMPORT_ERROR,
                "Graph import failed",
                f"Failed to import graph '{entrypoint}': {str(e)}",
                UiPathErrorCategory.USER,
            ) from e
        except TypeError as e:
            raise LangGraphRuntimeError(
                LangGraphErrorCode.GRAPH_TYPE_ERROR,
                "Invalid graph type",
                f"Graph '{entrypoint}' is not a valid StateGraph or CompiledStateGraph: {str(e)}",
                UiPathErrorCategory.USER,
            ) from e
        except ValueError as e:
            raise LangGraphRuntimeError(
                LangGraphErrorCode.GRAPH_VALUE_ERROR,
                "Invalid graph value",
                f"Invalid value in graph '{entrypoint}': {str(e)}",
                UiPathErrorCategory.USER,
            ) from e
        except Exception as e:
            raise LangGraphRuntimeError(
                LangGraphErrorCode.GRAPH_LOAD_ERROR,
                "Failed to load graph",
                f"Unexpected error loading graph '{entrypoint}': {str(e)}",
                UiPathErrorCategory.USER,
            ) from e

    async def _compile_graph(
        self,
        graph: StateGraph[Any, Any, Any] | CompiledStateGraph[Any, Any, Any, Any],
        memory: AsyncSqliteSaver,
    ) -> CompiledStateGraph[Any, Any, Any, Any]:
        """
        Compile a graph with the given memory/checkpointer.

        An already-compiled graph is recompiled from its ``builder`` so that
        it uses this factory's checkpointer.

        Args:
            graph: The graph to compile (StateGraph or already compiled)
            memory: Checkpointer to use for compiled graph

        Returns:
            The compiled StateGraph
        """
        builder = graph.builder if isinstance(graph, CompiledStateGraph) else graph
        # NOTE(review): only the checkpointer is forwarded. Other compile
        # options set on a pre-compiled graph (e.g. interrupt_before/after,
        # store) are presumably not carried over. Confirm this is intended.
        return builder.compile(checkpointer=memory)

    async def _resolve_and_compile_graph(
        self, entrypoint: str, memory: AsyncSqliteSaver, **kwargs
    ) -> CompiledStateGraph[Any, Any, Any, Any]:
        """
        Resolve a graph from configuration and compile it.

        Results are cached per entrypoint for reuse across runtime instances.
        Loading and compiling are serialised by ``self._graph_lock``.

        Args:
            entrypoint: Name of the graph to resolve
            memory: Checkpointer to use for compiled graph
            **kwargs: Forwarded to ``_load_graph``.

        Returns:
            The compiled StateGraph ready for execution

        Raises:
            LangGraphRuntimeError: If resolution or compilation fails
        """
        async with self._graph_lock:
            if entrypoint in self._graph_cache:
                return self._graph_cache[entrypoint]

            loaded_graph = await self._load_graph(entrypoint, **kwargs)
            compiled_graph = await self._compile_graph(loaded_graph, memory)
            self._graph_cache[entrypoint] = compiled_graph
            return compiled_graph

    def discover_entrypoints(self) -> list[str]:
        """
        Discover all graph entrypoints.

        Returns:
            List of graph names that can be used as entrypoints; empty when
            no configuration exists.
        """
        config = self._load_config()
        if not config.exists:
            return []
        return config.entrypoints

    async def get_settings(self) -> UiPathRuntimeFactorySettings | None:
        """
        Get the factory settings.

        Returns:
            Always ``None``; this factory exposes no settings.
        """
        return None

    async def get_storage(self) -> UiPathRuntimeStorageProtocol | None:
        """
        Get the runtime storage protocol instance.

        Returns:
            A new ``SqliteResumableStorage`` wrapping the shared checkpointer.
        """
        memory = await self._get_memory()
        return SqliteResumableStorage(memory)

    async def _create_runtime_instance(
        self,
        compiled_graph: CompiledStateGraph[Any, Any, Any, Any],
        runtime_id: str,
        entrypoint: str,
        **kwargs,
    ) -> UiPathRuntimeProtocol:
        """
        Create a runtime instance from a compiled graph.

        Wraps a ``UiPathLangGraphRuntime`` in a ``UiPathResumableRuntime``.
        Both share one ``SqliteResumableStorage`` over the shared
        checkpointer.

        Args:
            compiled_graph: The compiled graph
            runtime_id: Unique identifier for the runtime instance
            entrypoint: Graph entrypoint name
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Configured runtime instance
        """
        memory = await self._get_memory()
        storage = SqliteResumableStorage(memory)
        trigger_manager = UiPathResumeTriggerHandler()

        base_runtime = UiPathLangGraphRuntime(
            graph=compiled_graph,
            runtime_id=runtime_id,
            entrypoint=entrypoint,
            storage=storage,
        )

        return UiPathResumableRuntime(
            delegate=base_runtime,
            storage=storage,
            trigger_manager=trigger_manager,
            runtime_id=runtime_id,
        )

    async def new_runtime(
        self, entrypoint: str, runtime_id: str, **kwargs
    ) -> UiPathRuntimeProtocol:
        """
        Create a new LangGraph runtime instance.

        Args:
            entrypoint: Graph name from langgraph.json
            runtime_id: Unique identifier for the runtime instance
            **kwargs: Forwarded to graph resolution and runtime creation.

        Returns:
            Configured runtime instance with compiled graph
        """
        # Get shared memory instance
        memory = await self._get_memory()

        compiled_graph = await self._resolve_and_compile_graph(
            entrypoint, memory, **kwargs
        )

        return await self._create_runtime_instance(
            compiled_graph=compiled_graph,
            runtime_id=runtime_id,
            entrypoint=entrypoint,
            **kwargs,
        )

    async def dispose(self) -> None:
        """Cleanup factory resources.

        Every registered graph loader is cleaned up and the shared
        checkpointer's context manager is exited, even if an earlier step
        raises. Factory state is reset before any cleanup runs, so a second
        call does nothing. If any step fails, the first exception is
        re-raised after all cleanup has been attempted.
        """
        loaders = list(self._graph_loaders.values())
        memory_cm = self._memory_cm

        self._graph_loaders.clear()
        self._graph_cache.clear()
        self._memory_cm = None
        self._memory = None

        first_error: Exception | None = None

        for loader in loaders:
            try:
                await loader.cleanup()
            except Exception as exc:
                # Keep going so later loaders and the DB connection still
                # get released.
                if first_error is None:
                    first_error = exc

        if memory_cm is not None:
            try:
                await memory_cm.__aexit__(None, None, None)
            except Exception as exc:
                if first_error is None:
                    first_error = exc

        if first_error is not None:
            raise first_error