Skip to content

Commit 632c290

Browse files
committed
feat(component): add test utility
1 parent 9ecb0b2 commit 632c290

6 files changed

Lines changed: 456 additions & 3 deletions

File tree

README.md

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,112 @@ export class AddTool implements IGraphTool {
355355
}
356356
}
357357
```
358+
359+
# Testing
360+
361+
## Generation Acceptance Builder
362+
363+
The `generation.acceptance.builder.ts` file provides a utility to run acceptance tests for the `llm-chat-component`. These tests validate the functionality of the `/reply` endpoint and ensure that the generated SQL queries and their results align with expectations.
364+
365+
## Overview
366+
367+
This builder facilitates the execution of multiple test cases, each defined with specific prompts, expected results, and configurations. It also generates detailed reports to analyze the performance and correctness of the tests.
368+
369+
## Key Features
370+
371+
- **Dynamic Prompt Parsing**: Replaces placeholders in prompts with environment-specific values.
372+
- **Token Generation**: Creates JWT tokens with required permissions for test execution.
373+
- **Query Execution**: Executes the generated SQL queries and compares the results with expected outputs.
374+
- **Detailed Reporting**: Generates markdown reports with metrics such as success rates, token usage, and execution times.
375+
376+
## Usage
377+
378+
### Importing the Builder
379+
380+
```typescript
381+
import {generationAcceptanceBuilder} from './generation.acceptance.builder';
382+
```
383+
384+
### Running Tests
385+
386+
To use the builder, define your test cases as an array of `GenerationAcceptanceTestCase` objects and pass them to the `generationAcceptanceBuilder` function along with the required parameters.
387+
388+
#### Example
389+
390+
```typescript
391+
const testCases = [
392+
{
393+
case: 'Test Case 1',
394+
prompt: 'Find all the active resources',
395+
outputInstructions:
396+
'The output should have a single column `resource_name` arranged in alphabetical order.',
397+
resultQuery:
398+
'SELECT name as resource_name FROM resource WHERE status = 1 ORDER BY name',
399+
count: 1,
400+
},
401+
];
402+
403+
const result = await generationAcceptanceBuilder(
404+
testCases,
405+
client,
406+
app,
407+
1,
408+
true,
409+
);
410+
console.log(result);
411+
```
412+
413+
### Parameters
414+
415+
- `cases`: An array of test cases to execute.
416+
- `client`: The LoopBack test client.
417+
- `app`: The LoopBack application instance.
418+
- `countPerPrompt`: Number of iterations per test case (default: 1).
419+
- `writeReport`: Whether to generate a markdown report (default: false).
420+
421+
### Test Case Structure
422+
423+
Each test case should follow the `GenerationAcceptanceTestCase` interface:
424+
425+
```typescript
426+
interface GenerationAcceptanceTestCase {
427+
case: string; // Name of the test case
428+
prompt: string; // Prompt to send to the LLM
429+
outputInstructions: string; // Additional instructions for the output
430+
resultQuery: string; // Expected SQL query
431+
count?: number; // Number of iterations (optional)
432+
only?: boolean; // Run only this test case (optional)
433+
skip?: boolean; // Skip this test case (optional)
434+
}
435+
```
436+
437+
## Report Generation
438+
439+
The builder generates a markdown report summarizing the test results. The report includes:
440+
441+
- Success metrics
442+
- Time metrics
443+
- Token usage metrics
444+
- Detailed results for each test case
445+
- Failed queries with actual and expected results
446+
447+
The report is saved in the `llm-reports` directory with a filename based on the model name.
448+
449+
## Environment Variables
450+
451+
The builder relies on the following environment variables:
452+
453+
- `SAMPLE_DEAL_NAME`: Default value for `<testDeal>` placeholder.
454+
- `TEST_TENANT_ID`: Tenant ID for token generation.
455+
- `JWT_SECRET`: Secret key for signing JWT tokens.
456+
- `JWT_ISSUER`: Issuer for JWT tokens.
457+
458+
## Dependencies
459+
460+
- `@loopback/testlab`
461+
- `@loopback/core`
462+
- `@loopback/repository`
463+
- `@sourceloop/core`
464+
- `jsonwebtoken`
465+
- `crypto`
466+
- `fs`

package.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@
4646
"default": "./dist/sub-modules/providers/pg/index.js"
4747
},
4848
"./db-query/testing": {
49-
"type": "./dist/sub-modules/providers/db-query/testing/index.d.ts",
50-
"default": "./dist/sub-modules/providers/db-query/testing/index.js"
49+
"type": "./dist/components/db-query/testing/index.d.ts",
50+
"default": "./dist/components/db-query/testing/index.js"
5151
}
5252
},
5353
"typesVersions": {
@@ -77,7 +77,7 @@
7777
"dist/sub-modules/providers/pg/index.d.ts"
7878
],
7979
"db-query/testing": [
80-
"dist/sub-modules/providers/db-query/testing/index.d.ts"
80+
"dist/components/db-query/testing/index.d.ts"
8181
]
8282
}
8383
},
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
import {Client} from '@loopback/testlab';
2+
import {
3+
GenerationAcceptanceSuiteResult,
4+
GenerationAcceptanceTestCase,
5+
GenerationAcceptanceTestResult,
6+
} from './types';
7+
import {Application} from '@loopback/core';
8+
import {PermissionKey} from '../../../permissions';
9+
import {DbQueryAIExtensionBindings} from '../keys';
10+
import {sign} from 'jsonwebtoken';
11+
import {randomUUID} from 'crypto';
12+
import {
13+
LLMStreamEvent,
14+
LLMStreamEventType,
15+
LLMStreamTokenCountEvent,
16+
LLMStreamToolStatusEvent,
17+
ToolStatus,
18+
} from '../../../graphs';
19+
import {generateMarkdownTable, getModelNameFromEnv} from './utils';
20+
import {writeFileSync} from 'fs';
21+
import {juggler} from '@loopback/repository';
22+
import {ILogger, LOGGER} from '@sourceloop/core';
23+
24+
function parsePrompt(prompt: string) {
25+
const keys: Record<string, string> = {
26+
testDeal: process.env.SAMPLE_DEAL_NAME ?? 'test-deal',
27+
};
28+
for (const key of Object.keys(keys)) {
29+
prompt = prompt.replace(new RegExp(`\\<${key}\\>`, 'g'), keys[key]);
30+
}
31+
return prompt;
32+
}
33+
34+
function parseQuery(prompt: string) {
35+
const keys: Record<string, string> = {
36+
testDeal: (process.env.SAMPLE_DEAL_NAME ?? 'test-deal')
37+
.split(' ')
38+
.join('%')
39+
.split('_')
40+
.join('%'),
41+
tenantId: process.env.TEST_TENANT_ID ?? 'test-tenant',
42+
date: new Date().toISOString().split('T')[0],
43+
};
44+
for (const key of Object.keys(keys)) {
45+
prompt = prompt.replace(new RegExp(`\\<${key}\\>`, 'g'), keys[key]);
46+
}
47+
return prompt;
48+
}
49+
50+
function tokenBuilder(tenantid: string, permissions: string[]) {
51+
return sign(
52+
{
53+
id: randomUUID(),
54+
userTenantId: randomUUID(),
55+
permissions: permissions,
56+
tenantId: tenantid,
57+
},
58+
process.env.JWT_SECRET ?? '',
59+
{
60+
issuer: process.env.JWT_ISSUER ?? '',
61+
},
62+
);
63+
}
64+
65+
export async function generationAcceptanceBuilder(
66+
cases: GenerationAcceptanceTestCase[],
67+
client: Client,
68+
app: Application,
69+
countPerPrompt = 1,
70+
writeReport = false,
71+
): Promise<GenerationAcceptanceSuiteResult> {
72+
// setup app
73+
const config = app.getSync(DbQueryAIExtensionBindings.Config);
74+
const token = tokenBuilder(process.env.TEST_TENANT_ID ?? 'test-tenant', [
75+
...config.models.map(v => v.readPermissionKey),
76+
PermissionKey.AskAI,
77+
PermissionKey.ViewDataset,
78+
PermissionKey.ExecuteDataset,
79+
]);
80+
const datasetStore = await app.get(DbQueryAIExtensionBindings.DatasetStore);
81+
const ds = await app.get<juggler.DataSource>('datasources.db');
82+
const logger = await app.get<ILogger>(LOGGER.LOGGER_INJECT);
83+
84+
const results: GenerationAcceptanceTestResult[] = [];
85+
const anyOnly = cases.some(q => q.only);
86+
const queriesToRun = anyOnly
87+
? cases.filter(q => q.only && !q.skip)
88+
: cases.filter(q => !q.skip);
89+
90+
for (const query of queriesToRun) {
91+
const count = query.count ?? countPerPrompt;
92+
for (let i = 0; i < count; i++) {
93+
logger.info(
94+
`Running query: ${query.case} ${i > 0 ? `Iteration: ${i + 1}` : ''}`,
95+
);
96+
const result: GenerationAcceptanceTestResult = {
97+
success: false,
98+
time: 0,
99+
inputTokens: 0,
100+
outputTokens: 0,
101+
emptyOutput: false,
102+
generationCount: 0,
103+
query: '',
104+
case: query.case,
105+
description: '',
106+
actualResult: null,
107+
expectedResult: null,
108+
};
109+
try {
110+
const startTime = Date.now();
111+
const {body} = await client
112+
.post('/reply')
113+
.set('Authorization', `Bearer ${token}`)
114+
.field(
115+
'prompt',
116+
`${parsePrompt(query.prompt)}. ${query.outputInstructions}`,
117+
)
118+
.expect(200);
119+
// time in seconds
120+
result.time = (Date.now() - startTime) / 1000;
121+
const status = body.filter(
122+
(v: LLMStreamEvent) => v.type === LLMStreamEventType.ToolStatus,
123+
);
124+
const lastStatus: LLMStreamToolStatusEvent = status[status.length - 1];
125+
const [tokenCount]: LLMStreamTokenCountEvent[] = body.filter(
126+
(v: LLMStreamEvent) => v.type === LLMStreamEventType.TokenCount,
127+
);
128+
result.inputTokens = tokenCount.data.inputTokens;
129+
result.outputTokens = tokenCount.data.outputTokens;
130+
131+
const finalDescription = body.filter(
132+
(v: LLMStreamEvent) =>
133+
v.type === LLMStreamEventType.ToolStatus &&
134+
v.data.status.startsWith('DESCRIPTION:'),
135+
);
136+
if (finalDescription.length > 0) {
137+
result.description = finalDescription
138+
.pop()
139+
.data.status.replace('DESCRIPTION:', '');
140+
}
141+
result.generationCount = body.filter(
142+
(v: LLMStreamEvent) =>
143+
v.type === LLMStreamEventType.ToolStatus &&
144+
v.data.status === 'Generating SQL query from the prompt',
145+
).length;
146+
if (lastStatus.data.status === ToolStatus.Completed) {
147+
const dataset = await datasetStore.findById(
148+
lastStatus.data.data?.['datasetId'],
149+
);
150+
result.query = parseQuery(dataset.query);
151+
const {body: actualData} = await client
152+
.get(`/datasets/${dataset.id}/execute`)
153+
.set('Authorization', `Bearer ${token}`)
154+
.expect(200);
155+
const expectedData = await ds.execute(parseQuery(query.resultQuery));
156+
result.actualResult = actualData;
157+
result.expectedResult = expectedData;
158+
// compare actualData and expectedData
159+
if (JSON.stringify(actualData) === JSON.stringify(expectedData)) {
160+
result.success = true;
161+
}
162+
if (expectedData.length === 0) {
163+
result.emptyOutput = true;
164+
}
165+
} else {
166+
result.actualResult = JSON.stringify(lastStatus);
167+
logger.error('Tool did not complete successfully');
168+
}
169+
} catch (error) {
170+
result.actualResult = error.message ?? error.toString();
171+
logger.error('Error: ', error);
172+
}
173+
results.push(result);
174+
if (writeReport) {
175+
writeResultSoFar(results);
176+
}
177+
}
178+
}
179+
180+
return buildFinalResult(results);
181+
}
182+
183+
function buildFinalResult(results: GenerationAcceptanceTestResult[]) {
184+
const success = results.filter(r => r.success).length;
185+
const total = results.length;
186+
return {
187+
total,
188+
success,
189+
results,
190+
};
191+
}
192+
193+
function writeResultSoFar(results: GenerationAcceptanceTestResult[]) {
194+
const successCount = results.filter(r => r.success).length;
195+
const totalCount = results.length;
196+
const totalInputTokens = results.reduce((acc, r) => acc + r.inputTokens, 0);
197+
const totalOutputTokens = results.reduce((acc, r) => acc + r.outputTokens, 0);
198+
const totalTime = results.reduce((acc, r) => acc + r.time, 0);
199+
const avgTime = totalTime / totalCount || 0;
200+
const avgInputTokens = totalInputTokens / totalCount || 0;
201+
const avgOutputTokens = totalOutputTokens / totalCount || 0;
202+
const modelName = getModelNameFromEnv();
203+
let report = `# For Model - ${modelName}\n`;
204+
// print a table with success, non empty success, total time, avg time, total tokens, avg tokens
205+
report += `## Success Metrics\n`;
206+
report += generateMarkdownTable([
207+
{
208+
'Success Count': successCount,
209+
'Total Count': results.length,
210+
'Success Rate': ((successCount / totalCount) * 100).toFixed(2) + '%',
211+
},
212+
]);
213+
report += `\n## Time Metrics\n`;
214+
report += generateMarkdownTable([
215+
{
216+
'Total Time (s)': totalTime.toFixed(2),
217+
'Avg Time (s)': avgTime.toFixed(2),
218+
},
219+
]);
220+
report += `\n## Token Metrics\n`;
221+
report += generateMarkdownTable([
222+
{
223+
'Total Input Tokens': totalInputTokens,
224+
'Total Output Tokens': totalOutputTokens,
225+
'Avg Input Tokens': avgInputTokens.toFixed(2),
226+
'Avg Output Tokens': avgOutputTokens.toFixed(2),
227+
'Total Tokens': (totalInputTokens + totalOutputTokens).toFixed(2),
228+
},
229+
]);
230+
report += `\n## Detailed Results\n`;
231+
report += generateMarkdownTable(
232+
results.map(result => ({
233+
Query: result.case,
234+
Success: result.success ? `:green_circle:` : `:red_circle:`,
235+
'Empty Output': result.emptyOutput,
236+
'Time (s)': result.time.toFixed(2),
237+
'Input Tokens Used': result.inputTokens,
238+
'Output Tokens Used': result.outputTokens,
239+
'Generation Count': result.generationCount,
240+
})),
241+
);
242+
report += `\n## Failed Queries and Results\n`;
243+
for (const result of results) {
244+
if (result.success) continue;
245+
report += `\n ### Query: ${result.case}\n`;
246+
report += `**Description:** ${result.description}\n\n`;
247+
report += `\n \`\`\`sql\n${result.query}\n\`\`\`\n`;
248+
report += `\n**Actual Result:**\n\n`;
249+
if (Array.isArray(result.actualResult)) {
250+
report += generateMarkdownTable(result.actualResult ?? []);
251+
} else {
252+
report += '```\n' + JSON.stringify(result.actualResult) + '\n```\n';
253+
}
254+
report += `\n**Expected Result:**\n\n`;
255+
report += generateMarkdownTable(result.expectedResult ?? []);
256+
report += `\n---\n`;
257+
}
258+
writeFileSync(
259+
`./llm-reports/generation-report-${modelName.toLowerCase().replace(/[\s\_\/\\]/g, '-')}.md`,
260+
report,
261+
);
262+
}

0 commit comments

Comments
 (0)