Skip to content

Commit 007360e

Browse files
committed
fix(wan): preserve catcher kwargs forwarding during calibration
1 parent f261203 commit 007360e

2 files changed

Lines changed: 12 additions & 9 deletions

File tree

llmc/models/wan2_2_t2v.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -405,16 +405,17 @@ def __init__(self, module):
405405

406406
def forward(self, *args, **kwargs):
407407
params = list(self.signature.parameters.keys())
408+
capture_kwargs = dict(kwargs)
408409
for i, arg in enumerate(args):
409410
if i > 0:
410-
kwargs[params[i]] = arg
411+
capture_kwargs[params[i]] = arg
411412
first_block_input['data'].append(args[0])
412-
first_block_input['kwargs'].append(kwargs)
413+
first_block_input['kwargs'].append(capture_kwargs)
413414
self.step += 1
414415
if self.step == sample_steps:
415416
raise ValueError
416417
else:
417-
return self.module(*args)
418+
return self.module(*args, **kwargs)
418419

419420
return Catcher
420421

@@ -442,20 +443,21 @@ def _to_cpu(self, x):
442443

443444
def forward(self, *args, **kwargs):
444445
params = list(self.signature.parameters.keys())
446+
capture_kwargs = dict(kwargs)
445447
for i, arg in enumerate(args):
446448
if i > 0:
447-
kwargs[params[i]] = arg
449+
capture_kwargs[params[i]] = arg
448450
cur_num = len(first_block_input[self.expert_name]['data'])
449451
if cur_num < sample_steps:
450452
first_block_input[self.expert_name]['data'].append(
451453
args[0].detach().cpu() if torch.is_tensor(args[0]) else args[0]
452454
)
453455
first_block_input[self.expert_name]['kwargs'].append(
454-
{k: self._to_cpu(v) for k, v in kwargs.items()}
456+
{k: self._to_cpu(v) for k, v in capture_kwargs.items()}
455457
)
456458
if all(len(first_block_input[name]['data']) >= sample_steps for name in first_block_input):
457459
raise ValueError
458-
return self.module(*args)
460+
return self.module(*args, **kwargs)
459461

460462
first_block = self.Pipeline.transformer.blocks[0]
461463
self.Pipeline.transformer.blocks[0] = Catcher(first_block, 'transformer')

llmc/models/wan_t2v.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,17 @@ def __init__(self, module):
6464

6565
def forward(self, *args, **kwargs):
6666
params = list(self.signature.parameters.keys())
67+
capture_kwargs = dict(kwargs)
6768
for i, arg in enumerate(args):
6869
if i > 0:
69-
kwargs[params[i]] = arg
70+
capture_kwargs[params[i]] = arg
7071
first_block_input['data'].append(args[0])
71-
first_block_input['kwargs'].append(kwargs)
72+
first_block_input['kwargs'].append(capture_kwargs)
7273
self.step += 1
7374
if self.step == sample_steps:
7475
raise ValueError
7576
else:
76-
return self.module(*args)
77+
return self.module(*args, **kwargs)
7778

7879
return Catcher
7980

0 commit comments

Comments
 (0)