@@ -55,6 +55,7 @@ import {
5555 getTableName ,
5656 isBeforeInvocation ,
5757 isTrueNode ,
58+ logicalNot ,
5859 trueNode ,
5960} from './utils' ;
6061
@@ -100,27 +101,17 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
100101
101102 // #region Pre mutation work
102103
104+ // create
103105 if ( InsertQueryNode . is ( node ) ) {
104- // pre-create policy evaluation happens before execution of the query
105- const isManyToManyJoinTable = this . isManyToManyJoinTable ( mutationModel ) ;
106- let needCheckPreCreate = true ;
107-
108- // many-to-many join table is not a model so can't have policies on it
109- if ( ! isManyToManyJoinTable ) {
110- // check constant policies
111- const constCondition = this . tryGetConstantPolicy ( mutationModel , 'create' ) ;
112- if ( constCondition === true ) {
113- needCheckPreCreate = false ;
114- } else if ( constCondition === false ) {
115- throw createRejectedByPolicyError ( mutationModel , RejectedByPolicyReason . NO_ACCESS ) ;
116- }
117- }
106+ await this . preCreateCheck ( mutationModel , node , proceed ) ;
107+ }
118108
119- if ( needCheckPreCreate ) {
120- await this . enforcePreCreatePolicy ( node , mutationModel , isManyToManyJoinTable , proceed ) ;
121- }
109+ // update
110+ if ( UpdateQueryNode . is ( node ) ) {
111+ await this . preUpdateCheck ( mutationModel , node , proceed ) ;
122112 }
123113
114+ // post-update: load before-update entities if needed
124115 const hasPostUpdatePolicies = UpdateQueryNode . is ( node ) && this . hasPostUpdatePolicies ( mutationModel ) ;
125116
126117 let beforeUpdateInfo : Awaited < ReturnType < typeof this . loadBeforeUpdateEntities > > | undefined ;
@@ -130,7 +121,7 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
130121
131122 // #endregion
132123
133- // #region query execution
124+ // #region mutation execution
134125
135126 const result = await proceed ( this . transformNode ( node ) ) ;
136127
@@ -238,23 +229,76 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
238229 // #endregion
239230 }
240231
241- // #endregion
232+ private async preCreateCheck ( mutationModel : string , node : InsertQueryNode , proceed : ProceedKyselyQueryFunction ) {
233+ const isManyToManyJoinTable = this . isManyToManyJoinTable ( mutationModel ) ;
234+ let needCheckPreCreate = true ;
235+
236+ // many-to-many join table is not a model so can't have policies on it
237+ if ( ! isManyToManyJoinTable ) {
238+ // check constant policies
239+ const constCondition = this . tryGetConstantPolicy ( mutationModel , 'create' ) ;
240+ if ( constCondition === true ) {
241+ needCheckPreCreate = false ;
242+ } else if ( constCondition === false ) {
243+ throw createRejectedByPolicyError ( mutationModel , RejectedByPolicyReason . NO_ACCESS ) ;
244+ }
245+ }
242246
243- // correction to kysely mutation result may be needed because we might have added
244- // returning clause to the query and caused changes to the result shape
245- private postProcessMutationResult ( result : QueryResult < any > , node : MutationQueryNode ) {
246- if ( node . returning ) {
247- return result ;
248- } else {
249- return {
250- ...result ,
251- rows : [ ] ,
252- numAffectedRows : result . numAffectedRows ?? BigInt ( result . rows . length ) ,
253- } ;
247+ if ( needCheckPreCreate ) {
248+ await this . enforcePreCreatePolicy ( node , mutationModel , isManyToManyJoinTable , proceed ) ;
249+ }
250+ }
251+
252+ private async preUpdateCheck ( mutationModel : string , node : UpdateQueryNode , proceed : ProceedKyselyQueryFunction ) {
253+ // check if any rows will be filtered out by field-level update policies, and reject the whole update if so
254+
255+ const fieldsToUpdate =
256+ node . updates
257+ ?. map ( ( u ) => ( ColumnNode . is ( u . column ) ? u . column . column . name : undefined ) )
258+ . filter ( ( f ) : f is string => ! ! f ) ?? [ ] ;
259+ const fieldUpdatePolicies = fieldsToUpdate . map ( ( f ) => this . buildFieldPolicyFilter ( mutationModel , f , 'update' ) ) ;
260+
261+ // filter combining field-level update policies
262+ const fieldLevelFilter = conjunction ( this . dialect , fieldUpdatePolicies ) ;
263+ if ( isTrueNode ( fieldLevelFilter ) ) {
264+ return ;
265+ }
266+
267+ // model-level update policy filter
268+ const modelLevelFilter = this . buildPolicyFilter ( mutationModel , undefined , 'update' ) ;
269+
270+ // filter combining model-level update policy and update where
271+ const updateFilter = conjunction ( this . dialect , [ modelLevelFilter , node . where ?. where ?? trueNode ( this . dialect ) ] ) ;
272+
273+ // build a query to count rows that will be rejected by field-level policies
274+ // `SELECT COALESCE(SUM((not <fieldsFilter>) as integer), 0) AS $filteredCount WHERE <updateFilter> AND <rowFilter>`
275+ const preUpdateCheckQuery = expressionBuilder < any , any > ( )
276+ . selectFrom ( mutationModel )
277+ . select ( ( eb ) =>
278+ eb . fn
279+ . coalesce (
280+ eb . fn . sum (
281+ eb . cast ( new ExpressionWrapper ( logicalNot ( this . dialect , fieldLevelFilter ) ) , 'integer' ) ,
282+ ) ,
283+ eb . lit ( 0 ) ,
284+ )
285+ . as ( '$filteredCount' ) ,
286+ )
287+ . where ( ( ) => new ExpressionWrapper ( updateFilter ) ) ;
288+
289+ const preUpdateResult = await proceed ( preUpdateCheckQuery . toOperationNode ( ) ) ;
290+ if ( preUpdateResult . rows [ 0 ] . $filteredCount > 0 ) {
291+ throw createRejectedByPolicyError (
292+ mutationModel ,
293+ RejectedByPolicyReason . NO_ACCESS ,
294+ 'some rows cannot be updated due to field policies' ,
295+ ) ;
254296 }
255297 }
256298
257- // #region overrides
299+ // #endregion
300+
301+ // #region Transformations
258302
259303 protected override transformSelectQuery ( node : SelectQueryNode ) {
260304 if ( ! node . from ) {
@@ -269,7 +313,7 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
269313 const hasFieldLevelPolicies = node . from . froms . some ( ( table ) => {
270314 const extractedTable = this . extractTableName ( table ) ;
271315 if ( extractedTable ) {
272- return this . hasFieldLevelPolicies ( extractedTable . model ) ;
316+ return this . hasFieldLevelPolicies ( extractedTable . model , 'read' ) ;
273317 } else {
274318 return false ;
275319 }
@@ -573,14 +617,12 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
573617 return { hasPolicies : true , selection : SelectionNode . create ( selection ) } ;
574618 }
575619
576- private hasFieldLevelPolicies ( model : string ) {
620+ private hasFieldLevelPolicies ( model : string , operation : FieldLevelPolicyOperations ) {
577621 const modelDef = QueryUtils . getModel ( this . client . $schema , model ) ;
578622 if ( ! modelDef ) {
579623 return false ;
580624 }
581- return Object . values ( modelDef . fields ) . some ( ( fieldDef ) =>
582- fieldDef . attributes ?. some ( ( attr ) => [ '@allow' , '@deny' ] . includes ( attr . name ) ) ,
583- ) ;
625+ return Object . keys ( modelDef . fields ) . some ( ( field ) => this . getFieldPolicies ( model , field , operation ) . length > 0 ) ;
584626 }
585627
586628 private buildFieldPolicyFilter ( model : string , field : string , operation : FieldLevelPolicyOperations ) {
@@ -1235,5 +1277,19 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
12351277 }
12361278 }
12371279
1280+ // correction to kysely mutation result may be needed because we might have added
1281+ // returning clause to the query and caused changes to the result shape
1282+ private postProcessMutationResult ( result : QueryResult < any > , node : MutationQueryNode ) {
1283+ if ( node . returning ) {
1284+ return result ;
1285+ } else {
1286+ return {
1287+ ...result ,
1288+ rows : [ ] ,
1289+ numAffectedRows : result . numAffectedRows ?? BigInt ( result . rows . length ) ,
1290+ } ;
1291+ }
1292+ }
1293+
12381294 // #endregion
12391295}
0 commit comments