|
4 | 4 | import uuid |
5 | 5 | from typing import Any, Callable, Dict, overload |
6 | 6 |
|
| 7 | +from llama_index.core.agent.workflow import BaseWorkflowAgent |
7 | 8 | from llama_index.core.workflow import ( |
8 | 9 | HumanResponseEvent, |
9 | 10 | InputRequiredEvent, |
@@ -71,18 +72,45 @@ def generate_schema_from_workflow(workflow: Workflow) -> Dict[str, Any]: |
71 | 72 |
|
72 | 73 | # Generate input schema from StartEvent using Pydantic's schema method |
73 | 74 | try: |
74 | | - input_schema = start_event_class.model_json_schema() |
75 | | - # Resolve references and handle nullable types |
76 | | - input_schema = resolve_refs(input_schema) |
77 | | - schema["input"]["properties"] = process_nullable_types( |
78 | | - input_schema.get("properties", {}) |
79 | | - ) |
80 | | - schema["input"]["required"] = input_schema.get("required", []) |
| 75 | + if isinstance(workflow, BaseWorkflowAgent): |
| 76 | + # For workflow agents, define a simple schema with just user_msg |
| 77 | + schema["input"] = { |
| 78 | + "type": "object", |
| 79 | + "properties": { |
| 80 | + "user_msg": { |
| 81 | + "type": "string", |
| 82 | + "title": "User Message", |
| 83 | + "description": "The user's question or request" |
| 84 | + } |
| 85 | + }, |
| 86 | + "required": ["user_msg"] |
| 87 | + } |
| 88 | + else: |
| 89 | + input_schema = start_event_class.model_json_schema() |
| 90 | + # Resolve references and handle nullable types |
| 91 | + input_schema = resolve_refs(input_schema) |
| 92 | + schema["input"]["properties"] = process_nullable_types( |
| 93 | + input_schema.get("properties", {}) |
| 94 | + ) |
| 95 | + schema["input"]["required"] = input_schema.get("required", []) |
81 | 96 | except (AttributeError, Exception): |
82 | 97 | pass |
83 | 98 |
|
84 | | - # For output schema, check if it's the base StopEvent or a custom subclass |
85 | | - if stop_event_class is StopEvent: |
| 99 | + # Handle output schema - check if it's a workflow agent with output_cls first |
| 100 | + if isinstance(workflow, BaseWorkflowAgent) and getattr(workflow, 'output_cls', None) is not None: |
| 101 | + # Use the output_cls schema for structured output |
| 102 | + try: |
| 103 | + output_schema = workflow.output_cls.model_json_schema() |
| 104 | + # Resolve references and handle nullable types |
| 105 | + output_schema = resolve_refs(output_schema) |
| 106 | + schema["output"]["properties"] = process_nullable_types( |
| 107 | + output_schema.get("properties", {}) |
| 108 | + ) |
| 109 | + schema["output"]["required"] = output_schema.get("required", []) |
| 110 | + except (AttributeError, Exception): |
| 111 | + pass |
| 112 | + # Check if it's the base StopEvent or a custom subclass |
| 113 | + elif stop_event_class is StopEvent: |
86 | 114 | # base StopEvent |
87 | 115 | schema["output"] = { |
88 | 116 | "type": "object", |
|
0 commit comments