@@ -343,7 +343,7 @@ def fn(
343343 if task_buf_order :
344344 for name , val in zip (task_buf_order , task_buf_vals ):
345345 if name .startswith (_AM_PREFIX ):
346- actual = name [len (_AM_PREFIX ):]
346+ actual = name [len (_AM_PREFIX ) :]
347347 if _atomic_model is not None :
348348 originals [name ] = _atomic_model ._buffers .get (actual )
349349 _atomic_model ._buffers [actual ] = val
@@ -364,7 +364,7 @@ def fn(
364364 finally :
365365 for name , orig in originals .items ():
366366 if name .startswith (_AM_PREFIX ):
367- actual = name [len (_AM_PREFIX ):]
367+ actual = name [len (_AM_PREFIX ) :]
368368 if _atomic_model is not None :
369369 _atomic_model ._buffers [actual ] = orig
370370 else :
@@ -533,7 +533,7 @@ def forward(
533533 _vals : list [torch .Tensor ] = []
534534 for _name in self ._task_buf_order :
535535 if _name .startswith (_AM_PREFIX ):
536- _actual = _name [len (_AM_PREFIX ):]
536+ _actual = _name [len (_AM_PREFIX ) :]
537537 _vals .append (_am ._buffers [_actual ])
538538 else :
539539 _vals .append (getattr (_fitting , _name ))
0 commit comments