Skip to content

Commit 886c24b

Browse files
committed
feat(query-generation): handle missing token count, handle duplicate context
1 parent cf56d36 commit 886c24b

4 files changed

Lines changed: 21 additions & 17 deletions

File tree

src/components/db-query/db-query.graph.ts

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,8 @@ export class DbQueryGraph extends BaseGraph<DbQueryState> {
129129
// GetColumns → GenerateChecklist (fast pass) → parallel fan-out
130130
.addEdge(DbQueryNodes.GetColumns, DbQueryNodes.GenerateChecklist)
131131
.addEdge(DbQueryNodes.GenerateChecklist, DbQueryNodes.SqlGeneration)
132-
.addEdge(DbQueryNodes.GenerateChecklist, DbQueryNodes.GenerateDescription)
133132
.addEdge(DbQueryNodes.GenerateChecklist, DbQueryNodes.VerifyChecklist)
134-
// All three fan-in to PreValidation
135-
.addEdge(DbQueryNodes.GenerateDescription, DbQueryNodes.PreValidation)
133+
// Both fan-in to PreValidation
136134
.addEdge(DbQueryNodes.VerifyChecklist, DbQueryNodes.PreValidation)
137135
// SqlGeneration routes to validation or failure
138136
.addConditionalEdges(
@@ -146,12 +144,14 @@ export class DbQueryGraph extends BaseGraph<DbQueryState> {
146144
Failed: DbQueryNodes.Failed,
147145
},
148146
)
149-
// Parallel fan-out: both validators run concurrently
147+
// Parallel fan-out: validators and description generation run concurrently
150148
.addEdge(DbQueryNodes.PreValidation, DbQueryNodes.SyntacticValidator)
151149
.addEdge(DbQueryNodes.PreValidation, DbQueryNodes.SemanticValidator)
150+
.addEdge(DbQueryNodes.PreValidation, DbQueryNodes.GenerateDescription)
152151
// Fan-in at PostValidation
153152
.addEdge(DbQueryNodes.SyntacticValidator, DbQueryNodes.PostValidation)
154153
.addEdge(DbQueryNodes.SemanticValidator, DbQueryNodes.PostValidation)
154+
.addEdge(DbQueryNodes.GenerateDescription, DbQueryNodes.PostValidation)
155155
.addConditionalEdges(
156156
DbQueryNodes.PostValidation,
157157
(state: DbQueryState) => {
@@ -164,9 +164,7 @@ export class DbQueryGraph extends BaseGraph<DbQueryState> {
164164
return 'Failed';
165165
},
166166
{
167-
// SaveDataset fans-in from both PostValidation and GenerateDescription
168167
Accepted: DbQueryNodes.SaveDataset,
169-
// FixSQL goes through GenerateChecklist (no-op) to re-trigger both SqlGeneration and GenerateDescription
170168
FixSQL: DbQueryNodes.GenerateChecklist,
171169
ReselectTables: DbQueryNodes.GetTables,
172170
Failed: DbQueryNodes.Failed,

src/components/db-query/nodes/generate-description.node.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,19 @@ export class GenerateDescriptionNode implements IGraphNode<DbQueryState> {
2727

2828
prompt = PromptTemplate.fromTemplate(`
2929
<instructions>
30-
You are an AI assistant that summarizes what data a query would fetch to answer the user's question.
31-
Write a concise, bulleted summary in plain english. No SQL, no technical jargon, no table/column names.
30+
You are an AI assistant that describes what a SQL query does in plain english.
31+
Analyze the actual query below and write a concise, bulleted summary of the data it retrieves and any filters/conditions it applies.
32+
Write in plain english. No SQL, no technical jargon, no table/column names.
3233
</instructions>
3334
3435
<user-question>
3536
{prompt}
3637
</user-question>
3738
39+
<sql-query>
40+
{sql}
41+
</sql-query>
42+
3843
<database-schema>
3944
{schema}
4045
</database-schema>
@@ -56,7 +61,7 @@ Return a short bulleted list where each bullet is one condition, filter, or piec
5661
const generateDesc =
5762
this.config.nodes?.sqlGenerationNode?.generateDescription !== false;
5863

59-
if (!generateDesc || state.description) {
64+
if (!generateDesc || !state.sql) {
6065
return {} as DbQueryState;
6166
}
6267

@@ -68,6 +73,7 @@ Return a short bulleted list where each bullet is one condition, filter, or piec
6873
const chain = RunnableSequence.from([this.prompt, this.llm]);
6974
const stream = await chain.stream({
7075
prompt: state.prompt,
76+
sql: state.sql,
7177
schema: this.schemaHelper.asString(state.schema),
7278
checks: [
7379
'<must-follow-rules>',

src/components/db-query/services/db-schema-helper.service.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,19 @@ export class DbSchemaHelperService {
2828
private readonly config: DbQueryConfig,
2929
) {}
3030
getTablesContext(schema: DatabaseSchema) {
31-
const tableContexts: string[] = [];
31+
const contextSet = new Set<string>();
3232
Object.keys(schema.tables).forEach(table => {
3333
if (schema.tables[table].context) {
3434
for (const item of schema.tables[table].context) {
3535
if (typeof item === 'string' && item.trim().length > 0) {
36-
tableContexts.push(item.trim());
36+
contextSet.add(item.trim());
3737
} else if (typeof item === 'object') {
3838
const tableSet = new Set(
3939
Object.keys(schema.tables).map(t => t.split('.').pop() ?? t),
4040
);
4141
Object.keys(item).forEach(withTable => {
4242
if (tableSet.has(`${withTable}`)) {
43-
tableContexts.push(item[withTable].trim());
43+
contextSet.add(item[withTable].trim());
4444
}
4545
});
4646
} else {
@@ -49,7 +49,7 @@ export class DbSchemaHelperService {
4949
}
5050
}
5151
});
52-
return tableContexts;
52+
return [...contextSet];
5353
}
5454

5555
asString(schema: DatabaseSchema): string {

src/services/token-counter.service.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ export class TokenCounter {
3737
outputTokens: 0,
3838
};
3939
if (usageMetadata) {
40-
this.inputs += usageMetadata.input_tokens;
41-
this.outputs += usageMetadata.output_tokens;
42-
prev.inputTokens += usageMetadata.input_tokens;
43-
prev.outputTokens += usageMetadata.output_tokens;
40+
this.inputs += usageMetadata.input_tokens ?? 0;
41+
this.outputs += usageMetadata.output_tokens ?? 0;
42+
prev.inputTokens += usageMetadata.input_tokens ?? 0;
43+
prev.outputTokens += usageMetadata.output_tokens ?? 0;
4444
this.countMap.set(llmName, prev);
4545
}
4646
return {

0 commit comments

Comments
 (0)