Commit 282675b

realAsma authored and danielkorzekwa committed
[Minor] Force 'fuse_wgrad_accumulation' to false for TE GroupedLinear (#814)
## What does this PR do? **Type of change:** ? Minor **Overview:** ? ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Automatically disables fuse_wgrad_accumulation when using ModelOpt quantization with Transformer Engine-based quantization paths. A warning is now displayed to notify users when this adjustment occurs. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent c17f6c2 commit 282675b

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

modelopt/torch/quantization/plugins/transformer_engine.py

```diff
@@ -120,6 +120,13 @@ def _functionals_to_replace(self, value):
         self._functionals_to_replace = value
 
     def _setup(self):
+        if getattr(self, "fuse_wgrad_accumulation", False):
+            warnings.warn(
+                "fuse_wgrad_accumulation is not supported with ModelOpt quantization. "
+                "Setting fuse_wgrad_accumulation to False."
+            )
+            self.fuse_wgrad_accumulation = False
+
         # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
         # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
         # self.weight0 to self.weight to run the quantizer states initialization.
@@ -131,6 +138,9 @@ def _setup(self):
         # Remove self.weight after setup.
         delattr(self, "weight")
 
+    # TODO: GroupedLinear supports weights split by `num_gemms`, to support quantization
+    # with static parameters beyond per-tensor, we need to support a unique quantizer for each gemm.
+
     def modelopt_post_restore(self, prefix: str = ""):
         # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
         # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
```
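The guard added in this commit follows a common pattern: detect an unsupported flag during setup, warn, and force it off. A minimal, self-contained sketch of that behavior is below; `_FakeGroupedLinear` is a hypothetical stand-in for the real TE `GroupedLinear` module, not ModelOpt source code.

```python
import warnings


class _FakeGroupedLinear:
    """Illustrative stand-in for TE GroupedLinear (name and attributes are assumptions)."""

    def __init__(self, fuse_wgrad_accumulation: bool):
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation

    def _setup(self):
        # Same guard shape as the commit: warn once, then disable the flag.
        if getattr(self, "fuse_wgrad_accumulation", False):
            warnings.warn(
                "fuse_wgrad_accumulation is not supported with ModelOpt quantization. "
                "Setting fuse_wgrad_accumulation to False."
            )
            self.fuse_wgrad_accumulation = False


module = _FakeGroupedLinear(fuse_wgrad_accumulation=True)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    module._setup()
```

After `_setup()` runs, `module.fuse_wgrad_accumulation` is `False` and exactly one `UserWarning` has been recorded; calling `_setup()` again is a no-op since the flag is already off.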
