
Commit 358ee83

updated bmm and matmul for GPT-OSS (#999)
### What does this PR do?

Fixes a maximum-recursion bug for GPT-OSS. It replaces `torch._bmm` and `torch.matmul` with `torch.ops.aten.bmm` and `torch.ops.aten.matmul` so the quantized wrappers no longer recurse back into themselves.

### Usage

Docker image: `nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc4`

Repro steps (gpt-oss):

Step 1:

```shell
accelerate launch --config_file configs/zero3.yaml sft.py --config configs/sft_full.yaml \
  --model_name_or_path openai/gpt-oss-20b \
  --output_dir /tmp/pytest-of-root/pytest-0/test_gpt_oss_complete_pipeline0/gpt-oss-20b-sft
```

Step 1 produces the SFT checkpoint at `/tmp/pytest-of-root/pytest-0/test_gpt_oss_complete_pipeline0/gpt-oss-20b-sft`.

Step 2:

```shell
accelerate launch --config_file configs/zero3.yaml sft.py --config configs/sft_full.yaml \
  --model_name_or_path /tmp/pytest-of-root/pytest-0/test_gpt_oss_complete_pipeline0/gpt-oss-20b-sft \
  --quant_cfg MXFP4_MLP_WEIGHT_ONLY_CFG \
  --output_dir /tmp/pytest-of-root/pytest-0/test_gpt_oss_complete_pipeline0/gpt-oss-20b-qat
```

### Testing

```shell
pytest tests/examples/gpt_oss/test_gpt_oss_qat.py
```

### Before your PR is "*Ready for review*"

Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and that your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow the guidance in `CONTRIBUTING.md`?: N/A
- Did you write any new necessary tests?: N/A (tests already exist)
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A

## Summary by CodeRabbit

- **Bug Fixes**
  - Fixed a recursion-related instability in attention quantization that could cause errors during certain matrix operations, improving reliability.
- **Performance**
  - Improved handling of batched and matrix-multiplication operations under quantization for more consistent and efficient runtime behavior, including better support for caller-specified outputs.

---

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
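For context on the fix above: the plugin replaces `torch.bmm`/`torch.matmul` with quantizing wrappers, and the old wrappers called back into the same public entry points they were patched over. Below is a minimal, self-contained sketch (the wrapper names are illustrative, not from this PR) of that failure mode and of the `torch.ops.aten` workaround the PR adopts:

```python
import torch

_original_matmul = torch.matmul  # keep a handle so we can restore it


def naive_wrapper(a, b):
    # If this were installed as torch.matmul, the call below would re-enter
    # the wrapper itself and eventually raise RecursionError.
    return torch.matmul(a, b)


def aten_wrapper(a, b):
    # torch.ops.aten.matmul dispatches straight to the registered aten kernel,
    # so it never routes back through the patched Python-level torch.matmul.
    return torch.ops.aten.matmul(a, b)


a, b = torch.randn(4, 8), torch.randn(8, 2)
torch.matmul = aten_wrapper
print(torch.matmul(a, b).shape)  # torch.Size([4, 2]), no recursion
torch.matmul = _original_matmul  # restore the original binding
```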
1 parent a5d46ff commit 358ee83

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

modelopt/torch/quantization/plugins/huggingface.py

```diff
@@ -1040,15 +1040,22 @@ def _setup(self):

     @property
     def functionals_to_replace(self):
-        def _quantized_bmm(batch1, batch2):
+        # Use torch.ops.aten to bypass Python dispatch and avoid RecursionError
+        # (torch.matmul / __matmul__ can dispatch to each other)
+        _aten_bmm = torch.ops.aten.bmm
+        _aten_matmul = torch.ops.aten.matmul
+
+        def _quantized_bmm(batch1, batch2, *, out=None):
             batch1 = self.down_proj_input_quantizer(batch1) if self._down_proj_mul else batch1
             self._down_proj_mul = not self._down_proj_mul  # toggle the flag
-            return torch._bmm(batch1, batch2)
+            if out is not None:
+                return torch.ops.aten.bmm.out(batch1, batch2, out=out)
+            return _aten_bmm(batch1, batch2)

         def _tensor_matmul(self_t, other):
             self_t = self.down_proj_input_quantizer(self_t) if self._down_proj_mul else self_t
             self._down_proj_mul = not self._down_proj_mul
-            return torch.matmul(self_t, other)
+            return _aten_matmul(self_t, other)

         return [
             (torch, "bmm", _quantized_bmm),
```
