@@ -2,7 +2,7 @@ import { match } from 'ts-pattern';
22import type { GetModels , SchemaDef } from '../../../schema' ;
33import type { WhereInput } from '../../crud-types' ;
44import { createRejectedByPolicyError , RejectedByPolicyReason } from '../../errors' ;
5- import { getIdValues } from '../../query-utils' ;
5+ import { getIdValues , requireIdFields } from '../../query-utils' ;
66import { BaseOperationHandler } from './base' ;
77
88export class UpdateOperationHandler < Schema extends SchemaDef > extends BaseOperationHandler < Schema > {
@@ -28,7 +28,10 @@ export class UpdateOperationHandler<Schema extends SchemaDef> extends BaseOperat
2828 // analyze if we need to read back the update record, or just return the updated result
2929 const { needReadBack, selectedFields } = this . needReadBack ( args ) ;
3030
31- const result = await this . safeTransaction ( async ( tx ) => {
31+ // analyze if the update involves nested updates
32+ const needsNestedUpdate = this . needsNestedUpdate ( args . data ) ;
33+
34+ const result = await this . safeTransactionIf ( needReadBack || needsNestedUpdate , async ( tx ) => {
3235 const updateResult = await this . update (
3336 tx ,
3437 this . model ,
@@ -76,10 +79,11 @@ export class UpdateOperationHandler<Schema extends SchemaDef> extends BaseOperat
7679 return result ;
7780 }
7881 }
79-
8082 private async runUpdateMany ( args : any ) {
81- // TODO: avoid using transaction for simple update
82- return this . safeTransaction ( async ( tx ) => {
83+ // analyze if the update involves nested updates
84+ const needsNestedUpdate = this . needsNestedUpdate ( args . data ) ;
85+
86+ return this . safeTransactionIf ( needsNestedUpdate , async ( tx ) => {
8387 return this . updateMany ( tx , this . model , args . where , args . data , args . limit , false ) ;
8488 } ) ;
8589 }
@@ -92,37 +96,43 @@ export class UpdateOperationHandler<Schema extends SchemaDef> extends BaseOperat
9296 // analyze if we need to read back the updated record, or just return the update result
9397 const { needReadBack, selectedFields } = this . needReadBack ( args ) ;
9498
95- const { readBackResult, updateResult } = await this . safeTransaction ( async ( tx ) => {
96- const updateResult = await this . updateMany (
97- tx ,
98- this . model ,
99- args . where ,
100- args . data ,
101- args . limit ,
102- true ,
103- undefined ,
104- undefined ,
105- selectedFields ,
106- ) ;
99+ // analyze if the update involves nested updates
100+ const needsNestedUpdate = this . needsNestedUpdate ( args . data ) ;
107101
108- if ( needReadBack ) {
109- const readBackResult = await this . read (
102+ const { readBackResult, updateResult } = await this . safeTransactionIf (
103+ needReadBack || needsNestedUpdate ,
104+ async ( tx ) => {
105+ const updateResult = await this . updateMany (
110106 tx ,
111107 this . model ,
112- {
113- select : args . select ,
114- omit : args . omit ,
115- where : {
116- OR : updateResult . map ( ( item ) => getIdValues ( this . schema , this . model , item ) as any ) ,
117- } ,
118- } as any , // TODO: fix type
108+ args . where ,
109+ args . data ,
110+ args . limit ,
111+ true ,
112+ undefined ,
113+ undefined ,
114+ selectedFields ,
119115 ) ;
120116
121- return { readBackResult, updateResult } ;
122- } else {
123- return { readBackResult : updateResult , updateResult } ;
124- }
125- } ) ;
117+ if ( needReadBack ) {
118+ const readBackResult = await this . read (
119+ tx ,
120+ this . model ,
121+ {
122+ select : args . select ,
123+ omit : args . omit ,
124+ where : {
125+ OR : updateResult . map ( ( item ) => getIdValues ( this . schema , this . model , item ) as any ) ,
126+ } ,
127+ } as any , // TODO: fix type
128+ ) ;
129+
130+ return { readBackResult, updateResult } ;
131+ } else {
132+ return { readBackResult : updateResult , updateResult } ;
133+ }
134+ } ,
135+ ) ;
126136
127137 if ( readBackResult . length < updateResult . length && this . hasPolicyEnabled ) {
128138 // some of the updated entities cannot be read back
@@ -140,6 +150,7 @@ export class UpdateOperationHandler<Schema extends SchemaDef> extends BaseOperat
140150 // analyze if we need to read back the updated record, or just return the update result
141151 const { needReadBack, selectedFields } = this . needReadBack ( args ) ;
142152
153+ // upsert is intrinsically multi-step and is always run in a transaction
143154 const result = await this . safeTransaction ( async ( tx ) => {
144155 let mutationResult : unknown = await this . update (
145156 tx ,
@@ -191,9 +202,11 @@ export class UpdateOperationHandler<Schema extends SchemaDef> extends BaseOperat
191202 return baseResult ;
192203 }
193204
205+ const idFields = requireIdFields ( this . schema , this . model ) ;
206+
194207 if ( ! this . dialect . supportsReturning ) {
195208 // if dialect doesn't support "returning", we always need to read back
196- return { needReadBack : true , selectedFields : undefined } ;
209+ return { needReadBack : true , selectedFields : idFields } ;
197210 }
198211
199212 // further check if we're not updating any non-relation fields, because if so,
@@ -206,14 +219,33 @@ export class UpdateOperationHandler<Schema extends SchemaDef> extends BaseOperat
206219
207220 // update/updateMany payload
208221 if ( args . data && ! Object . keys ( args . data ) . some ( ( field ) => nonRelationFields . includes ( field ) ) ) {
209- return { needReadBack : true , selectedFields : undefined } ;
222+ return { needReadBack : true , selectedFields : idFields } ;
210223 }
211224
212225 // upsert payload
213226 if ( args . update && ! Object . keys ( args . update ) . some ( ( field : string ) => nonRelationFields . includes ( field ) ) ) {
214- return { needReadBack : true , selectedFields : undefined } ;
227+ return { needReadBack : true , selectedFields : idFields } ;
215228 }
216229
217230 return baseResult ;
218231 }
232+
233+ private needsNestedUpdate ( data : any ) {
234+ const modelDef = this . requireModel ( this . model ) ;
235+ if ( modelDef . baseModel ) {
236+ // involve delegate base models
237+ return true ;
238+ }
239+
240+ // has relation manipulation in the payload
241+ const hasRelation = Object . entries ( data ) . some ( ( [ field , value ] ) => {
242+ const fieldDef = this . getField ( this . model , field ) ;
243+ return fieldDef ?. relation && value !== undefined ;
244+ } ) ;
245+ if ( hasRelation ) {
246+ return true ;
247+ }
248+
249+ return false ;
250+ }
219251}
0 commit comments