Commit a076e6c

Make GPU tests >2x faster by reusing spawn processes between tests (#958)
### What does this PR do?

**Type of change:** Test infra improvement

* Instead of spawning multiprocesses per test case (which adds ~10s of overhead x 150 tests), we now spawn once per file (gpu_megatron) or once per session (gpu), making the tests much faster.
* Remove some unused / unnecessary tests.
* Fix and enable the Minitron NAS test on Hybrid models for Mcore 0.16.

Test time (`pytest tests/<test_type>`) on 2x RTX Pro 6000 Blackwell: https://github.com/NVIDIA/Model-Optimizer/actions/runs/22625069275

|            | gpu | gpu_megatron |
|------------|-----|--------------|
| **Before** | 27m | 40m          |
| **After**  | 16m | 19m          |

The full PR-merge CI/CD run now takes only ~45 minutes.

### Testing

Tested on 1, 2, 4, and 8 GPU setups.

### Before your PR is "*Ready for review*"

Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and that your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, using `torch.load(..., weights_only=True)`, avoiding `pickle`, etc.).

- Is this change backward compatible?: ✅
- If you copied code from any other source, did you follow the IP policy in [CONTRIBUTING.md](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md#-copying-code-from-other-sources)?: N/A
- Did you write any new necessary tests?: ✅
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A

## Summary by CodeRabbit

* **Chores**
  * Reworked the GPU test infra to use a persistent distributed worker pool with improved per-worker teardown, aggregated error reporting, Megatron-specific cleanup, reduced GPU CI timeouts, and a small build-system dependency change.
* **Tests**
  * Converted many multi-process GPU tests to a fixture-driven distributed runner, updating test entry points and orchestration while preserving test logic; adjusted imports and conditional guards for several GPU tests.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
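To make the new pattern concrete, here is a minimal sketch of how a test file can drive one persistent worker pool instead of spawning processes per test. `DistributedWorkerPool`, `default_worker_teardown`, and `spawn_multiprocess_job` are the test utilities touched in this PR (see `tests/_test_utils/torch/distributed/utils.py` below); the fixture, helper, and test names are illustrative only, not taken from the repository.

```python
# Sketch only: fixture/helper/test names are illustrative, not code from this PR.
from functools import partial

import pytest
from _test_utils.torch.distributed.utils import (
    DistributedWorkerPool,
    default_worker_teardown,
)


def _check_allreduce(rank, size):
    ...  # per-rank test body, same (rank, size) signature as before


# Before: each test spawned fresh processes (~10s of overhead per test), e.g.
# spawn_multiprocess_job(size=2, job=_check_allreduce, backend="nccl")


@pytest.fixture(scope="module")
def worker_pool():
    # One spawn + init_process_group for the whole file; torn down at module exit.
    pool = DistributedWorkerPool(
        world_size=2, backend="nccl", teardown_fn=default_worker_teardown
    )
    yield pool
    pool.shutdown()


def test_allreduce(worker_pool):
    worker_pool.run(_check_allreduce)


def test_allreduce_again(worker_pool):
    # Reuses the same live workers; only the function dispatch is paid per test.
    worker_pool.run(partial(_check_allreduce))
```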
1 parent d780fa5 commit a076e6c

40 files changed

Lines changed: 568 additions & 691 deletions

.github/workflows/gpu_tests.yml

Lines changed: 2 additions & 3 deletions
@@ -1,4 +1,3 @@
-# TODO: Optimize gpu tests runtime!
 name: GPU tests
 
 on:
@@ -64,10 +63,10 @@ jobs:
       matrix:
         include:
           - example: gpu
-            timeout: 60
+            timeout: 45
            container_image: pytorch:26.01-py3
          - example: gpu-megatron
-            timeout: 90
+            timeout: 45
            container_image: pytorch:26.01-py3
          - example: gpu-trtllm
            timeout: 30

modelopt/torch/nas/plugins/megatron.py

Lines changed: 13 additions & 10 deletions
@@ -203,17 +203,12 @@ class _DynamicMLP(DynamicModule):
     Use for standard MLP and inside MoE layers (SequentialMLP and SharedExpertMLP).
     """
 
-    def _setup(self, *, hidden_size: TracedHp):
+    def _setup(self, *, hidden_size: TracedHp, hp_name: str):
         """Setup the MLP dynamic module with global hidden_size hparam."""
         assert self.input_size == self.config.hidden_size, (
             "MLP input_size must be equal to hidden_size"
         )
-        if isinstance(self, SharedExpertMLP):
-            self.hparam_name = "moe_shared_expert_intermediate_size"
-        elif self.config.num_moe_experts is not None:
-            self.hparam_name = "moe_ffn_hidden_size"
-        else:
-            self.hparam_name = "ffn_hidden_size"
+        self.hparam_name = hp_name
 
         ffn_hidden_size = TracedHp(list(range(1, self.config.ffn_hidden_size + 1)))
         self._register_hparam(self.hparam_name, ffn_hidden_size)
@@ -552,7 +547,7 @@ def _setup(self, *, hidden_size: TracedHp):
         DynamicModuleList.convert(self.local_experts)
         self.local_experts.depth = num_moe_experts  # Reuse same hparam for depth
         for expert in self.local_experts:
-            DMRegistry.convert(expert, hidden_size=hidden_size)
+            DMRegistry.convert(expert, hidden_size=hidden_size, hp_name="moe_ffn_hidden_size")
 
     def export(self) -> torch.nn.Module:
         """Export the dynamic module to a standard SequentialMLP."""
@@ -582,7 +577,11 @@ def _setup(self, *, hidden_size: TracedHp):
             lambda mod, val: num_moe_experts_hp.active,  # EP = 1
         )
         if self.use_shared_expert:
-            DMRegistry.convert(self.shared_experts, hidden_size=hidden_size)
+            DMRegistry.convert(
+                self.shared_experts,
+                hidden_size=hidden_size,
+                hp_name="moe_shared_expert_intermediate_size",
+            )
 
     def forward(self, *args, **kwargs):
         """Forward pass for the MoE layer."""
@@ -651,7 +650,11 @@ def _setup(self, *, hidden_size: TracedHp):
 
         if isinstance(self.mlp, (MLP, MoELayer)):
             DMRegistry.convert(self.pre_mlp_layernorm, num_features=hidden_size)
-            DMRegistry.convert(self.mlp, hidden_size=hidden_size)
+            if isinstance(self.mlp, MoELayer):
+                setup_kwargs = {}
+            else:
+                setup_kwargs = {"hp_name": "ffn_hidden_size"}
+            DMRegistry.convert(self.mlp, hidden_size=hidden_size, **setup_kwargs)
 
     def modify(
         self,

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 1 addition & 5 deletions
@@ -331,11 +331,7 @@ def run_search(self) -> None:
         )
         print_rank_0(f"Pruned hybrid_override_pattern: {self.model.hybrid_override_pattern}")
 
-    def _prune(
-        self,
-        export_config: dict,
-        prune_depth: bool = True,
-    ) -> None:
+    def _prune(self, export_config: dict, prune_depth: bool = True) -> None:
         """Prune the model homogeneously based on the export_config by setting active choices for configurable hparams.
 
         Args:

modelopt/torch/quantization/utils.py

Lines changed: 2 additions & 5 deletions
@@ -15,8 +15,6 @@
 
 """Quantization utilities."""
 
-from __future__ import annotations
-
 from collections import namedtuple
 from contextlib import ExitStack, contextmanager, nullcontext
 from typing import TYPE_CHECKING, Any
@@ -28,14 +26,13 @@
 from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
 from torch.distributed.tensor import Replicate
 
+from modelopt.torch.opt.searcher import ForwardLoop
 from modelopt.torch.utils import get_unwrapped_name, print_rank_0
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
 
 if TYPE_CHECKING:
     from collections.abc import Generator
 
-    from modelopt.torch.opt.searcher import ForwardLoop
-
 __all__ = [
     "EXPORT_MODE",
     "convert_quantization_axis_to_reduce_axis",
@@ -220,7 +217,7 @@ def reduce_sum(input, axis=None, keepdims=True):
     return output
 
 
-def weight_attr_names(module: nn.Module) -> Generator[str, None, None]:
+def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]":
    """Get the weight param attribute names in a converted module, non-recursive.

    We consider the following two cases for each weight param attribute:
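For context on the quoted return type above: once `from __future__ import annotations` is removed, a name imported only under `TYPE_CHECKING` does not exist at runtime, so it has to be referenced as a string annotation. A minimal standalone illustration of the pattern (illustrative snippet, not code from this repository):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never imported at runtime.
    from collections.abc import Generator


def weight_attr_names_example(module) -> "Generator[str, None, None]":
    # The string annotation is never evaluated at runtime, so the
    # TYPE_CHECKING-only import above is sufficient.
    yield "weight"
```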

modelopt/torch/utils/nemotron_vlm_dataset_utils.py

Lines changed: 0 additions & 2 deletions
@@ -23,8 +23,6 @@
 VLM calibration pipeline.
 """
 
-from __future__ import annotations
-
 import functools
 import json
 import os

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 ############################### BUILD CONFIGURATION ##############################################
 ####################################################################################################
 [build-system]
-requires = ["cython", "setuptools>=80", "setuptools-scm>=8"]
+requires = ["setuptools>=80", "setuptools-scm>=8"]
 build-backend = "setuptools.build_meta"
 
 [tool.setuptools_scm]

tests/_test_utils/torch/distributed/utils.py

Lines changed: 131 additions & 7 deletions
@@ -15,8 +15,8 @@
 
 import os
 import socket
+import traceback
 
-import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -65,12 +65,136 @@ def spawn_multiprocess_job(size, job, backend="gloo"):
         assert not p.exitcode
 
 
-def get_device_counts():
-    num_gpus = torch.cuda.device_count()
-    return [
-        1,
-        pytest.param(2, marks=pytest.mark.skipif(num_gpus < 2, reason="need 2 GPUs!")),
-    ]
+def default_worker_teardown(rank, world_size):
+    """Minimal cleanup between tests in persistent workers."""
+    try:
+        from accelerate.state import AcceleratorState
+
+        AcceleratorState._reset_state()
+    except ImportError:
+        pass
+    except Exception as e:
+        print(f"Error resetting AcceleratorState: {e}")
+    torch.cuda.empty_cache()
+
+
+class DistributedWorkerPool:
+    """Persistent worker pool that keeps distributed processes alive across multiple test dispatches.
+
+    Instead of spawning/destroying processes per test (which adds ~10s overhead each time),
+    workers are spawned once and reuse the same ``torch.distributed`` process group.
+    Use with a module-scoped pytest fixture to share workers across all tests in a file.
+
+    Usage::
+
+        pool = DistributedWorkerPool(
+            world_size=2, backend="nccl", teardown_fn=default_worker_teardown
+        )
+
+
+        def _test_fn(rank, size): ...
+
+
+        pool.run(_test_fn)
+        pool.run(partial(other_fn, arg1))
+        pool.shutdown()
+    """
+
+    def __init__(self, world_size, backend="nccl", teardown_fn=default_worker_teardown):
+        assert world_size > 0, "World size must be greater than 0"
+        self.world_size = world_size
+        ctx = mp.get_context("spawn")
+        self._cmd_queues = [ctx.Queue() for _ in range(world_size)]
+        self._result_queue = ctx.Queue()
+        self._processes = []
+
+        port = get_free_port()
+        for rank in range(world_size):
+            p = ctx.Process(
+                target=self._worker_loop,
+                args=(
+                    rank,
+                    world_size,
+                    backend,
+                    port,
+                    self._cmd_queues[rank],
+                    self._result_queue,
+                    teardown_fn,
+                ),
+            )
+            p.start()
+            self._processes.append(p)
+
+        for _ in range(world_size):
+            msg = self._result_queue.get(timeout=120)
+            assert msg == "ready", f"Worker failed to initialize: {msg}"
+
+    @staticmethod
+    def _worker_loop(rank, world_size, backend, port, cmd_queue, result_queue, teardown_fn):
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(port)
+        os.environ["LOCAL_RANK"] = str(rank)
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(world_size)
+        dist.init_process_group(backend, rank=rank, world_size=world_size)
+        if backend == "nccl" and torch.cuda.is_available():
+            torch.cuda.set_device(rank)
+        torch.manual_seed(1234)
+        result_queue.put("ready")
+
+        while True:
+            cmd = cmd_queue.get()
+            if cmd is None:
+                break
+            fn, args, kwargs = cmd
+            status = "ok"
+            tb = None
+            try:
+                fn(rank, world_size, *args, **kwargs)
+            except Exception:
+                status = "error"
+                tb = traceback.format_exc()
+            finally:
+                if teardown_fn is not None:
+                    try:
+                        teardown_fn(rank, world_size)
+                    except Exception as e:
+                        print(f"Error tearing down worker: {e}")
+                        status = "error"
+                        teardown_tb = traceback.format_exc()
+                        tb = (tb + "\n" if tb else "") + f"[teardown] {teardown_tb}"
+                result_queue.put((status, rank, tb))
+
+        dist.destroy_process_group()
+
+    def run(self, fn, *args, **kwargs):
+        """Dispatch ``fn`` to all workers and block until completion.
+
+        ``fn`` is called as ``fn(rank, world_size, *args, **kwargs)`` and must be picklable
+        (top-level function or ``functools.partial`` of one).
+        """
+        for q in self._cmd_queues:
+            q.put((fn, args, kwargs))
+
+        errors = []
+        for _ in range(self.world_size):
+            status, rank, tb = self._result_queue.get(timeout=600)
+            if status == "error":
+                errors.append(f"--- Rank {rank} ---\n{tb}")
+
+        if errors:
+            raise RuntimeError("Worker(s) failed:\n" + "\n".join(errors))
+
+    def shutdown(self):
+        """Signal all workers to exit and wait for them to finish."""
+        for q in self._cmd_queues:
+            q.put(None)
+        for p in self._processes:
+            p.join(timeout=60)
+            if p.is_alive():
+                p.terminate()
+                # Ensure the terminated process is fully reaped to avoid zombies.
+                p.join(timeout=10)
 
 
 def synchronize_state_dict(model: nn.Module):
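A usage note on the error handling above: `run` collects per-rank tracebacks and raises a single `RuntimeError`, so a failing rank reports its full traceback without killing the pool, and later tests keep reusing the live workers. A hedged sketch of how a test might assert on this, building on the module-scoped `worker_pool` fixture sketched earlier (helper and test names are illustrative):

```python
# Illustrative only: shows how per-rank failures surface to the calling test.
import pytest


def _failing_on_rank_1(rank, size):
    if rank == 1:
        raise ValueError("bad shard")


def test_error_aggregation(worker_pool):
    # worker_pool: a module-scoped DistributedWorkerPool fixture as sketched above.
    with pytest.raises(RuntimeError, match="Rank 1"):
        # run() raises one RuntimeError containing "--- Rank 1 ---" plus that worker's
        # traceback; rank 0 finished normally and stays alive for the next dispatch.
        worker_pool.run(_failing_on_rank_1)
```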

tests/conftest.py

Lines changed: 55 additions & 0 deletions
@@ -16,6 +16,11 @@
 import platform
 
 import pytest
+import torch
+import torch.distributed as dist
+from _test_utils.torch.distributed.utils import init_process
+
+import modelopt.torch.opt as mto
 
 
 @pytest.fixture(scope="session")
@@ -57,3 +62,53 @@ def pytest_collection_modifyitems(config, items):
 def skip_on_windows():
     if platform.system() == "Windows":
         pytest.skip("Skipping on Windows")
+
+
+@pytest.fixture(scope="session")
+def num_gpus():
+    return torch.cuda.device_count()
+
+
+@pytest.fixture(scope="session")
+def cuda_capability():
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+    return torch.cuda.get_device_capability()
+
+
+@pytest.fixture
+def distributed_setup_size_1():
+    init_process(rank=0, size=1, backend="nccl")
+    yield
+    dist.destroy_process_group()
+
+
+@pytest.fixture
+def need_2_gpus():
+    if torch.cuda.device_count() < 2:
+        pytest.skip("Need at least 2 GPUs to run this test")
+
+
+@pytest.fixture
+def need_4_gpus():
+    if torch.cuda.device_count() < 4:
+        pytest.skip("Need at least 4 GPUs to run this test")
+
+
+@pytest.fixture
+def need_8_gpus():
+    if torch.cuda.device_count() < 8:
+        pytest.skip("Need at least 8 GPUs to run this test")
+
+
+@pytest.fixture(scope="module")
+def set_torch_dtype(request):
+    orig_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(request.param)
+    yield
+    torch.set_default_dtype(orig_dtype)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def enable_hf_checkpointing():
+    mto.enable_huggingface_checkpointing()

tests/examples/conftest.py

Lines changed: 0 additions & 11 deletions
@@ -15,20 +15,9 @@
 
 
 import pytest
-import torch
 from _test_utils.torch.transformers_models import create_tiny_llama_dir
 
 
-@pytest.fixture(scope="session")
-def num_gpus():
-    return torch.cuda.device_count()
-
-
-@pytest.fixture(scope="session")
-def cuda_capability():
-    return torch.cuda.get_device_capability()
-
-
 @pytest.fixture(scope="session")
 def tiny_llama_path(tmp_path_factory):
     return str(
