11import { invariant } from '@zenstackhq/common-helpers' ;
22import type { BaseCrudDialect , ClientContract , CRUD_EXT , ProceedKyselyQueryFunction } from '@zenstackhq/orm' ;
3- import { getCrudDialect , QueryUtils , RejectedByPolicyReason , SchemaUtils } from '@zenstackhq/orm' ;
3+ import { CoreWriteOperations , getCrudDialect , QueryUtils , RejectedByPolicyReason , SchemaUtils , SingleRowReadOperations } from '@zenstackhq/orm' ;
44import {
55 ExpressionUtils ,
66 type BuiltinType ,
@@ -42,6 +42,7 @@ import {
4242} from 'kysely' ;
4343import { match } from 'ts-pattern' ;
4444import { ColumnCollector } from './column-collector' ;
45+ import { policyContextStorage } from './context' ;
4546import { ExpressionTransformer } from './expression-transformer' ;
4647import type { PolicyPluginOptions } from './options' ;
4748import type { Policy , PolicyOperation } from './types' ;
@@ -59,6 +60,9 @@ import {
5960 trueNode ,
6061} from './utils' ;
6162
63+ const SINGLE_ROW_READ_OPERATIONS = new Set < string > ( SingleRowReadOperations ) ;
64+ const ORM_WRITE_OPERATIONS = new Set < string > ( CoreWriteOperations ) ;
65+
6266export type CrudQueryNode = SelectQueryNode | InsertQueryNode | UpdateQueryNode | DeleteQueryNode ;
6367
6468export type MutationQueryNode = InsertQueryNode | UpdateQueryNode | DeleteQueryNode ;
@@ -93,8 +97,13 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
9397 }
9498
9599 if ( ! this . isMutationQueryNode ( node ) ) {
96- // transform and proceed with read directly
97- return proceed ( this . transformNode ( node ) ) ;
100+ const selectNode = node as SelectQueryNode ;
101+ const result = await proceed ( this . transformNode ( node ) ) ;
102+ // When 0 rows returned on a single-row read, distinguish "not found" from policy denial
103+ if ( result . rows . length === 0 && SINGLE_ROW_READ_OPERATIONS . has ( policyContextStorage . getStore ( ) ?. operation ?? '' ) ) {
104+ await this . postReadZeroRowsCheck ( selectNode , proceed ) ;
105+ }
106+ return result ;
98107 }
99108
100109 const { mutationModel } = this . getMutationModel ( node ) ;
@@ -142,9 +151,9 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
142151 // Use > 0 negation (not === 0) because numAffectedRows is BigInt in some drivers
143152 if ( ! ( ( result . numAffectedRows ?? 0 ) > 0 ) ) {
144153 if ( DeleteQueryNode . is ( node ) ) {
145- await this . postMutationZeroRowsCheck ( mutationModel , 'delete' , node . where ?. where , proceed ) ;
154+ await this . postZeroRowsCheck ( mutationModel , 'delete' , node . where ?. where , proceed ) ;
146155 } else if ( UpdateQueryNode . is ( node ) ) {
147- await this . postMutationZeroRowsCheck ( mutationModel , 'update' , node . where ?. where , proceed ) ;
156+ await this . postZeroRowsCheck ( mutationModel , 'update' , node . where ?. where , proceed ) ;
148157 }
149158 }
150159
@@ -259,38 +268,47 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
259268 }
260269 }
261270
262- // Checks if any row matching the original WHERE exists without the policy filter.
263- // Called when numAffectedRows == 0 for UPDATE or DELETE.
271+ // Called when a single-row read returns 0 rows. Skips internal reads (read-back after mutation).
272+ private async postReadZeroRowsCheck ( node : SelectQueryNode , proceed : ProceedKyselyQueryFunction ) : Promise < void > {
273+ if ( ORM_WRITE_OPERATIONS . has ( policyContextStorage . getStore ( ) ?. operation ?? '' ) ) return ;
274+ if ( ! node . from || node . from . froms . length !== 1 ) return ;
275+ const extractedTable = this . extractTableName ( node . from . froms [ 0 ] ! ) ;
276+ if ( ! extractedTable ) return ;
277+ const { model } = extractedTable ;
278+ if ( ! QueryUtils . getModel ( this . client . $schema , model ) ) return ;
279+ return this . postZeroRowsCheck ( model , 'read' , node . where ?. where , proceed ) ;
280+ }
281+
282+ // Checks if any row matching WHERE exists without the policy filter.
264283 // If a row exists but was filtered by policy → throws REJECTED_BY_POLICY with codes.
265- // If no row matches → returns silently (ORM layer handles "not found").
266- // Combines existence check and code diagnostics into a single query.
267- private async postMutationZeroRowsCheck (
284+ // If no row matches → returns silently.
285+ private async postZeroRowsCheck (
268286 model : string ,
269- operation : 'update' | 'delete' ,
270- originalWhere : OperationNode | undefined ,
287+ operation : 'read' | ' update' | 'delete' ,
288+ whereCondition : OperationNode | undefined ,
271289 proceed : ProceedKyselyQueryFunction ,
272290 ) {
273- if ( this . isManyToManyJoinTable ( model ) ) return ;
274291 if ( this . tryGetConstantPolicy ( model , operation ) === true ) return ;
275292 if ( this . options . fetchPolicyCodes === false ) return ;
276293 const policiesWithCode = this . getModelPolicies ( model , operation ) . filter ( ( p ) => p . code ) ;
277- // Skip if no policies carry an error code — nothing to surface.
278294 if ( policiesWithCode . length === 0 ) return ;
295+ if ( this . isManyToManyJoinTable ( model ) ) return ;
279296
280- const whereCondition = originalWhere ?? trueNode ( this . dialect ) ;
297+ // No WHERE clause means "match all rows" — use a literal TRUE so the existence sub-query is valid SQL.
298+ const where = whereCondition ?? trueNode ( this . dialect ) ;
281299
282300 const rowExistsInner = this . eb
283301 . selectFrom ( model )
284302 . select ( this . eb . lit ( 1 ) . as ( '_' ) )
285- . where ( ( ) => new ExpressionWrapper ( whereCondition ) ) ;
303+ . where ( ( ) => new ExpressionWrapper ( where ) ) ;
286304
287305 const codeSelections = policiesWithCode . map ( ( policy , i ) => {
288306 const condition = this . compilePolicyCondition ( model , undefined , operation , policy ) ;
289307 const violationCondition = policy . kind === 'allow' ? logicalNot ( this . dialect , condition ) : condition ;
290308 const inner = this . eb
291309 . selectFrom ( model )
292310 . select ( this . eb . lit ( 1 ) . as ( '_' ) )
293- . where ( ( ) => new ExpressionWrapper ( conjunction ( this . dialect , [ whereCondition , violationCondition ] ) ) ) ;
311+ . where ( ( ) => new ExpressionWrapper ( conjunction ( this . dialect , [ where , violationCondition ] ) ) ) ;
294312 return SelectionNode . create (
295313 AliasNode . create ( this . eb . exists ( inner ) . toOperationNode ( ) , IdentifierNode . create ( `$c${ i } ` ) ) ,
296314 ) ;
@@ -312,10 +330,7 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
312330 const row = result . rows [ 0 ] ?? { } ;
313331 if ( ! row . $exists ) return ;
314332
315- const policyCodes = policiesWithCode
316- ? policiesWithCode . filter ( ( _ , i ) => row [ `$c${ i } ` ] ) . map ( ( p ) => p . code ! )
317- : undefined ;
318-
333+ const policyCodes = policiesWithCode . filter ( ( _ , i ) => row [ `$c${ i } ` ] ) . map ( ( p ) => p . code ! ) ;
319334 throw createRejectedByPolicyError ( model , RejectedByPolicyReason . NO_ACCESS , undefined , policyCodes ) ;
320335 }
321336
0 commit comments