Commit 5a51ded

docs/wan2.2 + refactor(wan2.2): move native save helpers and update quant/run config
Made-with: Cursor
1 parent 5aee498 commit 5a51ded

6 files changed

Lines changed: 220 additions & 249 deletions

File tree

configs/quantization/video_gen/wan2_2_t2v/awq_w_a_skip_first.yaml

Lines changed: 0 additions & 73 deletions
This file was deleted.

docs/wan2.2_quantization_guide.md

Lines changed: 139 additions & 0 deletions
# Wan2.2 Video Generation Model Quantization Guide

## Overview

The ready-made example this repository provides for **Wan2.2-T2V** is **4-bit AWQ simulated (fake) quantization** (`configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml`).

Wan2.2 is a **dual-expert MoE**: a high-noise expert (`transformer`) and a low-noise expert (`transformer_2`); calibration and block-wise quantization cover both branches. The example saves with `save_fake` by default; wiring the result into inference must be aligned with your own inference stack.

**Example model (native checkpoint layout)**: [Wan-AI/Wan2.2-T2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B)

## Key differences from Wan2.1

| Item | Description |
|------|------|
| Registered name | `Wan2T2V` |
| Architecture | dual-expert MoE, not a single-path DiT |
| Inference backend | prefers the official `wan` package with the native checkpoint directory; can fall back to Diffusers per the YAML comments |
| CFG | `guidance_scale` (high noise) and `guidance_scale_2` (low noise), matching the official dual guidance |

## Example quantization config

The `quant` section of `awq_w_a.yaml` matches the repository, for example:

```yaml
quant:
  video_gen:
    method: Awq
    weight:
      quant_type: hif4
      bit: 4
      symmetric: True
      granularity: per_channel
      group_size: -1
    act:
      quant_type: hif4
      bit: 4
      symmetric: True
      granularity: per_token
    special:
      trans: True
      trans_version: v2
      weight_clip: True
      clip_sym: True
```
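For intuition, symmetric per-channel 4-bit fake quantization of a weight matrix can be sketched as below. This is a minimal illustration on an integer grid, not the actual `hif4` float format used by the config:

```python
import numpy as np

def fake_quant_per_channel_sym(w: np.ndarray, bit: int = 4) -> np.ndarray:
    """Symmetric per-output-channel fake quantization: quantize to a
    signed integer grid, then immediately dequantize (shape preserved)."""
    qmax = 2 ** (bit - 1) - 1                       # 7 for 4 bits
    scale = np.abs(w).max(axis=1, keepdims=True) / qmax
    scale = np.maximum(scale, 1e-8)                 # guard all-zero rows
    q = np.clip(np.round(w / scale), -qmax - 1, qmax)
    return q * scale

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 16))
w_q = fake_quant_per_channel_sym(w)
# rounding error is bounded by half a quantization step per channel
assert np.all(np.abs(w - w_q) <= np.abs(w).max(axis=1, keepdims=True) / 14 + 1e-9)
```

`granularity: per_channel` means each output row gets its own scale, and `per_token` applies the same idea row-wise to activations.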

First use requires building the in-repo hif4 GPU extension: `HiFloat4/hif4_gpu/build.sh`.

## Running

### 1. Environment

```bash
export llmc=/path/to/LightCompress
export PYTHONPATH=$llmc:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
```

The native layout requires that `import wan` works, typically:

```bash
pip install -e /path/to/Wan2.2
```

Alternatively, set `wan2_repo_path: /path/to/Wan2.2` in the YAML.
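A quick way to verify that `wan` is importable before launching a long run (a generic check, nothing repository-specific; the path argument mirrors the YAML's `wan2_repo_path`):

```python
import importlib.util
import sys
from typing import Optional

def can_import(module: str, extra_path: Optional[str] = None) -> bool:
    """Return True if `module` is importable, optionally after adding an
    extra directory (e.g. the YAML's wan2_repo_path) to sys.path."""
    if extra_path and extra_path not in sys.path:
        sys.path.insert(0, extra_path)
    return importlib.util.find_spec(module) is not None

# On a machine without Wan2.2 installed this prints False.
print("wan importable:", can_import("wan"))
```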

### 2. Calibration data

Same as Wan2.1 T2V: a directory of text prompt files, for example:

```
assets/wan_t2v/calib/
├── prompt_1.txt
├── prompt_2.txt
└── ...
```

In the config, set `calib.name: t2v` and point `calib.path` at this directory.
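Preparing such a directory can be sketched as follows (the prompts below are placeholders; substitute your own calibration prompts):

```python
from pathlib import Path

# Directory matching calib.path in the YAML.
calib_dir = Path("assets/wan_t2v/calib")
calib_dir.mkdir(parents=True, exist_ok=True)

prompts = [
    "A red fox running through fresh snow at dawn",
    "Time-lapse of clouds rolling over a mountain ridge",
]
# One prompt per file, named prompt_<n>.txt as in the tree above.
for i, prompt in enumerate(prompts, start=1):
    (calib_dir / f"prompt_{i}.txt").write_text(prompt, encoding="utf-8")

print(sorted(p.name for p in calib_dir.glob("prompt_*.txt")))
```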

### 3. Edit `awq_w_a.yaml`

Required changes:

- `model.path`: path to the Wan2.2 weights
- `calib.path` / `eval.path`: calibration and evaluation data
- `save.save_path`: output directory

Optional (see the YAML comments):

- `use_cpu_to_save_cuda_mem_for_catcher: True`: reduces peak GPU memory when calibration is memory-constrained
- `allow_diffusers_fallback: True`: fall back to Diffusers when the official backend is unavailable

Dual-guidance example:

```yaml
calib:
  guidance_scale: 4.0    # high_noise
  guidance_scale_2: 3.0  # low_noise
eval:
  guidance_scale: 4.0
  guidance_scale_2: 3.0
```
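Conceptually, classifier-free guidance blends a conditional and an unconditional prediction using the scale of whichever expert is active at the current timestep. A minimal sketch (the expert-switching threshold here is illustrative, not the official pipeline logic):

```python
import numpy as np

def cfg_combine(eps_uncond: np.ndarray, eps_cond: np.ndarray, scale: float) -> np.ndarray:
    """Standard classifier-free guidance blend of two noise predictions."""
    return eps_uncond + scale * (eps_cond - eps_uncond)

def pick_scale(t: int, boundary: int, guidance_scale: float, guidance_scale_2: float) -> float:
    """High-noise expert (early steps, large t) uses guidance_scale;
    the low-noise expert uses guidance_scale_2. `boundary` is a
    hypothetical switch point for illustration."""
    return guidance_scale if t >= boundary else guidance_scale_2

eps_uncond, eps_cond = np.zeros(4), np.ones(4)
scale = pick_scale(t=900, boundary=875, guidance_scale=4.0, guidance_scale_2=3.0)
out = cfg_combine(eps_uncond, eps_cond, scale)  # scale is 4.0 here
```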

### 4. Launch quantization

```bash
torchrun \
  --nnodes 1 \
  --nproc_per_node 1 \
  --rdzv_id $RANDOM \
  --rdzv_backend c10d \
  --rdzv_endpoint 127.0.0.1:29500 \
  ${llmc}/llmc/__main__.py \
  --config ${llmc}/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml \
  --task_id wan22_awq_int4
```

Alternatively, in `scripts/run_llmc.sh` set `model_name=wan2_2_t2v` and `task_name=awq_w_a` to match the YAML above (adjust the Python path etc. in the script for your machine).

## Parameter quick reference

| Field | Description |
|------|------|
| `model.type` | `Wan2T2V` |
| `quant.video_gen.method` | `Awq` |
| `weight` / `act` | `bit: 4` (see the YAML for the exact `quant_type`) |
| `save` | `save_fake: True` and `save_path` in the example |

## FAQ

- **OOM**: reduce `sample_steps`, `num_frames`, or the resolution; use `bs: 1`; optionally enable `use_cpu_to_save_cuda_mem_for_catcher`.
- **Cannot `import wan`**: install the official repository or set `wan2_repo_path`.
- **hif4 extension fails to build**: check the CUDA / PyTorch versions against the `HiFloat4/hif4_gpu/build.sh` logs.
- **Quality degradation**: add more, and more diverse, calibration prompts; within the supported range, tune `special` and the calibration scale.

## References

- `configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml`
- `llmc/models/wan2_2_t2v.py`
- For other precisions (e.g. FP8, INT8), follow the approach in `docs/wan2.1_quantization_guide.md`: add a new YAML under `wan2_2_t2v` and replace `model.type` and the paths.

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 2 additions & 59 deletions
```diff
@@ -38,7 +38,6 @@
                      RotateLinear)
 from .quant import (
     FloatQuantizer,
-    HiFloat4Quantizer,
     IntegerQuantizer,
     Weight48IntegerQuantizer,
 )
@@ -163,8 +162,6 @@ def set_quant_config(self):
             self.weight_quant_module = IntegerQuantizer
         elif quant_type == 'float-quant':
             self.weight_quant_module = FloatQuantizer
-        elif quant_type == 'hif4':
-            self.weight_quant_module = HiFloat4Quantizer
         logger.info(f'The used Weight Quant Module is {self.weight_quant_module}')
         self.wquantizer = self.weight_quant_module(**self.quant_config['weight'])
 
@@ -183,8 +180,6 @@ def set_quant_config(self):
             self.act_quant_module = IntegerQuantizer
         elif quant_type == 'float-quant':
             self.act_quant_module = FloatQuantizer
-        elif quant_type == 'hif4':
-            self.act_quant_module = HiFloat4Quantizer
         else:
             raise ValueError(
                 f"Unsupported act quant_type: {quant_type}. "
@@ -1060,58 +1055,6 @@ def contiguous_params(self):
             if not param.is_contiguous():
                 param.data = param.data.contiguous()
 
-    def _copy_wan22_native_checkpoint(self, src, dst):
-        if not isinstance(src, str) or not os.path.isdir(src):
-            raise RuntimeError(
-                'Wan2.2 official save expects a local native checkpoint directory, '
-                f'but got src={src!r}.'
-            )
-        if os.path.abspath(src) == os.path.abspath(dst):
-            raise RuntimeError(
-                'Wan2.2 official save path must differ from source checkpoint path '
-                f'(src=dst={src}).'
-            )
-        if os.path.exists(dst):
-            shutil.rmtree(dst)
-        shutil.copytree(src, dst)
-        logger.info(f'Copied original Wan2.2 native checkpoint from {src} to {dst}')
-
-    def _validate_wan22_native_save_structure(self, save_path, source_path=None):
-        if not os.path.isdir(save_path):
-            raise RuntimeError(f'Wan2.2 saved path is not a directory: {save_path}')
-
-        required_entries = ['configuration.json', 'high_noise_model', 'low_noise_model']
-        missing_required = [
-            name for name in required_entries
-            if not os.path.exists(os.path.join(save_path, name))
-        ]
-        if missing_required:
-            raise RuntimeError(
-                'Wan2.2 saved structure is incomplete. Missing required entries: '
-                f'{missing_required}. save_path={save_path}'
-            )
-
-        if isinstance(source_path, str) and os.path.isdir(source_path):
-            source_entries = set(os.listdir(source_path))
-            source_non_expert_entries = sorted(
-                name for name in source_entries
-                if name not in {'high_noise_model', 'low_noise_model'}
-            )
-            missing_non_expert = [
-                name for name in source_non_expert_entries
-                if not os.path.exists(os.path.join(save_path, name))
-            ]
-            if missing_non_expert:
-                raise RuntimeError(
-                    'Wan2.2 saved structure lost original non-expert files/directories: '
-                    f'{missing_non_expert}. source_path={source_path}, save_path={save_path}'
-                )
-
-        logger.info(
-            f'Wan2.2 native save structure verified. '
-            f'top-level entries={sorted(os.listdir(save_path))}'
-        )
-
     @torch.no_grad()
     def save_model(self, path):
         if int(os.environ['RANK']) != 0:
@@ -1135,7 +1078,7 @@ def save_model(self, path):
         elif self.config.model.type in ['Wan2T2V']:
             if getattr(self.model.Pipeline, '_is_wan_official', False):
                 src = getattr(self.model, 'pipeline_model_path', self.model.model_path)
-                self._copy_wan22_native_checkpoint(src, path)
+                self.model.copy_native_checkpoint(src, path)
 
                 self.model.Pipeline.transformer.save_pretrained(
                     os.path.join(path, 'high_noise_model')
@@ -1149,7 +1092,7 @@ def save_model(self, path):
                     os.path.join(path, 'low_noise_model')
                 )
                 logger.info('save Wan2.2 low_noise_model done --')
-                self._validate_wan22_native_save_structure(path, source_path=src)
+                self.model.validate_native_save_structure(path, source_path=src)
                 return
 
     # Copy the full original pipeline (VAE, text encoder, tokenizer, scheduler, etc.)
```
