diff --git a/modelopt/torch/quantization/nn/modules/quant_batchnorm.py b/modelopt/torch/quantization/nn/modules/quant_batchnorm.py index 21eed5b82e0..2f07547566b 100644 --- a/modelopt/torch/quantization/nn/modules/quant_batchnorm.py +++ b/modelopt/torch/quantization/nn/modules/quant_batchnorm.py @@ -22,3 +22,4 @@ QuantModuleRegistry.register({nn.BatchNorm1d: "nn.BatchNorm1d"})(QuantInputBase) QuantModuleRegistry.register({nn.BatchNorm2d: "nn.BatchNorm2d"})(QuantInputBase) QuantModuleRegistry.register({nn.BatchNorm3d: "nn.BatchNorm3d"})(QuantInputBase) +QuantModuleRegistry.register({nn.SyncBatchNorm: "nn.SyncBatchNorm"})(QuantInputBase) diff --git a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml index 2adcf1f60f0..cced0e3729f 100644 --- a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml +++ b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml @@ -45,6 +45,9 @@ - parent_class: 'nn.BatchNorm3d' quantizer_name: '*' enable: false + - parent_class: 'nn.SyncBatchNorm' + quantizer_name: '*' + enable: false - parent_class: 'nn.LeakyReLU' quantizer_name: '*' enable: false diff --git a/modelopt_recipes/general/ptq/int4_blockwise_weight_only.yaml b/modelopt_recipes/general/ptq/int4_blockwise_weight_only.yaml index 432b970339c..36b49d42912 100644 --- a/modelopt_recipes/general/ptq/int4_blockwise_weight_only.yaml +++ b/modelopt_recipes/general/ptq/int4_blockwise_weight_only.yaml @@ -57,6 +57,9 @@ quantize: - parent_class: 'nn.BatchNorm3d' quantizer_name: '*' enable: false + - parent_class: 'nn.SyncBatchNorm' + quantizer_name: '*' + enable: false - parent_class: 'nn.LeakyReLU' quantizer_name: '*' enable: false diff --git a/modelopt_recipes/huggingface/step3p5/Step3.5-Flash/ptq/nvfp4-mlp-only.yaml b/modelopt_recipes/huggingface/step3p5/Step3.5-Flash/ptq/nvfp4-mlp-only.yaml index d0adbe00479..2af21bb016f 100644 --- a/modelopt_recipes/huggingface/step3p5/Step3.5-Flash/ptq/nvfp4-mlp-only.yaml +++ b/modelopt_recipes/huggingface/step3p5/Step3.5-Flash/ptq/nvfp4-mlp-only.yaml @@ -71,6 +71,9 @@ quantize: - parent_class: 'nn.BatchNorm3d' quantizer_name: '*' enable: false + - parent_class: 'nn.SyncBatchNorm' + quantizer_name: '*' + enable: false - parent_class: 'nn.LeakyReLU' quantizer_name: '*' enable: false diff --git a/tests/unit/torch/quantization/test_quant_batchnorm.py b/tests/unit/torch/quantization/test_quant_batchnorm.py index c55b4b0b0e4..ef9bdd87222 100644 --- a/tests/unit/torch/quantization/test_quant_batchnorm.py +++ b/tests/unit/torch/quantization/test_quant_batchnorm.py @@ -34,6 +34,7 @@ class TestQuantBatchNormND: (nn.BatchNorm1d, (2, NUM_CHANNELS, 8)), (nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)), (nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)), + (nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)), ], ) def test_no_quant(self, original_cls, input_shape): @@ -60,6 +61,7 @@ def test_no_quant(self, original_cls, input_shape): (nn.BatchNorm1d, (2, NUM_CHANNELS, 8)), (nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)), (nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)), + (nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)), ], ) def test_fake_quant_per_tensor(self, original_cls, input_shape): @@ -86,6 +88,7 @@ def test_fake_quant_per_tensor(self, original_cls, input_shape): (nn.BatchNorm1d, (2, NUM_CHANNELS, 8)), (nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)), (nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)), + (nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)), ], ) def test_fake_quant_per_channel(self, original_cls, input_shape):