-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
119 lines (91 loc) · 4.75 KB
/
main.py
File metadata and controls
119 lines (91 loc) · 4.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import annotations as _annotations
from dataclasses import dataclass, field
from typing import List
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
from DspyModules.Helpers.LoadOptimizedPrograms import load_optimized_error_resolver_program, load_optimized_query_generator_program
from DspyModules.ReportRequestExtractorModule import ReportRequestExtractor
from dto import ReportRequest
from Schema.full_chema_graphql import full_schema
from query_validator import validate_graphql_query_for_workflow
import logfire
# Set up Logfire and enable its pydantic_ai instrumentation when this module is imported.
logfire.configure()
logfire.instrument_pydantic_ai()
# GraphQL schema that the workflow nodes in this module work against.
# Replace `full_schema` with your own schema definition to target a different API.
schema = full_schema
@dataclass
class State:
    """Mutable state shared by the nodes of the query-generation graph."""

    # Free-text request supplied by the user.
    input: str = ""
    # GraphQL schema text supplied for this run.
    schema: str = ""
    # Number of failed validations recorded so far.
    retry_count: int = 0
    # Structured form of the user's request; None until it has been extracted.
    report_request: ReportRequest | None = None
    # Flag toggled by the validation and error-resolution steps.
    is_query_validated: bool = False
# Node that extracts the user's request into a structured ReportRequest
@dataclass
class ExtractReportReuest(BaseNode[State]):
    """Entry node: convert the free-text request held in state into a report request.

    The class name keeps its original spelling because the graph definition
    and the start-node call reference it by this name.
    """

    async def run(self, ctx: GraphRunContext[State]) -> GenerateGraphQlQuery:
        """Extract the structured request and hand it to query generation.

        Args:
            ctx: Graph run context. ``ctx.state.input`` and ``ctx.state.schema``
                are read; ``ctx.state.report_request`` is written.

        Returns:
            A ``GenerateGraphQlQuery`` node carrying the extracted request.
        """
        # Honor the schema the caller placed in the run state; fall back to the
        # module-level schema when the state leaves it empty.
        active_schema = ctx.state.schema or schema
        extractor = ReportRequestExtractor()
        extraction = extractor(
            user_input=ctx.state.input,
            graphQl_schema=active_schema,
        )
        # Keep a copy on the shared state so it is available after the run.
        ctx.state.report_request = extraction.report_request
        return GenerateGraphQlQuery(extraction.report_request)
# Node that generates the GraphQL query from the structured report request
@dataclass
class GenerateGraphQlQuery(BaseNode[State, None, str]):
    """Produce a candidate GraphQL query for a structured report request."""

    # Structured request the query should satisfy.
    user_request: ReportRequest

    async def run(self, ctx: GraphRunContext[State]) -> validateGraphQlQuery:
        """Generate a query with the optimized DSPy program and send it to validation.

        Args:
            ctx: Graph run context; ``ctx.state.schema`` is read.

        Returns:
            A ``validateGraphQlQuery`` node holding the request and the
            generated query.
        """
        # Honor the schema the caller placed in the run state; fall back to the
        # module-level schema when the state leaves it empty.
        active_schema = ctx.state.schema or schema
        # The optimized program can be rebuilt with
        # DspyModules/QueryGeneratorModule.py.
        query_model = load_optimized_query_generator_program()
        generated = query_model(
            graphql_schema=active_schema,
            request=self.user_request,
        )
        return validateGraphQlQuery(
            user_request=self.user_request,
            query_to_be_validated=generated.query,
        )
# Node that validates the generated GraphQL query
@dataclass
class validateGraphQlQuery(BaseNode[State, None, str]):
    """Validate a candidate query and decide whether to finish, retry, or give up."""

    # Structured request the query is meant to satisfy.
    user_request: ReportRequest
    # Candidate GraphQL query text to check.
    query_to_be_validated: str

    async def run(self, ctx: GraphRunContext[State]) -> ResolveError | End[str]:
        """Check the query against the schema and pick the next step.

        Args:
            ctx: Graph run context. ``ctx.state.schema`` is read;
                ``ctx.state.is_query_validated`` and ``ctx.state.retry_count``
                are written.

        Returns:
            ``End`` with the query when the validator reports no error,
            ``End`` with a failure message once ``retry_count`` exceeds 3,
            otherwise a ``ResolveError`` node carrying the validation error.
        """
        # Honor the schema the caller placed in the run state; fall back to the
        # module-level schema when the state leaves it empty.
        active_schema = ctx.state.schema or schema
        validation_error = validate_graphql_query_for_workflow(
            query=self.query_to_be_validated,
            schema_str=active_schema,
        )
        # Set once the validator has run, whether or not the query passed.
        # NOTE(review): confirm whether this flag is meant to signal "passed
        # validation" instead; as written it stays True on the give-up path.
        ctx.state.is_query_validated = True

        # The validator returns None when it found nothing to report.
        if validation_error is None:
            return End(self.query_to_be_validated)

        ctx.state.retry_count += 1
        if ctx.state.retry_count > 3:
            return End("Unable to generate a valid GraphQL query for the user request.")

        return ResolveError(
            user_request=self.user_request,
            validation_error=validation_error,
            query_to_be_Resolved=self.query_to_be_validated,
        )
# Node that resolves validation errors in the GraphQL query
@dataclass
class ResolveError(BaseNode[State, None, str]):
    """Ask the optimized error-resolver program for a corrected query."""

    # Structured request the query is meant to satisfy.
    user_request: ReportRequest
    # Query text that failed validation.
    query_to_be_Resolved: str
    # Error details passed on from the validation step.
    validation_error: List[str]

    async def run(self, ctx: GraphRunContext[State]) -> validateGraphQlQuery:
        """Produce a corrected query and send it back to validation.

        Args:
            ctx: Graph run context. ``ctx.state.schema`` is read and
                ``ctx.state.is_query_validated`` is reset to ``False``.

        Returns:
            A ``validateGraphQlQuery`` node holding the corrected query.
        """
        # Honor the schema the caller placed in the run state; fall back to the
        # module-level schema when the state leaves it empty.
        active_schema = ctx.state.schema or schema
        # The optimized program can be rebuilt with
        # DspyModules/ErrorResolverModule.py.
        error_resolver = load_optimized_error_resolver_program()
        correction = error_resolver(
            graphql_schema=active_schema,
            request=self.user_request,
            validation_error=self.validation_error,
            initial_query=self.query_to_be_Resolved,
        )
        # The corrected query has not been through the validator yet.
        ctx.state.is_query_validated = False
        return validateGraphQlQuery(
            user_request=self.user_request,
            query_to_be_validated=correction.query,
        )
# Assemble the workflow graph from its four node types.
query_generation_graph = Graph(
    nodes=(
        ExtractReportReuest,
        GenerateGraphQlQuery,
        validateGraphQlQuery,
        ResolveError,
    )
)

# To inspect the graph structure, print its Mermaid source:
#     print(query_generation_graph.mermaid_code(start_node=ExtractReportReuest))

# Example request; the run starts at the extraction node.
user_prompt = "Generate a chart to visualize how bill amount varies from month to month for the commercial account of John."
result = query_generation_graph.run_sync(
    ExtractReportReuest(),
    state=State(input=user_prompt, schema=schema),
)
print("Final GraphQL Query:", result.output)