Skip to content

Commit dea2eeb

Browse files
Fix failures
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent f0bd99c commit dea2eeb

File tree

16 files changed

+93
-66
lines changed

16 files changed

+93
-66
lines changed

.github/workflows/example_tests.yml

Lines changed: 4 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -66,11 +66,11 @@ jobs:
6666
example: [llm_distill, llm_qat, llm_sparsity]
6767
include:
6868
- example: speculative_decoding
69-
docker_image: "nvcr.io/nvidia/pytorch:26.01-py3"
69+
docker_image: "26.01"
7070
uses: ./.github/workflows/_example_tests_runner.yml
7171
secrets: inherit
7272
with:
73-
docker_image: ${{ matrix.docker_image || 'nvcr.io/nvidia/pytorch:26.01-py3' }}
73+
docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.01' }}-py3"
7474
example: ${{ matrix.example }}
7575
pip_install_extras: "[hf,dev-test]"
7676
runner: linux-amd64-gpu-l4-latest-1
@@ -83,11 +83,11 @@ jobs:
8383
example: [llm_distill, llm_qat, llm_sparsity]
8484
include:
8585
- example: speculative_decoding
86-
docker_image: "nvcr.io/nvidia/pytorch:26.01-py3"
86+
docker_image: "26.01"
8787
uses: ./.github/workflows/_example_tests_runner.yml
8888
secrets: inherit
8989
with:
90-
docker_image: ${{ matrix.docker_image || 'nvcr.io/nvidia/pytorch:26.01-py3' }}
90+
docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.01' }}-py3"
9191
example: ${{ matrix.example }}
9292
pip_install_extras: "[hf,dev-test]"
9393
runner: linux-amd64-gpu-h100-latest-2

.github/workflows/gpu_tests.yml

Lines changed: 4 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -63,9 +63,9 @@ jobs:
6363
fail-fast: false
6464
matrix:
6565
include:
66-
- example: py312-cuda13-gpu
66+
- example: cuda13-gpu
6767
timeout: 90
68-
- example: py312-cuda13-gpu-megatron
68+
- example: cuda13-gpu-megatron
6969
timeout: 120
7070
runs-on: linux-amd64-gpu-l4-latest-1
7171
timeout-minutes: ${{ matrix.timeout }}
@@ -89,9 +89,9 @@ jobs:
8989
fail-fast: false
9090
matrix:
9191
include:
92-
- example: py312-cuda12-gpu
92+
- example: cuda13-gpu
9393
timeout: 90
94-
- example: py312-cuda12-gpu-megatron
94+
- example: cuda13-gpu-megatron
9595
timeout: 120
9696
runs-on: linux-amd64-gpu-h100-latest-2
9797
timeout-minutes: ${{ matrix.timeout }}

.github/workflows/unit_tests.yml

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -55,7 +55,8 @@ jobs:
5555
with:
5656
python-version: "3.12"
5757
- name: Run unit tests (without coverage)
58-
run: pip install tox && tox -e py312-torch210-tf_latest-unit
58+
# Some issues with torch 2.10 on Windows, so using 2.9 for now
59+
run: pip install tox && tox -e py312-torch29-tf_latest-unit
5960
multi-py:
6061
if: github.event_name == 'pull_request'
6162
needs: [linux]

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 48 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -15,6 +15,7 @@
1515

1616
"""ModelOpt plugin for transformers Trainer."""
1717

18+
import contextlib
1819
import gc
1920
import json
2021
import os
@@ -100,6 +101,52 @@ class QuantizationArgumentsWithConfig(QuantizationArguments):
100101
)
101102

102103

104+
def _patch_fsdp2_post_backward():
105+
"""Patch FSDP2 ``post_backward`` to handle mixed-precision gradient dtypes.
106+
107+
FSDP2 with bf16 mixed precision upcasts bf16 parameters to fp32 for optimizer
108+
precision, while gradients are reduced in bf16. In PyTorch >= 2.6, assigning a
109+
bf16 gradient to a fp32 parameter raises a ``RuntimeError`` due to the
110+
``grad_dtype`` check, and the fused Adam optimizer also rejects mixed dtypes.
111+
112+
This patch wraps ``FSDPParamGroup.post_backward`` to:
113+
1. Set ``grad_dtype=None`` on sharded params before reduction (allowing bf16 assignment).
114+
2. Cast gradients to match parameter dtype after reduction (so the optimizer sees matching dtypes).
115+
116+
.. note::
117+
This is a workaround. The proper fix should come from PyTorch's FSDP2
118+
``foreach_reduce`` (which should cast gradients to match the parameter dtype)
119+
or from accelerate (which should set ``grad_dtype`` when it upcasts params).
120+
Remove this once the upstream fix is available.
121+
"""
122+
try:
123+
from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
124+
except ImportError:
125+
return
126+
127+
if hasattr(FSDPParamGroup, "_modelopt_original_post_backward"):
128+
return # Already patched
129+
130+
FSDPParamGroup._modelopt_original_post_backward = FSDPParamGroup.post_backward
131+
132+
@torch.no_grad()
133+
def _patched_post_backward(self):
134+
# Allow bf16 gradients to be assigned to fp32 parameters
135+
for fsdp_param in self.fsdp_params:
136+
with contextlib.suppress(AttributeError):
137+
fsdp_param.sharded_param.grad_dtype = None
138+
139+
self._modelopt_original_post_backward()
140+
141+
# Cast gradients to parameter dtype so the optimizer sees matching dtypes
142+
for fsdp_param in self.fsdp_params:
143+
sp = fsdp_param.sharded_param
144+
if sp.grad is not None and sp.grad.dtype != sp.dtype:
145+
sp.grad = sp.grad.to(sp.dtype)
146+
147+
FSDPParamGroup.post_backward = _patched_post_backward
148+
149+
103150
def check_awq_smoothquant(quant_cfg):
104151
# TODO: Remove this once deepspeed for AWQ and SmoothQuant is added
105152
"""Get the quantization type from the configuration."""
@@ -337,6 +384,7 @@ def _patch_accelerate_for_fsdp2_fix(self):
337384
is causing issues with quantized models since quantization modules adds buffers which are not sharded.
338385
This patch hides the buffers added by quantization modules from the original accelerate prepare.
339386
"""
387+
_patch_fsdp2_post_backward()
340388

341389
def _modelopt_prepare(self, *args, **kwargs):
342390
if not self.is_fsdp2:

tests/_test_utils/import_helper.py

Lines changed: 22 additions & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -12,8 +12,9 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
15+
import ctypes
1616
import importlib.metadata
17+
import os
1718
import shutil
1819

1920
import pytest
@@ -28,6 +29,23 @@ def skip_if_no_tensorrt():
2829
except (AssertionError, ImportError) as e:
2930
pytest.skip(f"{e}", allow_module_level=True)
3031

32+
# Also verify that ORT's TensorRT EP can actually load its native library.
33+
# The tensorrt Python package may be installed, but ORT's provider shared library
34+
# (libonnxruntime_providers_tensorrt.so) could fail to load due to CUDA version
35+
# mismatches (e.g., ORT built for CUDA 12 running on a CUDA 13 system).
36+
try:
37+
import onnxruntime
38+
39+
ort_capi_dir = os.path.join(os.path.dirname(onnxruntime.__file__), "capi")
40+
trt_provider_lib = os.path.join(ort_capi_dir, "libonnxruntime_providers_tensorrt.so")
41+
if os.path.isfile(trt_provider_lib):
42+
ctypes.CDLL(trt_provider_lib)
43+
except OSError as e:
44+
pytest.skip(
45+
f"ORT TensorRT EP native library cannot be loaded: {e}",
46+
allow_module_level=True,
47+
)
48+
3149

3250
def skip_if_no_trtexec():
3351
if not shutil.which("trtexec"):
@@ -43,19 +61,12 @@ def skip_if_no_libcudnn():
4361
pytest.skip(f"{e}!", allow_module_level=True)
4462

4563

46-
def skip_if_no_megatron(apex_or_te_required: bool = False, mamba_required: bool = False):
64+
def skip_if_no_megatron(*, te_required: bool = True, mamba_required: bool = False):
4765
try:
4866
import megatron # noqa: F401
4967
except ImportError:
5068
pytest.skip("megatron not available", allow_module_level=True)
5169

52-
try:
53-
import apex # noqa: F401
54-
55-
has_apex = True
56-
except ImportError:
57-
has_apex = False
58-
5970
try:
6071
import transformer_engine # noqa: F401
6172

@@ -70,8 +81,8 @@ def skip_if_no_megatron(apex_or_te_required: bool = False, mamba_required: bool
7081
except ImportError:
7182
has_mamba = False
7283

73-
if apex_or_te_required and not has_apex and not has_te:
74-
pytest.skip("Apex or TE required for Megatron test", allow_module_level=True)
84+
if te_required and not has_te:
85+
pytest.skip("TE required for Megatron test", allow_module_level=True)
7586

7687
if mamba_required and not has_mamba:
7788
pytest.skip("Mamba required for Megatron test", allow_module_level=True)

tests/gpu_megatron/torch/distill/plugins/test_distill_megatron.py

Lines changed: 0 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -16,10 +16,6 @@
1616
from functools import partial
1717

1818
import torch
19-
from _test_utils.import_helper import skip_if_no_megatron
20-
21-
skip_if_no_megatron(apex_or_te_required=True)
22-
2319
from _test_utils.torch.distributed.utils import spawn_multiprocess_job
2420
from _test_utils.torch.megatron.models import get_mcore_gpt_model
2521
from _test_utils.torch.megatron.utils import run_mcore_inference_with_dummy_input

tests/gpu_megatron/torch/export/test_unified_export_megatron.py

Lines changed: 0 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -21,14 +21,11 @@
2121
import pytest
2222
import torch
2323
import transformers
24-
from _test_utils.import_helper import skip_if_no_megatron
2524
from _test_utils.torch.distributed.utils import spawn_multiprocess_job
2625
from _test_utils.torch.megatron.models import get_mcore_gpt_model
2726
from _test_utils.torch.megatron.utils import get_forward
2827
from _test_utils.torch.transformers_models import create_tiny_llama_dir
2928

30-
skip_if_no_megatron(apex_or_te_required=True)
31-
3229
import modelopt.torch.quantization as mtq
3330
import modelopt.torch.speculative as mtsp
3431
from modelopt.torch.export import KV_CACHE_FP8, export_mcore_gpt_to_hf, import_mcore_gpt_from_hf

tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py

Lines changed: 0 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -18,15 +18,12 @@
1818

1919
import pytest
2020
import torch
21-
from _test_utils.import_helper import skip_if_no_megatron
2221
from _test_utils.torch.distributed.utils import spawn_multiprocess_job
2322
from _test_utils.torch.megatron.models import get_mcore_gpt_model
2423

2524
import modelopt.torch.quantization as mtq
2625
from modelopt.torch.export import export_mcore_gpt_to_hf_vllm_fq
2726

28-
skip_if_no_megatron(apex_or_te_required=True)
29-
3027

3128
def _test_mcore_vllm_export(tmp_path, quant_cfg, rank, size):
3229
"""Test megatron-core model export for vLLM with fake quantization."""

tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py

Lines changed: 0 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -17,10 +17,6 @@
1717

1818
import pytest
1919
import torch
20-
from _test_utils.import_helper import skip_if_no_megatron
21-
22-
skip_if_no_megatron(apex_or_te_required=True)
23-
2420
from _test_utils.torch.distributed.utils import spawn_multiprocess_job
2521
from _test_utils.torch.megatron.models import get_mcore_gpt_model
2622
from _test_utils.torch.megatron.utils import run_mcore_inference

tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -17,7 +17,7 @@
1717
import torch
1818
from _test_utils.import_helper import skip_if_no_megatron
1919

20-
skip_if_no_megatron(apex_or_te_required=True, mamba_required=True)
20+
skip_if_no_megatron(mamba_required=True)
2121

2222
from _test_utils.torch.distributed.utils import spawn_multiprocess_job
2323
from _test_utils.torch.megatron.models import get_mcore_mamba_hybrid_model

0 commit comments

Comments (0)