-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathcli_init.py
More file actions
366 lines (305 loc) · 13.2 KB
/
cli_init.py
File metadata and controls
366 lines (305 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
import asyncio
import json
import os
import uuid
from typing import Any, Dict
from llama_index.core.workflow import StopEvent, Workflow
from llama_index.core.workflow.drawing import StepConfig # type: ignore
from llama_index.core.workflow.events import (
HumanResponseEvent,
InputRequiredEvent,
)
from llama_index.core.workflow.utils import (
get_steps_from_class,
get_steps_from_instance,
)
from uipath._cli._utils._console import ConsoleLogger
from uipath._cli._utils._parse_ast import generate_bindings_json # type: ignore
from uipath._cli.middlewares import MiddlewareResult
from ._utils._config import LlamaIndexConfig
console = ConsoleLogger()
def resolve_refs(schema: Dict[str, Any]) -> Dict[str, Any]:
"""Resolve references in a schema"""
if "$ref" in schema:
ref = schema["$ref"].split("/")[-1]
if "definitions" in schema and ref in schema["definitions"]:
return resolve_refs(schema["definitions"][ref])
properties = schema.get("properties", {})
for prop, prop_schema in properties.items():
if "$ref" in prop_schema:
properties[prop] = resolve_refs(prop_schema)
return schema
def process_nullable_types(properties: Dict[str, Any]) -> Dict[str, Any]:
    """Collapse ``anyOf``-with-``null`` unions into ``nullable`` annotations.

    A property declared as ``anyOf: [{type: T}, {type: "null"}]`` becomes
    ``{"type": T, "nullable": True}``; with several non-null alternatives the
    ``type`` is the list of those alternatives. Properties without a nullable
    ``anyOf`` pass through unchanged.
    """
    processed: Dict[str, Any] = {}
    for prop_name, prop_schema in properties.items():
        if "anyOf" not in prop_schema:
            processed[prop_name] = prop_schema
            continue
        # Only typed alternatives participate; untyped entries are ignored.
        declared = [alt["type"] for alt in prop_schema["anyOf"] if "type" in alt]
        if "null" not in declared:
            # Not a nullable union — keep the property schema as-is.
            processed[prop_name] = prop_schema
            continue
        concrete = [t for t in declared if t != "null"]
        collapsed = concrete[0] if len(concrete) == 1 else concrete
        processed[prop_name] = {"type": collapsed, "nullable": True}
    return processed
def generate_schema_from_workflow(workflow: Workflow) -> Dict[str, Any]:
    """Extract input/output JSON schemas from a LlamaIndex workflow.

    The input schema comes from the workflow's start-event class and the
    output schema from its stop-event class. When the workflow uses the base
    ``StopEvent`` (whose ``result`` is untyped), a generic ``result`` object
    schema is emitted instead.

    Args:
        workflow: The workflow instance to inspect.

    Returns:
        A dict with ``"input"`` and ``"output"`` JSON-schema objects.
    """
    schema = {
        "input": {"type": "object", "properties": {}, "required": []},
        "output": {"type": "object", "properties": {}, "required": []},
    }
    # Find the actual StartEvent and StopEvent classes used in this workflow.
    # NOTE(review): these are private Workflow attributes — confirm they are
    # stable across llama_index versions.
    start_event_class = workflow._start_event_class
    stop_event_class = workflow._stop_event_class
    # Generate input schema from StartEvent using Pydantic's schema method.
    try:
        input_schema = start_event_class.model_json_schema()
        # Resolve references and handle nullable types.
        input_schema = resolve_refs(input_schema)
        schema["input"]["properties"] = process_nullable_types(
            input_schema.get("properties", {})
        )
        schema["input"]["required"] = input_schema.get("required", [])
    except Exception:
        # Best-effort: fall back to the empty skeleton. (`except
        # (AttributeError, Exception)` was redundant — AttributeError is
        # already an Exception subclass.)
        pass
    # For output schema, check if it's the base StopEvent or a custom subclass.
    if stop_event_class is StopEvent:
        # Base StopEvent: its payload is an arbitrary `result` object.
        schema["output"] = {
            "type": "object",
            "properties": {
                "result": {
                    "title": "Result",
                    "type": "object",
                }
            },
            "required": ["result"],
        }
    else:
        # For custom StopEvent subclasses, extract their Pydantic schema.
        try:
            output_schema = stop_event_class.model_json_schema()
            # Resolve references and handle nullable types.
            output_schema = resolve_refs(output_schema)
            schema["output"]["properties"] = process_nullable_types(
                output_schema.get("properties", {})
            )
            schema["output"]["required"] = output_schema.get("required", [])
        except Exception:
            # Best-effort: keep the empty output skeleton on failure.
            pass
    return schema
def draw_all_possible_flows_mermaid(
    workflow: Workflow,
    filename: str = "workflow_all_flows.mermaid",
) -> str:
    """Draws all possible flows of the workflow as a Mermaid diagram.

    Renders every step and every event type the steps accept or return as a
    Mermaid ``flowchart TD`` graph, writes the diagram to ``filename`` (when
    truthy) and returns the diagram source as a string. Human-in-the-loop
    events (``InputRequiredEvent`` / ``HumanResponseEvent``) are linked to a
    conceptual ``external_step`` node.
    """
    # Initialize Mermaid flowchart string
    mermaid_diagram = ["flowchart TD"]
    # Add nodes from all steps
    steps = get_steps_from_class(workflow)
    if not steps:
        # If no steps are defined in the class, try to get them from the instance
        steps = get_steps_from_instance(workflow)
    # Track all nodes and edges to avoid duplicates
    nodes = set()
    edges = set()
    # Track event types to avoid duplicates (keyed by class name)
    event_types = {}
    current_stop_event = (
        None  # Only one kind of `StopEvent` is allowed in a `Workflow`.
    )
    step_config: StepConfig | None = None
    # Pre-scan: find the concrete StopEvent subclass returned by any step.
    for _, step_func in steps.items():
        step_config = getattr(step_func, "__step_config", None)
        if step_config is None:
            continue
        for return_type in step_config.return_types:
            if issubclass(return_type, StopEvent):
                current_stop_event = return_type
                break
        if current_stop_event:
            break
    # First pass: collect all event types (both return types and accepted events)
    for _, step_func in steps.items():
        step_config = getattr(step_func, "__step_config", None)
        if step_config is None:
            continue
        # Collect accepted event types
        for event_type in step_config.accepted_events:
            # Skip the *base* StopEvent when a custom stop-event subclass is
            # in use — only the concrete subclass gets a node.
            if event_type == StopEvent and event_type != current_stop_event:
                continue
            event_name = event_type.__name__
            event_types[event_name] = event_type
        # Collect return types
        for return_type in step_config.return_types:
            if return_type is type(None):
                # Steps may return None; that produces no event node.
                continue
            return_name = return_type.__name__
            event_types[return_name] = return_type
    # Generate step nodes
    for step_name, step_func in steps.items():
        step_config = getattr(step_func, "__step_config", None)
        if step_config is None:
            continue
        # Add step node (use step_name with cleaned ID)
        step_id = f"step_{clean_id(step_name)}"
        if step_id not in nodes:
            nodes.add(step_id)
            mermaid_diagram.append(f' {step_id}["{step_name}"]:::stepStyle')
    # Generate event nodes (only once per event type)
    for event_name, event_type in event_types.items():
        event_id = f"event_{clean_id(event_name)}"
        if event_id not in nodes:
            nodes.add(event_id)
            style = get_event_style(event_type)
            mermaid_diagram.append(f" {event_id}([<p>{event_name}</p>]):::{style}")
            if issubclass(event_type, InputRequiredEvent):
                # Add node for conceptual external step (human in the loop)
                if "external_step" not in nodes:
                    nodes.add("external_step")
                    mermaid_diagram.append(
                        ' external_step["external_step"]:::externalStyle'
                    )
    # Generate edges
    for step_name, step_func in steps.items():
        step_config = getattr(step_func, "__step_config", None)
        if step_config is None:
            continue
        step_id = f"step_{clean_id(step_name)}"
        # Add edges for return types: step --> event it emits
        for return_type in step_config.return_types:
            if return_type is not type(None):
                return_name = return_type.__name__
                return_id = f"event_{clean_id(return_name)}"
                edge = f"{step_id} --> {return_id}"
                if edge not in edges:
                    edges.add(edge)
                    mermaid_diagram.append(f" {edge}")
                if issubclass(return_type, InputRequiredEvent):
                    # Input-required events flow out to the external step.
                    return_name = return_type.__name__
                    return_id = f"event_{clean_id(return_name)}"
                    edge = f"{return_id} --> external_step"
                    if edge not in edges:
                        edges.add(edge)
                        mermaid_diagram.append(f" {edge}")
        # Add edges for accepted events: event --> step that consumes it
        for event_type in step_config.accepted_events:
            event_name = event_type.__name__
            event_id = f"event_{clean_id(event_name)}"
            if step_name == "_done" and issubclass(event_type, StopEvent):
                # The internal `_done` step accepts the base StopEvent; draw
                # the edge from the workflow's concrete stop-event node.
                if current_stop_event:
                    stop_event_name = current_stop_event.__name__
                    stop_event_id = f"event_{clean_id(stop_event_name)}"
                    edge = f"{stop_event_id} --> {step_id}"
                    if edge not in edges:
                        edges.add(edge)
                        mermaid_diagram.append(f" {edge}")
            else:
                edge = f"{event_id} --> {step_id}"
                if edge not in edges:
                    edges.add(edge)
                    mermaid_diagram.append(f" {edge}")
                if issubclass(event_type, HumanResponseEvent):
                    # Human responses come back in from the external step.
                    edge = f"external_step --> {event_id}"
                    if edge not in edges:
                        edges.add(edge)
                        mermaid_diagram.append(f" {edge}")
    # Add style definitions
    mermaid_diagram.append(" classDef stepStyle fill:#f2f0ff,line-height:1.2")
    mermaid_diagram.append(" classDef externalStyle fill:#f2f0ff,line-height:1.2")
    mermaid_diagram.append(" classDef defaultEventStyle fill-opacity:0")
    mermaid_diagram.append(" classDef stopEventStyle fill:#bfb6fc")
    mermaid_diagram.append(
        " classDef inputRequiredStyle fill:#f2f0ff,line-height:1.2"
    )
    # Join all lines
    mermaid_string = "\n".join(mermaid_diagram)
    # Write to file if filename is provided
    if filename:
        with open(filename, "w") as f:
            f.write(mermaid_string)
    return mermaid_string
def clean_id(name: str) -> str:
    """Sanitize *name* into an identifier Mermaid accepts as a node ID."""
    # Mermaid node IDs must not contain spaces, hyphens, or dots.
    sanitized = name
    for forbidden in (" ", "-", "."):
        sanitized = sanitized.replace(forbidden, "_")
    return sanitized
def get_event_style(event_type) -> str:
    """Return the appropriate Mermaid style class for an event type."""
    # Checked in order: StopEvent takes precedence over InputRequiredEvent.
    style_by_base = (
        (StopEvent, "stopEventStyle"),
        (InputRequiredEvent, "inputRequiredStyle"),
    )
    for base_class, style_name in style_by_base:
        if issubclass(event_type, base_class):
            return style_name
    return "defaultEventStyle"
async def llamaindex_init_middleware_async(entrypoint: str) -> MiddlewareResult:
    """Middleware to check for llama_index.json and create uipath.json with schemas.

    For each configured workflow (or only the one named by ``entrypoint`` when
    it is non-empty), loads the workflow, derives its input/output schemas,
    collects resource bindings, draws a Mermaid diagram, and finally writes
    ``uipath.json``.

    Args:
        entrypoint: Optional workflow name filter; empty means all workflows.

    Returns:
        MiddlewareResult: ``should_continue=True`` when no llama_index.json
        exists (normal flow continues); otherwise ``should_continue=False``,
        with a stacktrace flag on unexpected errors.
    """
    config = LlamaIndexConfig()
    if not config.exists:
        # No llama_index.json: this middleware does not apply.
        return MiddlewareResult(
            should_continue=True
        )  # Continue with normal flow if no llama_index.json
    try:
        config.load_config()
        entrypoints = []
        all_bindings: dict[str, Any] = {"version": "2.0", "resources": []}
        for workflow in config.workflows:
            if entrypoint and workflow.name != entrypoint:
                continue
            try:
                loaded_workflow = await workflow.load_workflow()
                schema = generate_schema_from_workflow(loaded_workflow)
                try:
                    # Make sure the file path exists
                    if os.path.exists(workflow.file_path):
                        file_bindings = generate_bindings_json(workflow.file_path)
                        # Merge bindings: accumulate resources across
                        # workflows. (A plain assignment here would keep
                        # only the last workflow's bindings.)
                        if "resources" in file_bindings:
                            all_bindings["resources"].extend(
                                file_bindings["resources"]
                            )
                except Exception as e:
                    # Bindings are best-effort; a failure must not abort init.
                    console.warning(
                        f"Warning: Could not generate bindings for {workflow.file_path}: {str(e)}"
                    )
                new_entrypoint: dict[str, Any] = {
                    "filePath": workflow.name,
                    "uniqueId": str(uuid.uuid4()),
                    "type": "agent",
                    "input": schema["input"],
                    "output": schema["output"],
                }
                entrypoints.append(new_entrypoint)
                # Emit a Mermaid diagram of the workflow next to the config.
                draw_all_possible_flows_mermaid(
                    loaded_workflow, filename=f"{workflow.name}.mermaid"
                )
            except Exception as e:
                console.error(f"Error during workflow load: {e}")
                return MiddlewareResult(
                    should_continue=False,
                    should_include_stacktrace=True,
                )
            finally:
                # Always release workflow resources, even on failure.
                await workflow.cleanup()
        if entrypoint and not entrypoints:
            console.error(f"Error: No workflow found with name '{entrypoint}'")
            return MiddlewareResult(
                should_continue=False,
            )
        uipath_config = {"entryPoints": entrypoints, "bindings": all_bindings}
        # Save the uipath.json file
        config_path = "uipath.json"
        with open(config_path, "w") as f:
            json.dump(uipath_config, f, indent=2)
        console.success(f" Created '{config_path}' file.")
        return MiddlewareResult(should_continue=False)
    except Exception as e:
        console.error(f"Error processing LlamaIndex configuration: {str(e)}")
        return MiddlewareResult(
            should_continue=False,
            should_include_stacktrace=True,
        )
def llamaindex_init_middleware(entrypoint: str) -> MiddlewareResult:
    """Synchronous wrapper around :func:`llamaindex_init_middleware_async`.

    Runs the async middleware to completion on a fresh event loop and returns
    its result.
    """
    result: MiddlewareResult = asyncio.run(
        llamaindex_init_middleware_async(entrypoint)
    )
    return result