Skip to content

Commit 2bb019f

Browse files
refactor(verify-checklist): simplify prompt to reduce LLM reasoning output
pass all tables to semantic-validator node issues-fix
1 parent 4a106d1 commit 2bb019f

File tree

9 files changed

+119
-350
lines changed

9 files changed

+119
-350
lines changed

package-lock.json

Lines changed: 2 additions & 336 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,33 @@ import {
44
EvaluationResult,
55
SemanticValidatorNode,
66
} from '../../../../components';
7-
import {DbSchemaHelperService} from '../../../../components/db-query/services';
7+
import {
8+
DbSchemaHelperService,
9+
TableSearchService,
10+
} from '../../../../components/db-query/services';
811
import {LLMProvider} from '../../../../types';
912

1013
describe('SemanticValidatorNode Unit', function () {
1114
let node: SemanticValidatorNode;
1215
let llmStub: sinon.SinonStub;
16+
let tableSearchStub: sinon.SinonStubbedInstance<TableSearchService>;
1317

1418
beforeEach(() => {
1519
llmStub = sinon.stub();
1620
const llm = llmStub as unknown as LLMProvider;
1721
const schemaHelper = {
1822
asString: sinon.stub().returns(''),
1923
} as unknown as DbSchemaHelperService;
24+
tableSearchStub = sinon.createStubInstance(TableSearchService);
25+
tableSearchStub.getTables.resolves([]);
2026

21-
node = new SemanticValidatorNode(llm, llm, {models: []}, schemaHelper);
27+
node = new SemanticValidatorNode(
28+
llm,
29+
llm,
30+
{models: []},
31+
tableSearchStub,
32+
schemaHelper,
33+
);
2234
});
2335

2436
afterEach(() => {
@@ -64,6 +76,7 @@ describe('SemanticValidatorNode Unit', function () {
6476
});
6577

6678
it('should return QueryError if the query is invalid', async () => {
79+
tableSearchStub.getTables.resolves(['users', 'orders']);
6780
const state = {
6881
prompt: 'Get all users',
6982
sql: 'SELECT * FROM invalid_table',
@@ -154,4 +167,72 @@ describe('SemanticValidatorNode Unit', function () {
154167
const prompt = llmStub.firstCall.args[0];
155168
expect(prompt.value).to.containEql('the previous query was wrong');
156169
});
170+
171+
it('should pass all accessible tables from tableSearchService into available-tables so LLM can flag missing ones', async () => {
172+
const searchedTables = [
173+
'public.users',
174+
'public.orders',
175+
'public.payments',
176+
'analytics.reports',
177+
];
178+
tableSearchStub = sinon.createStubInstance(TableSearchService);
179+
tableSearchStub.getTables.resolves(searchedTables);
180+
181+
const schemaHelper = {
182+
asString: sinon.stub().returns(''),
183+
} as unknown as DbSchemaHelperService;
184+
185+
const nodeWithTables = new SemanticValidatorNode(
186+
llmStub as unknown as LLMProvider,
187+
llmStub as unknown as LLMProvider,
188+
{models: []},
189+
tableSearchStub,
190+
schemaHelper,
191+
);
192+
193+
const state = {
194+
prompt: 'Get revenue per user',
195+
sql: 'SELECT u.name, SUM(p.amount) FROM users u JOIN payments p ON u.id = p.user_id GROUP BY u.name',
196+
schema: {tables: {}, relations: []},
197+
status: EvaluationResult.Pass,
198+
id: 'test-id',
199+
feedbacks: [],
200+
replyToUser: '',
201+
datasetId: 'test-dataset-id',
202+
done: false,
203+
sampleSqlPrompt: '',
204+
sampleSql: '',
205+
fromCache: false,
206+
resultArray: undefined,
207+
directCall: false,
208+
description: undefined,
209+
syntacticStatus: undefined,
210+
syntacticFeedback: undefined,
211+
semanticStatus: undefined,
212+
semanticFeedback: undefined,
213+
syntacticErrorTables: undefined,
214+
semanticErrorTables: undefined,
215+
fromTemplate: undefined,
216+
templateId: undefined,
217+
validationChecklist: '1. Revenue grouped by user',
218+
changeType: undefined,
219+
};
220+
221+
llmStub.resolves({content: '<valid/>'});
222+
223+
await nodeWithTables.execute(state, {});
224+
225+
sinon.assert.calledOnce(tableSearchStub.getTables);
226+
expect(tableSearchStub.getTables.firstCall.args[0]).to.equal(
227+
'Get revenue per user',
228+
);
229+
expect(tableSearchStub.getTables.firstCall.args[1]).to.equal(10);
230+
231+
sinon.assert.calledOnce(llmStub);
232+
const prompt = llmStub.firstCall.args[0];
233+
expect(prompt.value).to.containEql('<available-tables>');
234+
expect(prompt.value).to.containEql(
235+
'public.users, public.orders, public.payments, analytics.reports',
236+
);
237+
});
157238
});
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export * from './dataset.model';
21
export * from './dataset-action.model';
32
export * from './dataset-update-dto.model';
3+
export * from './dataset.model';
44
export * from './query-template-dto.model';

src/components/db-query/nodes/semantic-validator.node.ts

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@ import {LLMProvider} from '../../../types';
99
import {stripThinkingTokens} from '../../../utils';
1010
import {DbQueryAIExtensionBindings} from '../keys';
1111
import {DbQueryNodes} from '../nodes.enum';
12-
import {DbSchemaHelperService} from '../services';
12+
import {
13+
DbSchemaHelperService,
14+
PermissionHelper,
15+
TableSearchService,
16+
} from '../services';
1317
import {DbQueryState} from '../state';
1418
import {DbQueryConfig, EvaluationResult} from '../types';
1519

@@ -22,8 +26,12 @@ export class SemanticValidatorNode implements IGraphNode<DbQueryState> {
2226
private readonly cheapllm: LLMProvider,
2327
@inject(DbQueryAIExtensionBindings.Config)
2428
private readonly config: DbQueryConfig,
29+
@service(TableSearchService)
30+
private readonly tableSearchService: TableSearchService,
2531
@service(DbSchemaHelperService)
2632
private readonly schemaHelper: DbSchemaHelperService,
33+
@service(PermissionHelper)
34+
private readonly permissionHelper?: PermissionHelper,
2735
) {}
2836

2937
prompt = PromptTemplate.fromTemplate(`
@@ -103,13 +111,15 @@ Keep these feedbacks in mind while validating the new query.
103111
const useSmartLLM =
104112
this.config.nodes?.semanticValidatorNode?.useSmartLLM ?? false;
105113
const llm = useSmartLLM ? this.smartllm : this.cheapllm;
106-
const tableNames = Object.keys(state.schema?.tables ?? {});
114+
const tableList =
115+
(await this.tableSearchService.getTables(state.prompt)) ?? [];
116+
const accessibleTables = this._filterByPermissions(tableList);
107117
const chain = RunnableSequence.from([this.prompt, llm]);
108118
const output = await chain.invoke({
109119
userPrompt: state.prompt,
110120
query: state.sql,
111121
schema: this.schemaHelper.asString(state.schema),
112-
tableNames: tableNames.join(', '),
122+
tableNames: accessibleTables.join(', '),
113123
checklist: state.validationChecklist ?? 'No checklist provided.',
114124
feedbacks: await this.getFeedbacks(state),
115125
});
@@ -153,4 +163,15 @@ Keep these feedbacks in mind while validating the new query.
153163
}
154164
return '';
155165
}
166+
167+
private _filterByPermissions(tables: string[]): string[] {
168+
const permHelper = this.permissionHelper;
169+
if (!permHelper) {
170+
return tables;
171+
}
172+
return tables.filter(t => {
173+
const name = t.toLowerCase().slice(t.indexOf('.') + 1);
174+
return permHelper.findMissingPermissions([name]).length === 0;
175+
});
176+
}
156177
}

src/components/db-query/nodes/verify-checklist.node.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import {AIMessage} from '@langchain/core/messages';
12
import {PromptTemplate} from '@langchain/core/prompts';
23
import {RunnableSequence} from '@langchain/core/runnables';
34
import {LangGraphRunnableConfig} from '@langchain/langgraph';
@@ -7,7 +8,6 @@ import {IGraphNode, LLMStreamEventType} from '../../../graphs';
78
import {AiIntegrationBindings} from '../../../keys';
89
import {LLMProvider} from '../../../types';
910
import {stripThinkingTokens} from '../../../utils';
10-
import {AIMessage} from '@langchain/core/messages';
1111
import {DbQueryAIExtensionBindings} from '../keys';
1212
import {DbQueryNodes} from '../nodes.enum';
1313
import {DbSchemaHelperService} from '../services';
@@ -43,7 +43,7 @@ A rule is relevant if:
4343
- It is a dependency of another relevant rule (e.g. if rule 3 requires a currency conversion, and rule 5 defines how currency conversion works, both must be included).
4444
- It applies to any of the selected tables or their relationships.
4545
46-
After selecting relevant rules, review your selection and ensure:
46+
Ensure:
4747
- Any rule that is referenced by, or is a prerequisite for, another selected rule is also included.
4848
- Do not include rules that are completely unrelated to the question, schema, or selected tables.
4949
</instructions>
@@ -82,8 +82,8 @@ If no rules are relevant: <result>none</result>
8282
</output-instructions>`;
8383

8484
simpleOutputInstructions = `<output-instructions>
85-
Return only a comma-separated list of the relevant rule indexes inside a result tag.
86-
Do not include any other text, explanation, or formatting.
85+
Return ONLY the comma-separated list of relevant rule indexes inside a result tag.
86+
Do NOT include any reasoning, analysis, or explanation — only the result tag.
8787
Example:
8888
<result>1,3,5</result>
8989
If no rules are relevant:
@@ -96,7 +96,7 @@ If no rules are relevant:
9696
): Promise<DbQueryState> {
9797
const empty = {} as DbQueryState;
9898

99-
if (this.config.nodes?.generateChecklistNode?.enabled === false) {
99+
if (this.config.nodes?.verifyChecklistNode?.enabled === false) {
100100
return empty;
101101
}
102102

src/components/db-query/services/knowledge-graph/db-knowledge-graph.service.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ export class DbKnowledgeGraphService implements KnowledgeGraph<
5353
config.knowledgeGraph?.maxClusterSize ?? MAX_CLUSTER_SIZE; // Default max cluster size
5454
}
5555

56-
async find(query: string, topK: number): Promise<string[]> {
56+
async find(query: string, topK = 10): Promise<string[]> {
5757
debug(`Selecting tables for query: "${query}"`);
5858

5959
// Step 1: Generate query embedding

src/components/db-query/services/knowledge-graph/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ export interface Graph {
4646
export interface KnowledgeGraph<T, S> extends Graph {
4747
toJSON(): string;
4848
fromJSON(json: string): void;
49-
find(query: string, count: number): Promise<T[]>;
49+
find(query: string, count?: number): Promise<T[]>;
5050
seed(data: S): Promise<void>;
5151
}
5252

src/components/db-query/services/search/table-search.service.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ export class TableSearchService {
3131
private readonly dbSchemaHelper: DbSchemaHelperService,
3232
) {}
3333

34-
async getTables(prompt: string, count: number): Promise<string[]> {
34+
async getTables(prompt: string, count?: number): Promise<string[]> {
3535
if (this.config.noKnowledgeGraph) {
3636
return this._tables;
3737
}

src/components/db-query/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ export type DbQueryConfig = {
134134
parallelism?: number;
135135
};
136136
verifyChecklistNode?: {
137+
enabled?: boolean;
137138
evaluation?: boolean;
138139
};
139140
};

0 commit comments

Comments
 (0)