11import { BaseMessage } from '@langchain/core/messages' ;
2- import { BadRequestException , Inject , Injectable , Logger , NotFoundException , Scope } from '@nestjs/common' ;
2+ import {
3+ BadRequestException ,
4+ ForbiddenException ,
5+ Inject ,
6+ Injectable ,
7+ Logger ,
8+ NotFoundException ,
9+ Scope ,
10+ } from '@nestjs/common' ;
311import { getDataAccessObject } from '@rocketadmin/shared-code/dist/src/data-access-layer/shared/create-data-access-object.js' ;
412import { ConnectionTypesEnum } from '@rocketadmin/shared-code/dist/src/shared/enums/connection-types-enum.js' ;
513import { IDataAccessObject } from '@rocketadmin/shared-code/dist/src/shared/interfaces/data-access-object.interface.js' ;
@@ -9,6 +17,7 @@ import { Response } from 'express';
917import { AIToolCall , AIToolDefinition } from '../../../ai-core/interfaces/ai-provider.interface.js' ;
1018import { AIProviderType } from '../../../ai-core/interfaces/ai-service.interface.js' ;
1119import { AICoreService } from '../../../ai-core/services/ai-core.service.js' ;
20+ import { collectMongoPipelineCollections } from '../../../ai-core/tools/collect-mongo-pipeline-collections.js' ;
1221import { createDatabaseTools } from '../../../ai-core/tools/database-tools.js' ;
1322import { searchDocumentation } from '../../../ai-core/tools/documentation-search.js' ;
1423import { createDatabaseQuerySystemPrompt } from '../../../ai-core/tools/prompts.js' ;
@@ -22,7 +31,9 @@ import { Messages } from '../../../exceptions/text/messages.js';
2231import { getErrorMessage } from '../../../helpers/get-error-message.js' ;
2332import { isConnectionTypeAgent } from '../../../helpers/is-connection-entity-agent.js' ;
2433import { slackPostMessage } from '../../../helpers/slack/slack-post-message.js' ;
34+ import { CedarPermissionsService } from '../../cedar-authorization/cedar-permissions.service.js' ;
2535import { ConnectionEntity } from '../../connection/connection.entity.js' ;
36+ import { assertUserCanReadQueryTables } from '../../visualizations/panel/utils/assert-query-tables-readable.util.js' ;
2637import { MessageRole } from '../ai-conversation-history/ai-chat-messages/message-role.enum.js' ;
2738import { UserAiChatEntity } from '../ai-conversation-history/user-ai-chat/user-ai-chat.entity.js' ;
2839import { IRequestInfoFromTableV2 } from '../ai-use-cases.interface.js' ;
@@ -41,6 +52,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
4152 @Inject ( BaseType . GLOBAL_DB_CONTEXT )
4253 protected _dbContext : IGlobalDatabaseContext ,
4354 private readonly aiCoreService : AICoreService ,
55+ private readonly cedarPermissions : CedarPermissionsService ,
4456 ) {
4557 super ( ) ;
4658 }
@@ -104,6 +116,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
104116 tableName ,
105117 userEmail ,
106118 foundConnection ,
119+ user_id ,
107120 ) ;
108121
109122 if ( accumulatedResponse ) {
@@ -132,6 +145,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
132145 inputTableName : string ,
133146 userEmail : string ,
134147 foundConnection : ConnectionEntity ,
148+ userId : string ,
135149 ) : Promise < string > {
136150 let currentMessages = [ ...messages ] ;
137151 let depth = 0 ;
@@ -178,6 +192,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
178192 inputTableName ,
179193 userEmail ,
180194 foundConnection ,
195+ userId ,
181196 ) ;
182197
183198 for ( const toolResult of toolResults ) {
@@ -226,6 +241,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
226241 inputTableName : string ,
227242 userEmail : string ,
228243 foundConnection : ConnectionEntity ,
244+ userId : string ,
229245 ) : Promise < Array < { toolCallId : string ; result : string } > > {
230246 const results : Array < { toolCallId : string ; result : string } > = [ ] ;
231247
@@ -236,11 +252,13 @@ export class RequestInfoFromTableWithAIUseCaseV7
236252 switch ( toolCall . name ) {
237253 case 'getTableStructure' : {
238254 const tableName = ( toolCall . arguments . tableName as string ) || inputTableName ;
255+ await this . assertUserCanReadTables ( [ tableName ] , userId , foundConnection . id ) ;
239256 const structureInfo = await this . getTableStructureInfo (
240257 dataAccessObject ,
241258 tableName ,
242259 userEmail ,
243260 foundConnection ,
261+ userId ,
244262 ) ;
245263 result = encodeToToon ( structureInfo ) ;
246264 break ;
@@ -256,6 +274,14 @@ export class RequestInfoFromTableWithAIUseCaseV7
256274 'Invalid SQL query. Please ensure it is a read-only SELECT statement without any forbidden keywords.' ,
257275 ) ;
258276 }
277+ await assertUserCanReadQueryTables ( {
278+ query,
279+ connectionType : foundConnection . type as ConnectionTypesEnum ,
280+ connectionId : foundConnection . id ,
281+ validateTableRead : ( referencedTableName ) =>
282+ this . cedarPermissions . improvedCheckTableRead ( userId , foundConnection . id , referencedTableName ) ,
283+ listAllTableNames : async ( ) => ( await dataAccessObject . getTablesFromDB ( ) ) . map ( ( table ) => table . tableName ) ,
284+ } ) ;
259285 const wrappedQuery = wrapQueryWithLimit ( query , foundConnection . type as ConnectionTypesEnum ) ;
260286 const queryResult = await dataAccessObject . executeRawQuery ( wrappedQuery , inputTableName , userEmail ) ;
261287 result = encodeToToon ( queryResult ) ;
@@ -272,6 +298,13 @@ export class RequestInfoFromTableWithAIUseCaseV7
272298 'Invalid MongoDB command. Please ensure it is a read-only aggregation pipeline without any forbidden keywords.' ,
273299 ) ;
274300 }
301+ await this . assertUserCanReadPipelineCollections (
302+ pipeline ,
303+ inputTableName ,
304+ userId ,
305+ foundConnection . id ,
306+ dataAccessObject ,
307+ ) ;
275308 const pipelineResult = await dataAccessObject . executeRawQuery ( pipeline , inputTableName , userEmail ) ;
276309 result = encodeToToon ( pipelineResult ) ;
277310 break ;
@@ -307,32 +340,50 @@ export class RequestInfoFromTableWithAIUseCaseV7
307340 tableName : string ,
308341 userEmail : string ,
309342 foundConnection : ConnectionEntity ,
343+ userId : string ,
310344 ) {
311345 const [ tableStructure , tableForeignKeys , referencedTableNamesAndColumns ] = await Promise . all ( [
312346 dao . getTableStructure ( tableName , userEmail ) ,
313347 dao . getTableForeignKeys ( tableName , userEmail ) ,
314348 dao . getReferencedTableNamesAndColumns ( tableName , userEmail ) ,
315349 ] ) ;
316350
351+ // Only expose the structure of related tables the user is permitted to
352+ // read — otherwise foreign-key traversal would leak the schema of tables
353+ // the user has no access to.
317354 const referencedTablesStructures = [ ] ;
318355 const structurePromises = referencedTableNamesAndColumns . flatMap ( ( referencedTable ) =>
319- referencedTable . referenced_by . map ( ( table ) =>
320- dao . getTableStructure ( table . table_name , userEmail ) . then ( ( structure ) => ( {
321- tableName : table . table_name ,
322- structure,
323- } ) ) ,
324- ) ,
356+ referencedTable . referenced_by . map ( async ( table ) => {
357+ const canRead = await this . cedarPermissions . improvedCheckTableRead (
358+ userId ,
359+ foundConnection . id ,
360+ table . table_name ,
361+ ) ;
362+ if ( ! canRead ) {
363+ return null ;
364+ }
365+ const structure = await dao . getTableStructure ( table . table_name , userEmail ) ;
366+ return { tableName : table . table_name , structure } ;
367+ } ) ,
325368 ) ;
326- referencedTablesStructures . push ( ...( await Promise . all ( structurePromises ) ) ) ;
369+ referencedTablesStructures . push ( ...( await Promise . all ( structurePromises ) ) . filter ( ( item ) => item !== null ) ) ;
327370
328371 const foreignTablesStructures = [ ] ;
329- const foreignTablesStructurePromises = tableForeignKeys . flatMap ( ( foreignKey ) =>
330- dao . getTableStructure ( foreignKey . referenced_table_name , userEmail ) . then ( ( structure ) => ( {
331- tableName : foreignKey . referenced_table_name ,
332- structure,
333- } ) ) ,
372+ const foreignTablesStructurePromises = tableForeignKeys . map ( async ( foreignKey ) => {
373+ const canRead = await this . cedarPermissions . improvedCheckTableRead (
374+ userId ,
375+ foundConnection . id ,
376+ foreignKey . referenced_table_name ,
377+ ) ;
378+ if ( ! canRead ) {
379+ return null ;
380+ }
381+ const structure = await dao . getTableStructure ( foreignKey . referenced_table_name , userEmail ) ;
382+ return { tableName : foreignKey . referenced_table_name , structure } ;
383+ } ) ;
384+ foreignTablesStructures . push (
385+ ...( await Promise . all ( foreignTablesStructurePromises ) ) . filter ( ( item ) => item !== null ) ,
334386 ) ;
335- foreignTablesStructures . push ( ...( await Promise . all ( foreignTablesStructurePromises ) ) ) ;
336387
337388 return {
338389 tableStructure,
@@ -345,6 +396,64 @@ export class RequestInfoFromTableWithAIUseCaseV7
345396 } ;
346397 }
347398
399+ /**
400+ * Verifies the user has read permission on every supplied table before the
401+ * AI is allowed to query or inspect them. Throws a `ForbiddenException` on
402+ * the first unreadable table; inside the tool loop this surfaces back to the
403+ * model as a tool error, so the offending query is never executed. Empty or
404+ * blank names are ignored.
405+ */
406+ private async assertUserCanReadTables (
407+ tableNames : Array < string > ,
408+ userId : string ,
409+ connectionId : string ,
410+ ) : Promise < void > {
411+ const uniqueTableNames = Array . from (
412+ new Set ( tableNames . map ( ( name ) => name ?. trim ( ) ) . filter ( ( name ) : name is string => Boolean ( name ) ) ) ,
413+ ) ;
414+
415+ for ( const tableName of uniqueTableNames ) {
416+ const canRead = await this . cedarPermissions . improvedCheckTableRead ( userId , connectionId , tableName ) ;
417+ if ( ! canRead ) {
418+ this . logger . warn (
419+ `AI request blocked for user ${ userId } on connection ${ connectionId } : ` +
420+ `no read permission for table "${ tableName } "` ,
421+ ) ;
422+ throw new ForbiddenException ( Messages . NO_READ_PERMISSION_FOR_TABLE ( tableName ) ) ;
423+ }
424+ }
425+ }
426+
427+ /**
428+ * Guards a MongoDB aggregation pipeline against table-level read permissions:
429+ * the user must be able to read the base collection and every collection the
430+ * pipeline pulls in (`$lookup` / `$graphLookup` / `$unionWith`). When the
431+ * pipeline cannot be parsed we cannot trust it to be harmless, so we fall
432+ * back to requiring read permission on every collection in the connection.
433+ */
434+ private async assertUserCanReadPipelineCollections (
435+ pipeline : string ,
436+ baseCollection : string ,
437+ userId : string ,
438+ connectionId : string ,
439+ dataAccessObject : IDataAccessObject | IDataAccessObjectAgent ,
440+ ) : Promise < void > {
441+ const collected = collectMongoPipelineCollections ( pipeline ) ;
442+
443+ let collectionsToCheck : Array < string > ;
444+ if ( collected . kind === 'tables' ) {
445+ collectionsToCheck = [ baseCollection , ...collected . tables ] ;
446+ } else {
447+ this . logger . warn (
448+ `AI pipeline permission check could not resolve referenced collections for connection ${ connectionId } ` +
449+ `(reason: ${ collected . reason } ); falling back to all-collections read check.` ,
450+ ) ;
451+ collectionsToCheck = ( await dataAccessObject . getTablesFromDB ( ) ) . map ( ( table ) => table . tableName ) ;
452+ }
453+
454+ await this . assertUserCanReadTables ( collectionsToCheck , userId , connectionId ) ;
455+ }
456+
348457 private setupResponseHeaders ( response : Response ) : void {
349458 response . setHeader ( 'Content-Type' , 'text/event-stream' ) ;
350459 response . setHeader ( 'Cache-Control' , 'no-cache' ) ;
0 commit comments