Skip to content
This repository was archived by the owner on May 27, 2026. It is now read-only.

Commit 0dd09c0

Browse files
authored
fix: store ModelRequest in generate_sql messages list (#834)
1 parent 1db2468 commit 0dd09c0

2 files changed

Lines changed: 16 additions & 16 deletions

File tree

projects/pgai/pgai/cli.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1347,7 +1347,10 @@ async def do() -> GenerateSQLResponse:
13471347
console.print(Syntax(resp.sql_statement, "sql", word_wrap=True))
13481348

13491349
if save_final_prompt:
1350-
save_final_prompt.expanduser().resolve().write_text(resp.final_prompt)
1350+
# The final prompt is the user prompt of the last message request we made.
1351+
save_final_prompt.expanduser().resolve().write_text(
1352+
str(resp.messages[-1][0].parts[-1].content)
1353+
)
13511354

13521355

13531356
@semantic_catalog.command()

projects/pgai/pgai/semantic_catalog/gen_sql.py

Lines changed: 12 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -358,17 +358,16 @@ class GenerateSQLResponse:
358358
context: The database context used to generate the SQL statement.
359359
command_type: The type of SQL statement generated (e.g. SELECT, INSERT, UPDATE)
360360
query_plan: The PostgreSQL query plan for the generated SQL statement.
361-
final_prompt: The final prompt that was sent to the model.
362-
messages: List of all messages exchanged during the generation process.
361+
messages: List of all messages exchanged during the generation process, where the
362+
ModelRequest has two parts: (SystemPromptPart, UserPromptPart).
363363
usage: Usage statistics for the AI model calls.
364364
"""
365365

366366
sql_statement: str
367367
context: DatabaseContext
368368
command_type: str
369369
query_plan: dict[str, Any]
370-
final_prompt: str
371-
messages: list[ModelRequest | ModelResponse]
370+
messages: list[tuple[ModelRequest, ModelResponse]]
372371
usage: Usage
373372

374373

@@ -727,7 +726,7 @@ async def generate_sql(
727726
answer: str | None = None
728727
command_type: str | None = None
729728
pgversion: int | None = await _get_database_version(target_con)
730-
messages: list[ModelRequest | ModelResponse] = []
729+
messages: list[tuple[ModelRequest, ModelResponse]] = []
731730
user_prompt: str | None = None
732731
query_plan: dict[str, Any] | None = None
733732
error: str | None = None
@@ -755,16 +754,15 @@ async def generate_sql(
755754
error=error,
756755
)
757756

757+
request = ModelRequest(
758+
parts=[
759+
SystemPromptPart(content=system_prompt),
760+
UserPromptPart(content=user_prompt),
761+
]
762+
)
758763
model_response: ModelResponse = await model_request(
759764
model=model,
760-
messages=[
761-
ModelRequest(
762-
parts=[
763-
SystemPromptPart(content=system_prompt),
764-
UserPromptPart(content=user_prompt),
765-
]
766-
)
767-
],
765+
messages=[request],
768766
model_request_parameters=ModelRequestParameters(
769767
function_tools=[_search_tool_definition()]
770768
if iteration < iteration_limit
@@ -775,7 +773,7 @@ async def generate_sql(
775773
model_settings=model_settings,
776774
)
777775

778-
messages.append(model_response)
776+
messages.append((request, model_response))
779777
usage = usage + model_response.usage
780778

781779
for part in model_response.parts:
@@ -882,7 +880,6 @@ async def generate_sql(
882880
context=ctx,
883881
command_type=command_type or "UNKNOWN",
884882
query_plan=query_plan or {},
885-
final_prompt=user_prompt or "MISSING",
886883
messages=messages,
887884
usage=usage,
888885
)

0 commit comments

Comments
 (0)