Skip to content

Commit 7eac455

Browse files
authored
Merge pull request #1818 from rocket-admin/backend_ai_requests_permissions
feat: implement permission checks for qai requests and add tests
2 parents 4c8b7cb + 681610c commit 7eac455

3 files changed

Lines changed: 239 additions & 14 deletions

File tree

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import { CollectQueryTablesResult } from '../../entities/visualizations/panel/utils/collect-query-tables.util.js';
2+
import { getErrorMessage } from '../../helpers/get-error-message.js';
3+
4+
/**
5+
* Recursively collects collection names referenced by stages that read from
6+
* other collections (`$lookup`, `$graphLookup`, `$unionWith`) anywhere in a
7+
* MongoDB aggregation pipeline, including nested sub-pipelines.
8+
*/
9+
function collectReferencedCollections(node: unknown, collected: Set<string>): void {
10+
if (Array.isArray(node)) {
11+
for (const item of node) {
12+
collectReferencedCollections(item, collected);
13+
}
14+
return;
15+
}
16+
if (!node || typeof node !== 'object') {
17+
return;
18+
}
19+
for (const [key, value] of Object.entries(node as Record<string, unknown>)) {
20+
if (key === '$lookup' || key === '$graphLookup') {
21+
const from = (value as { from?: unknown })?.from;
22+
if (typeof from === 'string' && from.length > 0) {
23+
collected.add(from);
24+
}
25+
} else if (key === '$unionWith') {
26+
// `$unionWith` accepts either a collection-name string or `{ coll: <name>, pipeline: [...] }`.
27+
if (typeof value === 'string' && value.length > 0) {
28+
collected.add(value);
29+
} else {
30+
const coll = (value as { coll?: unknown })?.coll;
31+
if (typeof coll === 'string' && coll.length > 0) {
32+
collected.add(coll);
33+
}
34+
}
35+
}
36+
collectReferencedCollections(value, collected);
37+
}
38+
}
39+
40+
/**
41+
* Resolves the collections a MongoDB aggregation pipeline reads from besides
42+
* its base collection (the `$lookup` / `$graphLookup` / `$unionWith` targets),
43+
* so the caller can verify the user has read permission on each.
44+
*
45+
* Returns `{ kind: 'tables' }` (possibly empty) when the pipeline parses, and
46+
* `{ kind: 'indeterminate' }` when it cannot be parsed — in which case the
47+
* caller must fall back to a stricter check rather than assume it is harmless.
48+
*/
49+
export function collectMongoPipelineCollections(pipeline: string): CollectQueryTablesResult {
50+
let parsedPipeline: unknown;
51+
try {
52+
parsedPipeline = JSON.parse(pipeline);
53+
} catch (error) {
54+
return { kind: 'indeterminate', reason: `pipeline parse error: ${getErrorMessage(error)}` };
55+
}
56+
const collected = new Set<string>();
57+
collectReferencedCollections(parsedPipeline, collected);
58+
return { kind: 'tables', tables: Array.from(collected) };
59+
}

backend/src/entities/ai/use-cases/request-info-from-table-with-ai-v7.use.case.ts

Lines changed: 123 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
import { BaseMessage } from '@langchain/core/messages';
2-
import { BadRequestException, Inject, Injectable, Logger, NotFoundException, Scope } from '@nestjs/common';
2+
import {
3+
BadRequestException,
4+
ForbiddenException,
5+
Inject,
6+
Injectable,
7+
Logger,
8+
NotFoundException,
9+
Scope,
10+
} from '@nestjs/common';
311
import { getDataAccessObject } from '@rocketadmin/shared-code/dist/src/data-access-layer/shared/create-data-access-object.js';
412
import { ConnectionTypesEnum } from '@rocketadmin/shared-code/dist/src/shared/enums/connection-types-enum.js';
513
import { IDataAccessObject } from '@rocketadmin/shared-code/dist/src/shared/interfaces/data-access-object.interface.js';
@@ -9,6 +17,7 @@ import { Response } from 'express';
917
import { AIToolCall, AIToolDefinition } from '../../../ai-core/interfaces/ai-provider.interface.js';
1018
import { AIProviderType } from '../../../ai-core/interfaces/ai-service.interface.js';
1119
import { AICoreService } from '../../../ai-core/services/ai-core.service.js';
20+
import { collectMongoPipelineCollections } from '../../../ai-core/tools/collect-mongo-pipeline-collections.js';
1221
import { createDatabaseTools } from '../../../ai-core/tools/database-tools.js';
1322
import { searchDocumentation } from '../../../ai-core/tools/documentation-search.js';
1423
import { createDatabaseQuerySystemPrompt } from '../../../ai-core/tools/prompts.js';
@@ -22,7 +31,9 @@ import { Messages } from '../../../exceptions/text/messages.js';
2231
import { getErrorMessage } from '../../../helpers/get-error-message.js';
2332
import { isConnectionTypeAgent } from '../../../helpers/is-connection-entity-agent.js';
2433
import { slackPostMessage } from '../../../helpers/slack/slack-post-message.js';
34+
import { CedarPermissionsService } from '../../cedar-authorization/cedar-permissions.service.js';
2535
import { ConnectionEntity } from '../../connection/connection.entity.js';
36+
import { assertUserCanReadQueryTables } from '../../visualizations/panel/utils/assert-query-tables-readable.util.js';
2637
import { MessageRole } from '../ai-conversation-history/ai-chat-messages/message-role.enum.js';
2738
import { UserAiChatEntity } from '../ai-conversation-history/user-ai-chat/user-ai-chat.entity.js';
2839
import { IRequestInfoFromTableV2 } from '../ai-use-cases.interface.js';
@@ -41,6 +52,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
4152
@Inject(BaseType.GLOBAL_DB_CONTEXT)
4253
protected _dbContext: IGlobalDatabaseContext,
4354
private readonly aiCoreService: AICoreService,
55+
private readonly cedarPermissions: CedarPermissionsService,
4456
) {
4557
super();
4658
}
@@ -104,6 +116,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
104116
tableName,
105117
userEmail,
106118
foundConnection,
119+
user_id,
107120
);
108121

109122
if (accumulatedResponse) {
@@ -132,6 +145,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
132145
inputTableName: string,
133146
userEmail: string,
134147
foundConnection: ConnectionEntity,
148+
userId: string,
135149
): Promise<string> {
136150
let currentMessages = [...messages];
137151
let depth = 0;
@@ -178,6 +192,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
178192
inputTableName,
179193
userEmail,
180194
foundConnection,
195+
userId,
181196
);
182197

183198
for (const toolResult of toolResults) {
@@ -226,6 +241,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
226241
inputTableName: string,
227242
userEmail: string,
228243
foundConnection: ConnectionEntity,
244+
userId: string,
229245
): Promise<Array<{ toolCallId: string; result: string }>> {
230246
const results: Array<{ toolCallId: string; result: string }> = [];
231247

@@ -236,11 +252,13 @@ export class RequestInfoFromTableWithAIUseCaseV7
236252
switch (toolCall.name) {
237253
case 'getTableStructure': {
238254
const tableName = (toolCall.arguments.tableName as string) || inputTableName;
255+
await this.assertUserCanReadTables([tableName], userId, foundConnection.id);
239256
const structureInfo = await this.getTableStructureInfo(
240257
dataAccessObject,
241258
tableName,
242259
userEmail,
243260
foundConnection,
261+
userId,
244262
);
245263
result = encodeToToon(structureInfo);
246264
break;
@@ -256,6 +274,14 @@ export class RequestInfoFromTableWithAIUseCaseV7
256274
'Invalid SQL query. Please ensure it is a read-only SELECT statement without any forbidden keywords.',
257275
);
258276
}
277+
await assertUserCanReadQueryTables({
278+
query,
279+
connectionType: foundConnection.type as ConnectionTypesEnum,
280+
connectionId: foundConnection.id,
281+
validateTableRead: (referencedTableName) =>
282+
this.cedarPermissions.improvedCheckTableRead(userId, foundConnection.id, referencedTableName),
283+
listAllTableNames: async () => (await dataAccessObject.getTablesFromDB()).map((table) => table.tableName),
284+
});
259285
const wrappedQuery = wrapQueryWithLimit(query, foundConnection.type as ConnectionTypesEnum);
260286
const queryResult = await dataAccessObject.executeRawQuery(wrappedQuery, inputTableName, userEmail);
261287
result = encodeToToon(queryResult);
@@ -272,6 +298,13 @@ export class RequestInfoFromTableWithAIUseCaseV7
272298
'Invalid MongoDB command. Please ensure it is a read-only aggregation pipeline without any forbidden keywords.',
273299
);
274300
}
301+
await this.assertUserCanReadPipelineCollections(
302+
pipeline,
303+
inputTableName,
304+
userId,
305+
foundConnection.id,
306+
dataAccessObject,
307+
);
275308
const pipelineResult = await dataAccessObject.executeRawQuery(pipeline, inputTableName, userEmail);
276309
result = encodeToToon(pipelineResult);
277310
break;
@@ -307,32 +340,50 @@ export class RequestInfoFromTableWithAIUseCaseV7
307340
tableName: string,
308341
userEmail: string,
309342
foundConnection: ConnectionEntity,
343+
userId: string,
310344
) {
311345
const [tableStructure, tableForeignKeys, referencedTableNamesAndColumns] = await Promise.all([
312346
dao.getTableStructure(tableName, userEmail),
313347
dao.getTableForeignKeys(tableName, userEmail),
314348
dao.getReferencedTableNamesAndColumns(tableName, userEmail),
315349
]);
316350

351+
// Only expose the structure of related tables the user is permitted to
352+
// read — otherwise foreign-key traversal would leak the schema of tables
353+
// the user has no access to.
317354
const referencedTablesStructures = [];
318355
const structurePromises = referencedTableNamesAndColumns.flatMap((referencedTable) =>
319-
referencedTable.referenced_by.map((table) =>
320-
dao.getTableStructure(table.table_name, userEmail).then((structure) => ({
321-
tableName: table.table_name,
322-
structure,
323-
})),
324-
),
356+
referencedTable.referenced_by.map(async (table) => {
357+
const canRead = await this.cedarPermissions.improvedCheckTableRead(
358+
userId,
359+
foundConnection.id,
360+
table.table_name,
361+
);
362+
if (!canRead) {
363+
return null;
364+
}
365+
const structure = await dao.getTableStructure(table.table_name, userEmail);
366+
return { tableName: table.table_name, structure };
367+
}),
325368
);
326-
referencedTablesStructures.push(...(await Promise.all(structurePromises)));
369+
referencedTablesStructures.push(...(await Promise.all(structurePromises)).filter((item) => item !== null));
327370

328371
const foreignTablesStructures = [];
329-
const foreignTablesStructurePromises = tableForeignKeys.flatMap((foreignKey) =>
330-
dao.getTableStructure(foreignKey.referenced_table_name, userEmail).then((structure) => ({
331-
tableName: foreignKey.referenced_table_name,
332-
structure,
333-
})),
372+
const foreignTablesStructurePromises = tableForeignKeys.map(async (foreignKey) => {
373+
const canRead = await this.cedarPermissions.improvedCheckTableRead(
374+
userId,
375+
foundConnection.id,
376+
foreignKey.referenced_table_name,
377+
);
378+
if (!canRead) {
379+
return null;
380+
}
381+
const structure = await dao.getTableStructure(foreignKey.referenced_table_name, userEmail);
382+
return { tableName: foreignKey.referenced_table_name, structure };
383+
});
384+
foreignTablesStructures.push(
385+
...(await Promise.all(foreignTablesStructurePromises)).filter((item) => item !== null),
334386
);
335-
foreignTablesStructures.push(...(await Promise.all(foreignTablesStructurePromises)));
336387

337388
return {
338389
tableStructure,
@@ -345,6 +396,64 @@ export class RequestInfoFromTableWithAIUseCaseV7
345396
};
346397
}
347398

399+
/**
400+
* Verifies the user has read permission on every supplied table before the
401+
* AI is allowed to query or inspect them. Throws a `ForbiddenException` on
402+
* the first unreadable table; inside the tool loop this surfaces back to the
403+
* model as a tool error, so the offending query is never executed. Empty or
404+
* blank names are ignored.
405+
*/
406+
private async assertUserCanReadTables(
407+
tableNames: Array<string>,
408+
userId: string,
409+
connectionId: string,
410+
): Promise<void> {
411+
const uniqueTableNames = Array.from(
412+
new Set(tableNames.map((name) => name?.trim()).filter((name): name is string => Boolean(name))),
413+
);
414+
415+
for (const tableName of uniqueTableNames) {
416+
const canRead = await this.cedarPermissions.improvedCheckTableRead(userId, connectionId, tableName);
417+
if (!canRead) {
418+
this.logger.warn(
419+
`AI request blocked for user ${userId} on connection ${connectionId}: ` +
420+
`no read permission for table "${tableName}"`,
421+
);
422+
throw new ForbiddenException(Messages.NO_READ_PERMISSION_FOR_TABLE(tableName));
423+
}
424+
}
425+
}
426+
427+
/**
428+
* Guards a MongoDB aggregation pipeline against table-level read permissions:
429+
* the user must be able to read the base collection and every collection the
430+
* pipeline pulls in (`$lookup` / `$graphLookup` / `$unionWith`). When the
431+
* pipeline cannot be parsed we cannot trust it to be harmless, so we fall
432+
* back to requiring read permission on every collection in the connection.
433+
*/
434+
private async assertUserCanReadPipelineCollections(
435+
pipeline: string,
436+
baseCollection: string,
437+
userId: string,
438+
connectionId: string,
439+
dataAccessObject: IDataAccessObject | IDataAccessObjectAgent,
440+
): Promise<void> {
441+
const collected = collectMongoPipelineCollections(pipeline);
442+
443+
let collectionsToCheck: Array<string>;
444+
if (collected.kind === 'tables') {
445+
collectionsToCheck = [baseCollection, ...collected.tables];
446+
} else {
447+
this.logger.warn(
448+
`AI pipeline permission check could not resolve referenced collections for connection ${connectionId} ` +
449+
`(reason: ${collected.reason}); falling back to all-collections read check.`,
450+
);
451+
collectionsToCheck = (await dataAccessObject.getTablesFromDB()).map((table) => table.tableName);
452+
}
453+
454+
await this.assertUserCanReadTables(collectionsToCheck, userId, connectionId);
455+
}
456+
348457
private setupResponseHeaders(response: Response): void {
349458
response.setHeader('Content-Type', 'text/event-stream');
350459
response.setHeader('Cache-Control', 'no-cache');
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import test from 'ava';
2+
import { collectMongoPipelineCollections } from '../../../src/ai-core/tools/collect-mongo-pipeline-collections.js';
3+
4+
function tablesOf(pipeline: string): Array<string> {
5+
const result = collectMongoPipelineCollections(pipeline);
6+
if (result.kind !== 'tables') {
7+
throw new Error(`expected resolved tables, got indeterminate: ${result.reason}`);
8+
}
9+
return [...result.tables].sort();
10+
}
11+
12+
test('resolves no referenced collections for a pipeline without joins', (t) => {
13+
t.deepEqual(tablesOf('[{"$match":{"status":"active"}},{"$group":{"_id":"$type"}}]'), []);
14+
});
15+
16+
test('resolves a $lookup target collection', (t) => {
17+
t.deepEqual(tablesOf('[{"$lookup":{"from":"salaries","localField":"id","foreignField":"user_id","as":"s"}}]'), [
18+
'salaries',
19+
]);
20+
});
21+
22+
test('resolves a $graphLookup target collection', (t) => {
23+
t.deepEqual(
24+
tablesOf(
25+
'[{"$graphLookup":{"from":"org_chart","startWith":"$managerId","connectFromField":"managerId","connectToField":"_id","as":"chain"}}]',
26+
),
27+
['org_chart'],
28+
);
29+
});
30+
31+
test('resolves a $unionWith string collection', (t) => {
32+
t.deepEqual(tablesOf('[{"$unionWith":"archived_orders"}]'), ['archived_orders']);
33+
});
34+
35+
test('resolves a $unionWith object collection', (t) => {
36+
t.deepEqual(tablesOf('[{"$unionWith":{"coll":"audit_log","pipeline":[]}}]'), ['audit_log']);
37+
});
38+
39+
test('resolves collections nested inside a $lookup sub-pipeline', (t) => {
40+
const pipeline =
41+
'[{"$lookup":{"from":"orders","as":"o","pipeline":[{"$lookup":{"from":"secret_payouts","localField":"a","foreignField":"b","as":"p"}}]}}]';
42+
t.deepEqual(tablesOf(pipeline), ['orders', 'secret_payouts']);
43+
});
44+
45+
test('deduplicates repeated collection references', (t) => {
46+
const pipeline =
47+
'[{"$lookup":{"from":"orders","localField":"a","foreignField":"b","as":"o1"}},{"$lookup":{"from":"orders","localField":"c","foreignField":"d","as":"o2"}}]';
48+
t.deepEqual(tablesOf(pipeline), ['orders']);
49+
});
50+
51+
test('returns indeterminate for an unparseable pipeline', (t) => {
52+
const result = collectMongoPipelineCollections('not valid json {');
53+
t.is(result.kind, 'indeterminate');
54+
if (result.kind === 'indeterminate') {
55+
t.true(result.reason.includes('parse error'));
56+
}
57+
});

0 commit comments

Comments
 (0)