Skip to content

Commit 0229d2c

Browse files
committed
[Minor] Force 'fuse_wgrad_accumulation' to False for TransformerEngine GroupedLinear
Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent b44c60a commit 0229d2c

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ def _functionals_to_replace(self, value):
120120
self._functionals_to_replace = value
121121

122122
def _setup(self):
123+
if getattr(self, "fuse_wgrad_accumulation", False):
124+
warnings.warn(
125+
"fuse_wgrad_accumulation is not supported with ModelOpt quantization. "
126+
"Setting fuse_wgrad_accumulation to False."
127+
)
128+
self.fuse_wgrad_accumulation = False
129+
123130
# GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
124131
# initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
125132
# self.weight0 to self.weight to run the quantizer states initialization.
@@ -131,6 +138,9 @@ def _setup(self):
131138
# Remove self.weight after setup.
132139
delattr(self, "weight")
133140

141+
# TODO: GroupedLinear supports weights split by `num_gemms`, to support quantization
142+
# with static parameters beyond per-tensor, we need to support a unique quantizer for each gemm.
143+
134144
def modelopt_post_restore(self, prefix: str = ""):
135145
# GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
136146
# initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning

0 commit comments

Comments
 (0)