Skip to content

Commit 1f5d6f6

Browse files
author
gushiqiao
committed
Support wan t2v quantization and fix some bugs.
1 parent 36026bf commit 1f5d6f6

20 files changed

Lines changed: 395 additions & 37 deletions

File tree

assets/wan_t2v/calib/samples.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[
2+
{
3+
"prompt": "A cat walks on the grass, realistic",
4+
"negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
5+
}
6+
]

assets/wan_t2v/eval/samples.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[
2+
{
3+
"prompt": "A cat walks on the grass, realistic",
4+
"negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
5+
}
6+
]
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
base:
2+
seed: &seed 42
3+
model:
4+
type: WanT2V
5+
path: /path/to/wan_t2v
6+
torch_dtype: auto
7+
calib:
8+
name: custom_t2v
9+
download: False
10+
path: ../assets/wan_t2v/calib/
11+
sample_steps: 20
12+
bs: 1
13+
target_height: 480
14+
target_width: 832
15+
num_frames: 81
16+
guidance_scale: 5.0
17+
seed: *seed
18+
eval:
19+
eval_pos: [transformed, fake_quant]
20+
type: video_gen
21+
name: custom_t2v
22+
download: False
23+
path: ../assets/wan_t2v/calib/
24+
bs: 1
25+
target_height: 480
26+
target_width: 832
27+
num_frames: 81
28+
guidance_scale: 5.0
29+
output_video_path: ./output_videos_awq/
30+
quant:
31+
video_gen:
32+
method: Awq
33+
weight:
34+
bit: 6
35+
symmetric: True
36+
granularity: per_channel
37+
group_size: -1
38+
act:
39+
bit: 6
40+
symmetric: True
41+
granularity: per_token
42+
special:
43+
trans: True
44+
trans_version: v2
45+
weight_clip: True
46+
clip_sym: True
47+
save:
48+
save_trans: False
49+
save_fake: False
50+
save_path: /path/to/save/
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
base:
2+
seed: &seed 42
3+
model:
4+
type: WanT2V
5+
path: /path/to/wan_t2v
6+
torch_dtype: auto
7+
eval:
8+
eval_pos: [transformed, fake_quant]
9+
type: video_gen
10+
name: custom_t2v
11+
download: False
12+
path: /mtc/gushiqiao/llmc_video_new/llmc/assets/wan_t2v/
13+
bs: 1
14+
target_height: 480
15+
target_width: 832
16+
num_frames: 81
17+
guidance_scale: 5.0
18+
output_video_path: ./output_videos_sq/
19+
quant:
20+
video_gen:
21+
method: RTN
22+
weight:
23+
bit: 6
24+
symmetric: True
25+
granularity: per_channel
26+
act:
27+
bit: 6
28+
symmetric: True
29+
granularity: per_token
30+
save:
31+
save_trans: False
32+
save_fake: False
33+
save_path: /path/to/save/
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
base:
2+
seed: &seed 42
3+
model:
4+
type: WanT2V
5+
path: /path/to/wan_t2v
6+
torch_dtype: auto
7+
calib:
8+
name: custom_t2v
9+
download: False
10+
path: ../assets/wan_t2v/calib/
11+
sample_steps: 20
12+
bs: 1
13+
target_height: 480
14+
target_width: 832
15+
num_frames: 81
16+
guidance_scale: 5.0
17+
seed: *seed
18+
eval:
19+
eval_pos: [transformed, fake_quant]
20+
type: video_gen
21+
name: custom_t2v
22+
download: False
23+
path: ../assets/wan_t2v/calib/
24+
bs: 1
25+
target_height: 480
26+
target_width: 832
27+
num_frames: 81
28+
guidance_scale: 5.0
29+
output_video_path: ./output_videos_sq/
30+
quant:
31+
video_gen:
32+
method: SmoothQuant
33+
weight:
34+
bit: 6
35+
symmetric: True
36+
granularity: per_channel
37+
act:
38+
bit: 6
39+
symmetric: True
40+
granularity: per_token
41+
special:
42+
alpha: 0.7
43+
save:
44+
save_trans: False
45+
save_fake: False
46+
save_path: /path/to/save/

llmc/__main__.py

100644100755
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ def main(config):
3030
logger.info(f'model: {model}')
3131
logger.info(f'tokenizer: {model.get_tokenizer()}')
3232

33+
eval_list = get_eval_list(model, config)
34+
eval_model(model, None, eval_list, eval_pos='pretrain')
35+
3336
blockwise_opts = []
3437
modalities, modality_configs = get_modality(config)
38+
3539
for modality, modality_config in zip(modalities, modality_configs):
3640
model.set_modality(modality)
37-
eval_list = get_eval_list(model, config)
38-
eval_model(model, None, eval_list, eval_pos='pretrain')
3941
if not config.get('calib', False):
4042
blockwise_opt = ALGO_REGISTRY[modality_config.method](
4143
model,

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ def set_quant_config(self):
249249
self.config['model']['type'] in ['Opt', 'Llama']
250250
), 'Please set online_rotate=False'
251251
self.fp32_had = special_config.get('fp32_had', False)
252-
self.hidden_size = self.model.model_config.hidden_size
253-
self.set_model_config()
252+
if self.quant_config.modality != 'video_gen':
253+
self.set_model_config()
254254
self.modality = self.quant_config.modality
255255
logger.info(f'self.quant_objects : {self.quant_config.modality}')
256256

@@ -373,12 +373,12 @@ def block_forward(self, block, input_data=None):
373373
if torch.is_tensor(self.input['kwargs'][i][k]):
374374
self.input['kwargs'][i][k] = self.input['kwargs'][i][k].to(
375375
device=next(block.parameters()).device
376-
) # noqa
376+
)
377377
if isinstance(self.input['kwargs'][i][k], tuple):
378378
self.input['kwargs'][i][k] = tuple(
379379
tmp.to(device=next(block.parameters()).device)
380380
for tmp in self.input['kwargs'][i][k]
381-
) # noqa
381+
)
382382
with torch.no_grad():
383383
out = block(input_data[i], **self.input['kwargs'][i])
384384
if isinstance(out, tuple):
@@ -474,9 +474,10 @@ def block_transform(self, block, input_feat, block_kwargs):
474474
inspect_has_kwargs = subset['has_kwargs']
475475
if inspect_has_kwargs:
476476
if 'sub_keys' in subset:
477-
subset_kwargs = [
478-
{k: block_kwargs[0][v] for k, v in subset['sub_keys'].items()}
479-
]
477+
subset_kwargs = []
478+
for i in range(len(block_kwargs)):
479+
for k, v in subset['sub_keys'].items():
480+
subset_kwargs.append({k: block_kwargs[i][v]})
480481
else:
481482
subset_kwargs = block_kwargs
482483
else:
@@ -746,7 +747,10 @@ def shift_ln_fcs(self, ln, fcs, shifts):
746747
def scale_ln_fcs(self, ln, fcs, scales):
747748
if not isinstance(fcs, list):
748749
fcs = [fcs]
750+
749751
scales = scales.to(ln.weight.device)
752+
scales = scales.to(ln.weight.dtype)
753+
750754
ln.weight.div_(scales)
751755

752756
if hasattr(ln, 'bias') and ln.bias is not None:
@@ -954,6 +958,13 @@ def deploy(self, quant_format, keep_device=False):
954958
self.get_replacement_params(mode=quant_format, w_only=self.w_only),
955959
keep_device=keep_device,
956960
)
961+
if self.modality == 'video_gen':
962+
self.model.replace_video_gen_module_all(
963+
module,
964+
self.get_replacement_params(mode=quant_format, w_only=self.w_only),
965+
keep_device=keep_device,
966+
)
967+
957968
self.set_non_linear_mode(quant_format, self.model.model, False)
958969

959970
if self.quant_kvcache:
@@ -973,8 +984,11 @@ def deploy(self, quant_format, keep_device=False):
973984

974985
@torch.no_grad()
975986
def copy_tokenizer(self, path):
976-
self.model.tokenizer.save_pretrained(path)
977-
logger.info('copy tokenizer done --')
987+
if self.model.tokenizer is not None:
988+
self.model.tokenizer.save_pretrained(path)
989+
logger.info('copy tokenizer done --')
990+
else:
991+
logger.info('no tokenizer, skip --')
978992

979993
@torch.no_grad()
980994
def contiguous_params(self):

llmc/compression/quantization/dgq.py

100644100755
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def set_quant_config(self):
4343
self.quant_out = True
4444
else:
4545
self.quant_out = False
46-
self.quant_type = self.quant_config.get('quant_type', 'int_quant')
47-
assert self.quant_type != 'float_quant', 'DGQ do not support Float quant now.'
46+
self.quant_type = self.quant_config.get('quant_type', 'int-quant')
47+
assert self.quant_type != 'float-quant', 'DGQ do not support Float quant now.'
4848
# set weight quant config
4949
self.wquantizer_w4 = IntegerQuantizer(**self.quant_config['weight']['w_1'])
5050
perchannel_setting = {

llmc/compression/quantization/module_utils.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,88 @@ def block_wise_fp8_forward_func(x, w, w_scale, block_size, bias):
4040
return y
4141

4242

43+
class FakeAffineLayerNorm(nn.Module):
44+
def __init__(self, norm, shape):
45+
super().__init__()
46+
self.register_parameter('weight', nn.Parameter(torch.ones(shape, dtype=torch.float)))
47+
self.register_parameter('bias', nn.Parameter(torch.ones(shape, dtype=torch.float)))
48+
self.norm = norm
49+
50+
def forward(self, x):
51+
return self.norm(x)
52+
53+
def extra_repr(self):
54+
return f'affine=True (emulated), shape={self.weight.shape}'
55+
56+
57+
class LlmcWanTransformerBlock(nn.Module):
58+
def __init__(self, module):
59+
super().__init__()
60+
61+
self.norm1 = FakeAffineLayerNorm(module.norm1, module.scale_shift_table.shape[-1])
62+
self.attn1 = module.attn1
63+
64+
self.attn2 = module.attn2
65+
self.norm2 = module.norm2
66+
67+
self.norm3 = FakeAffineLayerNorm(module.norm1, module.scale_shift_table.shape[-1])
68+
self.ffn = module.ffn
69+
self.scale_shift_table = module.scale_shift_table
70+
71+
def forward(
72+
self,
73+
hidden_states,
74+
encoder_hidden_states,
75+
temb,
76+
rotary_emb,
77+
):
78+
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
79+
self.scale_shift_table + temb
80+
).chunk(6, dim=1)
81+
82+
# 1. Self-attention
83+
norm1_weight = (1 + scale_msa) * self.norm1.weight
84+
norm1_bias = shift_msa * self.norm1.bias
85+
86+
norm_hidden_states = (
87+
self.norm1(hidden_states.float()) * norm1_weight + norm1_bias
88+
).type_as(hidden_states)
89+
attn_output = self.attn1(
90+
hidden_states=norm_hidden_states, rotary_emb=rotary_emb
91+
)
92+
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(
93+
hidden_states
94+
)
95+
96+
# 2. Cross-attention
97+
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
98+
attn_output = self.attn2(
99+
hidden_states=norm_hidden_states,
100+
encoder_hidden_states=encoder_hidden_states,
101+
)
102+
hidden_states = hidden_states + attn_output
103+
104+
# 3. Feed-forward
105+
norm3_weight = (1 + c_scale_msa) * self.norm3.weight
106+
norm3_bias = c_shift_msa * self.norm3.bias
107+
108+
norm_hidden_states = (
109+
self.norm3(hidden_states.float()) * norm3_weight + norm3_bias
110+
).type_as(hidden_states)
111+
ff_output = self.ffn(norm_hidden_states)
112+
hidden_states = (
113+
hidden_states.float() + ff_output.float() * c_gate_msa
114+
).type_as(hidden_states)
115+
116+
return hidden_states
117+
118+
@classmethod
119+
@torch.no_grad()
120+
def new(cls, module):
121+
new_module = cls(module)
122+
return new_module
123+
124+
43125
class LlmcFp8Linear(nn.Module):
44126
def __init__(self, in_features, out_features, bias, block_size):
45127
super().__init__()

llmc/compression/quantization/spqr.py

100644100755
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def add_quant_config(self):
5050
scale_config = special_config['scale']
5151
zero_config = special_config['zero']
5252

53-
self.quant_type = self.quant_config.get('quant_type', 'int_quant')
54-
assert self.quant_type != 'float_quant', 'SPQR do not support Float quant now.'
53+
self.quant_type = self.quant_config.get('quant_type', 'int-quant')
54+
assert self.quant_type != 'float-quant', 'SPQR do not support Float quant now.'
5555
self.scale_quantizer = IntegerQuantizer(**scale_config)
5656
self.zero_quantizer = IntegerQuantizer(**zero_config)
5757
self.Q = IntegerQuantizer(

0 commit comments

Comments
 (0)