Skip to content

Commit 22b2f5d

Browse files
committed
update default flow shift
Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>
1 parent 6a35224 commit 22b2f5d

2 files changed

Lines changed: 16 additions & 6 deletions

File tree

tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -499,12 +499,17 @@ def forward(
499499
resolved_flow_shift = (
500500
flow_shift if flow_shift is not None else self._default_flow_shift(height, width)
501501
)
502-
if self.scheduler.config.shift != resolved_flow_shift:
502+
503+
sched_cfg = self.scheduler.config
504+
shift_key = (
505+
"shift" if "shift" in sched_cfg else "flow_shift" if "flow_shift" in sched_cfg else None
506+
)
507+
if shift_key is not None and sched_cfg[shift_key] != resolved_flow_shift:
503508
logger.info(
504-
f"flow_shift: {self.scheduler.config.shift} -> {resolved_flow_shift} "
509+
f"flow_shift: {sched_cfg[shift_key]} -> {resolved_flow_shift} "
505510
f"({'user' if flow_shift is not None else 'variant default'})"
506511
)
507-
self.scheduler.config.shift = resolved_flow_shift
512+
self.scheduler.register_to_config(**{shift_key: resolved_flow_shift})
508513

509514
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
510515

tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -550,12 +550,17 @@ def forward(
550550
resolved_flow_shift = (
551551
flow_shift if flow_shift is not None else self._default_flow_shift(height, width)
552552
)
553-
if self.scheduler.config.shift != resolved_flow_shift:
553+
554+
sched_cfg = self.scheduler.config
555+
shift_key = (
556+
"shift" if "shift" in sched_cfg else "flow_shift" if "flow_shift" in sched_cfg else None
557+
)
558+
if shift_key is not None and sched_cfg[shift_key] != resolved_flow_shift:
554559
logger.info(
555-
f"flow_shift: {self.scheduler.config.shift} -> {resolved_flow_shift} "
560+
f"flow_shift: {sched_cfg[shift_key]} -> {resolved_flow_shift} "
556561
f"({'user' if flow_shift is not None else 'variant default'})"
557562
)
558-
self.scheduler.config.shift = resolved_flow_shift
563+
self.scheduler.register_to_config(**{shift_key: resolved_flow_shift})
559564

560565
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
561566

0 commit comments

Comments
 (0)