@@ -405,16 +405,17 @@ def __init__(self, module):
405405
406406 def forward (self , * args , ** kwargs ):
407407 params = list (self .signature .parameters .keys ())
408+ capture_kwargs = dict (kwargs )
408409 for i , arg in enumerate (args ):
409410 if i > 0 :
410- kwargs [params [i ]] = arg
411+ capture_kwargs [params [i ]] = arg
411412 first_block_input ['data' ].append (args [0 ])
412- first_block_input ['kwargs' ].append (kwargs )
413+ first_block_input ['kwargs' ].append (capture_kwargs )
413414 self .step += 1
414415 if self .step == sample_steps :
415416 raise ValueError
416417 else :
417- return self .module (* args )
418+ return self .module (* args , ** kwargs )
418419
419420 return Catcher
420421
@@ -442,20 +443,21 @@ def _to_cpu(self, x):
442443
443444 def forward (self , * args , ** kwargs ):
444445 params = list (self .signature .parameters .keys ())
446+ capture_kwargs = dict (kwargs )
445447 for i , arg in enumerate (args ):
446448 if i > 0 :
447- kwargs [params [i ]] = arg
449+ capture_kwargs [params [i ]] = arg
448450 cur_num = len (first_block_input [self .expert_name ]['data' ])
449451 if cur_num < sample_steps :
450452 first_block_input [self .expert_name ]['data' ].append (
451453 args [0 ].detach ().cpu () if torch .is_tensor (args [0 ]) else args [0 ]
452454 )
453455 first_block_input [self .expert_name ]['kwargs' ].append (
454- {k : self ._to_cpu (v ) for k , v in kwargs .items ()}
456+ {k : self ._to_cpu (v ) for k , v in capture_kwargs .items ()}
455457 )
456458 if all (len (first_block_input [name ]['data' ]) >= sample_steps for name in first_block_input ):
457459 raise ValueError
458- return self .module (* args )
460+ return self .module (* args , ** kwargs )
459461
460462 first_block = self .Pipeline .transformer .blocks [0 ]
461463 self .Pipeline .transformer .blocks [0 ] = Catcher (first_block , 'transformer' )
0 commit comments