-
Notifications
You must be signed in to change notification settings - Fork 423
Expand file tree
/
Copy pathmcore_minitron.py
More file actions
1272 lines (1094 loc) · 54.4 KB
/
mcore_minitron.py
File metadata and controls
1272 lines (1094 loc) · 54.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module implementing top-level ``mcore_minitron`` pruning handler for NVIDIA Megatron-Core / NeMo models.
Minitron pruning algorithm uses activation magnitudes to estimate importance of neurons / attention heads / mamba heads
in the model.
More details on Minitron pruning algorithm can be found here: https://arxiv.org/pdf/2407.14679
Supports both GPT (attention-based) and Mamba (state-space) models, as well as hybrid models with both types of layers.
Actual dynamic module implementations are at :mod:`modelopt.torch.nas.plugins.megatron`.
"""
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from itertools import product
from typing import Any
from warnings import warn
import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.models.mamba.mamba_model import MambaModel
from megatron.core.parallel_state import (
get_pipeline_model_parallel_group,
get_pipeline_model_parallel_rank,
get_pipeline_model_parallel_world_size,
get_tensor_model_parallel_group,
)
from megatron.core.tensor_parallel import (
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
)
from pydantic import create_model
from tqdm import tqdm
from modelopt.torch.nas.conversion import NASModeRegistry
from modelopt.torch.nas.plugins.megatron import (
HAS_MAMBA,
SUPPORTED_MODELS,
_DynamicMambaLayer,
_DynamicMambaMixer,
_DynamicMCoreLanguageModel,
_DynamicMLP,
_DynamicMoELayer,
_DynamicSelfAttention,
_DynamicSequentialMLP,
_DynamicTransformerLayer,
)
from modelopt.torch.nas.registry import DMRegistry
from modelopt.torch.nas.utils import get_subnet_config, sample, sort_parameters
from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules
from modelopt.torch.opt.conversion import ApplyModeError
from modelopt.torch.opt.dynamic import DynamicModule, DynamicSpace
from modelopt.torch.opt.mode import (
ConvertEntrypoint,
ConvertReturnType,
ModeDescriptor,
RestoreEntrypoint,
)
from modelopt.torch.opt.searcher import BaseSearcher, SearchConfig, SearchStateDict
from modelopt.torch.opt.utils import named_hparams
from modelopt.torch.utils import distributed as dist
from modelopt.torch.utils import get_module_device, num2hrb, print_rank_0
from ..pruning import PruneModeRegistry
# Names of the hyperparameters Minitron pruning can prune, grouped by pruning axis.
SUPPORTED_HPARAMS = {
    # --- Width pruning -------------------------------------------------------
    "hidden_size",  # model / embedding width
    "ffn_hidden_size",  # dense MLP
    "num_attention_heads",  # self-attention
    "mamba_num_heads",  # Mamba mixer
    "mamba_head_dim",  # Mamba mixer
    "moe_ffn_hidden_size",  # MoE expert MLP
    "moe_shared_expert_intermediate_size",  # MoE shared expert
    "num_moe_experts",  # MoE router / experts
    # --- Depth pruning -------------------------------------------------------
    "num_layers",
}
# Public API of this module; restricts what ``from <module> import *`` exports.
__all__ = [
    "SUPPORTED_HPARAMS",
    "MCoreMinitronConfig",
    "MCoreMinitronModeDescriptor",
    "MCoreMinitronSearcher",
    "drop_mcore_language_model_layers",
    "get_mcore_minitron_config",
    "get_mcore_param_count",
]
def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None:
    """Remove given decoder layers (1-indexed) of the model in-place (works with TP and/or PP).

    If ``model`` is a wrapper around GPTModel or MambaModel, the first supported submodule found in
    ``model.named_modules()`` is unwrapped and modified. Kept layers are renumbered contiguously
    across pipeline-parallel ranks and ``config.num_layers`` is set to the new global layer count.

    Args:
        model: The MCore model, or a wrapper containing one.
        layers_to_drop: Global 1-indexed layer numbers to remove. An empty list is a no-op.

    Raises:
        AssertionError: If no supported model is found or a layer number is smaller than 1.
    """
    if not layers_to_drop:
        # Nothing to drop; the lower-bound check below would otherwise index an empty list.
        return
    layers_to_drop = sorted(layers_to_drop)
    drop_set = set(layers_to_drop)  # O(1) membership tests below

    supported_model_types = tuple(SUPPORTED_MODELS.keys())
    for _, m in model.named_modules():
        if isinstance(m, supported_model_types):
            model = m
            break
    assert isinstance(model, supported_model_types), (
        f"Model should have one of {supported_model_types} submodule, got {model}"
    )
    # Checked after unwrapping so the error message reads ``config`` from the actual MCore model
    # (a wrapper module may not expose ``config``).
    assert layers_to_drop[0] >= 1, (
        f"Layers to drop should be in range 1 to {model.config.num_layers}, got {layers_to_drop}."
    )

    print_rank_0(f"Dropping decoder layers {layers_to_drop} from model.")

    # Count the layers kept on this PP rank and gather the counts of all PP ranks.
    # The distributed gather below requires tensors to be on cuda.
    remaining_per_pp = torch.zeros(
        get_pipeline_model_parallel_world_size(),
        dtype=torch.int,
        device=get_module_device(model),
    ).cuda()
    remaining_here = torch.tensor(
        sum(1 for layer in model.decoder.layers if layer.layer_number not in drop_set),
        dtype=torch.int,
        device=get_module_device(model),
    ).cuda()
    torch.distributed.all_gather_into_tensor(
        remaining_per_pp, remaining_here, group=get_pipeline_model_parallel_group()
    )
    remaining_counts = [i.item() for i in remaining_per_pp]
    new_num_layers = sum(remaining_counts)

    # Reindex kept layers so numbering continues after the layers kept by earlier PP ranks;
    # dropped layers are removed from the ModuleList (and hence from the sharded state dict).
    layer_number = sum(remaining_counts[: get_pipeline_model_parallel_rank()]) + 1
    kept_layers = []
    for layer in model.decoder.layers:
        if layer.layer_number not in drop_set:
            layer.layer_number = layer_number
            layer_number += 1
            kept_layers.append(layer)
    model.decoder.layers = nn.ModuleList(kept_layers)
    model.config.num_layers = new_num_layers
@dataclass
class CandidateSubnet:
ss_config: dict
params: float
score: float | None
class MCoreMinitronSearcher(BaseSearcher):
    """Searcher for Minitron pruning algorithm.

    The searcher runs the ``forward_loop`` to collect importance estimates, sorts the prunable
    parameters by importance and prunes the model homogeneously, either to an explicit
    ``export_config`` constraint or to the best subnet found under a ``params`` constraint.

    Search config options (besides those of :class:`BaseSearcher`):

    - ``max_iter_data_loader`` (default: 1024).
    - ``skip_sorting``: Skip the forward loop and parameter sorting (default: False).
    - ``scores_path``: If set, used as the search ``checkpoint`` path (default: None).

    Additional config options (used when the ``params`` constraint is provided):

    - ``max_width_pruning``: Maximum fraction per width hyperparameter to prune (default: 0.40).
      Only the largest ``max_width_pruning`` fraction of each width hparam's sorted choices is
      searched, i.e. (for evenly spaced choices) at most ~that fraction of the width is pruned.
    - ``max_depth_pruning``: Maximum fraction per depth hyperparameter to prune (default: 0.20).
      Same rule applied to the ``num_layers`` choices.
    - ``hparams_to_skip``: List of hparams to skip during the search (default: None).
    - ``top_k``: Number of candidates to consider for score_func validation (default: 10).
    """

    activations_per_rank: list[dict[str, torch.Tensor]]
    layer_scores: dict[int, torch.Tensor]
    sorted_layers: list[int] | None  # 1-indexed sorted list of layer numbers
    # Dict from params constraint to list of candidate subnets (ss_config, params, score)
    top_k_candidates_per_constraint: dict[float, list[CandidateSubnet]]

    @property
    def default_search_config(self) -> SearchConfig:
        """Get the default config for the searcher."""
        return {
            **super().default_search_config,
            "max_iter_data_loader": 1024,
            "skip_sorting": False,
            "scores_path": None,
            # Additional search config for parameter-based pruning
            "max_width_pruning": 0.40,
            "max_depth_pruning": 0.20,
            "hparams_to_skip": None,
            "top_k": 10,
        }

    @property
    def default_state_dict(self) -> SearchStateDict:
        """Return default state dict for importance scores and activations from forward loop.

        When these entries are restored from a search checkpoint, ``run_search`` reuses the
        activations / layer scores and ``search_best_arch_by_params`` reuses the top-k candidates.
        """
        return {
            "activations_per_rank": [],
            "layer_scores": {},
            "sorted_layers": None,
            "top_k_candidates_per_constraint": {},
        }

    def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
        """Sanitize the search config dict.

        ``scores_path`` (if set) is used as the search ``checkpoint`` and ``verbose`` is always
        enabled.
        """
        config = super().sanitize_search_config(config)
        if config["scores_path"]:
            config["checkpoint"] = config["scores_path"]
        config["verbose"] = True  # Print for all ranks
        return config

    def before_search(self) -> None:
        """Validate the constraint and decide which hparams must be sorted.

        Raises:
            AssertionError: On an unsupported constraint, hparam or hparam choice.
            ValueError: If a deprecated ``num_query_groups`` differs from the model's value.
        """
        super().before_search()

        # Check that the constraint is valid
        assert len(self.constraints) == 1 and next(iter(self.constraints.keys())) in {
            "export_config",
            "params",
        }, "Only `export_config` or `params` constraint is supported!"

        if "export_config" in self.constraints:
            export_config = self.constraints["export_config"]
            assert isinstance(export_config, dict)  # to keep mypy happy
            if "num_query_groups" in export_config:
                warn("num_query_groups is no longer supported (since 0.41)! It will be ignored.")
                if export_config["num_query_groups"] != self.model.config.num_query_groups:
                    raise ValueError(
                        f"num_query_groups must be {self.model.config.num_query_groups}!"
                    )
                export_config.pop("num_query_groups")
            assert export_config.keys() <= SUPPORTED_HPARAMS, (
                f"Only {SUPPORTED_HPARAMS} are supported for pruning! Received: {export_config=}"
            )

            # Only sort the parameters that are to be pruned
            # If a user only prunes depth, we should not sort width parameters
            self.hps_to_sort = set(export_config.keys())
        else:
            assert isinstance(self.constraints["params"], (int, float)), "params must be a float!"
            assert self.has_score, "score_func (e.g. MMLU) is required for parameter-based pruning!"
            export_config = None
            # Sort all parameters for parameter-based pruning. Copy so that downstream code can never
            # mutate the module-level SUPPORTED_HPARAMS constant through this attribute.
            self.hps_to_sort = set(SUPPORTED_HPARAMS)

        for n, hp in named_hparams(self.model, unique=True):
            hp_name = n.split(".")[-1]
            if hp.is_configurable:
                # Make sure configurable hparams are the ones with right names else implementation needs to be fixed!
                assert hp_name in SUPPORTED_HPARAMS, f"[ImplError] Invalid hparam {hp_name}!"
                if export_config is not None and hp_name in export_config:
                    assert export_config[hp_name] in hp.choices, (
                        f"Invalid choice {export_config[hp_name]} for {n}! Available choices: {hp.choices}"
                    )
            hp.reset_choices()  # Make sure ConcatHparam choices are updated after modify()

        assert isinstance(self.model, _DynamicMCoreLanguageModel), (
            "Input should be unwrapped MCore model!"
        )

    def run_search(self) -> None:
        """Run forward loop to collect activations, sort parameters, and prune the model.

        Steps: collect (or restore) importance estimates, sort parameters unless ``skip_sorting``,
        rank layers by score (highest first), resolve the export config (directly or via
        :meth:`search_best_arch_by_params`), prune, and keep the hybrid layer pattern of Mamba
        models in sync with any dropped layers.
        """
        registry = ImportanceEstimatorRegistry(self.model)
        if self.layer_scores and self.activations_per_rank:  # Available from checkpoint
            registry.set_activations_and_layer_scores(self.activations_per_rank, self.layer_scores)
        elif not self.config["skip_sorting"]:
            assert self.forward_loop is not None
            is_training = self.model.training
            self.model.eval()
            with torch.no_grad():
                self.forward_loop(self.model)
            self.model.train(is_training)

            # Store activations and layer scores for re-pruning with different export configs
            self.activations_per_rank, self.layer_scores = (
                registry.get_activations_and_layer_scores()
            )
            self.save_search_checkpoint(verbose=True)

        if self.config["skip_sorting"]:
            print_rank_0("Skipping sorting parameters...")
        else:
            sort_parameters(self.model, self.hps_to_sort, verbose=True)
        registry.cleanup()

        if self.layer_scores:
            # sort layers by scores (highest first) so the lowest ones are dropped
            self.sorted_layers = [
                layer
                for layer, _ in sorted(self.layer_scores.items(), key=lambda x: x[1], reverse=True)
            ]
            assert sorted(self.sorted_layers) == list(range(1, self.model.config.num_layers + 1))
        else:
            assert (
                self.constraints.keys() == {"export_config"}
                and "num_layers" not in self.constraints["export_config"]
            ), "Cannot prune `num_layers` without collecting layer scores!"
            self.sorted_layers = None

        if "params" in self.constraints:
            export_config = self.search_best_arch_by_params()
        else:
            export_config = self.constraints["export_config"]

        # Prune homogeneously
        self._prune(export_config, prune_depth=True)

        # TODO: Rename to hybrid_layer_pattern after https://github.com/NVIDIA/Megatron-LM/pull/3377
        # Update hybrid_override_pattern if pruning is done on a hybrid model.
        # Without layer scores (sorted_layers is None) depth pruning is disallowed above, so no layer
        # was dropped and the pattern stays valid; skipping avoids failing on the missing ranking
        # (e.g. when skip_sorting=True).
        if isinstance(self.model, MambaModel) and self.sorted_layers is not None:
            print_rank_0(f"Original hybrid_override_pattern: {self.model.hybrid_override_pattern}")
            new_num_layers = self.model.config.num_layers
            kept_layers_numbers = set(self.sorted_layers[:new_num_layers])
            self.model.hybrid_override_pattern = "".join(
                c
                for i, c in enumerate(self.model.hybrid_override_pattern)
                if i + 1 in kept_layers_numbers
            )
            print_rank_0(f"Pruned hybrid_override_pattern: {self.model.hybrid_override_pattern}")

    def _prune(
        self,
        export_config: dict,
        prune_depth: bool = True,
    ) -> None:
        """Prune the model homogeneously based on the export_config by setting active choices for configurable hparams.

        Args:
            export_config: Dictionary mapping hyperparameter names to their pruned values.
            prune_depth: Whether to drop layers based on sorted_layers (default: True).
        """
        # Prune homogeneously
        for n, hp in named_hparams(self.model, configurable=True):
            hp_name = n.split(".")[-1]
            if hp_name in export_config:
                hp.active = export_config[hp_name]

        # Drop layers if depth pruning is enabled
        if prune_depth:
            num_layers_hp = self.model.get_hparam("num_layers")
            if num_layers_hp.active != num_layers_hp.max:
                assert self.sorted_layers is not None
                layers_to_drop = self.sorted_layers[num_layers_hp.active :]
                drop_mcore_language_model_layers(self.model, layers_to_drop=layers_to_drop)

        # Update model config with pruned architecture
        # kv_channels can be None so we need to save from original hidden_size and num_attention_heads
        if self.model.config.kv_channels is None:
            self.model.config.kv_channels = (
                self.model.config.hidden_size // self.model.config.num_attention_heads
            )
        # num_query_groups can be None so we need to save from original num_attention_heads
        if self.model.config.num_query_groups is None:
            self.model.config.num_query_groups = self.model.config.num_attention_heads
        # moe_ffn_hidden_size can be None so we need to save from original ffn_hidden_size
        if (
            self.model.config.moe_ffn_hidden_size is None
            and self.model.config.num_moe_experts is not None
        ):
            self.model.config.moe_ffn_hidden_size = self.model.config.ffn_hidden_size
        # Now set hparam active choices
        for hp_name, hp_value in export_config.items():
            setattr(self.model.config, hp_name, hp_value)

        # Reinitialize the MoE token dispatcher of EVERY MoE layer after pruning.
        # `_export_reinit_token_dispatcher` is a per-layer method; stopping after the first MoE
        # layer would leave all later MoE layers with dispatchers built for the unpruned experts.
        for m in self.model.modules():
            if isinstance(m, _DynamicMoELayer):
                m._export_reinit_token_dispatcher()

    def search_best_arch_by_params(self) -> dict:
        """Search for the best architecture based on the given parameters constraints.

        We perform a grid-search over the search space to find subnets (homogeneous) fitting the
        constraints. Top-k candidates (largest param count first) are then validated using the
        score_func (e.g. MMLU) and the best-scoring subnet is returned.

        Returns:
            export_config: Dictionary mapping hyperparameter names to their pruned values.
        """
        assert self.sorted_layers is not None
        max_params = float(self.constraints["params"])  # type: ignore[arg-type]
        max_width_pruning = self.config["max_width_pruning"]
        max_depth_pruning = self.config["max_depth_pruning"]
        hparams_to_skip = self.config["hparams_to_skip"]
        top_k = self.config["top_k"]
        print_rank_0(
            f"\nSearching for the best pruned architecture under {num2hrb(max_params)} params constraints..."
        )

        # 1. Find available search space choices (across all PP ranks)
        hp_choices = {}
        for n, hp in named_hparams(self.model, configurable=True):
            hp_name = n.split(".")[-1]
            hp_choices[hp_name] = hp.choices
        pp_group = dist.DistributedProcessGroup(get_pipeline_model_parallel_group())
        hp_choices = dist.DistributedProcessGroup.get_dist_syncd_obj(
            hp_choices,
            pp_group,
            op=lambda all_pp_search_spaces: {
                k: v for d in all_pp_search_spaces for k, v in d.items()
            },
        )

        # 2. Perform grid-search over the search space to find subnets fitting the constraints
        if (
            max_params not in self.top_k_candidates_per_constraint
            or len(self.top_k_candidates_per_constraint[max_params]) != top_k
        ):
            max_num_layers = self.model.get_hparam("num_layers").max
            search_space_configs = MCoreMinitronSearcher._generate_search_space_combos(
                hp_choices,
                max_width_pruning,
                max_depth_pruning,
                hparams_to_skip,
            )
            sample(self.model, sample_func=max)  # reset to max subnet (for sanity)
            selected = []
            for ss_config in tqdm(
                search_space_configs,
                desc=f"Finding top {top_k} (`config['top_k']`) candidates fitting the constraints...",
                disable=not dist.is_master(),
            ):
                self._prune(ss_config, prune_depth=False)
                # Layers are not dropped here; only the top-ranked layers are counted instead.
                layer_ids = None
                if "num_layers" in ss_config and ss_config["num_layers"] < max_num_layers:
                    layer_ids = self.sorted_layers[: ss_config["num_layers"]]
                candidate_params = _param_num_dynamic(self.model, layer_numbers_to_count=layer_ids)
                if candidate_params <= max_params:
                    selected.append(CandidateSubnet(ss_config, candidate_params, None))
                sample(self.model, sample_func=max)  # reset to max subnet
            assert len(selected) > 0, "No subnets found fitting the constraints!"
            print_rank_0(f"Found {len(selected)} candidates fitting the constraints!")
            self.top_k_candidates_per_constraint[max_params] = sorted(
                selected, key=lambda x: x.params, reverse=True
            )[:top_k]
            self.save_search_checkpoint(verbose=True)
        else:
            print_rank_0(f"\nUsing top {top_k} candidates from checkpoint")
        top_k_candidates = self.top_k_candidates_per_constraint[max_params]

        print_rank_0(f"\n====================\nTop {top_k} candidates:")
        for candidate in top_k_candidates:
            print_rank_0(f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params")
        print_rank_0("====================\n")

        # 3. Optional Knowledge Distillation (KD) step for all top-k candidates
        print_rank_0(
            "\nSkipping optional Knowledge Distillation (KD) step for candidates as it is a manual step. "
            "As per the original paper (https://arxiv.org/pdf/2407.14679), ideally we need to perform a short "
            f"Knowledge Distillation on ~2B tokens for all top {top_k} candidates before evaluating the "
            "`score_func`, which will take a lot longer to prune, require splitting the pruning process into multiple "
            "stages and a lot more compute for pruning but can lead to better pruned model selection. If you are "
            f"interested to do this, you can take the top {top_k} candidates' `export_config` from the logs above and "
            "then export all models separately and perform Knowledge Distillation on each of them before evaluating "
            "the `score_func`.\n"
        )

        # 4. Validate top-k candidates using the score_func and return the best subnet
        for candidate in tqdm(
            top_k_candidates,
            desc=f"Validating top {top_k} candidates on given score_func (this will take some time)...",
            disable=not dist.is_master(),
            smoothing=0.7,
        ):
            if candidate.score is None:  # not restored from checkpoint
                all_layers = self.model.decoder.layers
                start_layer_number = all_layers[0].layer_number

                self._prune(candidate.ss_config, prune_depth=True)
                candidate.score = self.eval_score(silent=False)
                self.save_search_checkpoint(verbose=False)

                # reset to max subnet and revert dropped layers
                sample(self.model, sample_func=max)
                for layer in all_layers:
                    layer.layer_number = start_layer_number
                    start_layer_number += 1
                self.model.decoder.layers = all_layers
            print_rank_0(
                f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score\n"
            )

        print_rank_0(f"\n====================\nTop {top_k} candidates with scores:")
        for candidate in top_k_candidates:
            print_rank_0(
                f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score"
            )
        print_rank_0("====================\n")

        dist.barrier()
        best = max(top_k_candidates, key=lambda x: x.score)  # type: ignore[arg-type, return-value]
        print_rank_0(
            f"\n[BEST SUBNET] {best.ss_config} -> {num2hrb(best.params)} params, {best.score:.4f} score\n"
        )
        return best.ss_config

    @staticmethod
    def _generate_search_space_combos(
        search_space: dict[str, list],
        max_width_pruning: float = 0.40,
        max_depth_pruning: float = 0.20,
        hparams_to_skip: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """Generate all possible combinations of hyperparameters from the search space.

        Args:
            search_space: Dictionary mapping hyperparameter names to their possible sorted choices.
                Example: {"hidden_size": [1024, 2048, 3072, 4096], "num_layers": [1, 2, ..., 31, 32]}
            max_width_pruning: Maximum fraction of width hyperparameters to prune (default: 0.40).
                For each width hparam, choices below index ``int((1 - max_width_pruning) * len(choices))``
                of its sorted choices are discarded, i.e. only the largest choices are kept.
            max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.20).
                Same rule applied to the ``num_layers`` choices.
            hparams_to_skip: List of hparams to skip during the search (default: None).
                Hparams with a single choice are always dropped from the search space.

        Returns:
            List of configuration dictionaries, where each dictionary maps hyperparameter
            names to their chosen values. Example:
            [
                {"hidden_size": 1024, "num_layers": 1},
                {"hidden_size": 1024, "num_layers": 2},
                ...
                {"hidden_size": 4096, "num_layers": 32},
            ]
        """
        print_rank_0(
            f"\nOnly considering atmost {(max_width_pruning * 100):.0f}% for width and "
            f"{max_depth_pruning * 100:.0f}% for depth pruning hparams"
        )

        if hparams_to_skip:
            search_space = dict(search_space)  # Avoid modifying the original search space
            print_rank_0(f"Skipping {hparams_to_skip=} during search space generation...")
            for hparam in hparams_to_skip:
                if hparam in search_space:
                    search_space.pop(hparam)
                else:
                    warn(f"Hparam {hparam} not found in search space! Skipping...")

        filtered_ss = {
            k: (
                sorted(v)[int((1 - max_depth_pruning) * len(v)) :]
                if k == "num_layers"
                else sorted(v)[int((1 - max_width_pruning) * len(v)) :]
            )
            for k, v in search_space.items()
            if len(v) > 1
        }

        ss_size = 1
        for k, v in filtered_ss.items():
            print_rank_0(f"\tSearch space for {k}: {v}")
            ss_size *= len(v)
        print_rank_0(f"\tTotal search space in consideration: {ss_size}\n")

        hparam_names = list(filtered_ss.keys())
        hparam_choices_lists = [filtered_ss[name] for name in hparam_names]

        search_space_combos = [
            dict(zip(hparam_names, choices)) for choices in product(*hparam_choices_lists)
        ]
        assert len(search_space_combos) == ss_size

        return search_space_combos
def get_mcore_param_count(model: GPTModel | MambaModel) -> float:
    """Return the total parameter count of an MCore GPTModel / MambaModel.

    The count is reduced across TP and PP ranks. Models converted to a DynamicModule are counted
    via the dynamic-aware counter so that the currently active subnet is measured.
    """
    assert isinstance(model, (GPTModel, MambaModel)), "Model must be a GPTModel or MambaModel"
    counter = _param_num_dynamic if isinstance(model, DynamicModule) else _param_num
    return counter(model)
def _param_num(model: GPTModel | MambaModel) -> float:
    """Get the number of parameters in the model (reduced across TP and PP ranks).

    If ``model.share_embeddings_and_output_weights`` is True, parameters whose name contains
    ``"output_layer.weight"`` are skipped so the tied weight is not counted twice.

    Returns:
        Total parameter count as a float.
    """
    skip_tied_output = model.share_embeddings_and_output_weights
    params = sum(
        p.numel()
        for name, p in model.named_parameters()
        if not skip_tied_output or "output_layer.weight" not in name
    )
    # float64 keeps integer counts exact up to 2**53; the default float32 tensor would round
    # multi-billion parameter counts (and their all-reduced sums) to multiples of hundreds.
    reduced_params = torch.tensor(
        [params], dtype=torch.float64, device=next(model.parameters()).device
    )
    torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group())
    torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group())
    return reduced_params.item()
def _param_num_dynamic(
    model: _DynamicMCoreLanguageModel, *, layer_numbers_to_count: list[int] | None = None
) -> float:
    """Get the number of parameters in the Dynamic Module (reduced across TP and PP ranks).

    Args:
        model: GPTModel or MambaModel converted to a DynamicModule.
        layer_numbers_to_count: If specified, only count the parameters of the decoder layers with
            the given layer numbers (1-indexed); all non-decoder-layer parameters are still counted.
            Lets callers size a depth-pruned subnet without actually dropping layers.

    Returns:
        Total parameter count as a float.
    """

    # NOTE: model.parameters() doesn't consider active_slice so we don't get sorted or trimmed
    # parameters! Parameters are therefore fetched via getattr on their owning module instead.
    def get_param_count(mod, name) -> int:
        """Use getattr to access parameters correctly."""
        module_path, _, param_name = name.rpartition(".")
        submodule = mod.get_submodule(module_path) if module_path else mod
        return getattr(submodule, param_name).numel()

    count_layer_subset = layer_numbers_to_count is not None
    # Don't double count output_layer parameters if model.share_embeddings_and_output_weights is True
    skip_tied_output = model.share_embeddings_and_output_weights

    # Account for depth pruning with uneven PP and hybrid models!
    params = sum(
        get_param_count(model, name)
        for name, _ in model.named_parameters()
        if ("decoder.layers." not in name or not count_layer_subset)
        and not (skip_tied_output and "output_layer.weight" in name)
    )
    if count_layer_subset:
        layers_to_count = set(layer_numbers_to_count)  # O(1) membership per layer
        for layer in model.decoder.layers:
            if layer.layer_number in layers_to_count:
                params += sum(get_param_count(layer, name) for name, _ in layer.named_parameters())

    # float64 keeps integer counts exact up to 2**53; a float32 tensor would round multi-billion
    # counts, making `candidate_params <= max_params` comparisons imprecise.
    reduced_params = torch.tensor(
        [params], dtype=torch.float64, device=next(model.parameters()).device
    )
    torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group())
    torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group())
    return reduced_params.item()
# Pydantic config class for the ``"mcore_minitron"`` mode, built from the DMRegistry with per-model
# default divisor rules. NOTE(review): the divisors presumably restrict each hparam's choices to
# multiples of the given value — confirm in modelopt.torch.nas.plugins.megatron.
# These defaults mirror the keyword defaults of ``get_mcore_minitron_config``; keep them in sync.
MCoreMinitronConfig: type[ModeloptBaseConfig] = create_model(
    "MCoreMinitronConfig",
    **get_kwargs_for_create_model_with_rules(
        registry=DMRegistry,
        default_rules={
            "megatron.core.models.gpt.GPTModel": {
                "hidden_size_divisor": 256,
                "ffn_hidden_size_divisor": 512,
                "num_moe_experts_divisor": 8,
                "num_layers_divisor": 2,
            },
            # MambaModel rules are only added when Mamba support is available (HAS_MAMBA).
            **(
                {
                    "megatron.core.models.mamba.MambaModel": {
                        "hidden_size_divisor": 256,
                        "ffn_hidden_size_divisor": 512,
                        "mamba_head_dim_divisor": 8,
                        "num_moe_experts_divisor": 8,
                        "num_layers_divisor": 2,
                    }
                }
                if HAS_MAMBA
                else {}
            ),
        },
        doc='Configuration for the ``"mcore_minitron"`` mode.',
    ),
)
def get_mcore_minitron_config(
    *,
    hidden_size_divisor: int = 256,
    ffn_hidden_size_divisor: int = 512,
    mamba_head_dim_divisor: int = 8,
    num_moe_experts_divisor: int = 8,
    num_layers_divisor: int = 2,
) -> ModeloptBaseConfig:
    """Build a default ``MCoreMinitronConfig`` and override its divisor rules.

    Each of the five ``*_divisor`` keys found anywhere in the (nested) default config is replaced
    by the matching keyword argument; every other entry is left untouched.
    """
    overrides = {
        "hidden_size_divisor": hidden_size_divisor,
        "ffn_hidden_size_divisor": ffn_hidden_size_divisor,
        "mamba_head_dim_divisor": mamba_head_dim_divisor,
        "num_moe_experts_divisor": num_moe_experts_divisor,
        "num_layers_divisor": num_layers_divisor,
    }

    def _apply_overrides(node) -> None:
        # Nested dicts are descended into; known divisor keys are replaced in place.
        for key, value in node.items():
            if isinstance(value, dict):
                _apply_overrides(value)
            elif key in overrides:
                node[key] = overrides[key]

    config = MCoreMinitronConfig()
    _apply_overrides(config)
    return config
def _convert_model_to_dynamic_space(
    model: nn.Module, config: ModeloptBaseConfig | None = None
) -> DynamicSpace:
    """Wrap ``model`` in a DynamicSpace and convert it to dynamic modules in-place.

    Conversion is restricted to instances of the supported MCore model classes.

    Raises:
        ApplyModeError: If the converted model has no configurable hyperparameters.
    """

    def _is_supported_model(mod) -> bool:
        return isinstance(mod, tuple(SUPPORTED_MODELS.keys()))

    space = DynamicSpace(model)
    space._should_be_converted = _is_supported_model
    rules = config.model_dump() if config else None
    space.convert_to_dynamic(rules, DMRegistry)

    if space.is_configurable():
        return space
    raise ApplyModeError(
        "The model does not contain any configurable hyperparameters! Please check the"
        " documentation for modules and config and how to get a configurable model."
    )
def convert_mcore_minitron(model: nn.Module, config: ModeloptBaseConfig) -> ConvertReturnType:
    """Convert ``model`` to the Minitron dynamic search space in-place.

    Unlike ``convert_fastnas_searchspace``, no automated recursive tracing is done: only the
    top-level model is converted to a DynamicModule and its submodules are converted from there.

    Returns:
        The same (now dynamic) model and a metadata dict holding its current ``subnet_config``.
    """
    _convert_model_to_dynamic_space(model, config)
    subnet_config = get_subnet_config(model)
    return model, {"subnet_config": subnet_config}
def restore_mcore_minitron(
    model: nn.Module, config: ModeloptBaseConfig, metadata: dict
) -> nn.Module:
    """Restore entrypoint for ``"mcore_minitron"``; intentionally a no-op.

    The model is handed back unchanged because converting it again would force TP=1.
    ``config`` and ``metadata`` are accepted only to match the restore entrypoint signature.
    """
    restored_model = model
    return restored_model
@NASModeRegistry.register_mode
@PruneModeRegistry.register_mode
class MCoreMinitronModeDescriptor(ModeDescriptor):
    """Descriptor of the ``"mcore_minitron"`` mode, registered for both NAS and pruning.

    Ties together the config class, the convert/restore entrypoints, the searcher and the export
    mode used for Minitron pruning of MCore models.
    """

    @property
    def name(self) -> str:
        """String identifier of this mode."""
        mode_name = "mcore_minitron"
        return mode_name

    @property
    def config_class(self) -> type[ModeloptBaseConfig]:
        """Config class accepted by this mode."""
        return MCoreMinitronConfig

    @property
    def next_modes(self) -> set[str] | None:
        """Modes allowed to directly follow this mode."""
        allowed = ("export_nas", "kd_loss", "quantize", "sparse_magnitude", "sparse_gpt")
        return set(allowed)

    @property
    def export_mode(self) -> str | None:
        """Name of the mode used to export models of this mode."""
        export_name = "export_nas"
        return export_name

    @property
    def search_algorithm(self) -> type[BaseSearcher]:
        """Searcher class driving this mode."""
        return MCoreMinitronSearcher

    @property
    def convert(self) -> ConvertEntrypoint:
        """Entrypoint that converts a model into the search space."""
        return convert_mcore_minitron

    @property
    def restore(self) -> RestoreEntrypoint:
        """Entrypoint that restores a model from its modelopt_state."""
        return restore_mcore_minitron
class ImportanceEstimatorRegistry:
    """Register importance estimators and forward hooks for all supported modules in the model.

    This class should be instantiated after converting the model to DynamicModule but before
    running the forward loop for importance estimation. Forward hooks registered through
    :meth:`register_hook` are tracked so that :meth:`cleanup` can detach them afterwards.
    """

    def __init__(self, model: DynamicModule):
        """Initialize the registry and attach estimators/hooks to every supported submodule.

        Args:
            model: The converted top-level language model. It must be a
                ``_DynamicMCoreLanguageModel``.
        """
        # The check is for the converted language-model class specifically, so the message
        # names that class rather than the generic DynamicModule base.
        assert isinstance(model, _DynamicMCoreLanguageModel), (
            f"Model must be a _DynamicMCoreLanguageModel, got {type(model).__name__}"
        )
        self.model = model
        self._hooks: list[tuple[nn.Module, Any]] = []  # List of (module, hook_handle) tuples

        print_rank_0("Registering importance estimators and forward hooks...")
        # ``modules()`` also yields the top-level model itself, which handles hidden_size.
        for module in self.model.modules():
            if isinstance(module, _DynamicMCoreLanguageModel):
                _register_hidden_size_importance(module, self)
            elif isinstance(module, (_DynamicTransformerLayer, _DynamicMambaLayer)):
                _register_depth_cosine_importance(module, self)
            elif isinstance(module, _DynamicSelfAttention):
                _register_self_attention_importance(module, self)
            elif isinstance(module, _DynamicMLP):
                _register_mlp_importance(module, self)
            elif isinstance(module, _DynamicSequentialMLP):
                _register_sequential_mlp_importance(module, self)
            elif isinstance(module, _DynamicMambaMixer):
                _register_mamba_mixer_importance(module, self)

    def register_hook(
        self,
        module: nn.Module,
        hook_fn: Callable,
        hook_type: str = "forward",
        **hook_kwargs,
    ) -> None:
        """Register a forward or forward_pre hook on a module and track its handle.

        Args:
            module: The module to register the hook on.
            hook_fn: The hook function to register.
            hook_type: Type of hook ("forward" or "forward_pre").
            **hook_kwargs: Additional kwargs forwarded to the torch hook registration call.

        Raises:
            ValueError: If ``hook_type`` is neither "forward" nor "forward_pre".
        """
        if hook_type == "forward":
            handle = module.register_forward_hook(hook_fn, **hook_kwargs)
        elif hook_type == "forward_pre":
            handle = module.register_forward_pre_hook(hook_fn, **hook_kwargs)
        else:
            raise ValueError(f"Unsupported hook_type: {hook_type}")
        self._hooks.append((module, handle))

    def register_importance(
        self,
        dynamic_module: DynamicModule,
        hparam_name: str,
        importance_fn: Callable,
        importance_is_order: bool = False,
    ) -> None:
        """Register an importance estimator for a hyperparameter.

        Args:
            dynamic_module: The DynamicModule instance.
            hparam_name: Name of the hyperparameter to register importance for.
            importance_fn: Function that returns importance scores.
            importance_is_order: Whether the importance is a ranking order.
        """
        hp = dynamic_module.get_hparam(hparam_name)
        if importance_is_order:
            hp._importance_is_order = True
        hp.register_importance(importance_fn)

    def cleanup(self) -> None:
        """Remove all hooks registered through :meth:`register_hook`.

        Only hook handles are removed here. Temporary attributes set on modules
        (e.g. ``_activations``) are not touched by this method.
        """
        for _, handle in self._hooks:
            handle.remove()
        self._hooks.clear()

    def get_layer_scores(self) -> dict[int, torch.Tensor]:
        """Gather the per-layer scores (1-indexed) from all pipeline-parallel ranks.

        Returns:
            Dictionary mapping layer number to layer score, covering every layer of the model.

        Raises:
            AssertionError: If a local layer has no positive score, or the gathered layer
                numbers do not cover ``1..num_layers`` exactly.
        """
        num_layers_hp = self.model.get_hparam("num_layers")

        # Validate and collect local scores in a single pass over this rank's layers.
        local_scores = {}
        for layer in self.model.decoder.layers:
            assert layer._scores > 0, "No scores collected for importance estimation."
            local_scores[layer.layer_number] = layer._scores

        # Merge the per-rank dicts from all PP ranks into one dict.
        pp_group = dist.DistributedProcessGroup(get_pipeline_model_parallel_group())
        layer_scores = dist.DistributedProcessGroup.get_dist_syncd_obj(
            local_scores,
            pp_group,
            op=lambda all_pp_layer_scores: {
                k: v for d in all_pp_layer_scores for k, v in d.items()
            },
        )
        print_rank_0(f"Layerwise scores (1-indexed, higher is better): {layer_scores}")
        expected_layers = list(range(1, num_layers_hp.max + 1))  # type: ignore[operator]
        assert sorted(layer_scores.keys()) == expected_layers, (
            f"Gathered layer numbers {sorted(layer_scores.keys())} do not match {expected_layers}"
        )
        return layer_scores

    def get_activations_and_layer_scores(
        self,
    ) -> tuple[list[dict[str, torch.Tensor]], dict[int, torch.Tensor]]:
        """Get the per-rank activations and the gathered layer scores from the model.

        Returns:
            A pair ``(activations_per_rank, layer_scores)``. ``activations_per_rank`` has one
            dict per PP rank, mapping module name to that module's ``_activations``.
        """
        local_activations = {
            n: m._activations for n, m in self.model.named_modules() if hasattr(m, "_activations")
        }
        activations_per_rank = dist.allgather(
            local_activations, group=get_pipeline_model_parallel_group()
        )
        assert len(activations_per_rank) == get_pipeline_model_parallel_world_size()

        layer_scores = self.get_layer_scores()
        return activations_per_rank, layer_scores

    def set_activations_and_layer_scores(
        self,
        activations_per_rank: list[dict[str, torch.Tensor]],
        layer_scores: dict[int, torch.Tensor],
    ) -> None:
        """Set the pre-computed layer_scores and per-rank activations instead of running forward.

        Args:
            activations_per_rank: List of dicts from module name to activations. Should match PP size.
            layer_scores: Dict from layer_number (1-indexed) to score.

        Raises:
            AssertionError: If the stored PP size differs from the current PP size.
        """
        print_rank_0("Loading activations and scores per rank from checkpoint...")
        rank = get_pipeline_model_parallel_rank()
        pp_size = get_pipeline_model_parallel_world_size()
        assert len(activations_per_rank) == pp_size, (
            f"Expected same PP size for stored pruning scores ({len(activations_per_rank)}) as current ({pp_size})!"
        )

        for layer in self.model.decoder.layers:
            layer._scores = layer_scores[layer.layer_number]
        # Only this rank's slice of the stored activations applies to the local modules.
        for n, m in self.model.named_modules():
            if hasattr(m, "_activations"):
                m._activations = activations_per_rank[rank][n]
# Module-specific registration functions
def _register_hidden_size_importance(
    module: _DynamicMCoreLanguageModel, registry: ImportanceEstimatorRegistry
) -> None:
    """Attach hidden_size importance collection to a Language Model (GPT/Mamba) module.

    Forward hooks are placed on the layernorms that feed each dynamic attention, MLP or Mamba
    block. They accumulate per-channel activation statistics in ``module._activations``. The
    importance function registered for ``hidden_size`` converts those statistics into a
    per-channel importance vector.
    """
    module._register_temp_attribute("_activations", {})

    def _emb_layernorm_forward_hook(mod, module_inner, _inputs, output):
        """Accumulate layernorm-output statistics used for hidden_size importance.

        For each call, the absolute output is averaged over seq_len, squared, and summed over
        batch_size. Results are summed across calls, keyed per hooked layernorm. The square
        root (giving an L2 norm) is taken later in the estimator.
        """
        # Gather output [seq_len, batch_size, hidden_size] over all TP regions
        # NOTE: This is not used at the moment since we restrict to TP=1
        full_output = gather_from_tensor_model_parallel_region(output).detach()
        # Upcast to full precision before squaring/summing to avoid overflow.
        per_sample = full_output.to(torch.float32).abs().mean(dim=0)  # [batch_size, hidden_size]
        batch_stat = per_sample.pow(2).sum(dim=0)

        store = mod._activations
        key = id(module_inner)
        if key in store:
            store[key] += batch_stat  # aggregate sum instead of mean of scores for simplicity
        else:
            store[key] = batch_stat

    def _estimate_hidden_size_importance(mod):
        """Return the activation-magnitude-based importance of each hidden_size channel."""
        assert mod._activations, "No activations collected for importance estimation."
        # Each stored value is a sum of squares; the square root turns it into an L2 norm.
        l2_norms = torch.stack([sq_sum.pow(0.5) for sq_sum in mod._activations.values()])
        importance = l2_norms.sum(dim=0).clone()  # [hidden_size]
        # Reduce over all PP ranks.
        # NOTE(review): no group is passed, so this uses the default (world) process group
        # and also sums over any non-PP ranks -- confirm this is intended.
        torch.distributed.all_reduce(importance, op=torch.distributed.ReduceOp.SUM)
        return importance

    # One bound hook callable is shared by every layernorm; each registration still gets
    # its own handle in the registry.
    collect_hook = partial(_emb_layernorm_forward_hook, module)
    for layer in module.decoder.layers:
        if isinstance(layer, _DynamicTransformerLayer):
            norms_to_hook = []
            if isinstance(layer.self_attention, _DynamicSelfAttention):
                norms_to_hook.append(layer.input_layernorm)
            if isinstance(layer.mlp, (_DynamicMLP, _DynamicSequentialMLP)):
                norms_to_hook.append(layer.pre_mlp_layernorm)
        elif isinstance(layer, _DynamicMambaLayer):
            norms_to_hook = [layer.norm]
        else:
            continue
        for norm in norms_to_hook:
            registry.register_hook(norm, collect_hook, hook_type="forward")

    registry.register_importance(
        module, "hidden_size", lambda: _estimate_hidden_size_importance(module)
    )
def _register_depth_cosine_importance(
module: _DynamicTransformerLayer | _DynamicMambaLayer, registry: ImportanceEstimatorRegistry
) -> None:
"""Register importance estimators for TransformerLayer and MambaLayer modules."""