Skip to content

Commit 2cf0c6f

Browse files
committed
feat(query-generation): add fix query flow, add query routing
1 parent 886c24b commit 2cf0c6f

20 files changed

Lines changed: 1003 additions & 92 deletions

src/__tests__/db-query/unit/db-query.graph.unit.ts

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ describe(`DbQueryGraph Unit`, function () {
3030
// Parallel branches must return partial state to avoid LastValue conflicts
3131
stubMap[DbQueryNodes.GetTables].callsFake(async () => ({}));
3232
stubMap[DbQueryNodes.CheckCache].callsFake(async () => ({}));
33+
stubMap[DbQueryNodes.GetColumns].callsFake(async () => ({}));
34+
stubMap[DbQueryNodes.ClassifyChange].callsFake(async () => ({}));
35+
stubMap[DbQueryNodes.FixQuery].callsFake(async () => ({}));
3336
// Checklist + Description run in parallel — must return partial state
3437
stubMap[DbQueryNodes.GenerateChecklist].callsFake(async () => ({
3538
validationChecklist: '1. Test check',
@@ -74,7 +77,7 @@ describe(`DbQueryGraph Unit`, function () {
7477
expect(stubMap[DbQueryNodes.Failed].called).to.be.false();
7578
});
7679

77-
it('should retry generation if syntactic validation fails with query error', async () => {
80+
it('should fix query via FixQuery if syntactic validation fails with query error', async () => {
7881
const compiledGraph = await graph.build();
7982
let syntacticRetryCount = 0;
8083
stubMap[DbQueryNodes.SyntacticValidator].callsFake(async () => {
@@ -102,7 +105,9 @@ describe(`DbQueryGraph Unit`, function () {
102105
expect(stubMap[DbQueryNodes.IsImprovement].calledOnce).to.be.true();
103106
expect(stubMap[DbQueryNodes.CheckCache].calledOnce).to.be.true();
104107
expect(stubMap[DbQueryNodes.GetTables].calledOnce).to.be.true();
105-
expect(stubMap[DbQueryNodes.SqlGeneration].calledTwice).to.be.true();
108+
// SqlGeneration called once; FixQuery handles the retry
109+
expect(stubMap[DbQueryNodes.SqlGeneration].calledOnce).to.be.true();
110+
expect(stubMap[DbQueryNodes.FixQuery].calledOnce).to.be.true();
106111
expect(stubMap[DbQueryNodes.SyntacticValidator].calledTwice).to.be.true();
107112
// Semantic runs in parallel with syntactic on both attempts
108113
expect(stubMap[DbQueryNodes.SemanticValidator].calledTwice).to.be.true();
@@ -137,7 +142,9 @@ describe(`DbQueryGraph Unit`, function () {
137142

138143
expect(stubMap[DbQueryNodes.IsImprovement].calledOnce).to.be.true();
139144
expect(stubMap[DbQueryNodes.CheckCache].calledOnce).to.be.true();
145+
// GetTables called twice: initial + retry after table error
140146
expect(stubMap[DbQueryNodes.GetTables].calledTwice).to.be.true();
147+
// SqlGeneration called twice: once per full pipeline pass
141148
expect(stubMap[DbQueryNodes.SqlGeneration].calledTwice).to.be.true();
142149
expect(stubMap[DbQueryNodes.SyntacticValidator].calledTwice).to.be.true();
143150
expect(stubMap[DbQueryNodes.SaveDataset].calledOnce).to.be.true();
@@ -165,17 +172,20 @@ describe(`DbQueryGraph Unit`, function () {
165172
expect(stubMap[DbQueryNodes.IsImprovement].calledOnce).to.be.true();
166173
expect(stubMap[DbQueryNodes.CheckCache].calledOnce).to.be.true();
167174
expect(stubMap[DbQueryNodes.GetTables].calledOnce).to.be.true();
168-
expect(stubMap[DbQueryNodes.SqlGeneration].getCalls().length).to.be.eql(
169-
MAX_ATTEMPTS,
170-
);
175+
// SqlGeneration runs once; FixQuery handles subsequent retries
176+
expect(stubMap[DbQueryNodes.SqlGeneration].calledOnce).to.be.true();
171177
expect(
172178
stubMap[DbQueryNodes.SyntacticValidator].getCalls().length,
173179
).to.be.eql(MAX_ATTEMPTS);
180+
// FixQuery called MAX_ATTEMPTS - 1 times (first attempt via SqlGeneration)
181+
expect(stubMap[DbQueryNodes.FixQuery].getCalls().length).to.be.eql(
182+
MAX_ATTEMPTS - 1,
183+
);
174184
expect(stubMap[DbQueryNodes.Failed].calledOnce).to.be.true();
175-
expect(stubMap[DbQueryNodes.SaveDataset].calledOnce).to.be.false();
185+
expect(stubMap[DbQueryNodes.SaveDataset].called).to.be.false();
176186
});
177187

178-
it('should retry generation if semantic validation fails with query error', async () => {
188+
it('should fix query via FixQuery if semantic validation fails with query error', async () => {
179189
const compiledGraph = await graph.build();
180190
let semanticRetryCount = 0;
181191
stubMap[DbQueryNodes.SemanticValidator].callsFake(async () => {
@@ -203,7 +213,9 @@ describe(`DbQueryGraph Unit`, function () {
203213
expect(stubMap[DbQueryNodes.IsImprovement].calledOnce).to.be.true();
204214
expect(stubMap[DbQueryNodes.CheckCache].calledOnce).to.be.true();
205215
expect(stubMap[DbQueryNodes.GetTables].calledOnce).to.be.true();
206-
expect(stubMap[DbQueryNodes.SqlGeneration].calledTwice).to.be.true();
216+
// SqlGeneration called once; FixQuery handles the retry
217+
expect(stubMap[DbQueryNodes.SqlGeneration].calledOnce).to.be.true();
218+
expect(stubMap[DbQueryNodes.FixQuery].calledOnce).to.be.true();
207219
expect(stubMap[DbQueryNodes.SyntacticValidator].calledTwice).to.be.true();
208220
expect(stubMap[DbQueryNodes.SemanticValidator].calledTwice).to.be.true();
209221
expect(stubMap[DbQueryNodes.SaveDataset].calledOnce).to.be.true();
@@ -235,10 +247,11 @@ describe(`DbQueryGraph Unit`, function () {
235247
expect(stubMap[DbQueryNodes.IsImprovement].calledOnce).to.be.true();
236248
expect(stubMap[DbQueryNodes.CheckCache].calledOnce).to.be.true();
237249
expect(stubMap[DbQueryNodes.GetTables].calledOnce).to.be.true();
238-
expect(stubMap[DbQueryNodes.SqlGeneration].getCalls().length).to.be.eql(
239-
MAX_ATTEMPTS,
240-
);
250+
// SqlGeneration runs once; FixQuery handles retries
251+
expect(stubMap[DbQueryNodes.SqlGeneration].calledOnce).to.be.true();
252+
// With both validators failing, feedbacks grow by 2 per iteration
253+
// so it reaches MAX_ATTEMPTS faster
241254
expect(stubMap[DbQueryNodes.Failed].calledOnce).to.be.true();
242-
expect(stubMap[DbQueryNodes.SaveDataset].calledOnce).to.be.false();
255+
expect(stubMap[DbQueryNodes.SaveDataset].called).to.be.false();
243256
});
244257
});
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import {expect, sinon} from '@loopback/testlab';
2+
import {ChangeType, ClassifyChangeNode} from '../../../../components';
3+
import {LLMProvider} from '../../../../types';
4+
5+
describe('ClassifyChangeNode Unit', function () {
6+
let node: ClassifyChangeNode;
7+
let llmStub: sinon.SinonStub;
8+
9+
beforeEach(() => {
10+
llmStub = sinon.stub();
11+
const llm = llmStub as unknown as LLMProvider;
12+
node = new ClassifyChangeNode(llm);
13+
});
14+
15+
afterEach(() => {
16+
sinon.restore();
17+
});
18+
19+
it('should return empty state when sampleSql is not present', async () => {
20+
const state = {
21+
prompt: 'Get all users',
22+
schema: {tables: {}, relations: []},
23+
sampleSql: undefined,
24+
sampleSqlPrompt: undefined,
25+
};
26+
27+
const result = await node.execute(state as any, {});
28+
29+
expect(result).to.deepEqual({});
30+
sinon.assert.notCalled(llmStub);
31+
});
32+
33+
it('should classify as Minor for small changes', async () => {
34+
llmStub.resolves({
35+
content: 'minor',
36+
});
37+
38+
const state = {
39+
prompt: 'Get users with age > 25',
40+
schema: {tables: {}, relations: []},
41+
sampleSql: 'SELECT * FROM users WHERE age > 20',
42+
sampleSqlPrompt: 'Get users with age > 20',
43+
};
44+
45+
const result = await node.execute(state as any, {});
46+
47+
expect(result.changeType).to.equal(ChangeType.Minor);
48+
sinon.assert.calledOnce(llmStub);
49+
});
50+
51+
it('should classify as Major for structural changes', async () => {
52+
llmStub.resolves({
53+
content: 'major',
54+
});
55+
56+
const state = {
57+
prompt: 'Get users with their orders and total amount',
58+
schema: {tables: {}, relations: []},
59+
sampleSql: 'SELECT * FROM users',
60+
sampleSqlPrompt: 'Get all users',
61+
};
62+
63+
const result = await node.execute(state as any, {});
64+
65+
expect(result.changeType).to.equal(ChangeType.Major);
66+
sinon.assert.calledOnce(llmStub);
67+
});
68+
69+
it('should classify as Rewrite for fundamentally different queries', async () => {
70+
llmStub.resolves({
71+
content: 'rewrite',
72+
});
73+
74+
const state = {
75+
prompt: 'Get monthly revenue breakdown by product category',
76+
schema: {tables: {}, relations: []},
77+
sampleSql: 'SELECT * FROM users',
78+
sampleSqlPrompt: 'Get all users',
79+
};
80+
81+
const result = await node.execute(state as any, {});
82+
83+
expect(result.changeType).to.equal(ChangeType.Rewrite);
84+
sinon.assert.calledOnce(llmStub);
85+
});
86+
87+
it('should default to Major for unrecognized LLM responses', async () => {
88+
llmStub.resolves({
89+
content: 'something unexpected',
90+
});
91+
92+
const state = {
93+
prompt: 'Get users',
94+
schema: {tables: {}, relations: []},
95+
sampleSql: 'SELECT * FROM users',
96+
sampleSqlPrompt: 'Get all users',
97+
};
98+
99+
const result = await node.execute(state as any, {});
100+
101+
expect(result.changeType).to.equal(ChangeType.Major);
102+
});
103+
104+
it('should pass original and new descriptions to the LLM', async () => {
105+
llmStub.resolves({
106+
content: 'minor',
107+
});
108+
109+
const state = {
110+
prompt: 'Get users with age > 30',
111+
schema: {tables: {}, relations: []},
112+
sampleSql: 'SELECT * FROM users WHERE age > 20',
113+
sampleSqlPrompt: 'Get users with age > 20',
114+
};
115+
116+
await node.execute(state as any, {});
117+
118+
const prompt = llmStub.firstCall.args[0];
119+
expect(prompt.value).to.containEql('Get users with age > 20');
120+
expect(prompt.value).to.containEql('Get users with age > 30');
121+
});
122+
123+
it('should handle empty sampleSqlPrompt gracefully', async () => {
124+
llmStub.resolves({
125+
content: 'major',
126+
});
127+
128+
const state = {
129+
prompt: 'Get all users',
130+
schema: {tables: {}, relations: []},
131+
sampleSql: 'SELECT * FROM users',
132+
sampleSqlPrompt: undefined,
133+
};
134+
135+
const result = await node.execute(state as any, {});
136+
137+
expect(result.changeType).to.equal(ChangeType.Major);
138+
sinon.assert.calledOnce(llmStub);
139+
});
140+
141+
it('should handle LLM response with extra whitespace and casing', async () => {
142+
llmStub.resolves({
143+
content: ' Minor \n',
144+
});
145+
146+
const state = {
147+
prompt: 'Get users with age > 25',
148+
schema: {tables: {}, relations: []},
149+
sampleSql: 'SELECT * FROM users WHERE age > 20',
150+
sampleSqlPrompt: 'Get users with age > 20',
151+
};
152+
153+
const result = await node.execute(state as any, {});
154+
155+
expect(result.changeType).to.equal(ChangeType.Minor);
156+
});
157+
});

0 commit comments

Comments
 (0)