@@ -358,17 +358,16 @@ class GenerateSQLResponse:
358358 context: The database context used to generate the SQL statement.
359359 command_type: The type of SQL statement generated (e.g. SELECT, INSERT, UPDATE)
360360 query_plan: The PostgreSQL query plan for the generated SQL statement.
361- final_prompt: The final prompt that was sent to the model.
362- messages: List of all messages exchanged during the generation process .
361+ messages: List of all messages exchanged during the generation process, where the
362+ ModelRequest has two parts: (SystemPromptPart, UserPromptPart) .
363363 usage: Usage statistics for the AI model calls.
364364 """
365365
366366 sql_statement : str
367367 context : DatabaseContext
368368 command_type : str
369369 query_plan : dict [str , Any ]
370- final_prompt : str
371- messages : list [ModelRequest | ModelResponse ]
370+ messages : list [tuple [ModelRequest , ModelResponse ]]
372371 usage : Usage
373372
374373
@@ -727,7 +726,7 @@ async def generate_sql(
727726 answer : str | None = None
728727 command_type : str | None = None
729728 pgversion : int | None = await _get_database_version (target_con )
730- messages : list [ModelRequest | ModelResponse ] = []
729+ messages : list [tuple [ ModelRequest , ModelResponse ] ] = []
731730 user_prompt : str | None = None
732731 query_plan : dict [str , Any ] | None = None
733732 error : str | None = None
@@ -755,16 +754,15 @@ async def generate_sql(
755754 error = error ,
756755 )
757756
757+ request = ModelRequest (
758+ parts = [
759+ SystemPromptPart (content = system_prompt ),
760+ UserPromptPart (content = user_prompt ),
761+ ]
762+ )
758763 model_response : ModelResponse = await model_request (
759764 model = model ,
760- messages = [
761- ModelRequest (
762- parts = [
763- SystemPromptPart (content = system_prompt ),
764- UserPromptPart (content = user_prompt ),
765- ]
766- )
767- ],
765+ messages = [request ],
768766 model_request_parameters = ModelRequestParameters (
769767 function_tools = [_search_tool_definition ()]
770768 if iteration < iteration_limit
@@ -775,7 +773,7 @@ async def generate_sql(
775773 model_settings = model_settings ,
776774 )
777775
778- messages .append (model_response )
776+ messages .append (( request , model_response ) )
779777 usage = usage + model_response .usage
780778
781779 for part in model_response .parts :
@@ -882,7 +880,6 @@ async def generate_sql(
882880 context = ctx ,
883881 command_type = command_type or "UNKNOWN" ,
884882 query_plan = query_plan or {},
885- final_prompt = user_prompt or "MISSING" ,
886883 messages = messages ,
887884 usage = usage ,
888885 )
0 commit comments