@@ -34,6 +34,7 @@ class TestQuantBatchNormND:
3434 (nn .BatchNorm1d , (2 , NUM_CHANNELS , 8 )),
3535 (nn .BatchNorm2d , (2 , NUM_CHANNELS , 8 , 8 )),
3636 (nn .BatchNorm3d , (2 , NUM_CHANNELS , 8 , 8 , 8 )),
37+ (nn .SyncBatchNorm , (2 , NUM_CHANNELS , 8 , 8 )),
3738 ],
3839 )
3940 def test_no_quant (self , original_cls , input_shape ):
@@ -60,6 +61,7 @@ def test_no_quant(self, original_cls, input_shape):
6061 (nn .BatchNorm1d , (2 , NUM_CHANNELS , 8 )),
6162 (nn .BatchNorm2d , (2 , NUM_CHANNELS , 8 , 8 )),
6263 (nn .BatchNorm3d , (2 , NUM_CHANNELS , 8 , 8 , 8 )),
64+ (nn .SyncBatchNorm , (2 , NUM_CHANNELS , 8 , 8 )),
6365 ],
6466 )
6567 def test_fake_quant_per_tensor (self , original_cls , input_shape ):
@@ -86,6 +88,7 @@ def test_fake_quant_per_tensor(self, original_cls, input_shape):
8688 (nn .BatchNorm1d , (2 , NUM_CHANNELS , 8 )),
8789 (nn .BatchNorm2d , (2 , NUM_CHANNELS , 8 , 8 )),
8890 (nn .BatchNorm3d , (2 , NUM_CHANNELS , 8 , 8 , 8 )),
91+ (nn .SyncBatchNorm , (2 , NUM_CHANNELS , 8 , 8 )),
8992 ],
9093 )
9194 def test_fake_quant_per_channel (self , original_cls , input_shape ):
0 commit comments