
Commit 492af47

Updated script, updated readme
1 parent b17ccac commit 492af47

3 files changed

Lines changed: 20 additions & 4 deletions


examples/megatron_bridge/README.md

Lines changed: 14 additions & 0 deletions
@@ -50,6 +50,20 @@ hf auth login --token <your token>
 
 This section shows how to prune a HuggingFace model using the Minitron algorithm in the Megatron-Bridge framework. Check out other available pruning algorithms, supported frameworks and models, and the general pruning getting-started guide in the [pruning README](../pruning/README.md).
 
+To estimate the importance of each layer of a model, one can use the `rank_layer_importance.py` script. This script compares the final hidden state representation with and without a particular layer. The intuition is that the less a layer affects the final hidden state, the less important it is, i.e. the safer it is to drop it completely from the model.
+The example below uses the script to estimate the importance of NVIDIA-Nemotron-Nano-12B-v2 layers. Usually the first and the last layers are the most important ones; the resulting scores should be similar to those in the figure.
+
+![Importance scores](nemotron-nano-12b-v2.png)
+
+```bash
+torchrun --nproc_per_node=8 examples/megatron_bridge/rank_layer_importance.py \
+    --hf_model_name_or_path /path/to/hf-checkpoint/nvidia/NVIDIA-Nemotron-Nano-12B-v2 \
+    --trust_remote_code \
+    --calib_dataset_name wikitext \
+    --num_layers 62 \
+    --save_scores_path /path/to/scores.pt
+```
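The idea described in the README text above (score a layer by how much the final hidden state changes when that layer is skipped) can be sketched on a toy model. This is only an illustration under stated assumptions and not the implementation in `rank_layer_importance.py`: the residual "layers", the helper names (`make_residual_layer`, `forward`, `rank_layers`), and the use of plain Python lists in place of tensors are all hypothetical.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]
Layer = Callable[[Vector], Vector]


def make_residual_layer(scale: float) -> Layer:
    """Toy residual layer h -> h + scale * h (hypothetical stand-in for a real block)."""
    def layer(h: Vector) -> Vector:
        return [x + scale * x for x in h]
    return layer


def forward(layers: Sequence[Layer], h: Vector, skip: int = -1) -> Vector:
    """Run the layers in order; the layer at index `skip` is treated as identity."""
    for i, layer in enumerate(layers):
        if i != skip:
            h = layer(h)
    return h


def mse(a: Vector, b: Vector) -> float:
    """Mean squared error between two equally sized vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def rank_layers(layers: Sequence[Layer], h: Vector) -> Tuple[Dict[int, float], List[int]]:
    """Score each layer by the MSE between the full output and the output without it."""
    reference = forward(layers, h)
    scores = {i: mse(reference, forward(layers, h, skip=i)) for i in range(len(layers))}
    # Lowest score first: these layers change the final hidden state the least.
    order = sorted(scores, key=lambda i: scores[i])
    return scores, order


# Assumed toy setup: four layers with different residual strengths.
layers = [make_residual_layer(s) for s in (0.5, 0.01, 0.2, 0.9)]
scores, order = rank_layers(layers, [1.0, -2.0, 0.5])
print("Layers from least to most important:", order)
```

In this toy setting the layer with the weakest residual contribution gets the lowest score and would be the first candidate to drop. A real model needs calibration data and per-batch averaging, which this sketch leaves out.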
+
 Example usage to prune Qwen3-8B to 6B on 2 GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads`, using the following defaults:
 1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration,
 at most 20% depth (`num_layers`) and 40% width is pruned per prunable hparam (`hidden_size`, `ffn_hidden_size`, ...),
[Binary image file, 47.8 KB; not rendered]

examples/megatron_bridge/rank_layer_importance.py

Lines changed: 6 additions & 4 deletions
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import argparse
-import pickle
 from collections import defaultdict
 
 import torch
@@ -254,6 +253,9 @@ def get_args() -> argparse.Namespace:
             "Useful for iterative pruning"
         ),
     )
+    parser.add_argument(
+        "--save_scores_path", type=str, default="scores.pt", help="Path to save scores"
+    )
 
     args = parser.parse_args()
 
@@ -280,7 +282,8 @@ def collect_scores(
     for i in range(num_layers):
         scores[i] = {}
         for metric in metrics:
-            scores[i][metric] = stats[metric][i]
+            scores[i][metric] = stats[metric][i].cpu()
+
     # print(f"{scores=}")
     print("Layers ordered by <MSE> importance:")
     res = sorted(
@@ -422,8 +425,7 @@ def reset_train_data_iterator():
     if is_pipeline_last_stage() and get_data_parallel_rank() == 0:
         scores = collect_scores(unwrapped_model)
         assert scores is not None
-        with open(f"scores_{get_pipeline_model_parallel_rank()}.p", "wb") as f:
-            pickle.dump(scores, f)
+        torch.save(scores, args.save_scores_path)
 
 
 if __name__ == "__main__":
