Skip to content

Commit ffd3c8e

Browse files
committed
Register SyncBatchNorm module for quantization.
Signed-off-by: Bryce Ferenczi <bryce.ferenczi@Arkeus.com>
1 parent 2c52e7b commit ffd3c8e

4 files changed

Lines changed: 10 additions & 0 deletions

File tree

modelopt/torch/quantization/nn/modules/quant_batchnorm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@
2222
QuantModuleRegistry.register({nn.BatchNorm1d: "nn.BatchNorm1d"})(QuantInputBase)
2323
QuantModuleRegistry.register({nn.BatchNorm2d: "nn.BatchNorm2d"})(QuantInputBase)
2424
QuantModuleRegistry.register({nn.BatchNorm3d: "nn.BatchNorm3d"})(QuantInputBase)
25+
QuantModuleRegistry.register({nn.SyncBatchNorm: "nn.SyncBatchNorm"})(QuantInputBase)

modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
- parent_class: 'nn.BatchNorm3d'
4646
quantizer_name: '*'
4747
enable: false
48+
- parent_class: 'nn.SyncBatchNorm'
49+
quantizer_name: '*'
50+
enable: false
4851
- parent_class: 'nn.LeakyReLU'
4952
quantizer_name: '*'
5053
enable: false

modelopt_recipes/huggingface/step3p5/Step3.5-Flash/ptq/nvfp4-mlp-only.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ quantize:
7171
- parent_class: 'nn.BatchNorm3d'
7272
quantizer_name: '*'
7373
enable: false
74+
- parent_class: 'nn.SyncBatchNorm'
75+
quantizer_name: '*'
76+
enable: false
7477
- parent_class: 'nn.LeakyReLU'
7578
quantizer_name: '*'
7679
enable: false

tests/unit/torch/quantization/test_quant_batchnorm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class TestQuantBatchNormND:
3434
(nn.BatchNorm1d, (2, NUM_CHANNELS, 8)),
3535
(nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)),
3636
(nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)),
37+
(nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)),
3738
],
3839
)
3940
def test_no_quant(self, original_cls, input_shape):
@@ -60,6 +61,7 @@ def test_no_quant(self, original_cls, input_shape):
6061
(nn.BatchNorm1d, (2, NUM_CHANNELS, 8)),
6162
(nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)),
6263
(nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)),
64+
(nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)),
6365
],
6466
)
6567
def test_fake_quant_per_tensor(self, original_cls, input_shape):
@@ -86,6 +88,7 @@ def test_fake_quant_per_tensor(self, original_cls, input_shape):
8688
(nn.BatchNorm1d, (2, NUM_CHANNELS, 8)),
8789
(nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)),
8890
(nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)),
91+
(nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)),
8992
],
9093
)
9194
def test_fake_quant_per_channel(self, original_cls, input_shape):

0 commit comments

Comments
 (0)