Skip to content

Commit 16e055d

Browse files
committed
feat(query-generation): reduce cognitive complexity
1 parent b14777f commit 16e055d

10 files changed

Lines changed: 117 additions & 74 deletions

File tree

src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import {expect, sinon} from '@loopback/testlab';
22
import {EvaluationResult, SemanticValidatorNode} from '../../../../components';
3+
import {DbSchemaHelperService} from '../../../../components/db-query/services';
34
import {LLMProvider} from '../../../../types';
45

56
describe('SemanticValidatorNode Unit', function () {
@@ -9,8 +10,11 @@ describe('SemanticValidatorNode Unit', function () {
910
beforeEach(() => {
1011
llmStub = sinon.stub();
1112
const llm = llmStub as unknown as LLMProvider;
13+
const schemaHelper = {
14+
asString: sinon.stub().returns(''),
15+
} as unknown as DbSchemaHelperService;
1216

13-
node = new SemanticValidatorNode(llm, llm, {models: []});
17+
node = new SemanticValidatorNode(llm, llm, {models: []}, schemaHelper);
1418
});
1519

1620
afterEach(() => {
@@ -84,11 +88,12 @@ describe('SemanticValidatorNode Unit', function () {
8488
sinon.assert.calledOnce(llmStub);
8589

8690
const prompt = llmStub.firstCall.args[0];
87-
// Verify the prompt contains the checklist and SQL but not schema or user prompt
91+
// Verify the prompt contains the user question, checklist, SQL, and schema
8892
expect(prompt.value).to.containEql(state.sql);
93+
expect(prompt.value).to.containEql(state.prompt);
8994
expect(prompt.value).to.containEql('1. Query selects from users table');
90-
expect(prompt.value).to.not.containEql('<database-schema>');
91-
expect(prompt.value).to.not.containEql('<user-question>');
95+
expect(prompt.value).to.containEql('<database-schema>');
96+
expect(prompt.value).to.containEql('<user-question>');
9297
});
9398

9499
it('should include feedbacks in the prompt', async () => {

src/components/db-query/nodes/generate-checklist.node.ts

Lines changed: 35 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ import {IGraphNode, LLMStreamEventType} from '../../../graphs';
77
import {AiIntegrationBindings} from '../../../keys';
88
import {LLMProvider} from '../../../types';
99
import {stripThinkingTokens} from '../../../utils';
10+
import {AIMessage} from '@langchain/core/messages';
1011
import {DbQueryAIExtensionBindings} from '../keys';
1112
import {DbQueryNodes} from '../nodes.enum';
1213
import {DbSchemaHelperService} from '../services';
@@ -68,15 +69,14 @@ If no rules are relevant, return: none
6869
state: DbQueryState,
6970
config: LangGraphRunnableConfig,
7071
): Promise<DbQueryState> {
71-
// Skip if checklist was already generated (e.g. retry paths)
72+
const empty = {} as DbQueryState;
7273
if (state.validationChecklist) {
73-
return {} as DbQueryState;
74+
return empty;
7475
}
7576

76-
// Skip for small schemas (1-2 tables) — context is already small enough
7777
const tableCount = Object.keys(state.schema?.tables ?? {}).length;
7878
if (tableCount <= 2) {
79-
return {} as DbQueryState;
79+
return empty;
8080
}
8181

8282
const allChecks = [
@@ -85,14 +85,32 @@ If no rules are relevant, return: none
8585
];
8686

8787
if (allChecks.length === 0) {
88-
return {} as DbQueryState;
88+
return empty;
8989
}
9090

9191
config.writer?.({
9292
type: LLMStreamEventType.Log,
9393
data: 'Filtering validation checklist for semantic validation.',
9494
});
9595

96+
const mergedIndexes = await this.runParallelChecklist(state, allChecks);
97+
98+
if (mergedIndexes.size === 0) {
99+
return empty;
100+
}
101+
102+
const validationChecklist = Array.from(mergedIndexes)
103+
.sort((a, b) => a - b)
104+
.map(i => allChecks[i - 1])
105+
.join('\n');
106+
107+
return {validationChecklist} as DbQueryState;
108+
}
109+
110+
private async runParallelChecklist(
111+
state: DbQueryState,
112+
allChecks: string[],
113+
): Promise<Set<number>> {
96114
const indexedChecks = allChecks
97115
.map((check, i) => `${i + 1}. ${check}`)
98116
.join('\n');
@@ -108,33 +126,25 @@ If no rules are relevant, return: none
108126
indexedChecks,
109127
};
110128

111-
// Run N parallel calls and union the results
112129
const results = await Promise.all(
113130
Array.from({length: parallelism}, () => chain.invoke(invokeArgs)),
114131
);
115132

116133
const mergedIndexes = new Set<number>();
117134
for (const output of results) {
118-
const response = stripThinkingTokens(output).trim();
119-
if (!response) continue;
120-
const indexStr = response;
121-
if (indexStr === 'none') continue;
122-
indexStr
123-
.split(',')
124-
.map(s => parseInt(s.trim(), 10))
125-
.filter(n => !isNaN(n) && n >= 1 && n <= allChecks.length)
126-
.forEach(n => mergedIndexes.add(n));
127-
}
128-
129-
if (mergedIndexes.size === 0) {
130-
return {} as DbQueryState;
135+
this.parseIndexes(output, allChecks.length).forEach(n =>
136+
mergedIndexes.add(n),
137+
);
131138
}
139+
return mergedIndexes;
140+
}
132141

133-
const validationChecklist = Array.from(mergedIndexes)
134-
.sort((a, b) => a - b)
135-
.map(i => allChecks[i - 1])
136-
.join('\n');
137-
138-
return {validationChecklist} as DbQueryState;
142+
private parseIndexes(output: AIMessage, maxIndex: number): number[] {
143+
const response = stripThinkingTokens(output).trim();
144+
if (!response || response === 'none') return [];
145+
return response
146+
.split(',')
147+
.map(s => Number.parseInt(s.trim(), 10))
148+
.filter(n => !Number.isNaN(n) && n >= 1 && n <= maxIndex);
139149
}
140150
}

src/components/db-query/nodes/get-tables.node.ts

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -209,12 +209,13 @@ Use these if they are relevant to the table selection, otherwise ignore them, th
209209
}
210210

211211
private _filterByPermissions(tables: string[]): string[] {
212-
if (!this.permissionHelper) {
212+
const permHelper = this.permissionHelper;
213+
if (!permHelper) {
213214
return tables;
214215
}
215216
return tables.filter(t => {
216217
const name = t.toLowerCase().slice(t.indexOf('.') + 1);
217-
return this.permissionHelper!.findMissingPermissions([name]).length === 0;
218+
return permHelper.findMissingPermissions([name]).length === 0;
218219
});
219220
}
220221

src/components/db-query/nodes/semantic-validator.node.ts

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,15 @@
11
import {PromptTemplate} from '@langchain/core/prompts';
22
import {RunnableSequence} from '@langchain/core/runnables';
33
import {LangGraphRunnableConfig} from '@langchain/langgraph';
4-
import {inject} from '@loopback/context';
4+
import {inject, service} from '@loopback/core';
55
import {graphNode} from '../../../decorators';
66
import {IGraphNode, LLMStreamEventType} from '../../../graphs';
77
import {AiIntegrationBindings} from '../../../keys';
88
import {LLMProvider} from '../../../types';
99
import {stripThinkingTokens} from '../../../utils';
1010
import {DbQueryAIExtensionBindings} from '../keys';
1111
import {DbQueryNodes} from '../nodes.enum';
12+
import {DbSchemaHelperService} from '../services';
1213
import {DbQueryState} from '../state';
1314
import {DbQueryConfig, EvaluationResult} from '../types';
1415

@@ -21,6 +22,8 @@ export class SemanticValidatorNode implements IGraphNode<DbQueryState> {
2122
private readonly cheapllm: LLMProvider,
2223
@inject(DbQueryAIExtensionBindings.Config)
2324
private readonly config: DbQueryConfig,
25+
@service(DbSchemaHelperService)
26+
private readonly schemaHelper: DbSchemaHelperService,
2427
) {}
2528

2629
prompt = PromptTemplate.fromTemplate(`
@@ -31,10 +34,18 @@ Go through each checklist item and verify it against the SQL query.
3134
DO NOT make up issues that do not exist in the query.
3235
</instructions>
3336
37+
<user-question>
38+
{userPrompt}
39+
</user-question>
40+
3441
<sql-query>
3542
{query}
3643
</sql-query>
3744
45+
<database-schema>
46+
{schema}
47+
</database-schema>
48+
3849
<validation-checklist>
3950
{checklist}
4051
</validation-checklist>
@@ -86,13 +97,15 @@ Keep these feedbacks in mind while validating the new query.
8697
const llm = useSmartLLM ? this.smartllm : this.cheapllm;
8798
const chain = RunnableSequence.from([this.prompt, llm]);
8899
const output = await chain.invoke({
100+
userPrompt: state.prompt,
89101
query: state.sql,
102+
schema: this.schemaHelper.asString(state.schema),
90103
checklist: state.validationChecklist ?? 'No checklist provided.',
91104
feedbacks: await this.getFeedbacks(state),
92105
});
93106
const response = stripThinkingTokens(output);
94107

95-
const invalidMatch = response.match(/<invalid>(.*?)<\/invalid>/s);
108+
const invalidMatch = /<invalid>(.*?)<\/invalid>/s.exec(response);
96109
const isValid =
97110
response.includes('<valid/>') || response.includes('<valid />');
98111

src/components/db-query/nodes/verify-checklist.node.ts

Lines changed: 44 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ import {IGraphNode, LLMStreamEventType} from '../../../graphs';
77
import {AiIntegrationBindings} from '../../../keys';
88
import {LLMProvider} from '../../../types';
99
import {stripThinkingTokens} from '../../../utils';
10+
import {AIMessage} from '@langchain/core/messages';
1011
import {DbQueryAIExtensionBindings} from '../keys';
1112
import {DbQueryNodes} from '../nodes.enum';
1213
import {DbSchemaHelperService} from '../services';
@@ -93,15 +94,15 @@ If no rules are relevant:
9394
state: DbQueryState,
9495
config: LangGraphRunnableConfig,
9596
): Promise<DbQueryState> {
96-
// Skip on retry — checklist was already verified on the first pass
97+
const empty = {} as DbQueryState;
98+
9799
if (state.feedbacks?.length) {
98-
return {} as DbQueryState;
100+
return empty;
99101
}
100102

101-
// Skip for small schemas (1-2 tables) — context is already small enough
102103
const tableCount = Object.keys(state.schema?.tables ?? {}).length;
103104
if (tableCount <= 2) {
104-
return {} as DbQueryState;
105+
return empty;
105106
}
106107

107108
const allChecks = [
@@ -110,14 +111,34 @@ If no rules are relevant:
110111
];
111112

112113
if (allChecks.length === 0) {
113-
return {} as DbQueryState;
114+
return empty;
114115
}
115116

116117
config.writer?.({
117118
type: LLMStreamEventType.Log,
118119
data: 'Verifying validation checklist with chain-of-thought.',
119120
});
120121

122+
const output = await this.invokeVerification(state, allChecks);
123+
const verifiedIndexes = this.parseVerifiedIndexes(output, allChecks.length);
124+
125+
if (verifiedIndexes.length === 0) {
126+
return empty;
127+
}
128+
129+
const validationChecklist = this.mergeWithExisting(
130+
state.validationChecklist,
131+
verifiedIndexes,
132+
allChecks,
133+
);
134+
135+
return {validationChecklist} as DbQueryState;
136+
}
137+
138+
private async invokeVerification(
139+
state: DbQueryState,
140+
allChecks: string[],
141+
): Promise<AIMessage> {
121142
const indexedChecks = allChecks
122143
.map((check, i) => `${i + 1}. ${check}`)
123144
.join('\n');
@@ -130,42 +151,40 @@ If no rules are relevant:
130151
? this.evaluationOutputInstructions
131152
: this.simpleOutputInstructions),
132153
);
154+
133155
const chain = RunnableSequence.from([promptTemplate, this.llm]);
134-
const output = await chain.invoke({
156+
return chain.invoke({
135157
prompt: state.prompt,
136158
tables: Object.keys(state.schema?.tables ?? {}).join(', '),
137159
schema: this.schemaHelper.asString(state.schema),
138160
indexedChecks,
139161
});
162+
}
140163

164+
private parseVerifiedIndexes(output: AIMessage, maxIndex: number): number[] {
141165
const response = stripThinkingTokens(output).trim();
142-
const resultMatch = response.match(/<result>(.*?)<\/result>/s);
166+
const resultMatch = /<result>(.*?)<\/result>/s.exec(response);
143167
const indexStr = resultMatch ? resultMatch[1].trim() : response;
144168

145-
if (indexStr === 'none' || !indexStr) {
146-
return {} as DbQueryState;
147-
}
169+
if (!indexStr || indexStr === 'none') return [];
148170

149-
const verifiedIndexes = indexStr
171+
return indexStr
150172
.split(',')
151-
.map(s => parseInt(s.trim(), 10))
152-
.filter(n => !isNaN(n) && n >= 1 && n <= allChecks.length);
153-
154-
if (verifiedIndexes.length === 0) {
155-
return {} as DbQueryState;
156-
}
173+
.map(s => Number.parseInt(s.trim(), 10))
174+
.filter(n => !Number.isNaN(n) && n >= 1 && n <= maxIndex);
175+
}
157176

158-
// Merge with existing checklist — union of both passes
177+
private mergeWithExisting(
178+
existing: string | undefined,
179+
verifiedIndexes: number[],
180+
allChecks: string[],
181+
): string {
159182
const existingChecks = new Set(
160-
(state.validationChecklist ?? '').split('\n').filter(c => c.length > 0),
183+
(existing ?? '').split('\n').filter(c => c.length > 0),
161184
);
162-
const verifiedChecks = verifiedIndexes.map(i => allChecks[i - 1]);
163-
for (const check of verifiedChecks) {
185+
for (const check of verifiedIndexes.map(i => allChecks[i - 1])) {
164186
existingChecks.add(check);
165187
}
166-
167-
const validationChecklist = Array.from(existingChecks).join('\n');
168-
169-
return {validationChecklist} as DbQueryState;
188+
return Array.from(existingChecks).join('\n');
170189
}
171190
}

src/components/db-query/services/knowledge-graph/db-knowledge-graph.service.ts

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,10 @@ import {Concept, GraphEdge, GraphNode, KnowledgeGraph} from './types';
1818
const debug = require('debug')('ai-integration:knowledge-graph');
1919

2020
@injectable({scope: BindingScope.SINGLETON})
21-
export class DbKnowledgeGraphService
22-
implements KnowledgeGraph<string, DatabaseSchema>
23-
{
21+
export class DbKnowledgeGraphService implements KnowledgeGraph<
22+
string,
23+
DatabaseSchema
24+
> {
2425
edges: Map<string, GraphEdge[]>;
2526
nodes: Map<string, GraphNode>;
2627
private vectorWeight: number;

src/components/db-query/tools/get-data-as-dataset.tool.ts

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,14 +48,14 @@ export class GetDataAsDatasetTool implements IGraphTool {
4848
prompt: z
4949
.string()
5050
.describe(
51-
`The user's request describing what data they need from the database.`,
51+
`Prompt from the user that will be used for generating an SQL query and create a dataset from it.`,
5252
),
5353
}) as AnyObject[string];
5454
return graph.asTool({
5555
name: this.key,
56-
description: `Tool for fetching data from the database and returning it as a dataset based on the user's request.
57-
Use this whenever the user wants to retrieve, look up, or explore data from the database.
58-
It returns a dataset ID and renders a data grid on the UI for the user to see the results.`,
56+
description: `Query tool for generating SQL queries for a users request. Use it to find data from the database based on the user's request.
57+
Note that it does not return the query, instead only a dataset ID that is not relevant to the user.
58+
It internally fires an event that renders a grid for the dataset on the UI for the user to see.`,
5959
schema,
6060
});
6161
}

src/components/visualization/nodes/call-query-generation.node.ts

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,7 @@ import {VisualizationGraphState} from '../state';
88
@graphNode(VisualizationGraphNodes.CallQueryGeneration, {
99
[POST_DATASET_TAG]: true,
1010
})
11-
export class CallQueryGenerationNode
12-
implements IGraphNode<VisualizationGraphState>
13-
{
11+
export class CallQueryGenerationNode implements IGraphNode<VisualizationGraphState> {
1412
constructor(
1513
@service(DbQueryGraph)
1614
private readonly queryPipeline: DbQueryGraph,

0 commit comments

Comments
 (0)