
Commit 7d25dc9

change quantization calibration method (#300)
* change quantization calibration method
* fix unit test
* change set
* fix unit test
1 parent d145419 · commit 7d25dc9

3 files changed: 22 additions & 16 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 4 additions & 2 deletions
@@ -319,8 +319,10 @@ quantization: ''
 quantization_local_shard_count: -1
 compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
 use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
-# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
-quantization_calibration_method: "absmax"
+# Quantization calibration method used for weights, activations and bwd. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
+weight_quantization_calibration_method: "absmax"
+act_quantization_calibration_method: "absmax"
+bwd_quantization_calibration_method: "absmax"
 qwix_module_path: ".*"
 
 # Eval model on per eval_every steps. -1 means don't eval.
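The three keys can now be tuned independently. As a hedged illustration (the values below are assumptions, not project recommendations), the "fixed,<min>,<max>" string format is inferred from the unit tests in this commit and from the qwix qconfig reference linked in the comment above:

# Hypothetical override: fixed calibration range for weights and activations,
# absmax calibration for the backward pass. The bounds are illustrative only.
weight_quantization_calibration_method: "fixed,-448,448"
act_quantization_calibration_method: "fixed,-448,448"
bwd_quantization_calibration_method: "absmax"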

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 6 additions & 6 deletions
@@ -302,9 +302,9 @@ def get_fp8_config(cls, config: HyperParameters):
         act_qtype=jnp.float8_e4m3fn,
         bwd_qtype=jnp.float8_e5m2,
         disable_channelwise_axes=True,  # per_tensor calibration
-        weight_calibration_method=config.quantization_calibration_method,
-        act_calibration_method=config.quantization_calibration_method,
-        bwd_calibration_method=config.quantization_calibration_method,
+        weight_calibration_method=config.weight_quantization_calibration_method,
+        act_calibration_method=config.act_quantization_calibration_method,
+        bwd_calibration_method=config.bwd_quantization_calibration_method,
         op_names=("dot_general", "einsum"),
     ),
     qwix.QtRule(
@@ -313,9 +313,9 @@ def get_fp8_config(cls, config: HyperParameters):
         act_qtype=jnp.float8_e4m3fn,
         bwd_qtype=jnp.float8_e4m3fn,
         disable_channelwise_axes=True,  # per_tensor calibration
-        weight_calibration_method=config.quantization_calibration_method,
-        act_calibration_method=config.quantization_calibration_method,
-        bwd_calibration_method=config.quantization_calibration_method,
+        weight_calibration_method=config.weight_quantization_calibration_method,
+        act_calibration_method=config.act_quantization_calibration_method,
+        bwd_calibration_method=config.bwd_quantization_calibration_method,
         op_names=("conv_general_dilated"),
     ),
 ]
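Read outside the diff, the updated rules look roughly like the minimal sketch below. module_path and weight_qtype are not visible in these hunks, so they are assumptions; everything else mirrors get_fp8_config as changed here. Note the source passes op_names=("conv_general_dilated") as a bare parenthesized string; the sketch writes it as a one-element tuple for clarity.

import jax.numpy as jnp
import qwix

def make_fp8_full_rules(config):
    # Sketch of the rule list built in get_fp8_config: two per-tensor fp8
    # rules whose calibration methods now come from three independent
    # config fields instead of one shared field.
    return [
        qwix.QtRule(
            module_path=config.qwix_module_path,  # assumption: not shown in the hunk
            weight_qtype=jnp.float8_e4m3fn,       # assumption: not shown in the hunk
            act_qtype=jnp.float8_e4m3fn,
            bwd_qtype=jnp.float8_e5m2,
            disable_channelwise_axes=True,  # per_tensor calibration
            weight_calibration_method=config.weight_quantization_calibration_method,
            act_calibration_method=config.act_quantization_calibration_method,
            bwd_calibration_method=config.bwd_quantization_calibration_method,
            op_names=("dot_general", "einsum"),
        ),
        qwix.QtRule(
            module_path=config.qwix_module_path,  # assumption: not shown in the hunk
            weight_qtype=jnp.float8_e4m3fn,       # assumption: not shown in the hunk
            act_qtype=jnp.float8_e4m3fn,
            bwd_qtype=jnp.float8_e4m3fn,
            disable_channelwise_axes=True,  # per_tensor calibration
            weight_calibration_method=config.weight_quantization_calibration_method,
            act_calibration_method=config.act_quantization_calibration_method,
            bwd_calibration_method=config.bwd_quantization_calibration_method,
            op_names=("conv_general_dilated",),  # 1-tuple here; the source passes a bare string
        ),
    ]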

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 12 additions & 8 deletions
@@ -332,7 +332,9 @@ def create_real_rule_instance(*args, **kwargs):
     config_fp8_full = Mock(spec=HyperParameters)
     config_fp8_full.use_qwix_quantization = True
     config_fp8_full.quantization = "fp8_full"
-    config_fp8_full.quantization_calibration_method = "absmax"
+    config_fp8_full.weight_quantization_calibration_method = "fixed,-224,224"
+    config_fp8_full.act_quantization_calibration_method = "fixed,-224,224"
+    config_fp8_full.bwd_quantization_calibration_method = "absmax"
     config_fp8_full.qwix_module_path = ".*"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
@@ -343,9 +345,9 @@ def create_real_rule_instance(*args, **kwargs):
         act_qtype=jnp.float8_e4m3fn,
         bwd_qtype=jnp.float8_e5m2,
         disable_channelwise_axes=True,  # per_tensor calibration
-        weight_calibration_method=config_fp8_full.quantization_calibration_method,
-        act_calibration_method=config_fp8_full.quantization_calibration_method,
-        bwd_calibration_method=config_fp8_full.quantization_calibration_method,
+        weight_calibration_method=config_fp8_full.weight_quantization_calibration_method,
+        act_calibration_method=config_fp8_full.act_quantization_calibration_method,
+        bwd_calibration_method=config_fp8_full.bwd_quantization_calibration_method,
         op_names=("dot_general", "einsum"),
     ),
     call(
@@ -354,9 +356,9 @@ def create_real_rule_instance(*args, **kwargs):
         act_qtype=jnp.float8_e4m3fn,
         bwd_qtype=jnp.float8_e4m3fn,
         disable_channelwise_axes=True,  # per_tensor calibration
-        weight_calibration_method=config_fp8_full.quantization_calibration_method,
-        act_calibration_method=config_fp8_full.quantization_calibration_method,
-        bwd_calibration_method=config_fp8_full.quantization_calibration_method,
+        weight_calibration_method=config_fp8_full.weight_quantization_calibration_method,
+        act_calibration_method=config_fp8_full.act_quantization_calibration_method,
+        bwd_calibration_method=config_fp8_full.bwd_quantization_calibration_method,
         op_names=("conv_general_dilated"),
     ),
 ]
@@ -381,7 +383,9 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_config.quantization = "fp8_full"
     mock_config.qwix_module_path = ".*"
     mock_config.per_device_batch_size = 1
-    mock_config.quantization_calibration_method = "absmax"
+    mock_config.weight_quantization_calibration_method = "fixed,-224,224"
+    mock_config.act_quantization_calibration_method = "fixed,-224,224"
+    mock_config.bwd_quantization_calibration_method = "absmax"
 
     mock_model = Mock(spec=WanModel)
     mock_pipeline = Mock()
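For a rough sense of how the new fields are consumed without mocks, here is a minimal sketch; SimpleNamespace is a hypothetical stand-in for HyperParameters, and the import path is assumed from the file tree above:

from types import SimpleNamespace

from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline  # path assumed

# Mixed calibration: fixed range for weights/activations, absmax for the
# backward pass, matching the values exercised by the tests above.
config = SimpleNamespace(
    use_qwix_quantization=True,
    quantization="fp8_full",
    qwix_module_path=".*",
    per_device_batch_size=1,
    weight_quantization_calibration_method="fixed,-224,224",
    act_quantization_calibration_method="fixed,-224,224",
    bwd_quantization_calibration_method="absmax",
)

provider = WanPipeline.get_qt_provider(config)
assert provider is not None  # the tests assert a provider exists for "fp8_full"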
