Skip to content

Commit 0cdfa67

Browse files
committed
refactor(wan2.2): move Wan2.2 save logic; update wan_t2v configs
Made-with: Cursor
1 parent e0fc7d4 commit 0cdfa67

6 files changed

Lines changed: 85 additions & 129 deletions

File tree

configs/quantization/video_gen/wan_t2v/awq_w_a.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ base:
22
seed: &seed 42
33
model:
44
type: WanT2V
5-
path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.1-T2V-14B-Diffusers
5+
path: /path/to/wan_t2v
66
torch_dtype: auto
77
calib:
88
name: t2v
99
download: False
10-
path: ./assets/wan_t2v/calib/
10+
path: ../assets/wan_t2v/calib/
1111
sample_steps: 20
1212
bs: 1
1313
target_height: 480
@@ -20,7 +20,7 @@ eval:
2020
type: video_gen
2121
name: t2v
2222
download: False
23-
path: ./assets/wan_t2v/calib/
23+
path: ../assets/wan_t2v/calib/
2424
bs: 1
2525
target_height: 480
2626
target_width: 832
@@ -31,12 +31,12 @@ quant:
3131
video_gen:
3232
method: Awq
3333
weight:
34-
bit: 4
34+
bit: 6
3535
symmetric: True
3636
granularity: per_channel
3737
group_size: -1
3838
act:
39-
bit: 4
39+
bit: 6
4040
symmetric: True
4141
granularity: per_token
4242
special:
@@ -46,4 +46,4 @@ quant:
4646
clip_sym: True
4747
save:
4848
save_lightx2v: True
49-
save_path: ../lightx2v/wan_t2v_awq_w_a/x2v/
49+
save_path: /path/to/x2v/

configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml

Lines changed: 0 additions & 49 deletions
This file was deleted.

configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ base:
22
seed: &seed 42
33
model:
44
type: WanT2V
5-
path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-14B-Diffusers
5+
path: /path/to/wan_t2v
66
torch_dtype: auto
77
calib:
88
name: t2v
99
download: False
10-
path: ./assets/wan_t2v/calib/
10+
path: ../assets/wan_t2v/calib/
1111
sample_steps: 20
1212
bs: 1
1313
target_height: 480
@@ -20,30 +20,26 @@ eval:
2020
type: video_gen
2121
name: t2v
2222
download: False
23-
path: ./assets/wan_t2v/calib/
23+
path: ../assets/wan_t2v/calib/
2424
bs: 1
2525
target_height: 480
2626
target_width: 832
2727
num_frames: 81
2828
guidance_scale: 5.0
29-
output_video_path: ./output_videos_awq/
29+
output_video_path: ./output_videos_sq/
3030
quant:
3131
video_gen:
32-
method: Awq
32+
method: SmoothQuant
3333
weight:
34-
bit: 4
34+
bit: 6
3535
symmetric: True
3636
granularity: per_channel
37-
group_size: -1
3837
act:
39-
bit: 4
38+
bit: 6
4039
symmetric: True
4140
granularity: per_token
4241
special:
43-
trans: True
44-
trans_version: v2
45-
weight_clip: True
46-
clip_sym: True
42+
alpha: 0.7
4743
save:
4844
save_lightx2v: True
49-
save_path: ../lightx2v/wan_t2v_awq_w_a/x2v/
45+
save_path: /path/to/x2v/

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,57 +1076,7 @@ def save_model(self, path):
10761076
logger.info('save model done --')
10771077
self.copy_tokenizer(path)
10781078
elif self.config.model.type in ['Wan2T2V']:
1079-
if getattr(self.model.Pipeline, '_is_wan_official', False):
1080-
src = getattr(self.model, 'pipeline_model_path', self.model.model_path)
1081-
self.model.copy_native_checkpoint(src, path)
1082-
1083-
self.model.Pipeline.transformer.save_pretrained(
1084-
os.path.join(path, 'high_noise_model')
1085-
)
1086-
logger.info('save Wan2.2 high_noise_model done --')
1087-
if (
1088-
hasattr(self.model.Pipeline, 'transformer_2')
1089-
and self.model.Pipeline.transformer_2 is not None
1090-
):
1091-
self.model.Pipeline.transformer_2.save_pretrained(
1092-
os.path.join(path, 'low_noise_model')
1093-
)
1094-
logger.info('save Wan2.2 low_noise_model done --')
1095-
self.model.validate_native_save_structure(path, source_path=src)
1096-
return
1097-
1098-
# Copy the full original pipeline (VAE, text encoder, tokenizer, scheduler, etc.)
1099-
# so that non-quantized components are preserved.
1100-
src = getattr(self.model, 'pipeline_model_path', self.model.model_path)
1101-
copied_from_source = False
1102-
if isinstance(src, str) and os.path.isdir(src) and os.path.abspath(src) != os.path.abspath(path):
1103-
if os.path.exists(path):
1104-
shutil.rmtree(path)
1105-
shutil.copytree(src, path)
1106-
logger.info(f'Copied original pipeline from {src} to {path}')
1107-
copied_from_source = True
1108-
if not copied_from_source:
1109-
if os.path.exists(path):
1110-
shutil.rmtree(path)
1111-
# Fallback for remote repo-id sources: materialize all non-quantized components first.
1112-
self.model.Pipeline.save_pretrained(path, safe_serialization=True)
1113-
logger.info(
1114-
'save Wan2.2 full pipeline done via Pipeline.save_pretrained '
1115-
f'(source={src}) --'
1116-
)
1117-
# Overwrite transformer subfolder with quantized weights.
1118-
self.model.Pipeline.transformer.save_pretrained(
1119-
os.path.join(path, 'transformer')
1120-
)
1121-
logger.info('save Wan2.2 transformer done --')
1122-
if (
1123-
hasattr(self.model.Pipeline, 'transformer_2')
1124-
and self.model.Pipeline.transformer_2 is not None
1125-
):
1126-
self.model.Pipeline.transformer_2.save_pretrained(
1127-
os.path.join(path, 'transformer_2')
1128-
)
1129-
logger.info('save Wan2.2 transformer_2 done --')
1079+
self.model.save_wan2_2_pretrained(path)
11301080
else:
11311081
self.model.get_model().save_pretrained(path)
11321082
logger.info('save model done --')

llmc/models/wan2_2_t2v.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,5 +677,68 @@ def validate_native_save_structure(save_path, source_path=None):
677677
f'top-level entries={sorted(os.listdir(save_path))}'
678678
)
679679

680+
def save_wan2_2_pretrained(self, path):
681+
"""Wan2.2 专用保存:支持官方 native 与非官方 Pipeline 两种布局。
682+
683+
该逻辑原本位于 llmc/compression/quantization/base_blockwise_quantization.py 的 Wan2T2V 分支。
684+
"""
685+
if int(os.environ.get('RANK', '0')) != 0:
686+
return
687+
688+
if getattr(self.Pipeline, '_is_wan_official', False):
689+
src = getattr(self, 'pipeline_model_path', self.model_path)
690+
self.copy_native_checkpoint(src, path)
691+
692+
self.Pipeline.transformer.save_pretrained(
693+
os.path.join(path, 'high_noise_model')
694+
)
695+
logger.info('save Wan2.2 high_noise_model done --')
696+
if (
697+
hasattr(self.Pipeline, 'transformer_2')
698+
and self.Pipeline.transformer_2 is not None
699+
):
700+
self.Pipeline.transformer_2.save_pretrained(
701+
os.path.join(path, 'low_noise_model')
702+
)
703+
logger.info('save Wan2.2 low_noise_model done --')
704+
705+
self.validate_native_save_structure(path, source_path=src)
706+
return
707+
708+
# Copy the full original pipeline (VAE, text encoder, tokenizer, scheduler, etc.)
709+
# so that non-quantized components are preserved.
710+
src = getattr(self, 'pipeline_model_path', self.model_path)
711+
copied_from_source = False
712+
if isinstance(src, str) and os.path.isdir(src) and os.path.abspath(src) != os.path.abspath(path):
713+
if os.path.exists(path):
714+
shutil.rmtree(path)
715+
shutil.copytree(src, path)
716+
logger.info(f'Copied original pipeline from {src} to {path}')
717+
copied_from_source = True
718+
719+
if not copied_from_source:
720+
if os.path.exists(path):
721+
shutil.rmtree(path)
722+
# Fallback for remote repo-id sources: materialize all non-quantized components first.
723+
self.Pipeline.save_pretrained(path, safe_serialization=True)
724+
logger.info(
725+
'save Wan2.2 full pipeline done via Pipeline.save_pretrained '
726+
f'(source={src}) --'
727+
)
728+
729+
# Overwrite transformer subfolder with quantized weights.
730+
self.Pipeline.transformer.save_pretrained(
731+
os.path.join(path, 'transformer')
732+
)
733+
logger.info('save Wan2.2 transformer done --')
734+
if (
735+
hasattr(self.Pipeline, 'transformer_2')
736+
and self.Pipeline.transformer_2 is not None
737+
):
738+
self.Pipeline.transformer_2.save_pretrained(
739+
os.path.join(path, 'transformer_2')
740+
)
741+
logger.info('save Wan2.2 transformer_2 done --')
742+
680743
def skip_layer_name(self):
681744
pass

llmc/models/wan_t2v.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,10 @@ def __init__(self, config, device_map=None, use_cache=False):
3131

3232
def build_model(self):
3333
vae = AutoencoderKLWan.from_pretrained(
34-
self.model_path, subfolder='vae', torch_dtype=torch.float32, use_safetensors=True
34+
self.model_path, subfolder='vae', torch_dtype=torch.float32
3535
)
36-
# self.Pipeline = WanPipeline.from_pretrained(
37-
# self.model_path, vae=vae, torch_dtype=torch.bfloat16
38-
# )
3936
self.Pipeline = WanPipeline.from_pretrained(
40-
self.model_path, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True
37+
self.model_path, vae=vae, torch_dtype=torch.bfloat16
4138
)
4239
self.find_llmc_model()
4340
self.find_blocks()
@@ -64,17 +61,16 @@ def __init__(self, module):
6461

6562
def forward(self, *args, **kwargs):
6663
params = list(self.signature.parameters.keys())
67-
capture_kwargs = dict(kwargs)
6864
for i, arg in enumerate(args):
6965
if i > 0:
70-
capture_kwargs[params[i]] = arg
66+
kwargs[params[i]] = arg
7167
first_block_input['data'].append(args[0])
72-
first_block_input['kwargs'].append(capture_kwargs)
68+
first_block_input['kwargs'].append(kwargs)
7369
self.step += 1
7470
if self.step == sample_steps:
7571
raise ValueError
7672
else:
77-
return self.module(*args, **kwargs)
73+
return self.module(*args)
7874

7975
return Catcher
8076

@@ -166,4 +162,4 @@ def get_layers_except_blocks(self):
166162
pass
167163

168164
def skip_layer_name(self):
169-
pass
165+
pass

0 commit comments

Comments (0)