-
Notifications
You must be signed in to change notification settings - Fork 467
Expand file tree
/
Copy pathtest_quant_batchnorm.py
More file actions
115 lines (101 loc) · 4.31 KB
/
Copy pathtest_quant_batchnorm.py
File metadata and controls
115 lines (101 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of QuantBatchNorm module."""
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelopt.torch.quantization import tensor_quant
from modelopt.torch.quantization.conversion import set_quantizer_attributes_partial
from modelopt.torch.quantization.nn import QuantModuleRegistry
# Channel count used to build every BatchNorm module under test; it is also
# dim 1 of each parametrized input shape.
NUM_CHANNELS: int = 3
class TestQuantBatchNormND:
    """Compare converted BatchNorm modules against ``F.batch_norm`` on reference-quantized input.

    Each test converts a stock ``nn.BatchNorm{1,2,3}d`` / ``nn.SyncBatchNorm`` with
    ``QuantModuleRegistry.convert``, runs it on a random input, and requires the result to be
    exactly equal (``rtol=0, atol=0``) to ``F.batch_norm`` applied to an input that the test
    fake-quantizes itself with ``tensor_quant.fake_tensor_quant``.
    """

    # (module class, input shape) pairs shared by every test below. Dim 1 of each shape is
    # the channel dim and must equal the NUM_CHANNELS the module is built with. A tuple
    # (not a list) so the shared class attribute cannot be mutated by accident.
    _BATCHNORM_CASES = (
        (nn.BatchNorm1d, (2, NUM_CHANNELS, 8)),
        (nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)),
        (nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)),
        (nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)),
    )

    @staticmethod
    def _reference_output(module, module_input):
        """Return ``F.batch_norm(module_input, ...)`` built from *module*'s own tensors.

        Args:
            module: The converted BatchNorm module; its running buffers, affine
                parameters and ``eps`` are passed straight to ``F.batch_norm``.
            module_input: The (possibly fake-quantized) tensor to normalize.

        Returns:
            The output tensor of ``F.batch_norm``.

        ``training=True`` makes ``F.batch_norm`` normalize with the batch statistics of
        *module_input*. This assumes the converted module is still in training mode (the
        ``nn.Module`` default) when it is called. NOTE(review): confirm ``convert`` keeps
        that mode.
        """
        return F.batch_norm(
            module_input,
            module.running_mean,
            module.running_var,
            module.weight,
            module.bias,
            training=True,
            # Use the module's eps rather than relying on F.batch_norm's default matching it.
            eps=module.eps,
        )

    @pytest.mark.parametrize(("original_cls", "input_shape"), _BATCHNORM_CASES)
    def test_no_quant(self, original_cls, input_shape):
        """With the input quantizer disabled, output equals batch norm of the raw input."""
        quant_bn = QuantModuleRegistry.convert(original_cls(NUM_CHANNELS, affine=True))
        quant_bn.input_quantizer.disable()
        test_input = torch.randn(input_shape)

        out = quant_bn(test_input)
        expected = self._reference_output(quant_bn, test_input)
        torch.testing.assert_close(out, expected, rtol=0, atol=0)

    @pytest.mark.parametrize(("original_cls", "input_shape"), _BATCHNORM_CASES)
    def test_fake_quant_per_tensor(self, original_cls, input_shape):
        """With default quantizer settings, output equals batch norm of a per-tensor fake-quant input."""
        quant_bn = QuantModuleRegistry.convert(original_cls(NUM_CHANNELS, affine=True))
        test_input = torch.randn(input_shape)
        # Reference: a single amax, max |x| over the whole input tensor.
        quant_input = tensor_quant.fake_tensor_quant(test_input, test_input.abs().max())

        out = quant_bn(test_input)
        expected = self._reference_output(quant_bn, quant_input)
        torch.testing.assert_close(out, expected, rtol=0, atol=0)

    @pytest.mark.parametrize(("original_cls", "input_shape"), _BATCHNORM_CASES)
    def test_fake_quant_per_channel(self, original_cls, input_shape):
        """With ``axis=1`` quantization, output equals batch norm of a per-channel fake-quant input."""
        quant_bn = QuantModuleRegistry.convert(original_cls(NUM_CHANNELS, affine=True))
        # The always-true name filter presumably selects every quantizer on the module.
        # axis is the int 1 (the channel dim); the original "(1)" was not a tuple either.
        set_quantizer_attributes_partial(quant_bn, lambda name: True, {"axis": 1})
        test_input = torch.randn(input_shape)
        # Reference: one amax per channel, reduced over every dim except dim 1.
        reduce_dims = [dim for dim in range(test_input.dim()) if dim != 1]
        quant_input = tensor_quant.fake_tensor_quant(
            test_input, test_input.abs().amax(dim=reduce_dims, keepdim=True)
        )

        out = quant_bn(test_input)
        expected = self._reference_output(quant_bn, quant_input)
        torch.testing.assert_close(out, expected, rtol=0, atol=0)