Skip to content

Commit f261203

Browse files
committed
fix(wan2.2): enforce native save structure and align default dual guidance
1 parent df9a09b commit f261203

4 files changed

Lines changed: 106 additions & 15 deletions

File tree

configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ model:
55
path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/model/Wan2.2-T2V-A14B
66
# 若未 `pip install -e /path/to/Wan2.2`,可显式指定官方仓库代码路径:
77
# wan2_repo_path: /path/to/Wan2.2
8+
# 默认严格走官方 Wan2.2 原生后端;官方代码不可用时会直接报错,不再静默回退到 Diffusers。
9+
# 若确实需要回退可开启:
10+
# allow_diffusers_fallback: True
811
torch_dtype: auto
912
# 显存不足时开启:校准阶段捕获的激活存到 CPU,量化时再按 block 搬到 GPU
1013
use_cpu_to_save_cuda_mem_for_catcher: True
@@ -17,7 +20,9 @@ calib:
1720
target_height: 480 # OOM 时可减小,如 320
1821
target_width: 832 # OOM 时可减小,如 576
1922
num_frames: 81 # OOM 时可减小,如 49 或 33
20-
guidance_scale: 5.0
23+
# 对齐官方 Wan2.2 默认 sample_guide_scale=(3.0, 4.0) (low_noise, high_noise)
24+
guidance_scale: 4.0 # high_noise
25+
guidance_scale_2: 3.0 # low_noise
2126
seed: *seed
2227
eval:
2328
eval_pos: [transformed, fake_quant]
@@ -29,7 +34,9 @@ eval:
2934
target_height: 480
3035
target_width: 832
3136
num_frames: 81
32-
guidance_scale: 5.0
37+
# 对齐官方 Wan2.2 默认 sample_guide_scale=(3.0, 4.0) (low_noise, high_noise)
38+
guidance_scale: 4.0 # high_noise
39+
guidance_scale_2: 3.0 # low_noise
3340
output_video_path: ./output_videos_awq/
3441
quant:
3542
video_gen:

configs/quantization/video_gen/wan2_2_t2v/awq_w_a_skip_first.yaml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ model:
55
path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/model/Wan2.2-T2V-A14B
66
# 若未 `pip install -e /path/to/Wan2.2`,可显式指定官方仓库代码路径:
77
# wan2_repo_path: /path/to/Wan2.2
8+
# 默认严格走官方 Wan2.2 原生后端;官方代码不可用时会直接报错,不再静默回退到 Diffusers。
9+
# 若确实需要回退可开启:
10+
# allow_diffusers_fallback: True
811
torch_dtype: auto
912
# 显存不足时开启:校准阶段捕获的激活存到 CPU,量化时再按 block 搬到 GPU
1013
use_cpu_to_save_cuda_mem_for_catcher: True
@@ -17,7 +20,9 @@ calib:
1720
target_height: 480 # OOM 时可减小,如 320
1821
target_width: 832 # OOM 时可减小,如 576
1922
num_frames: 81 # OOM 时可减小,如 49 或 33
20-
guidance_scale: 5.0
23+
# 对齐官方 Wan2.2 默认 sample_guide_scale=(3.0, 4.0) (low_noise, high_noise)
24+
guidance_scale: 4.0 # high_noise
25+
guidance_scale_2: 3.0 # low_noise
2126
seed: *seed
2227
eval:
2328
eval_pos: [transformed, fake_quant]
@@ -29,7 +34,9 @@ eval:
2934
target_height: 480
3035
target_width: 832
3136
num_frames: 81
32-
guidance_scale: 5.0
37+
# 对齐官方 Wan2.2 默认 sample_guide_scale=(3.0, 4.0) (low_noise, high_noise)
38+
guidance_scale: 4.0 # high_noise
39+
guidance_scale_2: 3.0 # low_noise
3340
output_video_path: ./output_videos_awq_skip_first/
3441
quant:
3542
video_gen:

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,58 @@ def contiguous_params(self):
10601060
if not param.is_contiguous():
10611061
param.data = param.data.contiguous()
10621062

1063+
def _copy_wan22_native_checkpoint(self, src, dst):
1064+
if not isinstance(src, str) or not os.path.isdir(src):
1065+
raise RuntimeError(
1066+
'Wan2.2 official save expects a local native checkpoint directory, '
1067+
f'but got src={src!r}.'
1068+
)
1069+
if os.path.abspath(src) == os.path.abspath(dst):
1070+
raise RuntimeError(
1071+
'Wan2.2 official save path must differ from source checkpoint path '
1072+
f'(src=dst={src}).'
1073+
)
1074+
if os.path.exists(dst):
1075+
shutil.rmtree(dst)
1076+
shutil.copytree(src, dst)
1077+
logger.info(f'Copied original Wan2.2 native checkpoint from {src} to {dst}')
1078+
1079+
def _validate_wan22_native_save_structure(self, save_path, source_path=None):
1080+
if not os.path.isdir(save_path):
1081+
raise RuntimeError(f'Wan2.2 saved path is not a directory: {save_path}')
1082+
1083+
required_entries = ['configuration.json', 'high_noise_model', 'low_noise_model']
1084+
missing_required = [
1085+
name for name in required_entries
1086+
if not os.path.exists(os.path.join(save_path, name))
1087+
]
1088+
if missing_required:
1089+
raise RuntimeError(
1090+
'Wan2.2 saved structure is incomplete. Missing required entries: '
1091+
f'{missing_required}. save_path={save_path}'
1092+
)
1093+
1094+
if isinstance(source_path, str) and os.path.isdir(source_path):
1095+
source_entries = set(os.listdir(source_path))
1096+
source_non_expert_entries = sorted(
1097+
name for name in source_entries
1098+
if name not in {'high_noise_model', 'low_noise_model'}
1099+
)
1100+
missing_non_expert = [
1101+
name for name in source_non_expert_entries
1102+
if not os.path.exists(os.path.join(save_path, name))
1103+
]
1104+
if missing_non_expert:
1105+
raise RuntimeError(
1106+
'Wan2.2 saved structure lost original non-expert files/directories: '
1107+
f'{missing_non_expert}. source_path={source_path}, save_path={save_path}'
1108+
)
1109+
1110+
logger.info(
1111+
f'Wan2.2 native save structure verified. '
1112+
f'top-level entries={sorted(os.listdir(save_path))}'
1113+
)
1114+
10631115
@torch.no_grad()
10641116
def save_model(self, path):
10651117
if int(os.environ['RANK']) != 0:
@@ -1082,16 +1134,8 @@ def save_model(self, path):
10821134
self.copy_tokenizer(path)
10831135
elif self.config.model.type in ['Wan2T2V']:
10841136
if getattr(self.model.Pipeline, '_is_wan_official', False):
1085-
src = self.model.model_path
1086-
copied_from_source = False
1087-
if isinstance(src, str) and os.path.isdir(src) and os.path.abspath(src) != os.path.abspath(path):
1088-
if os.path.exists(path):
1089-
shutil.rmtree(path)
1090-
shutil.copytree(src, path)
1091-
logger.info(f'Copied original Wan2.2 native checkpoint from {src} to {path}')
1092-
copied_from_source = True
1093-
if not copied_from_source and os.path.exists(path):
1094-
shutil.rmtree(path)
1137+
src = getattr(self.model, 'pipeline_model_path', self.model.model_path)
1138+
self._copy_wan22_native_checkpoint(src, path)
10951139

10961140
self.model.Pipeline.transformer.save_pretrained(
10971141
os.path.join(path, 'high_noise_model')
@@ -1105,6 +1149,7 @@ def save_model(self, path):
11051149
os.path.join(path, 'low_noise_model')
11061150
)
11071151
logger.info('save Wan2.2 low_noise_model done --')
1152+
self._validate_wan22_native_save_structure(path, source_path=src)
11081153
return
11091154

11101155
# Copy the full original pipeline (VAE, text encoder, tokenizer, scheduler, etc.)

llmc/models/wan2_2_t2v.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,24 @@ def _has_wan22_native_layout(model_path):
163163
and os.path.isdir(os.path.join(model_path, 'low_noise_model'))
164164
)
165165

166+
@staticmethod
167+
def _is_wan22_native_repo_id(model_path):
168+
if not isinstance(model_path, str):
169+
return False
170+
return model_path.rstrip('/\\') == 'Wan-AI/Wan2.2-T2V-A14B'
171+
172+
def _should_require_official_backend(self, normalized_model_path):
173+
if self.config.model.get('force_diffusers', False):
174+
return False
175+
if self.config.model.get('diffusers_path', None):
176+
return False
177+
if self.config.model.get('allow_diffusers_fallback', False):
178+
return False
179+
return (
180+
self._has_wan22_native_layout(normalized_model_path)
181+
or self._is_wan22_native_repo_id(normalized_model_path)
182+
)
183+
166184
def _import_official_wan(self):
167185
def _import_impl():
168186
from wan.configs import t2v_A14B
@@ -185,7 +203,8 @@ def _import_impl():
185203
)
186204
logger.warning(
187205
'Failed to import official Wan2.2 runtime (wan package). '
188-
f'Falling back to diffusers import path. import_error={e}'
206+
'Diffusers fallback depends on model.allow_diffusers_fallback/model.force_diffusers. '
207+
f'import_error={e}'
189208
)
190209
return None, None
191210

@@ -286,6 +305,9 @@ def _resolve_pipeline_model_path(self):
286305

287306
def build_model(self):
288307
self.use_official_wan = False
308+
normalized_model_path = self._normalize_hf_repo_path(self.model_path)
309+
require_official_backend = self._should_require_official_backend(normalized_model_path)
310+
289311
if self._try_build_official_wan_pipeline():
290312
self.find_llmc_model()
291313
self.find_blocks()
@@ -302,6 +324,16 @@ def build_model(self):
302324
logger.info('Model: %s', self.model)
303325
return
304326

327+
if require_official_backend:
328+
raise RuntimeError(
329+
'Detected Wan2.2 native source '
330+
f'({normalized_model_path}) but official Wan runtime is unavailable. '
331+
'Please install/prepare official Wan2.2 code (pip install -e /path/to/Wan2.2 '
332+
'or set model.wan2_repo_path). '
333+
'If you intentionally want Diffusers fallback, set '
334+
'model.allow_diffusers_fallback=True or model.force_diffusers=True.'
335+
)
336+
305337
self.pipeline_model_path = self._resolve_pipeline_model_path()
306338
vae = AutoencoderKLWan.from_pretrained(
307339
self.pipeline_model_path,

0 commit comments

Comments
 (0)