Skip to content

Commit a9fca91

Browse files
Allow searcher ckpt dir for per-rank ckpt files (#1091)
### What does this PR do? Type of change: minor improvement <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> Currently we modify user's `path.pth` to `path<rank>.pth` which may be confusing. Instead we now add alternative to provide `path/` and store `path/rank<rank>.pth` per-rank searcher state files ### Testing <!-- Mention how have you tested your change if applicable. --> ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ <!--- If ❌, explain why. --> - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A <!--- Mandatory --> - Did you write any new necessary tests?: N/A <!--- Mandatory for new features or examples. --> - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A <!--- Only for new features, API changes, critical bug fixes or backward incompatible changes. --> ### Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Documentation** * Updated pruning examples and CLI help to reference directory-based checkpoint paths for intermediate pruning scores and clarified default wording. * **Improvements** * Checkpoint handling now accepts directory paths for intermediate score storage, produces per-rank files automatically, and resolves to a directory default when unset. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent cf4e6cc commit a9fca91

5 files changed

Lines changed: 38 additions & 30 deletions

File tree

examples/megatron_bridge/prune_minitron.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def get_args() -> argparse.Namespace:
111111
type=str,
112112
default=None,
113113
help=(
114-
"Path to save/restore intermediate pruning scores for resuming / faster re-run. "
115-
"If not provided, it will default to `<output_path>/modelopt_pruning_scores.pth`"
114+
"Directory to save/restore per-rank intermediate pruning scores for resuming / faster re-run. "
115+
"If not provided, it will default to `<output_path>/modelopt_pruning_scores`"
116116
),
117117
)
118118

@@ -187,13 +187,11 @@ def get_args() -> argparse.Namespace:
187187
# Post-process arguments
188188
if args.prune_intermediate_ckpt is None:
189189
if args.output_megatron_path:
190-
args.prune_intermediate_ckpt = (
191-
f"{args.output_megatron_path}/modelopt_pruning_scores.pth"
192-
)
190+
args.prune_intermediate_ckpt = f"{args.output_megatron_path}/modelopt_pruning_scores"
193191
elif args.output_hf_path:
194-
args.prune_intermediate_ckpt = f"{args.output_hf_path}/modelopt_pruning_scores.pth"
192+
args.prune_intermediate_ckpt = f"{args.output_hf_path}/modelopt_pruning_scores"
195193
print_rank_0(
196-
"No checkpoint provided to cache intermediate pruning scores. "
194+
"No directory provided to cache per-rank intermediate pruning scores. "
197195
f"Setting to: {args.prune_intermediate_ckpt}"
198196
)
199197

examples/pruning/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ This mode can be useful when you know the exact dimensions you want to prune to
9797
# Specify the pruning constraints (Check Support Matrix for available pruning dimensions)
9898
# Save minitron scores at checkpoint so we can re-run pruning with different constraints without running the forward loop again
9999
constraints = {"export_config": {"num_layers": 32, "hidden_size": 3584, "ffn_hidden_size": 10240}}
100-
config = {"forward_loop": forward_loop, "checkpoint": "/path/to/cache/pruning/scores.pth"}
100+
config = {"forward_loop": forward_loop, "checkpoint": "/path/to/cache/pruning/scores/"}
101101

102102
mtp.prune(...)
103103
```
@@ -129,7 +129,7 @@ def score_func(m):
129129
constraints = {"params": 6e9} # Prune to 6B parameters
130130
config = {
131131
"forward_loop": forward_loop,
132-
"checkpoint": "/path/to/cache/pruning/scores.pth",
132+
"checkpoint": "/path/to/cache/pruning/scores/",
133133
"score_func": score_func,
134134
# Optional: Configure search space constraints (showing defaults)
135135
"max_width_pruning": 0.4, # Maximum 40% per width pruning hparams (hidden_size, ffn_hidden_size, etc.)

modelopt/torch/opt/searcher.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from abc import ABC, abstractmethod
2727
from collections.abc import Callable
2828
from contextlib import nullcontext
29-
from typing import Any, final
29+
from typing import TYPE_CHECKING, Any, final
3030

3131
import numpy as np
3232
import pulp
@@ -36,6 +36,9 @@
3636
from modelopt.torch.utils import distributed as dist
3737
from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop, warn_rank_0
3838

39+
if TYPE_CHECKING:
40+
from pathlib import Path
41+
3942
LimitsTuple = tuple[float, float]
4043
ConstraintsDict = dict[str, str | float | dict | None]
4144
Deployment = dict[str, str]
@@ -238,9 +241,18 @@ def state_dict(self) -> SearchStateDict:
238241

239242
def _get_checkpoint_path(self) -> str | None:
240243
"""Get per-rank checkpoint path when distributed, otherwise the original path."""
241-
checkpoint = self.config["checkpoint"]
244+
checkpoint: str | Path | None = self.config["checkpoint"]
242245
if checkpoint is None:
243246
return None
247+
checkpoint = str(checkpoint)
248+
# Detect directory: exists as dir, ends with separator, or has no file extension
249+
is_dir_path = (
250+
os.path.isdir(checkpoint)
251+
or checkpoint.endswith(os.sep)
252+
or not os.path.splitext(checkpoint)[1]
253+
)
254+
if is_dir_path:
255+
return os.path.join(checkpoint, f"rank{dist.rank()}.pth")
244256
if dist.is_initialized():
245257
dirname, basename = os.path.split(checkpoint)
246258
name, ext = os.path.splitext(basename)

tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _test_mcore_gpt_pruning(
123123
uneven_pp,
124124
position_embedding_type,
125125
skip_sorting,
126-
ckpt_path,
126+
ckpt_dir,
127127
rank,
128128
size,
129129
):
@@ -196,11 +196,11 @@ def forward_loop(m):
196196
constraints = {"export_config": export_config}
197197

198198
config = {
199-
"checkpoint": ckpt_path,
199+
"checkpoint": ckpt_dir,
200200
"skip_sorting": skip_sorting,
201201
}
202202
if skip_sorting:
203-
assert ckpt_path is None
203+
assert ckpt_dir is None
204204
else:
205205
config["forward_loop"] = forward_loop
206206
model, pruning_scores = prune_minitron(model, constraints, config, channel_divisor)
@@ -236,11 +236,11 @@ def forward_loop(m):
236236
output = run_mcore_inference(model, prompt_tokens, pruned_hidden_size)
237237

238238
# Assert re-pruning from checkpoint works without running the forward loop again
239-
if ckpt_path:
239+
if ckpt_dir:
240240
model_rerun = _get_model(initialize_megatron=False)
241241
model_rerun.load_state_dict(sd)
242242
model_rerun, pruning_scores = prune_minitron(
243-
model_rerun, constraints, {"checkpoint": ckpt_path}, channel_divisor
243+
model_rerun, constraints, {"checkpoint": ckpt_dir}, channel_divisor
244244
)
245245

246246
output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)
@@ -305,7 +305,7 @@ def test_mcore_gpt_pruning(
305305
uneven_pp,
306306
position_embedding_type,
307307
skip_sorting,
308-
tmp_path / "minitron_scores.pth" if test_ckpt else None,
308+
tmp_path / "minitron_scores" if test_ckpt else None,
309309
),
310310
)
311311

@@ -391,7 +391,7 @@ def test_mcore_gpt_moe_parameter_sorting(dist_workers):
391391
dist_workers.run(_test_mcore_gpt_moe_parameter_sorting)
392392

393393

394-
def _test_mcore_gpt_pruning_moe(ckpt_path, rank, size):
394+
def _test_mcore_gpt_pruning_moe(ckpt_dir, rank, size):
395395
channel_divisor = 4
396396

397397
num_layers = size
@@ -442,7 +442,7 @@ def forward_loop(m):
442442
prune_minitron(
443443
model,
444444
constraints,
445-
{"checkpoint": ckpt_path, "forward_loop": forward_loop},
445+
{"checkpoint": ckpt_dir, "forward_loop": forward_loop},
446446
channel_divisor,
447447
)
448448

@@ -479,14 +479,14 @@ def forward_loop(m):
479479
# Assert re-pruning from checkpoint works without running the forward loop again
480480
model_rerun = _get_model(initialize_megatron=False)
481481
model_rerun.load_state_dict(sd)
482-
prune_minitron(model_rerun, constraints, {"checkpoint": ckpt_path}, channel_divisor)
482+
prune_minitron(model_rerun, constraints, {"checkpoint": ckpt_dir}, channel_divisor)
483483

484484
output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)
485485
assert torch.allclose(output, output_rerun, atol=1e-5)
486486

487487

488488
def test_mcore_gpt_pruning_moe(dist_workers, tmp_path):
489-
dist_workers.run(partial(_test_mcore_gpt_pruning_moe, tmp_path / "minitron_scores.pth"))
489+
dist_workers.run(partial(_test_mcore_gpt_pruning_moe, tmp_path / "minitron_scores"))
490490

491491

492492
def test_generate_search_space_combos():

tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_mcore_mamba_parameter_sorting(dist_workers):
119119
dist_workers.run(_test_mcore_mamba_parameter_sorting)
120120

121121

122-
def _test_mcore_mamba_hybrid_pruning(ckpt_path, rank, size):
122+
def _test_mcore_mamba_hybrid_pruning(ckpt_dir, rank, size):
123123
channel_divisor = 4
124124

125125
num_layers = min(size * 2, 8)
@@ -193,7 +193,7 @@ def forward_loop(m):
193193
prune_minitron(
194194
model,
195195
constraints,
196-
{"forward_loop": forward_loop, "checkpoint": ckpt_path},
196+
{"forward_loop": forward_loop, "checkpoint": ckpt_dir},
197197
channel_divisor,
198198
)
199199

@@ -224,16 +224,14 @@ def forward_loop(m):
224224

225225
# Assert re-pruning from checkpoint works without running the forward loop again
226226
model = _get_model(initialize_megatron=False)
227-
prune_minitron(model, constraints, {"checkpoint": ckpt_path}, channel_divisor)
227+
prune_minitron(model, constraints, {"checkpoint": ckpt_dir}, channel_divisor)
228228

229229

230230
def test_mcore_mamba_hybrid_pruning(dist_workers, tmp_path):
231-
dist_workers.run(
232-
partial(_test_mcore_mamba_hybrid_pruning, tmp_path / "modelopt_minitron_scores.pth")
233-
)
231+
dist_workers.run(partial(_test_mcore_mamba_hybrid_pruning, tmp_path / "minitron_scores"))
234232

235233

236-
def _test_mcore_mamba_hybrid_pruning_nas(ckpt_path, rank, size):
234+
def _test_mcore_mamba_hybrid_pruning_nas(ckpt_dir, rank, size):
237235
set_seed(SEED)
238236
channel_divisor = 4
239237

@@ -297,7 +295,7 @@ def score_func(m):
297295
constraints = {"params": int(param_count * 0.7)}
298296
config = {
299297
"forward_loop": forward_loop,
300-
"checkpoint": ckpt_path,
298+
"checkpoint": ckpt_dir,
301299
"score_func": score_func,
302300
"max_width_pruning": 0.5,
303301
"max_depth_pruning": 0.5,
@@ -363,5 +361,5 @@ def score_func(m):
363361
)
364362
def test_mcore_mamba_hybrid_pruning_nas(dist_workers, tmp_path):
365363
dist_workers.run(
366-
partial(_test_mcore_mamba_hybrid_pruning_nas, tmp_path / "modelopt_minitron_scores.pth"),
364+
partial(_test_mcore_mamba_hybrid_pruning_nas, tmp_path / "minitron_scores"),
367365
)

0 commit comments

Comments
 (0)