Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 2 additions & 336 deletions package-lock.json

Large diffs are not rendered by default.

84 changes: 82 additions & 2 deletions src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,33 @@ import {
EvaluationResult,
SemanticValidatorNode,
} from '../../../../components';
import {DbSchemaHelperService} from '../../../../components/db-query/services';
import {
DbSchemaHelperService,
TableSearchService,
} from '../../../../components/db-query/services';
import {LLMProvider} from '../../../../types';

describe('SemanticValidatorNode Unit', function () {
let node: SemanticValidatorNode;
let llmStub: sinon.SinonStub;
let tableSearchStub: sinon.SinonStubbedInstance<TableSearchService>;

beforeEach(() => {
llmStub = sinon.stub();
const llm = llmStub as unknown as LLMProvider;
const schemaHelper = {
asString: sinon.stub().returns(''),
} as unknown as DbSchemaHelperService;
tableSearchStub = sinon.createStubInstance(TableSearchService);
tableSearchStub.getTables.resolves([]);

node = new SemanticValidatorNode(llm, llm, {models: []}, schemaHelper);
node = new SemanticValidatorNode(
llm,
llm,
{models: []},
tableSearchStub,
schemaHelper,
);
});

afterEach(() => {
Expand Down Expand Up @@ -64,6 +76,7 @@ describe('SemanticValidatorNode Unit', function () {
});

it('should return QueryError if the query is invalid', async () => {
tableSearchStub.getTables.resolves(['users', 'orders']);
const state = {
prompt: 'Get all users',
sql: 'SELECT * FROM invalid_table',
Expand Down Expand Up @@ -154,4 +167,71 @@ describe('SemanticValidatorNode Unit', function () {
const prompt = llmStub.firstCall.args[0];
expect(prompt.value).to.containEql('the previous query was wrong');
});

it('should pass all accessible tables from tableSearchService into available-tables so LLM can flag missing ones', async () => {
const searchedTables = [
'public.users',
'public.orders',
'public.payments',
'analytics.reports',
];
tableSearchStub = sinon.createStubInstance(TableSearchService);
tableSearchStub.getTables.resolves(searchedTables);

const schemaHelper = {
asString: sinon.stub().returns(''),
} as unknown as DbSchemaHelperService;

const nodeWithTables = new SemanticValidatorNode(
llmStub as unknown as LLMProvider,
llmStub as unknown as LLMProvider,
{models: []},
tableSearchStub,
schemaHelper,
);

const state = {
prompt: 'Get revenue per user',
sql: 'SELECT u.name, SUM(p.amount) FROM users u JOIN payments p ON u.id = p.user_id GROUP BY u.name',
schema: {tables: {}, relations: []},
status: EvaluationResult.Pass,
id: 'test-id',
feedbacks: [],
replyToUser: '',
datasetId: 'test-dataset-id',
done: false,
sampleSqlPrompt: '',
sampleSql: '',
fromCache: false,
resultArray: undefined,
directCall: false,
description: undefined,
syntacticStatus: undefined,
syntacticFeedback: undefined,
semanticStatus: undefined,
semanticFeedback: undefined,
syntacticErrorTables: undefined,
semanticErrorTables: undefined,
fromTemplate: undefined,
templateId: undefined,
validationChecklist: '1. Revenue grouped by user',
changeType: undefined,
};

llmStub.resolves({content: '<valid/>'});

await nodeWithTables.execute(state, {});

sinon.assert.calledOnce(tableSearchStub.getTables);
expect(tableSearchStub.getTables.firstCall.args[0]).to.equal(
'Get revenue per user',
);

sinon.assert.calledOnce(llmStub);
const prompt = llmStub.firstCall.args[0];
expect(prompt.value).to.containEql('<available-tables>');
expect(prompt.value).to.containEql(
'public.users, public.orders, public.payments, analytics.reports',
);
});
});
2 changes: 1 addition & 1 deletion src/components/db-query/models/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export * from './dataset.model';
export * from './dataset-action.model';
export * from './dataset-update-dto.model';
export * from './dataset.model';
export * from './query-template-dto.model';
27 changes: 24 additions & 3 deletions src/components/db-query/nodes/semantic-validator.node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ import {LLMProvider} from '../../../types';
import {stripThinkingTokens} from '../../../utils';
import {DbQueryAIExtensionBindings} from '../keys';
import {DbQueryNodes} from '../nodes.enum';
import {DbSchemaHelperService} from '../services';
import {
DbSchemaHelperService,
PermissionHelper,
TableSearchService,
} from '../services';
import {DbQueryState} from '../state';
import {DbQueryConfig, EvaluationResult} from '../types';

Expand All @@ -22,8 +26,12 @@ export class SemanticValidatorNode implements IGraphNode<DbQueryState> {
private readonly cheapllm: LLMProvider,
@inject(DbQueryAIExtensionBindings.Config)
private readonly config: DbQueryConfig,
@service(TableSearchService)
private readonly tableSearchService: TableSearchService,
@service(DbSchemaHelperService)
private readonly schemaHelper: DbSchemaHelperService,
@service(PermissionHelper)
private readonly permissionHelper?: PermissionHelper,
) {}

prompt = PromptTemplate.fromTemplate(`
Expand Down Expand Up @@ -103,13 +111,15 @@ Keep these feedbacks in mind while validating the new query.
const useSmartLLM =
this.config.nodes?.semanticValidatorNode?.useSmartLLM ?? false;
const llm = useSmartLLM ? this.smartllm : this.cheapllm;
const tableNames = Object.keys(state.schema?.tables ?? {});
const tableList =
Comment thread
akshatdubeysf marked this conversation as resolved.
(await this.tableSearchService.getTables(state.prompt)) ?? [];
const accessibleTables = this._filterByPermissions(tableList);
const chain = RunnableSequence.from([this.prompt, llm]);
const output = await chain.invoke({
userPrompt: state.prompt,
query: state.sql,
schema: this.schemaHelper.asString(state.schema),
tableNames: tableNames.join(', '),
tableNames: accessibleTables.join(', '),
checklist: state.validationChecklist ?? 'No checklist provided.',
feedbacks: await this.getFeedbacks(state),
});
Expand Down Expand Up @@ -153,4 +163,15 @@ Keep these feedbacks in mind while validating the new query.
}
return '';
}

private _filterByPermissions(tables: string[]): string[] {
const permHelper = this.permissionHelper;
if (!permHelper) {
return tables;
}
return tables.filter(t => {
const name = t.toLowerCase().slice(t.indexOf('.') + 1);
return permHelper.findMissingPermissions([name]).length === 0;
});
}
}
10 changes: 5 additions & 5 deletions src/components/db-query/nodes/verify-checklist.node.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import {AIMessage} from '@langchain/core/messages';
import {PromptTemplate} from '@langchain/core/prompts';
import {RunnableSequence} from '@langchain/core/runnables';
import {LangGraphRunnableConfig} from '@langchain/langgraph';
Expand All @@ -7,7 +8,6 @@ import {IGraphNode, LLMStreamEventType} from '../../../graphs';
import {AiIntegrationBindings} from '../../../keys';
import {LLMProvider} from '../../../types';
import {stripThinkingTokens} from '../../../utils';
import {AIMessage} from '@langchain/core/messages';
import {DbQueryAIExtensionBindings} from '../keys';
import {DbQueryNodes} from '../nodes.enum';
import {DbSchemaHelperService} from '../services';
Expand Down Expand Up @@ -43,7 +43,7 @@ A rule is relevant if:
- It is a dependency of another relevant rule (e.g. if rule 3 requires a currency conversion, and rule 5 defines how currency conversion works, both must be included).
- It applies to any of the selected tables or their relationships.

After selecting relevant rules, review your selection and ensure:
Ensure:
- Any rule that is referenced by, or is a prerequisite for, another selected rule is also included.
- Do not include rules that are completely unrelated to the question, schema, or selected tables.
</instructions>
Expand Down Expand Up @@ -82,8 +82,8 @@ If no rules are relevant: <result>none</result>
</output-instructions>`;

simpleOutputInstructions = `<output-instructions>
Return only a comma-separated list of the relevant rule indexes inside a result tag.
Do not include any other text, explanation, or formatting.
Return ONLY the comma-separated list of relevant rule indexes inside a result tag.
Do NOT include any reasoning, analysis, or explanation — only the result tag.
Example:
<result>1,3,5</result>
If no rules are relevant:
Expand All @@ -96,7 +96,7 @@ If no rules are relevant:
): Promise<DbQueryState> {
const empty = {} as DbQueryState;

if (this.config.nodes?.generateChecklistNode?.enabled === false) {
if (this.config.nodes?.verifyChecklistNode?.enabled === false) {
return empty;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ export class DbKnowledgeGraphService implements KnowledgeGraph<
config.knowledgeGraph?.maxClusterSize ?? MAX_CLUSTER_SIZE; // Default max cluster size
}

async find(query: string, topK: number): Promise<string[]> {
async find(query: string, topK = 10): Promise<string[]> {
debug(`Selecting tables for query: "${query}"`);

// Step 1: Generate query embedding
Expand Down
2 changes: 1 addition & 1 deletion src/components/db-query/services/knowledge-graph/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export interface Graph {
export interface KnowledgeGraph<T, S> extends Graph {
toJSON(): string;
fromJSON(json: string): void;
find(query: string, count: number): Promise<T[]>;
find(query: string, count?: number): Promise<T[]>;
seed(data: S): Promise<void>;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export class TableSearchService {
private readonly dbSchemaHelper: DbSchemaHelperService,
) {}

async getTables(prompt: string, count: number): Promise<string[]> {
async getTables(prompt: string, count?: number): Promise<string[]> {
if (this.config.noKnowledgeGraph) {
return this._tables;
}
Expand Down
1 change: 1 addition & 0 deletions src/components/db-query/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ export type DbQueryConfig = {
parallelism?: number;
};
verifyChecklistNode?: {
enabled?: boolean;
evaluation?: boolean;
};
};
Expand Down
Loading