Skip to content

Commit 9ffde26

Browse files
authored
fix qwix bug (#271)
* fix qwix bug * fix unit test
1 parent 0c10d44 commit 9ffde26

2 files changed

Lines changed: 0 additions & 4 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,6 @@ def get_fp8_config(cls, config: HyperParameters):
297297
weight_qtype=jnp.float8_e4m3fn,
298298
act_qtype=jnp.float8_e4m3fn,
299299
bwd_qtype=jnp.float8_e5m2,
300-
bwd_use_original_residuals=True,
301300
disable_channelwise_axes=True, # per_tensor calibration
302301
weight_calibration_method=config.quantization_calibration_method,
303302
act_calibration_method=config.quantization_calibration_method,
@@ -309,7 +308,6 @@ def get_fp8_config(cls, config: HyperParameters):
309308
weight_qtype=jnp.float8_e4m3fn, # conv_general_dilated requires the same dtypes
310309
act_qtype=jnp.float8_e4m3fn,
311310
bwd_qtype=jnp.float8_e4m3fn,
312-
bwd_use_original_residuals=True,
313311
disable_channelwise_axes=True, # per_tensor calibration
314312
weight_calibration_method=config.quantization_calibration_method,
315313
act_calibration_method=config.quantization_calibration_method,

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,6 @@ def create_real_rule_instance(*args, **kwargs):
342342
weight_qtype=jnp.float8_e4m3fn,
343343
act_qtype=jnp.float8_e4m3fn,
344344
bwd_qtype=jnp.float8_e5m2,
345-
bwd_use_original_residuals=True,
346345
disable_channelwise_axes=True, # per_tensor calibration
347346
weight_calibration_method=config_fp8_full.quantization_calibration_method,
348347
act_calibration_method=config_fp8_full.quantization_calibration_method,
@@ -354,7 +353,6 @@ def create_real_rule_instance(*args, **kwargs):
354353
weight_qtype=jnp.float8_e4m3fn,
355354
act_qtype=jnp.float8_e4m3fn,
356355
bwd_qtype=jnp.float8_e4m3fn,
357-
bwd_use_original_residuals=True,
358356
disable_channelwise_axes=True, # per_tensor calibration
359357
weight_calibration_method=config_fp8_full.quantization_calibration_method,
360358
act_calibration_method=config_fp8_full.quantization_calibration_method,

0 commit comments

Comments
 (0)