@@ -119,6 +119,8 @@ def get_transform_info(self) -> dict:
119119 """
120120 Return a dictionary with the relevant information pertaining to an applied transform.
121121 """
122+ self ._init_trace_threadlocal ()
123+
122124 vals = (
123125 self .__class__ .__name__ ,
124126 id (self ),
@@ -300,25 +302,45 @@ def track_transform_meta(
300302 return out_obj
301303
def check_transforms_match(self, transform: Mapping) -> None:
    """Verify that a traced transform entry belongs to this transform.

    The actual comparison is delegated to ``self._transforms_match``. When
    that comparison succeeds this method returns silently.

    Args:
        transform: the traced entry to check; it is read through
            ``TraceKeys.ID``, ``TraceKeys.CLASS_NAME`` and
            ``TraceKeys.EXTRA_INFO``.

    Raises:
        RuntimeError: if the entry does not match this transform. Any
            ``"warn"`` message stored under ``TraceKeys.EXTRA_INFO`` is
            emitted via ``warnings.warn`` before the error is raised.
    """
    if not self._transforms_match(transform):
        # Gather the details of the mismatching entry for the error report.
        traced_id = transform.get(TraceKeys.ID, "")
        traced_name = transform.get(TraceKeys.CLASS_NAME, "")
        extra_warning = transform.get(TraceKeys.EXTRA_INFO, {}).get("warn")
        if extra_warning:
            warnings.warn(extra_warning)
        own_name = self.__class__.__name__
        raise RuntimeError(
            f"Error {own_name} getting the most recently "
            f"applied invertible transform {traced_name} {traced_id} != {id(self)} ."
        )
321323
def _transforms_match(self, transform: Mapping) -> bool:
    """Tell whether a traced transform entry corresponds to this transform.

    Args:
        transform: the traced entry; ``TraceKeys.ID`` and
            ``TraceKeys.CLASS_NAME`` are looked up with ``""`` as default.

    Returns:
        ``True`` if any of the following holds, ``False`` otherwise:

        - the traced ID equals ``id(self)``;
        - the traced ID equals ``TraceKeys.NONE`` (ID check disabled);
        - the multiprocessing start method is ``"spawn"`` and the traced
          class name equals this transform's class name.
    """
    traced_id = transform.get(TraceKeys.ID, "")
    # Same instance, or the entry opted out of the ID check with TraceKeys.NONE.
    if traced_id == id(self) or traced_id == TraceKeys.NONE:
        return True
    traced_name = transform.get(TraceKeys.CLASS_NAME, "")
    # With 'spawn', objects get recreated and no longer share the original ID,
    # so fall back to comparing class names.
    uses_spawn = torch.multiprocessing.get_start_method(allow_none=True) == "spawn"
    return uses_spawn and traced_name == self.__class__.__name__
343+
322344 def get_most_recent_transform (self , data , key : Hashable = None , check : bool = True , pop : bool = False ):
323345 """
324346 Get most recent matching transform for the current class from the sequence of applied operations.
@@ -350,10 +372,16 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr
350372 if not all_transforms :
351373 raise ValueError (f"Item of type { type (data )} (key: { key } , pop: { pop } ) has empty 'applied_operations'" )
352374
375+ match_idx = len (all_transforms ) - 1
353376 if check :
354- self .check_transforms_match (all_transforms [- 1 ])
377+ for idx in range (len (all_transforms ) - 1 , - 1 , - 1 ):
378+ if self ._transforms_match (all_transforms [idx ]):
379+ match_idx = idx
380+ break
381+ else :
382+ self .check_transforms_match (all_transforms [- 1 ])
355383
356- return all_transforms .pop (- 1 ) if pop else all_transforms [- 1 ]
384+ return all_transforms .pop (match_idx ) if pop else all_transforms [match_idx ]
357385
358386 def pop_transform (self , data , key : Hashable = None , check : bool = True ):
359387 """
0 commit comments