diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 76b72806d5..9bac2756b8 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -24,6 +24,7 @@ modelopt/torch/nas @NVIDIA/modelopt-torch-nas-prune-codeowners modelopt/torch/opt @NVIDIA/modelopt-torch-opt-codeowners modelopt/torch/peft @NVIDIA/modelopt-torch-peft-codeowners modelopt/torch/prune @NVIDIA/modelopt-torch-nas-prune-codeowners +modelopt/torch/puzzletron @NVIDIA/modelopt-torch-puzzletron-codeowners modelopt/torch/quantization @NVIDIA/modelopt-torch-quantization-codeowners modelopt/torch/sparsity @NVIDIA/modelopt-torch-sparsity-codeowners modelopt/torch/speculative @NVIDIA/modelopt-torch-speculative-codeowners @@ -49,6 +50,7 @@ modelopt_recipes @NVIDIA/modelopt-recipes-codeowners /examples/model_hub @NVIDIA/modelopt-examples-model_hub-codeowners /examples/onnx_ptq @NVIDIA/modelopt-onnx-codeowners /examples/pruning @NVIDIA/modelopt-torch-nas-prune-codeowners +/examples/puzzletron @NVIDIA/modelopt-torch-puzzletron-codeowners /examples/specdec_bench @NVIDIA/modelopt-torch-speculative-codeowners /examples/speculative_decoding @NVIDIA/modelopt-torch-speculative-codeowners /examples/torch_onnx @NVIDIA/modelopt-onnx-codeowners diff --git a/.github/workflows/_example_tests_runner.yml b/.github/workflows/_example_tests_runner.yml index 71f9c3a443..b34aa87e0d 100644 --- a/.github/workflows/_example_tests_runner.yml +++ b/.github/workflows/_example_tests_runner.yml @@ -48,6 +48,7 @@ jobs: - name: Install dependencies run: | # use `python -m pip` instead of `pip` to avoid conflicts with system pip for nemo containers + pip uninstall -y nvidia-modelopt python -m pip install ".${{ inputs.pip_install_extras }}" if [[ "${{ inputs.example }}" == *"diffusers"* ]]; then @@ -64,7 +65,7 @@ jobs: COVERAGE_FILE: ${{ github.workspace }}/.coverage run: | echo "Running tests for: ${{ inputs.example }}" - pytest tests/examples/${{ inputs.example }} --cov + python -m pytest tests/examples/${{ inputs.example }} --cov - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index 085d858a99..e316618852 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -132,7 +132,7 @@ jobs: docker_image: "nvcr.io/nvidia/nemo:26.02" example: ${{ matrix.example }} timeout_minutes: 30 - pip_install_extras: "[hf,dev-test]" + pip_install_extras: "[hf,puzzletron,dev-test]" runner: linux-amd64-gpu-rtxpro6000-latest-1 nemo-non-pr: @@ -144,7 +144,7 @@ jobs: docker_image: "nvcr.io/nvidia/nemo:26.02" example: ${{ matrix.example }} timeout_minutes: 30 - pip_install_extras: "[hf,dev-test]" + pip_install_extras: "[hf,puzzletron,dev-test]" runner: linux-amd64-gpu-rtxpro6000-latest-2 ##### ONNX/TensorRT Example Tests ##### diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml index 9d36919a9a..30b47bb216 100644 --- a/.github/workflows/gpu_tests.yml +++ b/.github/workflows/gpu_tests.yml @@ -68,7 +68,7 @@ jobs: matrix: include: - example: gpu - timeout: 45 + timeout: 60 container_image: pytorch:26.01-py3 # tests/gpu/_extensions/test_onnx_extensions.py fails for newer containers until https://github.com/tbenthompson/cppimport/pull/98 - example: gpu-regression diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0f7a0da210..3c4c11a090 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -94,6 +94,7 @@ repos: modelopt/onnx/quantization/ort_patching.py| modelopt/torch/_deploy/utils/onnx_utils.py| modelopt/torch/export/transformer_engine.py| + modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py| modelopt/torch/quantization/export_onnx.py| modelopt/torch/quantization/plugins/attention.py| modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py| diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 78a70d6d94..63f8c79eca 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,7 @@ Changelog **New Features** - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. +- Add Puzzletron - a new algorithm for heterogeneous pruning of LLM and VLM models. See `examples/puzzletron/README.md `_ for more details. - Added iterator interface using CalibrationDataReader in ONNX quantization workflow. - Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. - Add skip-softmax skipping to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. diff --git a/docs/source/conf.py b/docs/source/conf.py index dbf66c18d2..6fe7a86002 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -31,6 +31,7 @@ # import sys # sys.path.insert(0, os.path.abspath('.')) +import contextlib import os import sys @@ -44,6 +45,14 @@ sys.path.insert(0, os.path.abspath("../../")) sys.path.append(os.path.abspath("./_ext")) +# Pre-import modelopt.torch so it is cached in sys.modules before Sphinx applies +# autodoc_mock_imports. Mocking triton/tensorrt_llm at the Sphinx level can break +# transitive imports (transformers, transformer_engine, …) and cause modelopt.torch +# to fail inside autosummary. Importing here — while the real packages are still on +# sys.path — avoids that problem entirely. +with contextlib.suppress(Exception): + import modelopt.torch # noqa: F401 + # -- Project information ----------------------------------------------------- project = "Model Optimizer" # pylint: disable=C0103 diff --git a/examples/llm_eval/README.md b/examples/llm_eval/README.md index eae306cdc4..79f1b85d7e 100644 --- a/examples/llm_eval/README.md +++ b/examples/llm_eval/README.md @@ -40,6 +40,22 @@ accelerate launch --multi_gpu --num_processes \ --batch_size 4 ``` +### Heterogeneous Pruned Checkpoints (Puzzletron) + +Heterogeneous pruned checkpoints produced by Puzzletron are automatically detected and loaded with the appropriate model patcher. No additional flags are needed beyond specifying the checkpoint path: + +```sh +python lm_eval_hf.py --model hf \ + --model_args pretrained=path/to/anymodel/checkpoint,dtype=bfloat16,parallelize=True \ + --tasks mmlu \ + --num_fewshot 5 \ + --batch_size 4 +``` + +For a quick smoke test, add `--limit 10`. + +> **Note:** Requires the `puzzletron` extra to be installed (`pip install -e ".[puzzletron]"`). + ### Quantized (simulated) - For simulated quantization with any of the default quantization formats: diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index 0ab19d0319..dbdf22d868 100755 --- a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -36,11 +36,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import warnings import datasets +import lm_eval from lm_eval import utils from lm_eval.__main__ import cli_evaluate, parse_eval_args, setup_parser + +if not lm_eval.__version__.startswith("0.4.8"): + warnings.warn( + f"lm_eval_hf.py is tested with lm-eval 0.4.8; found {lm_eval.__version__}. " + "Later versions may have incompatible API changes." + ) from lm_eval.api.model import T from lm_eval.models.huggingface import HFLM from quantization_utils import quantize_model @@ -50,9 +58,29 @@ from modelopt.torch.quantization.utils import is_quantized from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified +try: + import modelopt.torch.puzzletron as mtpz + + _ANYMODEL_AVAILABLE = True +except ImportError: + _ANYMODEL_AVAILABLE = False + + +def _anymodel_patcher_context(pretrained, trust_remote_code=False): + """Return a deci_x_patcher context if *pretrained* is a Puzzletron checkpoint, else a no-op.""" + if not _ANYMODEL_AVAILABLE or not pretrained: + return contextlib.nullcontext() + try: + descriptor = mtpz.anymodel.resolve_descriptor_from_pretrained( + pretrained, trust_remote_code=trust_remote_code + ) + except (ValueError, AttributeError): + return contextlib.nullcontext() + return mtpz.anymodel.deci_x_patcher(model_descriptor=descriptor) + def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T: - """Overrides the HFLM.create_from_arg_obj""" + """Override HFLM.create_from_arg_obj to add quantization, sparsity, and Puzzletron support.""" quant_cfg = arg_dict.pop("quant_cfg", None) auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None) @@ -72,7 +100,10 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | # Enable automatic save/load of modelopt state huggingface checkpointing mto.enable_huggingface_checkpointing() - model_obj = cls(**arg_dict, **additional_config) + with _anymodel_patcher_context( + arg_dict.get("pretrained"), arg_dict.get("trust_remote_code", False) + ): + model_obj = cls(**arg_dict, **additional_config) model_obj.tokenizer.padding_side = "left" if is_quantized(model_obj.model): # return if model is already quantized @@ -109,10 +140,28 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | return model_obj +def create_from_arg_string( + cls: type[T], arg_string: str, additional_config: dict | None = None +) -> T: + """Override HFLM.create_from_arg_string to support Puzzletron checkpoints.""" + args = utils.simple_parse_args_string(arg_string) + additional_config = {} if additional_config is None else additional_config + args2 = {k: v for k, v in additional_config.items() if v is not None} + + mto.enable_huggingface_checkpointing() + + with _anymodel_patcher_context(args.get("pretrained"), args.get("trust_remote_code", False)): + model_obj = cls(**args, **args2) + + return model_obj + + HFLM.create_from_arg_obj = classmethod(create_from_arg_obj) +HFLM.create_from_arg_string = classmethod(create_from_arg_string) def setup_parser_with_modelopt_args(): + """Extend the lm-eval argument parser with ModelOpt quantization and sparsity options.""" parser = setup_parser() parser.add_argument( "--quant_cfg", diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index 8d7f9b840c..571a0c4988 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -46,6 +46,9 @@ Note that the default dataset for pruning and quantization is [`nemotron-post-tr hf auth login --token ``` +> [!WARNING] +> Use `python -m pip` instead of `pip` to avoid conflicts with the system-wide installed packages in the NeMo containers. + ## Pruning This section shows how to prune a HuggingFace model using Minitron algorithm in Megatron-Bridge framework. Checkout other available pruning algorithms, supported frameworks and models, and general pruning getting-started in the [pruning README](../pruning/README.md). @@ -92,7 +95,7 @@ This section shows how to distill a student model from a teacher model in the Me This can be used stand-alone or after [Pruning](#pruning) / [Post-Training Quantization](#post-training-quantization) to recover accuracy of the model by distilling from the original model (teacher). -The [distill.py](distill.py) script loads student and teacher models from HuggingFace checkpoints and saves the distilled model to `/checkpoints` in Megatron distributed checkpoint format. +The [distill.py](distill.py) script supports both standard HuggingFace checkpoints and [Puzzletron AnyModel](../puzzletron/README.md) checkpoints as student/teacher inputs. Just pass the checkpoint path via `--student_hf_path` / `--teacher_hf_path`. The distilled model is saved to `/checkpoints` in Megatron distributed checkpoint format. ### Data Preparation @@ -158,9 +161,22 @@ torchrun --nproc_per_node 8 distill.py \ To run the distillation script on a Slurm cluster for multi-node training, you just need use `python` instead of `torchrun` and set the number of nodes using `#SBATCH --nodes=` clause in your Slurm script. -### Convert Megatron checkpoint to Hugging Face format +### Converting to Hugging Face format (optional) + +The distilled checkpoint is saved in Megatron distributed format. If you need a HuggingFace checkpoint, there are two ways to convert it: + +**Inline** -- add `--hf_export_path` and `--student_hf_model` to the `distill.py` command to automatically convert the final checkpoint after distillation: + +```bash +torchrun --nnodes 1 --nproc_per_node 8 distill.py \ + ... \ + --hf_export_path /path/to/save/distilled_hf_ckpt \ + --student_hf_model Qwen/Qwen3-4B +``` + +`--student_hf_model` should match the base architecture of the student (used as a template for export). For non-Puzzletron (i.e. standard) models, it should be same as `--student_hf_path`. -To convert the Megatron checkpoint from last iteration (or any intermediate iteration) to Hugging Face format, you need the pruned model config (`--output_hf_path` from `prune_minitron.py` script) and the distilled megatron checkpoint dir (`/checkpoints/iter_`) to run the following command: +**Separate conversion** -- convert any saved iteration using the Megatron-Bridge conversion script: ```bash uv run python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \ @@ -169,7 +185,11 @@ uv run python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py ex --hf-path ``` -For more details, you can refer to the checkpoint conversion scripts in the [Megatron-Bridge README](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/conversion). +For more details, see the [Megatron-Bridge conversion README](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/conversion). + +### Distillation Results + +See [results/puzzletron.md](results/puzzletron.md) for MMLU results demonstrating knowledge distillation on Puzzletron-compressed student models. ## Post-Training Quantization diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py index f725fa07ac..6adeb19e2f 100644 --- a/examples/megatron_bridge/distill.py +++ b/examples/megatron_bridge/distill.py @@ -15,17 +15,22 @@ """Distillation script for Megatron-Bridge. Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model -to `/checkpoints` in megatron distributed checkpoint format. +to `/checkpoints` in megatron distributed checkpoint or HuggingFace format. See `README.md` in this directory for example usage and data preparation instructions. """ import argparse +import contextlib import os +from dataclasses import fields import torch from megatron.bridge import AutoBridge -from megatron.bridge.models.distillation_provider import convert_to_distillation_provider +from megatron.bridge.models.distillation_provider import ( + DistillationProvider, + convert_to_distillation_provider, +) from megatron.bridge.recipes.utils.optimizer_utils import ( distributed_fused_adam_with_cosine_annealing, ) @@ -43,13 +48,50 @@ from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig from megatron.core.datasets.utils import get_blend_from_list from megatron.core.distributed import DistributedDataParallelConfig +from transformers import AutoConfig import modelopt.torch.utils.distributed as dist from modelopt.torch.utils import print_rank_0 +with contextlib.suppress(ModuleNotFoundError): + import modelopt.torch.puzzletron.plugins.mbridge # noqa: F401 + SEED = 1234 +def _patched_to_cfg_dict(self): + """Patched DistillationProvider.to_cfg_dict method for heterogeneous teacher and student models. + + TODO: Upstream this patch to Megatron-Bridge. + """ + from megatron.bridge.training.utils.config_utils import _ConfigContainerBase + + result = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"} + # Use fields from the actual student provider class, not DistillationProvider. + # DistillationProvider's __dataclass_fields__ only includes TransformerConfig fields + # (set at class definition time), missing GPTModelProvider-level fields like + # vocab_size, share_embeddings_and_output_weights, etc. + excluded_fields = {"teacher", "kd_config"} + for field in fields(self._super_class): + if field.name.startswith("_") or field.name in excluded_fields: + continue + if hasattr(self, field.name): + result[field.name] = _ConfigContainerBase._convert_value_to_dict( + getattr(self, field.name) + ) + for field in fields(self): + if field.name.startswith("_") or field.name in excluded_fields: + continue + if field.name not in result: + result[field.name] = _ConfigContainerBase._convert_value_to_dict( + getattr(self, field.name) + ) + return result + + +DistillationProvider.to_cfg_dict = _patched_to_cfg_dict + + def get_args(): """Parse command-line arguments.""" parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.") @@ -124,12 +166,33 @@ def get_args(): ) parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)") parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)") + # Export arguments + parser.add_argument( + "--hf_export_path", + type=str, + default=None, + help=( + "Path where to save the HuggingFace export. " + "If provided, exports last iteration checkpoint to HF format after distillation." + ), + ) + parser.add_argument( + "--student_hf_model", + type=str, + required=False, + default=None, + help="HuggingFace model ID to use as template for export (e.g., Qwen/Qwen3-0.6B). " + "Should match the base architecture of the student model if --hf_export_path is provided.", + ) args = parser.parse_args() # Sanity checks if not args.use_mock_data and not args.data_paths: raise ValueError("Must provide either --data_paths or set --use_mock_data.") + if args.hf_export_path and not args.student_hf_model: + raise ValueError("Must provide --student_hf_model if --hf_export_path is provided.") + print_rank_0("\n==================== Arguments ====================") for k, v in args.__dict__.items(): print_rank_0(f"{k:<35} {v}") @@ -252,9 +315,35 @@ def _build_model_provider(hf_path): print_rank_0("\nStarting distillation...") distill(config) print_rank_0( - f"\nDistillation done! Saved checkpoint to {checkpoint_dir} in megatron distributed checkpoint format.\n" + f"\nDistillation done! Saved checkpoint to {checkpoint_dir}" + " in megatron distributed checkpoint format.\n" ) + if args.hf_export_path: + print_rank_0(f"Exporting final distilled ckpt to HF format to {args.hf_export_path}") + # Save rank before destroying process group (dist.rank() won't work after destruction) + is_rank_0 = dist.rank() == 0 + + # Destroy process group on all ranks -- export_ckpt will create its own temporary one. + # This prevents cleanup from hanging (cleanup tries to barrier, but rank 0 would be gone). + dist.cleanup() + + if is_rank_0: + export_bridge = AutoBridge.from_hf_pretrained( + args.student_hf_model, trust_remote_code=args.trust_remote_code + ) + # Copy weights and remote code + export_bridge.export_ckpt( + megatron_path=f"{checkpoint_dir}/iter_{args.train_iters:07d}", + hf_path=args.hf_export_path, + show_progress=True, + strict=True, + ) + # Copy config.json from student_hf_path (handles both local paths and HF model IDs) + AutoConfig.from_pretrained( + args.student_hf_path, trust_remote_code=args.trust_remote_code + ).save_pretrained(args.hf_export_path) + if __name__ == "__main__": dist.setup() diff --git a/examples/megatron_bridge/results/puzzletron.md b/examples/megatron_bridge/results/puzzletron.md new file mode 100644 index 0000000000..89ba114f58 --- /dev/null +++ b/examples/megatron_bridge/results/puzzletron.md @@ -0,0 +1,42 @@ +# Puzzletron Distillation Results + +The following MMLU results demonstrate knowledge distillation on student models that were first compressed using [Puzzletron](../../puzzletron/README.md). The original (uncompressed) model serves as the teacher, and distillation recovers accuracy lost during compression. + +## Qwen3-8B compressed to 80% of original + +The student was created by compressing Qwen3-8B to 80% of its original size using Puzzletron. + +| Model | MMLU | Humanities | Other | Social Sci | STEM | +|-------|------|------------|-------|------------|------| +| Student (before distillation) | 0.5910 | 0.5046 | 0.6363 | 0.6831 | 0.5855 | +| Student (after distillation) | 0.6921 | 0.5906 | 0.7316 | 0.7975 | 0.7016 | +| Teacher (original Qwen3-8B) | 0.7493 | 0.6648 | 0.7856 | 0.8385 | 0.7526 | + +MMLU accuracy improved from 59.10% to 69.21% (+10.11 pp) after distillation with just 100 iterations on WikiText-103, recovering 64% of the gap to the teacher model. + +## Llama-3.1-8B-Instruct compressed to 50% of original + +The student was created by compressing Llama-3.1-8B-Instruct to 50% of its original size using Puzzletron. + +| Model | MMLU | Humanities | Other | Social Sciences | STEM | +|-------|------|------------|-------|-----------------|------| +| Student (before distillation) | 0.2316 | 0.2462 | 0.2292 | 0.2250 | 0.2274 | +| Student (after distillation) | 0.2960 | 0.3146 | 0.3085 | 0.2925 | 0.2768 | +| Teacher (original Llama-3.1-8B-Instruct) | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 | + +## Llama-3.1-8B-Instruct compressed to 69% of original (regression) + +The student was created by compressing Llama-3.1-8B-Instruct to ~69% of its original size using Puzzletron. This example shows regression due to overfitting on the small WikiText-103 dataset (100 iterations). MMLU was evaluated on a subset of 100 samples per task: + +| Model | MMLU | Humanities | Other | Social Sciences | STEM | +|-------|------|------------|-------|-----------------|------| +| Student (before distillation) | 0.6626 | 0.7069 | 0.6892 | 0.7525 | 0.5574 | +| Student (after distillation) | 0.6496 | 0.6862 | 0.6677 | 0.7433 | 0.5532 | +| Teacher (original Llama-3.1-8B-Instruct) | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 | + +MMLU decreased from 66.26% to 64.96% (-1.30 pp) -- the model overfitted to WikiText-103. This highlights the importance of using larger, more diverse datasets for distillation. + +## Recommendations + +- **Use larger datasets** for production distillation (e.g., [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1)) to avoid overfitting as shown in the regression case above. +- **Train for more iterations** to ensure proper convergence. diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 656d6315db..930e9c6d25 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -7,6 +7,7 @@ Pruning can involve removal (prune) of Linear and Conv layers; and Transformer a This section focuses on applying Model Optimizer's state-of-the-art complementary pruning modes to enable you to search for the best subnet architecture from your provided base model: 1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT (and later extended to Mamba, MoE, and Hybrid Transformer Mamba) models in NVIDIA Megatron-LM (M-LM) or Megatron-Bridge (M-Bridge) framework. It uses the activation magnitudes to prune the embedding hidden size; mlp ffn hidden size; transformer attention heads; mamba heads and head dimension; MoE number of experts, ffn hidden size, and shared expert intermediate size; and number of layers of the model. +1. [Puzzletron](../puzzletron/README.md): An advanced pruning method by NVIDIA using Mixed Integer Programming (MIP) based NAS search algorithm. 1. FastNAS: A pruning method recommended for Computer Vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints. 1. GradNAS: A light-weight pruning method recommended for language models like Hugging Face BERT, GPT-J. It uses the gradient information to prune the model's linear layers and attention heads to meet the given constraints. diff --git a/examples/puzzletron/GPTOSS.md b/examples/puzzletron/GPTOSS.md new file mode 100644 index 0000000000..7c160c8997 --- /dev/null +++ b/examples/puzzletron/GPTOSS.md @@ -0,0 +1,14 @@ + +## GptOss + +With this release Puzzle algorithm supports only experts removal for `Gpt-Oss`. + +This model comes as a quantized checkpoint i.e. MoE experts matrices are quantized with _MXFP4_ format. +In the pruning steps puzzle utilizes decompressed model (back to BF16) for statistics and scores computation. +This means, during the conversion to puzzle format we decompress the model and store it as a BF16. +Once the pruning is done i.e. experts to be removed are identified and the process is finished, user may want to get back the _MXFP4_ format of the checkpoint. +To do so, there is an additional script, that takes the original and the pruned checkpoint and outputs pruned checkpoint in _MXFP4_ format. + +```bash +python -m modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_pruned_to_mxfp4 --student-path /workspaces/any_model_gpt_oss/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/solution_0/ --original-path /workspaces/source_model_checkpoints/openai_gpt-oss-20b/ --output-path /workspaces/any_model_gpt_oss/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/mxfp4-ckpt/ --num-layers 24 +``` diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md new file mode 100644 index 0000000000..5bd7c65064 --- /dev/null +++ b/examples/puzzletron/README.md @@ -0,0 +1,306 @@ +# Puzzletron Algorithm Tutorial + +This tutorial demonstrates how to compress large language models using the puzzletron algorithm based on the [Puzzle paper](https://arxiv.org/abs/2411.19146). +The goal of the algorithm it to find the most optimal modifications to MLP and attention layers of the model, resulting in a heterogeneous model architecture. +The supported modifications are: + +- `ffn_intermediate_size`: different FFN intermediate sizes +- `attention op/noop`: complete removal of attention layers + +To use the Puzzle algorithm effectively, we need to specify the target number of parameters and/or the memory. The final stage is based on Mixed-Integer Programming (MIP) algorithm to find the most optimal combination of layer modifications that satisfy the target requirements. + +In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models should be compressed in a similar way. For GptOss there is one [additional step to be performed](GPTOSS.md). + +> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md). + +## Environment + +### Container setup (NeMo) + +The recommended way to run puzzletron is inside an NVIDIA NeMo container (e.g. `nvcr.io/nvidia/nemo:26.02`). NeMo containers ship a pre-installed `nvidia-modelopt` that does not include the puzzletron extras so you need to replace it with an editable install from this repo. + +> [!WARNING] +> Use `python -m pip` instead of `pip` to avoid conflicts with the system-wide installed packages in the NeMo containers. + +> [!NOTE] +> NeMo containers ship `nvidia-lm-eval` which may conflict with `lm-eval` that is used for evaluation, hence we uninstall and replace it with `lm-eval` from the repo. + +Once inside the container with the repo available, install dependencies from the repo root: + +```bash +python -m pip uninstall nvidia-lm-eval -y 2>/dev/null +python -m pip install -e ".[hf,puzzletron,dev-test]" +python -m pip install -r examples/puzzletron/requirements.txt +``` + +To verify the install, you can run the GPU tests as a smoke check: + +```bash +python -m pytest tests/gpu/torch/puzzletron/test_puzzletron.py -k "Qwen3-8B" +``` + +### Hardware + +- For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can use also use a single GPU. + +- To make use of [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) and [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2), you need to accept the terms and conditions for the corresponding model and the dataset in the Huggingface Hub. Log in to the Huggingface Hub and enter your HF token. + +```bash +hf auth login --token +``` + +## Compress the Model + +1. Download and prepare the [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2). + + dataset split: "code", "math", "stem", "chat", excluding reasoning samples (2.62GB) + + ```bash + python -m modelopt.torch.puzzletron.dataset.prepare_dataset \ + --dataset_name nvidia/Nemotron-Post-Training-Dataset-v2 \ + --output_dir path/to/Nemotron-Post-Training-Dataset-v2 + ``` + +2. Specify the `puzzle_dir`, `input_hf_model_path`, `dataset_path`, `intermediate_size_list`, and `target_memory` arguments in the [llama-3_1-8B_pruneffn_memory.yaml](./configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml) configuration file. + + - `puzzle_dir` indicates a new directory for saving the resulting model. + - `input_hf_model_path` indicates the local directory with the input model checkpoint. + - `dataset_path` indicates the directory with the dataset downloaded earlier. + + **_NOTE:_** + How to choose `intermediate_size_list`? + The list specifies the candidate FFN sizes that we wish to search over. It is recommended to choose several pruning sizes (e.g. 15%, 20%, 30% etc of the original). Note that the values must be hardware-friendly (divisible by a 256) to avoid issues with tensor operations in subsequent steps. + + Let's first shoot for 32% GPU memory reduction setting `target_memory = 78_000` MiB. This means that the algorithm will choose the candidates with highest accuracy that also meet the specified requirements. + + We can also set the target size of the resulting model using `num_params = 7_000_000_000`. This will be used as an upper bound for the number of parameters of the model. + +3. Run the puzzletron pipeline. + + ```bash + torchrun --nproc_per_node 2 examples/puzzletron/main.py --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml 2>&1 | tee ./log.txt | grep "Puzzletron Progress" + ``` + + This will save the full output to `log.txt` and display the following progress on screen: + + ```bash + [2025-11-02 12:06:34][rank-0][main.py:71] Puzzletron Progress 1/8: starting puzzletron pipeline + [2025-11-02 12:06:45][rank-0][puzzletron_nas_plugin.py:123] Puzzletron Progress 2/8: converting model from HF to DeciLM (single-gpu) + [2025-11-02 12:07:07][rank-0][puzzletron_nas_plugin.py:132] Puzzletron Progress 3/8: scoring pruning activations (multi-gpu) + [2025-11-02 12:11:36][rank-0][puzzletron_nas_plugin.py:137] Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu) + [2025-11-02 12:12:20][rank-0][puzzletron_nas_plugin.py:217] Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu) + [2025-11-02 12:12:21][rank-0][puzzletron_nas_plugin.py:222] Puzzletron Progress 6/8: calculating one block scores (multi-gpu) + [2025-11-02 12:50:41][rank-0][puzzletron_nas_plugin.py:226] Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu) + [2025-11-02 12:52:34][rank-0][main.py:115] Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu) + ``` + + Once the process is complete, the resulting network architecture will be recorded in `log.txt` for your review: + + ```bash + ... + block_0: attention gqa_4 ffn intermediate_14336 + block_1: attention gqa_4 ffn intermediate_14336 + block_2: attention gqa_4 ffn intermediate_14336 + block_3: attention gqa_4 ffn intermediate_14336 + block_4: attention gqa_4 ffn intermediate_14336 + block_5: attention gqa_4 ffn intermediate_14336 + block_6: attention gqa_4 ffn intermediate_14336 + block_7: attention gqa_4 ffn intermediate_14336 + block_8: attention gqa_4 ffn intermediate_14336 + block_9: attention gqa_4 ffn intermediate_14336 + block_10: attention gqa_4 ffn intermediate_14336 + block_11: attention gqa_4 ffn intermediate_14336 + block_12: attention gqa_4 ffn intermediate_14336 + block_13: attention gqa_4 ffn intermediate_14336 + block_14: attention gqa_4 ffn intermediate_14336 + block_15: attention gqa_4 ffn intermediate_14336 + block_16: attention gqa_4 ffn intermediate_14336 + block_17: attention no_op ffn intermediate_14336 + block_18: attention no_op ffn intermediate_14336 + block_19: attention no_op ffn intermediate_14336 + block_20: attention no_op ffn intermediate_14336 + block_21: attention no_op ffn intermediate_14336 + block_22: attention no_op ffn intermediate_14336 + block_23: attention no_op ffn intermediate_14336 + block_24: attention no_op ffn intermediate_14336 + block_25: attention no_op ffn intermediate_14336 + block_26: attention no_op ffn intermediate_14336 + block_27: attention no_op ffn intermediate_14336 + block_28: attention no_op ffn intermediate_14336 + block_29: attention gqa_4 ffn intermediate_14336 + block_30: attention gqa_4 ffn intermediate_14336 + block_31: attention gqa_4 ffn intermediate_14336 + + [2025-11-02 04:53:11,332]^[[92m[rank-0]^[[0m[run_puzzle.py:295] Total costs: {'stats.memory_mib': 75796.4140625, 'stats.ffn_num_params': 5637275648, 'stats.num_kv_heads': 160, 'stats.kv_cache_memory_mib': 61440.0, 'stats.ffn_memory_mib': 10752.25, 'stats.attention_memory_mib': 63040.15625, 'stats.attention_num_params': 838942720, 'stats.num_params': 7526895616, 'stats.has_attention': 20, 'stats.has_ffn': 32} + ... + ################################################################ + validate_model_and_extract_token_probs(model_name='teacher') + ################################################################ + ... + Average losses = {'lm_loss': 1.118250765837729, 'token_accuracy_top_1': 0.7331905364990234, 'token_accuracy_top_5': 0.9094219207763672, 'token_accuracy_top_10': 0.9423646926879883} + ... + ################################################################ + validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True) + ################################################################ + .... + Average losses = {'lm_loss': 1.7577573340386152, 'token_accuracy_top_1': 0.6225490570068359, 'token_accuracy_top_5': 0.846257209777832, 'token_accuracy_top_10': 0.8987817764282227} + ``` + + 30% GPU memory reduction leads to nearly 5% regression in token_accuracy_top_10 metric (0.898 / 0.942). + +## Re-run MIP Search with different constraints + +If you want to try different constraints without re-running the expensive pruning and scoring steps, use the `--mip-only` flag. +This assumes pruning, replacement library building, NAS scoring, and subblock stats calculation have already been completed. + +For example, let's set `target_memory: 96_000` in `llama-3_1-8B_pruneffn_memory.yaml`. + +```bash +torchrun --nproc_per_node 2 examples/puzzletron/main.py \ + --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml \ + --mip-only 2>&1 | tee ./log.txt | grep "Puzzletron Progress" +``` + +This will generate the following network architecture (see `log.txt`): + +```bash +block_0: attention gqa_4 ffn intermediate_14336 +block_1: attention gqa_4 ffn intermediate_14336 +block_2: attention gqa_4 ffn intermediate_14336 +block_3: attention gqa_4 ffn intermediate_14336 +block_4: attention gqa_4 ffn intermediate_14336 +block_5: attention gqa_4 ffn intermediate_14336 +block_6: attention gqa_4 ffn intermediate_14336 +block_7: attention gqa_4 ffn intermediate_14336 +block_8: attention gqa_4 ffn intermediate_14336 +block_9: attention gqa_4 ffn intermediate_14336 +block_10: attention gqa_4 ffn intermediate_14336 +block_11: attention gqa_4 ffn intermediate_14336 +block_12: attention gqa_4 ffn intermediate_14336 +block_13: attention gqa_4 ffn intermediate_14336 +block_14: attention gqa_4 ffn intermediate_14336 +block_15: attention gqa_4 ffn intermediate_14336 +block_16: attention gqa_4 ffn intermediate_14336 +block_17: attention gqa_4 ffn intermediate_14336 +block_18: attention no_op ffn intermediate_14336 +block_19: attention no_op ffn intermediate_14336 +block_20: attention no_op ffn intermediate_14336 +block_21: attention gqa_4 ffn intermediate_14336 +block_22: attention no_op ffn intermediate_14336 +block_23: attention no_op ffn intermediate_14336 +block_24: attention no_op ffn intermediate_14336 +block_25: attention gqa_4 ffn intermediate_14336 +block_26: attention gqa_4 ffn intermediate_14336 +block_27: attention gqa_4 ffn intermediate_14336 +block_28: attention gqa_4 ffn intermediate_14336 +block_29: attention gqa_4 ffn intermediate_14336 +block_30: attention gqa_4 ffn intermediate_14336 +block_31: attention gqa_4 ffn intermediate_14336 + +[2025-11-02 12:50:42,024]^[[92m[rank-0]^[[0m[run_puzzle.py:295] Total costs: {'stats.memory_mib': 94708.4609375, 'stats.has_ffn': 32, 'stats.ffn_memory_mib': 10752.25, 'stats.kv_cache_memory_mib': 79872.0, 'stats.attention_num_params': 1090625536, 'stats.ffn_num_params': 5637275648, 'stats.has_attention': 26, 'stats.num_params': 7778578432, 'stats.attention_memory_mib': 81952.203125, 'stats.num_kv_heads': 208} +... +################################################################ +validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True) +################################################################ +Average losses = {'lm_loss': 1.2425934937782586, 'token_accuracy_top_1': 0.703862190246582, 'token_accuracy_top_5': 0.8954982757568359, 'token_accuracy_top_10': 0.9336576461791992 +``` + +On the other hand, if you set `target_memory: 28_000`, you'll observe that the intermediate FFN sizes are significantly reduced in certain layers (see `log.txt` for details): + +```bash +block_5: attention no_op ffn intermediate_11520 +block_6: attention no_op ffn intermediate_14336 +block_7: attention no_op ffn intermediate_8704 +block_8: attention no_op ffn intermediate_14336 +block_9: attention no_op ffn intermediate_3072 +block_10: attention no_op ffn intermediate_11520 +block_11: attention no_op ffn intermediate_11520 +block_12: attention no_op ffn intermediate_11520 +block_13: attention no_op ffn intermediate_11520 +block_14: attention no_op ffn intermediate_3072 +``` + +### MIP Sweep Mode + +The **MIP sweep mode** lets you explore multiple memory compression rates in a single run and compare the accuracy-memory trade-offs. + +#### Quick Start + +1. Enable sweep in your config YAML (e.g., `llama-3_1-8B_pruneffn_memory.yaml`): + + ```yaml + mip: + sweep: + enabled: true + memory_compression_rates: [0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + output_csv: ${puzzle_dir}/mip_sweep_results.csv + ``` + +2. Run the sweep: + + ```bash + torchrun --nproc_per_node 2 examples/puzzletron/main.py \ + --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml \ + --mip-only 2>&1 | tee ./log.txt | grep "Puzzletron Progress" + ``` + +3. View results: The CSV file contains compression rates, memory usage, and accuracy metrics for each configuration. + +#### Example Results + +MIP Sweep Results + +The plot shows how token accuracy changes with different compression rates. Higher compression (0.5 = 50% of original memory) reduces accuracy, while lower compression maintains accuracy closer to the teacher model. + +## Evaluation + +Evaluate AnyModel checkpoints using [lm-eval](https://github.com/EleutherAI/lm-evaluation-harness) directly. + +```bash +python examples/llm_eval/lm_eval_hf.py \ + --model hf \ + --model_args pretrained=path/to/checkpoint,dtype=bfloat16,parallelize=True \ + --tasks mmlu \ + --num_fewshot 5 \ + --batch_size 4 +``` + +For a quick smoke test, add `--limit 10`. + +> **Alternative:** For server-based evaluation via an OpenAI-compatible endpoint, +> see [evaluation/nemo_evaluator_instructions.md](./evaluation/nemo_evaluator_instructions.md). + +## Inference Performance Benchmarking + +Now let's evaluate how much speedup we get with the compressed model in terms of throughput and latency. + +- Install [vLLM from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source). +- Rearrange the model safetensors to be used for vLLM. + +```bash +cd path/to/model +mv subblocks_safetensors/* . +sed -i 's+subblocks_safetensors/++g' model.safetensors.index.json +``` + +- Benchmark latency + +```bash +vllm bench latency --model path/to/model --load-format safetensors +``` + +- Benchmark throughput + +```bash +vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 --load-format safetensors +``` + +## Knowledge Distillation + +To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. + +See [Megatron-Bridge distillation](../megatron_bridge/README.md#distillation) for instructions on using Megatron-Bridge for knowledge distillation. The distillation script supports both standard HuggingFace and Puzzletron AnyModel checkpoints. + +## Advanced Usage + +Modify `llama-3_1-8B_pruneffn_memory.yaml` file for advanced compression scenarios. diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml new file mode 100644 index 0000000000..b48f1de78c --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml @@ -0,0 +1,110 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: gpt_oss +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 45_000 + num_params: 3_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml new file mode 100644 index 0000000000..8ed06e9568 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml @@ -0,0 +1,17 @@ +defaults: + - gptoss-20b + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/openai/gpt-oss-20b + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 16_000 # 45 GiB diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..258e6c38a3 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,21 @@ +defaults: + - pruning_defaults + +eval_samples: 2500 #10 +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor.GptOssExpertRemovalLayerDescriptor + target_name: "mlp.router" + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.RankedChoiceVotingHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 128 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks" + layer_prefix_template: "model.layers.{layer_idx}.mlp.router" + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..0eff799d7e --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 10_000 +micro_batch_size: 1 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..b80faea5f5 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml @@ -0,0 +1,18 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ab8c892182 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml @@ -0,0 +1,11 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false + diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml new file mode 100644 index 0000000000..21903db162 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: llama +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # ppath to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + num_params: 7_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml new file mode 100644 index 0000000000..ad16dbc5ea --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml @@ -0,0 +1,26 @@ +defaults: + - Llama-3_1-8B + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for puzzletron outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 78_000 # 78 GiB + # Memory sweep configuration (optional) + sweep: + enabled: false + memory_compression_rates: [0.5, 0.6, 0.7, 0.8, 0.9] + output_csv: ${puzzle_dir}/mip_sweep_results.csv + +# FFN intermediate sizes to search over (heterogeneous architecture) +pruning: + intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..da0b972070 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,19 @@ +defaults: + - pruning_defaults + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..e05e775bee --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml new file mode 100644 index 0000000000..7de281e788 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml @@ -0,0 +1,110 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: llama +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 45_000 + num_params: 3_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} + diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml new file mode 100644 index 0000000000..b5303d318a --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - Llama-3_2-3B + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.2-3B-Instruct + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 45_000 # 45 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 8192, so we use proportionally smaller values +pruning: + intermediate_size_list: [2048, 4096, 6144] diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..05de8bfdcc --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,21 @@ +defaults: + - pruning_defaults + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +# Llama-3.2-3B has intermediate_size=8192, so we use proportionally smaller pruning sizes +intermediate_size_list: [2048, 4096, 6144] +mlp_init_mode: "PruneByActivationsLog" + diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..e05e775bee --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..b80faea5f5 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,18 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} + diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ab8c892182 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,11 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false + diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml new file mode 100644 index 0000000000..18213f9b7a --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: mistral_small +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + num_params: 24_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml new file mode 100644 index 0000000000..68a0652d6f --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml @@ -0,0 +1,21 @@ +defaults: + - Mistral-Small-24B + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/mistralai/Mistral-Small-24B-Instruct-2501 + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 234_000 # 234 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +pruning: + intermediate_size_list: [8192, 16384, 24576] # teacher_intermediate_size is 32768 diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..cb24e1bc24 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,17 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# Mistral Small 24B: 32 query heads, 8 KV heads +# n_heads_in_group = num_query_heads / num_kv_heads +# num_kv_heads = num_query_heads / n_heads_in_group +# Base: n_heads_in_group = 4, num_kv_heads = 8 +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..5fb7fcbdd2 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,20 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor.MistralFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 32768 +intermediate_size_list: [8192, 16384, 24576] +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..7de32621e0 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +# Mistral Small 24B: hidden_size is 5120 +hidden_size_list: [3072, 4096] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..e05e775bee --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml new file mode 100644 index 0000000000..62b6ecb4cb --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: nemotron_h_v2 +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + runtime_stats: + backend: trt_torch + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 90_000 + num_params: 12_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml new file mode 100644 index 0000000000..3b880b2c7d --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - nemotron_nano_12b_v2 + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/nvidia/Nemotron-Nano-12B-v2 + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 90_000 # 90 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 20480 +pruning: + intermediate_size_list: [4352, 8448, 12544, 16384] diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..1e2ecf07a0 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor.NemotronHV2FFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mixer.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..8816eecc4a --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..3f7a248ee7 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${modelopt.torch.puzzletron.pruning.activation_hooks_kwargs.method}/${modelopt.torch.puzzletron.pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..18d7e234ac --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor.Qwen2FFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..af8af990b7 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${modelopt.torch.puzzletron.pruning.activation_hooks_kwargs.method}/${modelopt.torch.puzzletron.pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..8816eecc4a --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml new file mode 100644 index 0000000000..aa11499a3c --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: qwen2 +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + runtime_stats: + backend: trt_torch + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + num_params: 7_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml new file mode 100644 index 0000000000..fb961033bc --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - qwen2_5_7b_instruct + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/Qwen/Qwen2.5-7B-Instruct + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 78_000 # 78 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 18944 +pruning: + intermediate_size_list: [4096, 7808, 11520, 15104] diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..93590d13e5 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor.Qwen3FFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..8816eecc4a --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml new file mode 100644 index 0000000000..eec82a7d63 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: qwen3 +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + runtime_stats: + backend: trt_torch + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + num_params: 8_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml new file mode 100644 index 0000000000..4ee81286dd --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - qwen3_8b + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/Qwen/Qwen3-8B + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 78_000 # 78 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 12288 +pruning: + intermediate_size_list: [2560, 5120, 7424, 9984] diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/evaluation/hf_deployable_anymodel.py b/examples/puzzletron/evaluation/hf_deployable_anymodel.py new file mode 100644 index 0000000000..d0055dde63 --- /dev/null +++ b/examples/puzzletron/evaluation/hf_deployable_anymodel.py @@ -0,0 +1,720 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +import json +import logging +from typing import Any + +import numpy as np +import torch +from nemo_deploy import ITritonDeployable +from nemo_deploy.utils import broadcast_list, cast_output, str_ndarray2list +from nemo_export_deploy_common.import_utils import ( + MISSING_TRITON_MSG, + UnavailableError, + null_decorator, +) +from peft import PeftModel +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.puzzletron as mtpz + +try: + from pytriton.decorators import batch + from pytriton.model_config import Tensor + + HAVE_TRITON = True +except (ImportError, ModuleNotFoundError): + from unittest.mock import MagicMock + + HAVE_TRITON = False + batch = MagicMock() + Tensor = MagicMock() + batch = null_decorator + + +LOGGER = logging.getLogger("NeMo") + +SUPPORTED_TASKS = ["text-generation"] + + +class HuggingFaceLLMDeploy(ITritonDeployable): + """A Triton inference server compatible wrapper for HuggingFace models. + + This class provides a standardized interface for deploying HuggingFace models + in Triton inference server. It supports various NLP tasks and handles model + loading, inference, and deployment configurations. + + Args: + hf_model_id_path (Optional[str]): Path to the HuggingFace model or model identifier. + Can be a local path or a model ID from HuggingFace Hub. + hf_peft_model_id_path (Optional[str]): Path to the PEFT model or model identifier. + Can be a local path or a model ID from HuggingFace Hub. + tokenizer_id_path (Optional[str]): Path to the tokenizer or tokenizer identifier. + If None, will use the same path as hf_model_id_path. + model (Optional[AutoModel]): Pre-loaded HuggingFace model. + tokenizer (Optional[AutoTokenizer]): Pre-loaded HuggingFace tokenizer. + tokenizer_padding (bool): Whether to enable padding in tokenizer. Defaults to True. + tokenizer_truncation (bool): Whether to enable truncation in tokenizer. Defaults to True. + tokenizer_padding_side (str): Which side to pad on ('left' or 'right'). Defaults to 'left'. + task (str): HuggingFace task type (e.g., "text-generation"). Defaults to "text-generation". + **hf_kwargs: Additional keyword arguments to pass to HuggingFace model loading. + """ + + def __init__( + self, + hf_model_id_path: str | None = None, + hf_peft_model_id_path: str | None = None, + tokenizer_id_path: str | None = None, + model: AutoModel | None = None, + tokenizer: AutoTokenizer | None = None, + tokenizer_padding=True, + tokenizer_truncation=True, + tokenizer_padding_side="left", + task: str | None = "text-generation", + torch_dtype: torch.dtype | None = "auto", + device_map: str | None = "auto", + **hf_kwargs, + ): + if not HAVE_TRITON: + raise UnavailableError(MISSING_TRITON_MSG) + + if hf_model_id_path is None and model is None: + raise ValueError("hf_model_id_path or model parameters has to be passed.") + elif hf_model_id_path is not None and model is not None: + LOGGER.warning( + "hf_model_id_path will be ignored and the HuggingFace model set with model parameter will be used." + ) + + assert task in SUPPORTED_TASKS, "Task {} is not a support task.".format(task) + + self.hf_model_id_path = hf_model_id_path + self.hf_peft_model_id_path = hf_peft_model_id_path + self.task = task + self.model = model + self.tokenizer = tokenizer + self.tokenizer_padding = tokenizer_padding + self.tokenizer_truncation = tokenizer_truncation + self.tokenizer_padding_side = tokenizer_padding_side + + if tokenizer_id_path is None: + self.tokenizer_id_path = hf_model_id_path + else: + self.tokenizer_id_path = tokenizer_id_path + + if model is None: + self._load(torch_dtype=torch_dtype, device_map=device_map, **hf_kwargs) + + def _load( + self, torch_dtype: torch.dtype | None = "auto", device_map: str | None = "auto", **hf_kwargs + ) -> None: + """Load the HuggingFace pipeline with the specified model and task. + + This method initializes the HuggingFace AutoModel classes using the provided model + configuration and task type. It handles the model and tokenizer loading + process. + + Args: + torch_dtype (torch.dtype): Data type for the model. Defaults to "auto". + device_map (str): Device map for the model. Defaults to "auto". + **hf_kwargs: Additional keyword arguments to pass to the HuggingFace model loading. + + Raises: + AssertionError: If task is not specified. + """ + assert self.task is not None, "A task has to be given for the generation task." + + if self.task == "text-generation": + # ========================================================================= + # BEGIN ANYMODEL PATCH + # Wraps model loading with deci_x_patcher for heterogeneous layer configs. + # See: modelopt/torch/puzzletron/anymodel/puzzformer/patcher.py + # ========================================================================= + + descriptor = mtpz.anymodel.resolve_descriptor_from_pretrained( + self.hf_model_id_path, trust_remote_code=hf_kwargs.get("trust_remote_code", False) + ) + + with mtpz.anymodel.deci_x_patcher(model_descriptor=descriptor): + self.model = AutoModelForCausalLM.from_pretrained( + self.hf_model_id_path, + torch_dtype=torch_dtype, + device_map=device_map, + **hf_kwargs, + ) + # ========================================================================= + # END ANYMODEL PATCH + # ========================================================================= + + if self.hf_peft_model_id_path is not None: + self.model = PeftModel.from_pretrained(self.model, self.hf_peft_model_id_path) + else: + raise ValueError("Task {} is not supported.".format(self.task)) + num_gpus = torch.cuda.device_count() + # If there is only one GPU, move the model to GPU. If you are using device_map as "auto" or "balanced", + # the model will be moved to GPU automatically. + if device_map is None and num_gpus >= 1 and self.model.device.type != "cuda": + self.model.cuda() + self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_id_path, + trust_remote_code=hf_kwargs.pop("trust_remote_code", False), + padding=self.tokenizer_padding, + truncation=self.tokenizer_truncation, + padding_side=self.tokenizer_padding_side, + ) + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + def generate( + self, + **kwargs: Any, + ) -> list[str]: + """Generate text based on the provided input prompts. + + This method processes input prompts through the loaded pipeline and + generates text according to the specified parameters. + + Args: + **kwargs: Generation parameters including: + - text_inputs: List of input prompts + - max_length: Maximum number of tokens to generate + - num_return_sequences: Number of sequences to generate per prompt + - temperature: Sampling temperature + - top_k: Number of highest probability tokens to consider + - top_p: Cumulative probability threshold for token sampling + - do_sample: Whether to use sampling, default is False for greedy decoding + - echo: Whether to return prompt + generated text (True) or just generated text (False) + - return_full_text: Whether to return full text or only generated part + + Returns: + If output logits and output scores are False: + List[str]: A list of generated texts, one for each input prompt. + If output logits and output scores are True: + Dict: A dictionary containing: + - sentences: List of generated texts + - logits: List of logits + - scores: List of scores + - input_lengths: List of input token lengths (for echo processing) + + Raises: + RuntimeError: If the pipeline is not initialized. + """ + if not self.model: + raise RuntimeError("Model is not initialized") + + inputs = self.tokenizer( + kwargs["text_inputs"], + return_tensors="pt", + padding=self.tokenizer_padding, + truncation=self.tokenizer_truncation, + ) + + # Store input lengths to extract only generated tokens later + input_lengths = [len(input_ids) for input_ids in inputs["input_ids"]] + + # Get echo parameter (default False - only return generated text) + echo = kwargs.pop("echo", False) + kwargs.pop("text_inputs") # Remove text_inputs as it's already been tokenized + + kwargs = {**inputs, **kwargs} + for key, val in kwargs.items(): + if torch.is_tensor(val): + kwargs[key] = val.cuda() + + with torch.no_grad(): + generated_ids = self.model.generate(**kwargs) + return_dict_in_generate = kwargs.get("return_dict_in_generate", False) + if return_dict_in_generate: + # Handle dict output (when logits/scores are requested) + sequences = generated_ids["sequences"] + output = {"sentences": [], "input_lengths": input_lengths, "sequences": sequences} + + if echo: + # Return full text (prompt + generated). + # HF model's generate returns the input/prompt tokens as well by default. + for i, seq in enumerate(sequences): + full_text = self.tokenizer.decode(seq, skip_special_tokens=True) + output["sentences"].append(full_text) + else: + # Extract only the generated tokens (skip input tokens). + # This is required as HF model's generate returns the input/prompt tokens + # as well by default. (return_full_text is specific to some models) + for i, seq in enumerate(sequences): + input_len = input_lengths[i] if i < len(input_lengths) else 0 + generated_tokens = seq[input_len:] # Skip input tokens + generated_text = self.tokenizer.decode( + generated_tokens, skip_special_tokens=True + ) + output["sentences"].append(generated_text) + + if kwargs.get("output_logits", False): + output["logits"] = generated_ids["logits"] + if kwargs.get("output_scores", False): + output["scores"] = generated_ids["scores"] + else: + # Handle list output (normal case) + output = [] + if echo: + # Return full text (prompt + generated), which is the default in case of HF model generate. + for i, seq in enumerate(generated_ids): + full_text = self.tokenizer.decode(seq, skip_special_tokens=True) + output.append(full_text) + else: + # Extract only the generated tokens (skip input tokens) as the default + # behavior returns the input/prompt tokens as well. + for i, seq in enumerate(generated_ids): + input_len = input_lengths[i] if i < len(input_lengths) else 0 + generated_tokens = seq[input_len:] # Skip input tokens + generated_text = self.tokenizer.decode( + generated_tokens, skip_special_tokens=True + ) + output.append(generated_text) + + return output + + def generate_other_ranks(self): + """Generate function for ranks other than the rank 0.""" + while True: + message = torch.empty(1, dtype=torch.long, device="cuda") + torch.distributed.broadcast(message, src=0) + if message == 0: + prompts = broadcast_list(data=[None], src=0) + ( + temperature, + top_k, + top_p, + num_tokens_to_generate, + output_logits, + output_scores, + ) = broadcast_list(data=[None], src=0) + + return_dict_in_generate = False + if output_logits or output_scores: + return_dict_in_generate = True + + self.generate( + text_inputs=prompts, + do_sample=False, # do_sample=False for greedy decoding + top_k=top_k, + top_p=top_p, + temperature=temperature, + max_new_tokens=num_tokens_to_generate, + output_logits=output_logits, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + ) + else: + return + + @property + def get_triton_input(self): + inputs = ( + Tensor(name="prompts", shape=(-1,), dtype=bytes), + Tensor(name="max_length", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="max_batch_size", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="random_seed", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="output_logits", shape=(-1,), dtype=np.bool_, optional=True), + Tensor(name="output_scores", shape=(-1,), dtype=np.bool_, optional=True), + ) + return inputs + + @property + def get_triton_output(self): + return ( + Tensor(name="sentences", shape=(-1,), dtype=bytes), + Tensor(name="logits", shape=(-1,), dtype=np.single), + Tensor(name="scores", shape=(-1,), dtype=np.single), + ) + + @batch + def triton_infer_fn(self, **inputs: np.ndarray): + output_infer = {} + + try: + prompts = str_ndarray2list(inputs.pop("prompts")) + temperature = inputs.pop("temperature")[0][0] if "temperature" in inputs else 1.0 + top_k = int(inputs.pop("top_k")[0][0] if "top_k" in inputs else 1) + top_p = inputs.pop("top_p")[0][0] if "top_p" in inputs else 0 + num_tokens_to_generate = ( + inputs.pop("max_length")[0][0] if "max_length" in inputs else 256 + ) + output_logits = ( + inputs.pop("output_logits")[0][0] if "output_logits" in inputs else False + ) + output_scores = ( + inputs.pop("output_scores")[0][0] if "output_scores" in inputs else False + ) + return_dict_in_generate = False + if output_logits or output_scores: + return_dict_in_generate = True + + if torch.distributed.is_initialized(): + if torch.distributed.get_world_size() > 1: + torch.distributed.broadcast( + torch.tensor([0], dtype=torch.long, device="cuda"), src=0 + ) + broadcast_list(prompts, src=0) + broadcast_list( + data=[ + temperature, + top_k, + top_p, + num_tokens_to_generate, + output_logits, + output_scores, + ], + src=0, + ) + + output = self.generate( + text_inputs=prompts, + do_sample=False, # do_sample=False for greedy decoding + top_k=top_k, + top_p=top_p, + temperature=temperature, + max_new_tokens=num_tokens_to_generate, + output_logits=output_logits, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + echo=False, + ) + + if isinstance(output, dict): + output_infer = {"sentences": cast_output(output["sentences"], np.bytes_)} + + if "scores" in output: + output_scores = [] + for r in output["scores"]: + lp = torch.tensor(r).cpu().detach().numpy() + if len(lp) == 0: + output_scores.append([0]) + else: + output_scores.append(lp) + output_infer["scores"] = np.array(output_scores).transpose(1, 0, 2) + + if "logits" in output: + output_logits = [] + for r in output["logits"]: + lp = torch.tensor(r).cpu().detach().numpy() + if len(lp) == 0: + output_logits.append([0]) + else: + output_logits.append(lp) + output_infer["logits"] = np.array(output_logits).transpose(1, 0, 2) + else: + output_infer = {"sentences": cast_output(output, np.bytes_)} + + except Exception as error: + err_msg = "An error occurred: {}".format(str(error)) + output_infer["sentences"] = cast_output([err_msg], np.bytes_) + + return output_infer + + def _compute_logprobs( + self, + prompts: list[str], + output_infer: dict[str, Any], + compute_logprob: bool, + n_top_logprobs: int, + echo: bool, + ): + """Compute log probabilities and top log probabilities from model scores. + Used by ray_infer_fn to provide OAI API compatible output for evaluations. + + This method processes the raw scores from model generation to compute: + - Log probabilities for chosen tokens + - Top-k log probabilities for each position (if requested) + - Handles both prompt tokens (when echo=True) and generated tokens + + Args: + prompts: List of input prompts + output_infer: Dictionary containing model outputs including scores, sequences, and input_lengths + compute_logprob: Whether to compute log probabilities + n_top_logprobs: Number of top log probabilities to return (0 to disable) + echo: Whether to include prompt token log probabilities + + Returns: + Tuple[Optional[List], Optional[List]]: + - log_probs_list: List of log probabilities for each sample (None if not computed) + - top_logprobs_list: List of top-k log probabilities for each sample (None if not computed) + """ + # Tokenize the prompts to get prompt token IDs (needed for echo) + prompt_token_ids = None + prompt_inputs = None + if echo: + prompt_inputs = self.tokenizer( + prompts, + return_tensors="pt", + padding=self.tokenizer_padding, + truncation=self.tokenizer_truncation, + ) + prompt_token_ids = prompt_inputs["input_ids"] + # Move to same device as model + for key, val in prompt_inputs.items(): + if torch.is_tensor(val): + prompt_inputs[key] = val.cuda() + + # Process each sample + log_probs_list = [] + top_logprobs_list = [] + + for sample_idx in range(len(prompts)): + sample_log_probs = [] + sample_top_logprobs = [] + + # Get the generated sequence for this sample + sequences = output_infer["sequences"][sample_idx] + + # For echo, compute prompt token logprobs by running forward pass + if echo and prompt_token_ids is not None: + prompt_len = len(prompt_token_ids[sample_idx]) + + # Run forward pass on prompt to get logits for prompt tokens as scores in output_infer contains + # logits only for generated tokens. + with torch.no_grad(): + # Create input for this specific sample + sample_prompt_input = { + key: val[sample_idx : sample_idx + 1] for key, val in prompt_inputs.items() + } + prompt_outputs = self.model(**sample_prompt_input) + prompt_logits = prompt_outputs.logits[0] # Shape: [seq_len, vocab_size] + + # Calculate log probs for each prompt token (except the first BOS token) + for token_pos in range(1, prompt_len): # Start from 1 to skip BOS + # The logit at position i-1 predicts token at position i + logit_for_current_token = prompt_logits[token_pos - 1] + current_token_id = prompt_token_ids[sample_idx][token_pos].item() + + # Calculate log probabilities + log_probs = torch.nn.functional.log_softmax(logit_for_current_token, dim=-1) + chosen_log_prob = log_probs[current_token_id].item() + sample_log_probs.append(chosen_log_prob) + + # Calculate top log probabilities if requested + if n_top_logprobs > 0: + top_log_probs_dict = {} + top_k_values, top_k_indices = torch.topk( + log_probs, min(n_top_logprobs, len(log_probs)) + ) + for k_idx in range(len(top_k_indices)): + token_id = top_k_indices[k_idx].item() + token_str = self.tokenizer.decode([token_id]) + top_log_probs_dict[token_str] = top_k_values[k_idx].item() + sample_top_logprobs.append(top_log_probs_dict) + + # Process the scores for generated tokens + for token_idx, score_tensor in enumerate(output_infer["scores"]): + # Get the chosen token ID from the sequence + # Scores start after the prompt, so we need to offset + input_len = ( + output_infer.get("input_lengths", [0])[sample_idx] + if "input_lengths" in output_infer + else 0 + ) + seq_idx = input_len + token_idx + + if seq_idx < len(sequences): + chosen_token_id = ( + sequences[seq_idx].item() + if hasattr(sequences[seq_idx], "item") + else sequences[seq_idx] + ) + + # Calculate log probabilities + log_probs = torch.nn.functional.log_softmax(score_tensor[sample_idx], dim=-1) + chosen_log_prob = log_probs[chosen_token_id].item() + sample_log_probs.append(chosen_log_prob) + + # Calculate top log probabilities if requested + if n_top_logprobs > 0: + top_log_probs_dict = {} + top_k_values, top_k_indices = torch.topk( + log_probs, min(n_top_logprobs, len(log_probs)) + ) + for k_idx in range(len(top_k_indices)): + token_id = top_k_indices[k_idx].item() + token_str = self.tokenizer.decode([token_id]) + top_log_probs_dict[token_str] = top_k_values[k_idx].item() + sample_top_logprobs.append(top_log_probs_dict) + + log_probs_list.append(sample_log_probs) + if n_top_logprobs > 0: + top_logprobs_list.append(sample_top_logprobs) + + # Return log probs and top logprobs + return_log_probs = log_probs_list if compute_logprob else None + return_top_logprobs = top_logprobs_list if n_top_logprobs > 0 else None + + return return_log_probs, return_top_logprobs + + def ray_infer_fn(self, inputs: dict[Any, Any]): + """Perform inference using Ray with dictionary inputs and outputs. + + Args: + inputs (Dict[Any, Any]): Dictionary containing input parameters: + - prompts: List of input prompts + - temperature: Sampling temperature (optional) + - top_k: Number of highest probability tokens to consider (optional) + - top_p: Cumulative probability threshold for token sampling (optional) + - max_tokens: Maximum number of tokens to generate (optional) + - compute_logprob: Whether to compute log probabilities (optional) + - n_top_logprobs: Number of top log probabilities to return (optional) + - echo: Whether to echo the prompt in output (optional) + + Returns: + Dict[str, Any]: Dictionary containing: + - sentences: List of generated texts + - log_probs: Optional list of log probabilities if compute_logprob is True + - top_logprobs: Optional list of top log probabilities if n_top_logprobs > 0 + """ + try: + prompts = inputs.pop("prompts") + temperature = inputs.pop("temperature", 1.0) + top_k = int(inputs.pop("top_k", 1)) + top_p = inputs.pop("top_p", 0.0) + num_tokens_to_generate = inputs.pop("max_tokens", 256) + output_logits = inputs.pop("output_logits", False) + output_scores = inputs.pop("output_scores", False) + compute_logprob = inputs.pop("compute_logprob", False) + n_top_logprobs = inputs.pop("n_top_logprobs", 0) + echo = inputs.pop("echo", False) + + output_infer = self._infer_fn_ray( + prompts=prompts, + temperature=temperature, + top_k=top_k, + top_p=top_p, + num_tokens_to_generate=num_tokens_to_generate, + output_logits=output_logits, + output_scores=output_scores, + compute_logprob=compute_logprob, + n_top_logprobs=n_top_logprobs, + echo=echo, + ) + # Code to get logprobs (required in OAI API format for eval) from the scores in output_infer. + if ( + (compute_logprob or n_top_logprobs > 0) + and "scores" in output_infer + and output_infer["scores"] + ): + log_probs_list, top_logprobs_list = self._compute_logprobs( + prompts=prompts, + output_infer=output_infer, + compute_logprob=compute_logprob, + n_top_logprobs=n_top_logprobs, + echo=echo, + ) + + # Add to output + if log_probs_list is not None: + output_infer["log_probs"] = log_probs_list + if top_logprobs_list is not None: + # Convert to JSON strings for compatibility + output_infer["top_logprobs"] = [ + json.dumps(top_logprobs) for top_logprobs in top_logprobs_list + ] + + # Remove raw outputs that are not needed in the final response + output_infer.pop("scores", None) + output_infer.pop("sequences", None) + output_infer.pop("input_lengths", None) + return output_infer + except Exception as error: + err_msg = "An error occurred: {}".format(str(error)) + return {"sentences": [err_msg]} + + def _infer_fn_ray( + self, + prompts, + temperature=1.0, + top_k=1, + top_p=0.0, + num_tokens_to_generate=256, + output_logits=False, + output_scores=False, + compute_logprob=False, + n_top_logprobs=0, + echo=False, + cast_output_func=None, + ): + """Common internal function for inference operations. + + Args: + prompts: List of input prompts + temperature: Sampling temperature + top_k: Number of highest probability tokens to consider + top_p: Cumulative probability threshold for token sampling + num_tokens_to_generate: Maximum number of tokens to generate + output_logits: Whether to output logits + output_scores: Whether to output scores + compute_logprob: Whether to compute log probabilities + n_top_logprobs: Number of top log probabilities to return + echo: Whether to echo the prompt in output + cast_output_func: Optional function to cast output values + + Returns: + Dict containing inference results with raw outputs + """ + # Enable return_dict if we need scores for logprobs or if output_logits/scores are requested + return_dict_in_generate = ( + output_logits or output_scores or compute_logprob or n_top_logprobs > 0 + ) + # Enable output_scores if we need to compute logprobs. scores and logits from generate are both identical in + # case of greedy decoding. Hence setting output_scores to True when compute_logprob or n_top_logprobs > 0. + if compute_logprob or n_top_logprobs > 0: + output_scores = True + + if torch.distributed.is_initialized(): + if torch.distributed.get_world_size() > 1: + torch.distributed.broadcast( + torch.tensor([0], dtype=torch.long, device="cuda"), src=0 + ) + broadcast_list(prompts, src=0) + broadcast_list( + data=[ + temperature, + top_k, + top_p, + num_tokens_to_generate, + output_logits, + output_scores, + ], + src=0, + ) + + output = self.generate( + text_inputs=prompts, + do_sample=False, # do_sample=False for greedy decoding + top_k=top_k, + top_p=top_p, + temperature=temperature, + max_new_tokens=num_tokens_to_generate, + output_logits=output_logits, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + echo=echo, + ) + + if isinstance(output, dict): + return output + + else: + return {"sentences": output} diff --git a/examples/puzzletron/evaluation/nemo_evaluator_instructions.md b/examples/puzzletron/evaluation/nemo_evaluator_instructions.md new file mode 100644 index 0000000000..f8b53889c6 --- /dev/null +++ b/examples/puzzletron/evaluation/nemo_evaluator_instructions.md @@ -0,0 +1,70 @@ +# Evaluation with NeMo Evaluator (Alternative) + +> **Recommended approach:** Use lm-eval for direct evaluation without a +> deployment server. See the main [README](../README.md#evaluation) for details. + +Evaluate AnyModel checkpoints by deploying a local OpenAI-compatible completions endpoint and running benchmarks against it. + +This flow requires Ray for serving the model and NeMo Export-Deploy (included in NeMo containers): + +```bash +pip install -r examples/puzzletron/requirements.txt +``` + +**1. Deploy the model (2 GPUs example):** + +We need to patch the `hf_deployable.py` script from Export-Deploy. Best way is to do it as a mount in docker run: + +```bash +export MODELOPT_DIR=${PWD}/Model-Optimizer # or set to your local Model-Optimizer repository path if you have cloned it +if [ ! -d "${MODELOPT_DIR}" ]; then + git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} +fi + +export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.02 +docker run \ + --gpus all \ + --shm-size=16GB \ + --net=host \ + --ulimit memlock=-1 \ + --rm -it \ + -v ${MODELOPT_DIR}:/opt/Model-Optimizer \ + -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \ + -v ${MODELOPT_DIR}/examples/puzzletron/evaluation/hf_deployable_anymodel.py:/opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py \ + -w /opt/Model-Optimizer/examples/megatron_bridge \ + ${DOCKER_IMAGE} bash +``` + +Alternatively you can manually update the file + +```bash +# Install the AnyModel-patched deployable (first time only: backs up the original) +# /opt/Export-Deploy is the default path in NeMo containers — adjust if needed +cp /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py.bak +cp examples/puzzletron/evaluation/hf_deployable_anymodel.py /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py +``` + +Now start ray server and deploy the model + +```bash +# Start the server (blocks while running — use a separate terminal) +ray start --head --num-gpus 2 --port 6379 --disable-usage-stats +python /opt/Export-Deploy/scripts/deploy/nlp/deploy_ray_hf.py \ + --model_path path/to/checkpoint \ + --model_id anymodel-hf \ + --num_gpus 2 --num_gpus_per_replica 2 --num_cpus_per_replica 16 \ + --trust_remote_code --port 8083 --device_map "auto" --cuda_visible_devices "0,1" +``` + +**2. Run MMLU:** + +```bash +eval-factory run_eval \ + --eval_type mmlu \ + --model_id anymodel-hf \ + --model_type completions \ + --model_url http://0.0.0.0:8083/v1/completions/ \ + --output_dir examples/puzzletron/evals/mmlu_anymodel +``` + +For a quick debug run, add `--overrides "config.params.limit_samples=5"`. diff --git a/examples/puzzletron/main.py b/examples/puzzletron/main.py new file mode 100644 index 0000000000..8ceed37831 --- /dev/null +++ b/examples/puzzletron/main.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Main script for running the puzzletron algorithm on large language models (based on Puzzle paper https://arxiv.org/abs/2411.19146). + +This script provides three modes: +1. Default mode: Runs the full puzzletron pipeline +2. MIP-only mode: Runs only the MIP search and realize models phase +3. MIP sweep mode: Runs MIP for multiple memory compression rates (enabled via config) + +Usage: + # Full puzzletron pipeline + torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml + + # Only MIP search and realize models phase + torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only + + # MIP sweep mode (set mip.sweep.enabled: true in config) + torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only +""" + +import argparse +from datetime import timedelta +from pathlib import Path + +import modelopt.torch.nas as mtn +import modelopt.torch.puzzletron as mtpz +import modelopt.torch.utils.distributed as dist + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Compress large language models using the Puzzletron algorithm (based on Puzzle paper https://arxiv.org/abs/2411.19146)" + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to the main config YAML file (e.g., ./configs/llama_3.2_1B_pruneffn_memory.yaml)", + ) + parser.add_argument( + "--mip-only", + action="store_true", + help="Run only the MIP search and realize models phase (skip pruning and NAS scoring)", + ) + + return parser.parse_args() + + +def run_full_puzzletron(hydra_config_path: str): + """Run the full puzzletron pipeline. + + Args: + config_path: Path to the YAML configuration file + """ + mtpz.tools.mprint("Puzzletron Progress 1/8: starting puzzletron pipeline") + dist.setup(timeout=timedelta(minutes=10)) + + # Register Hydra custom resolvers (needed for config resolution) + mtpz.tools.register_hydra_resolvers() + + hydra_config_path = Path(hydra_config_path).resolve() + hydra_config_dir = str(hydra_config_path.parent) + hydra_config_name = hydra_config_path.stem + + # Load hydra config + hydra_cfg = mtpz.tools.initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[], + ) + + # Convert model (convert from HF to DeciLM, score pruning activations, + # prune the model and save pruned checkpoints) + input_model = mtpz.puzzletron_nas_plugin.PuzzletronModel() + converted_model = mtn.convert( + input_model, + mode=[ + ( + "puzzletron", + { + "puzzle_dir": str(hydra_cfg.puzzle_dir), + "input_model_path": hydra_cfg.input_hf_model_path, + "hydra_config_dir": hydra_config_dir, + "hydra_config_name": hydra_config_name, + "dataset_path": str(hydra_cfg.dataset_path), + }, + ) + ], + ) + + # Run NAS search (build replacement library and compute stats, + # compute one block scores, run MIP and realize models) + mtn.search( + converted_model, + constraints={}, # this is not used as the search space is defined in the hydra config + dummy_input=None, # Not used + config={}, # this is not used as the search space is defined in the hydra config + ) + + dist.cleanup() + mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") + + +def run_mip_only(hydra_config_path: str): + """Run only the MIP search and realize models phase. + + This assumes that pruning, replacement library building, NAS scoring, and subblock stats calculation + have already been completed. + + Args: + hydra_config_path: Path to the YAML configuration file + """ + dist.setup(timeout=timedelta(minutes=10)) + + # Register Hydra custom resolvers (needed for config resolution) + mtpz.tools.register_hydra_resolvers() + + hydra_config_path = Path(hydra_config_path).resolve() + hydra_config_dir = str(hydra_config_path.parent) + hydra_config_name = hydra_config_path.stem + + # Load hydra config + hydra_cfg = mtpz.tools.initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[], + ) + + # Check if sweep mode is enabled + if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False): + mtpz.tools.mprint( + "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)" + ) + mtpz.mip.run_mip_sweep(hydra_cfg) + else: + # mip_and_realize_models (distributed processing) + # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API + mtpz.tools.mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") + mtpz.mip.launch_mip_and_realize_model(hydra_cfg) + + dist.cleanup() + mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") + + +def main(): + args = parse_args() + + if args.mip_only: + run_mip_only(hydra_config_path=args.config) + else: + run_full_puzzletron(hydra_config_path=args.config) + + +if __name__ == "__main__": + main() diff --git a/examples/puzzletron/mip_sweep_example.png b/examples/puzzletron/mip_sweep_example.png new file mode 100644 index 0000000000..4eb1089fe0 Binary files /dev/null and b/examples/puzzletron/mip_sweep_example.png differ diff --git a/examples/puzzletron/requirements.txt b/examples/puzzletron/requirements.txt new file mode 100644 index 0000000000..317a38f5ea --- /dev/null +++ b/examples/puzzletron/requirements.txt @@ -0,0 +1,5 @@ +lm-eval==0.4.8 +math-verify +ray +# Likely works for transformers v5 also, but we need to test it +transformers<5.0 diff --git a/modelopt/torch/export/plugins/__init__.py b/modelopt/torch/export/plugins/__init__.py index d54d423f38..99a755598a 100644 --- a/modelopt/torch/export/plugins/__init__.py +++ b/modelopt/torch/export/plugins/__init__.py @@ -25,3 +25,6 @@ with import_plugin("vllm_fakequant_megatron"): from .vllm_fakequant_megatron import * + +with import_plugin("hf_checkpoint_utils"): + from .hf_checkpoint_utils import * diff --git a/modelopt/torch/export/plugins/hf_checkpoint_utils.py b/modelopt/torch/export/plugins/hf_checkpoint_utils.py index d6696d8199..4d9bc6fc29 100644 --- a/modelopt/torch/export/plugins/hf_checkpoint_utils.py +++ b/modelopt/torch/export/plugins/hf_checkpoint_utils.py @@ -21,14 +21,13 @@ from pathlib import Path import torch -from huggingface_hub import hf_hub_download, list_repo_files +from huggingface_hub import snapshot_download from safetensors.torch import safe_open from tqdm import tqdm -def copy_remote_code( - pretrained_model_path: str | os.PathLike, - save_directory: str | os.PathLike, +def copy_hf_ckpt_remote_code( + pretrained_model_path: str | os.PathLike, save_directory: str | os.PathLike ): """Copy remote code from pretrained model to save directory. @@ -37,7 +36,7 @@ def copy_remote_code( frameworks. If ``pretrained_model_path`` is a local directory, Python files are copied directly. - If it is a HuggingFace Hub model ID, Python files are downloaded from the Hub first. + If it's a HF Hub model ID (e.g. ``nvidia/NVIDIA-Nemotron-Nano-12B-v2``), files are downloaded from the Hub. Args: pretrained_model_path: Local path to the pretrained model or HuggingFace Hub model ID. @@ -45,18 +44,17 @@ def copy_remote_code( """ hf_checkpoint_path = Path(pretrained_model_path) save_dir = Path(save_directory) + save_dir.mkdir(parents=True, exist_ok=True) if hf_checkpoint_path.is_dir(): for py_file in hf_checkpoint_path.glob("*.py"): - if py_file.is_file(): - shutil.copy(py_file, save_dir / py_file.name) + shutil.copy2(py_file, save_dir / py_file.name) else: - # Hub model ID: download any top-level .py files (custom modeling code) - repo_id = str(pretrained_model_path) - for filename in list_repo_files(repo_id): - if "/" not in filename and filename.endswith(".py"): - local_path = hf_hub_download(repo_id=repo_id, filename=filename) - shutil.copy(local_path, save_dir / filename) + snapshot_download( + repo_id=str(pretrained_model_path), + local_dir=str(save_dir), + allow_patterns=["*.py"], + ) def load_multimodal_components( diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 4750e3b59c..c901a2159d 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -45,7 +45,7 @@ QUANTIZATION_NONE, QUANTIZATION_NVFP4, ) -from .plugins.hf_checkpoint_utils import copy_remote_code, load_multimodal_components +from .plugins.hf_checkpoint_utils import copy_hf_ckpt_remote_code, load_multimodal_components from .plugins.mcore_common import all_mcore_hf_export_mapping from .plugins.mcore_custom import ( CustomModuleMapping, @@ -349,7 +349,7 @@ def save_pretrained( torch.distributed.barrier() if is_last_stage_main_rank and self._hf_config is not None: - copy_remote_code(pretrained_model_name_or_path, save_directory) + copy_hf_ckpt_remote_code(pretrained_model_name_or_path, save_directory) # Newer versions of VLLM expect config.json with hf_quant_config config_json_file = save_directory + "/config.json" diff --git a/modelopt/torch/prune/importance_hooks/__init__.py b/modelopt/torch/prune/importance_hooks/__init__.py index 3bf30c2a46..1e86ddcf65 100644 --- a/modelopt/torch/prune/importance_hooks/__init__.py +++ b/modelopt/torch/prune/importance_hooks/__init__.py @@ -18,6 +18,7 @@ from .base_hooks import * from .base_hooks_analysis import * +from .expert_removal_hooks import * with import_plugin("megatron_hooks"): from .plugins.megatron_hooks import * diff --git a/modelopt/torch/prune/importance_hooks/base_hooks.py b/modelopt/torch/prune/importance_hooks/base_hooks.py index 22e82c27b0..5eccd033d6 100644 --- a/modelopt/torch/prune/importance_hooks/base_hooks.py +++ b/modelopt/torch/prune/importance_hooks/base_hooks.py @@ -149,7 +149,8 @@ def dump_activations_logs( torch.save(activations_log, activations_log_path) if rank == 0: - args.activation_hooks_kwargs.pop("model") + if args.activation_hooks_kwargs is not None: + args.activation_hooks_kwargs.pop("model", None) json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json") dist.barrier() @@ -565,9 +566,9 @@ def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): assert self.optimize_for in ["latency", "memory"] self.hidden_size = model_config.hidden_size - self.n_heads_in_group = block_config.attention.n_heads_in_group self.num_q_heads = model_config.num_attention_heads - self.num_kv_heads = self.num_q_heads // self.n_heads_in_group + self.num_kv_heads = block_config.attention.num_key_value_heads + self.n_heads_in_group = self.num_q_heads // self.num_kv_heads self.head_dim = getattr(model_config, "head_dim", self.hidden_size // self.num_q_heads) self.agg_kv_head_contributions = torch.zeros( diff --git a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py index 0a7ea542b8..37e7ef6934 100644 --- a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py +++ b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py @@ -52,7 +52,8 @@ python compare_module_outputs.py \ --reference output_unpruned.pt \ --compare output_l2norm.pt \ - --output-json comparison_stats.json + --output-json comparison_stats.json \ + --no-weights-only The saved file format\: diff --git a/modelopt/torch/prune/importance_hooks/expert_removal_hooks.py b/modelopt/torch/prune/importance_hooks/expert_removal_hooks.py new file mode 100644 index 0000000000..2d0a9ad4c5 --- /dev/null +++ b/modelopt/torch/prune/importance_hooks/expert_removal_hooks.py @@ -0,0 +1,404 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""MoE expert-removal and ranked-choice importance hooks (uses Puzzletron BlockConfig).""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import torch +import transformers +from packaging.version import Version +from torch import nn + +from .base_hooks import ForwardHook + +if TYPE_CHECKING: + # Okay since this is only used for type hints else we should not import puzzletron here + # as its dependencies may not be installed + from modelopt.torch.puzzletron.block_config import BlockConfig + +__all__ = [ + "NemotronHRemoveExpertsIndependentHook", + "Qwen3VLRemoveExpertsIndependentHook", + "RankedChoiceVotingHook", + "RankedChoiceVotingHookNemotronH", + "RemoveExpertsIndependentHook", +] + + +class RemoveExpertsIndependentHook(ForwardHook, ABC): + """Base hook for measuring expert importance in Mixture-of-Experts models. + + This hook measures how much removing each expert affects the model output + by comparing outputs with and without each expert. + """ + + def __init__(self, moe: nn.Module, activation_hooks_kwargs: dict): + """Initialize the hook. + + Args: + moe: The MoE module to analyze + activation_hooks_kwargs: Configuration dict containing block_config + """ + self.moe = moe + block_config: BlockConfig = activation_hooks_kwargs["block_config"] + self.num_local_experts = block_config.ffn.moe.num_local_experts + self.num_experts_per_tok = block_config.ffn.moe.num_experts_per_tok + # tensor of zeros of size num experts + self.diffs = ["mse", "cosine"] + some_param = next(self.moe.parameters()) + self.diffs = { + k: torch.zeros( + size=(self.num_local_experts,), dtype=torch.float32, device=some_param.device + ) + for k in self.diffs + } + self.call_count = 0 + + @abstractmethod + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for measuring expert importance. + + This method is called twice per forward pass: + 1. First call (router_logits=None): Compute original routing and expert outputs + 2. Second call (router_logits provided): Re-run with modified logits (expert disabled) + + Args: + hidden_states: Input tensor of shape (batch, seq_len, hidden_dim) + router_logits: Optional pre-computed router logits. If None, compute from hidden_states. + + Returns: + tuple of (router_logits, routed_experts): + - router_logits: Shape (num_tokens, num_local_experts) + - routed_experts: Shape (num_tokens, hidden_dim) + """ + raise NotImplementedError + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that measures expert importance.""" + hidden_states = args[0] + router_logits, original_routed_out = self.get_router_logits_and_routed_experts( + hidden_states + ) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + original_routed_out = original_routed_out.view(-1, original_routed_out.shape[-1]) + + _, router_indices = torch.topk(router_logits, self.num_experts_per_tok, dim=-1) + self.call_count += 1 + + for i_expert in range(self.num_local_experts): + expert_mask = router_indices == i_expert + is_token_routed_to_this_expert = expert_mask.any(dim=-1) + + num_tokens_displaced = is_token_routed_to_this_expert.sum() + if num_tokens_displaced == 0: + continue + num_total_tokens = is_token_routed_to_this_expert.numel() + + relevant_hidden_states = hidden_states[is_token_routed_to_this_expert, :] + + router_logits_without_i = router_logits.clone() + router_logits_without_i[..., i_expert] = -float("inf") # disable expert i + router_logits_without_i = router_logits_without_i[is_token_routed_to_this_expert, :] + _, routed_out_without_i = self.get_router_logits_and_routed_experts( + relevant_hidden_states, router_logits_without_i + ) + + relevant_tokens_original_out = original_routed_out[is_token_routed_to_this_expert, :] + self.diffs["mse"][i_expert] += ( + nn.functional.mse_loss( + relevant_tokens_original_out, routed_out_without_i, reduction="mean" + ) + * num_tokens_displaced + / num_total_tokens + ) + self.diffs["cosine"][i_expert] += ( + -nn.functional.cosine_similarity( + relevant_tokens_original_out, routed_out_without_i, dim=-1 + ).mean() + * num_tokens_displaced + / num_total_tokens + ) + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert accumulated statistics to dict format.""" + expert_ranks_mse = torch.argsort(self.diffs["mse"]) + expert_ranks_cosine = torch.argsort(self.diffs["cosine"]) + return { + "expert_ranks_mse": expert_ranks_mse.cpu(), + "expert_ranks_cosine": expert_ranks_cosine.cpu(), + "cosine_diffs": (self.diffs["cosine"] / self.call_count).cpu(), + "mse_diffs": (self.diffs["mse"] / self.call_count).cpu(), + } + + def accumulate(self) -> torch.Tensor: + """Return accumulated expert importance scores.""" + return self.diffs["mse"] + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + return { + "diffs_mse": self.diffs["mse"].cpu(), + "diffs_cosine": self.diffs["cosine"].cpu(), + "call_count": self.call_count, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.diffs["mse"] = state_dict["diffs_mse"].to(self.diffs["mse"].device) + self.diffs["cosine"] = state_dict["diffs_cosine"].to(self.diffs["cosine"].device) + self.call_count = state_dict["call_count"] + + +class NemotronHRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): + """Expert removal importance hook for NemotronH models.""" + + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for NemotronH MoE. + + Based on NemotronHMOE forward, uses minimum ops to get router_logits and routed_experts. + """ + orig_shape = hidden_states.shape + # NemotronHMOE.gate forward, copied to extract router_logits + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + if router_logits is None: + router_logits = nn.functional.linear( + hidden_states.type(torch.float32), self.moe.gate.weight.type(torch.float32) + ) + router_logits = router_logits.sigmoid() + router_logits = router_logits + self.moe.gate.e_score_correction_bias.unsqueeze(0) + + topk_indices = self._get_topk_indices_without_correction_bias(router_logits) + topk_weights = router_logits.gather(1, topk_indices) + if self.moe.gate.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.moe.gate.routed_scaling_factor + # Routed experts forward + hidden_states = self.moe.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + return router_logits, hidden_states + + @torch.no_grad() + def _get_topk_indices_without_correction_bias(self, scores: torch.Tensor) -> torch.Tensor: + """Get topk indices without correction bias. + + Same as NemotronHMOE.gate.get_topk_indices but without adding e_score_correction_bias. + """ + group_scores = ( + scores.view( + -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group + ) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.moe.gate.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group + ) + .reshape(-1, self.moe.gate.n_routed_experts) + ) + scores_for_choice = scores.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.moe.gate.top_k, dim=-1, sorted=False)[1] + return topk_indices + + +class RankedChoiceVotingHook(ForwardHook): + """Hook for ranking experts using ranked choice voting algorithm. + + This hook tracks router decisions and uses ranked choice voting to determine + which experts are least important (can be pruned first). + """ + + def __init__(self, router: nn.Module, activation_hooks_kwargs: dict): + """Initialize the hook. + + Args: + router: The router module (typically nn.Linear) + activation_hooks_kwargs: Configuration dict containing block_config + """ + self.router_argsort: list[torch.Tensor] = [] + block_config: BlockConfig = activation_hooks_kwargs["block_config"] + self.top_k = block_config.ffn.moe.num_experts_per_tok + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that records router decisions. + + Args: + module: The router module + args: Tuple with one tensor entry (B, T, I) + output: Router logits of shape (B, T, E) + """ + router_logits = output[0] if isinstance(output, tuple) else output + num_experts = router_logits.shape[-1] + router_argsort = torch.argsort(router_logits, dim=-1, descending=True) + router_argsort = router_argsort.view(-1, num_experts).to(torch.int16).cpu() + self.router_argsort.append(router_argsort) + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert accumulated statistics to dict format using ranked choice voting.""" + router_argsort = torch.concat(self.router_argsort, dim=0) + num_tokens, num_experts = router_argsort.shape + + expert_ranks = torch.full((num_experts,), -1) + expert_counts_at_pruning_time = {} + + expert_kept_per_iteration: list[list[int]] = [] + expert_counts_per_iteration: list[dict[int, int]] = [] + + for rank in range(num_experts): + ids, counts = router_argsort[:, : self.top_k].unique(return_counts=True) + ids = ids.tolist() + counts = counts.tolist() + expert_counts = dict(zip(ids, counts)) + + expert_kept_per_iteration.append(ids) + expert_counts_per_iteration.append(expert_counts) + + least_popular_expert, min_count = min(expert_counts.items(), key=lambda tup: tup[1]) + + expert_ranks[least_popular_expert] = rank + expert_counts_at_pruning_time[least_popular_expert] = min_count + router_argsort = router_argsort[router_argsort != least_popular_expert].view( + num_tokens, -1 + ) + + zero_shot_expert_counts = torch.zeros((num_experts,), dtype=torch.long) + for expert_id, expert_counts_val in expert_counts_per_iteration[0].items(): + zero_shot_expert_counts[expert_id] = expert_counts_val + + # Compute zero-shot expert ranks (double argsort converts counts to rank positions) + zero_shot_expert_ranks = torch.argsort(torch.argsort(zero_shot_expert_counts)) + + return { + "expert_ranks": expert_ranks, + "zero_shot_expert_ranks": zero_shot_expert_ranks, + "expert_counts_at_pruning_time": expert_counts_at_pruning_time, + "expert_counts_per_iteration": expert_counts_per_iteration, + "top_k": self.top_k, + } + + def accumulate(self) -> torch.Tensor: + """Return accumulated expert ranks.""" + if not self.router_argsort: + return torch.tensor([]) + router_argsort = torch.concat(self.router_argsort, dim=0) + return router_argsort[:, 0].float() + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + return { + "router_argsort": [tensor.cpu().clone() for tensor in self.router_argsort], + "top_k": self.top_k, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.router_argsort = [tensor.cpu() for tensor in state_dict["router_argsort"]] + self.top_k = state_dict["top_k"] + + def get_progress_info(self) -> dict: + """Get progress information.""" + return { + "num_batches_processed": len(self.router_argsort), + "total_tokens_processed": sum(tensor.shape[0] for tensor in self.router_argsort) + if self.router_argsort + else 0, + } + + +class RankedChoiceVotingHookNemotronH(RankedChoiceVotingHook): + """Ranked choice voting hook for NemotronH models. + + In NemotronH, router_logits is an internal temporary state that never leaves + the forward() function. We reconstruct router_logits from the input hidden_states. + """ + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that reconstructs router logits from hidden states.""" + hidden_states = args[0] + hidden_states = hidden_states.view(-1, module.config.hidden_size) + router_logits = nn.functional.linear( + hidden_states.type(torch.float32), module.weight.type(torch.float32) + ) + super().__call__(module, args, router_logits) + + +class Qwen3VLRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): + """Expert removal importance hook for Qwen3-VL models.""" + + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for Qwen3-VL MoE. + + Based on Qwen3VLMoeSparseMoe forward pass. + """ + orig_shape = hidden_states.shape + # Use hidden_states.shape[-1] instead of self.moe.hidden_size for transformers v5 compatibility + hidden_size = ( + self.moe.hidden_size if hasattr(self.moe, "hidden_size") else hidden_states.shape[-1] + ) + + # Flatten to (num_tokens, hidden_size) for processing + hidden_states_flat = hidden_states.reshape(-1, hidden_size) + + if router_logits is None: + router_logits = self.moe.gate(hidden_states_flat) + # In transformers vf the gate returns (logits, aux_loss) tuple + if isinstance(router_logits, tuple): + router_logits = router_logits[0] + + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk( + routing_weights, self.num_experts_per_tok, dim=-1 + ) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states_flat.dtype) + + if Version(transformers.__version__) >= Version("5.0"): + # transformers 5.x: grouped_mm_experts_forward expects + # (hidden_states_flat 2D, top_k_index, top_k_weights) + routed_out = self.moe.experts(hidden_states_flat, router_indices, routing_weights) + else: + # transformers 4.x: loop-based experts expects + # (hidden_states_3d 3D, routing_weights_full, router_indices) + batch_size = orig_shape[0] if hidden_states.ndim == 3 else 1 + hidden_states_3d = hidden_states_flat.reshape(batch_size, -1, hidden_size) + router_weights = torch.zeros( + router_logits.shape, dtype=routing_weights.dtype, device=router_logits.device + ).scatter_(1, router_indices, routing_weights) + routed_out = self.moe.experts(hidden_states_3d, router_weights, router_indices) + + # Return in same shape as input + routed_out = routed_out.reshape(*orig_shape) + + return router_logits, routed_out diff --git a/modelopt/torch/puzzletron/README.md b/modelopt/torch/puzzletron/README.md new file mode 100644 index 0000000000..4c6da80e54 --- /dev/null +++ b/modelopt/torch/puzzletron/README.md @@ -0,0 +1,3 @@ +Experimental model compression algorithm based on a Local Neural Architecture Search. +Based on the Puzzle paper: +PoC for Llama 3.1 model. diff --git a/modelopt/torch/puzzletron/__init__.py b/modelopt/torch/puzzletron/__init__.py new file mode 100644 index 0000000000..15389dedfa --- /dev/null +++ b/modelopt/torch/puzzletron/__init__.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: Some modules also trigger factory registration as side effect +from . import ( + activation_scoring, + anymodel, + block_config, + build_library_and_stats, + dataset, + entrypoint, + mip, + plugins, + pruning, + puzzletron_nas_plugin, + replacement_library, + scoring, + subblock_stats, + tools, + utils, +) diff --git a/modelopt/torch/puzzletron/activation_scoring/__init__.py b/modelopt/torch/puzzletron/activation_scoring/__init__.py new file mode 100644 index 0000000000..c2785679db --- /dev/null +++ b/modelopt/torch/puzzletron/activation_scoring/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .activation_hooks import * +from .score_pruning_activations import * diff --git a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/__init__.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/__init__.py new file mode 100644 index 0000000000..8a8c9adce1 --- /dev/null +++ b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .utils import * diff --git a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py new file mode 100644 index 0000000000..1195d2369c --- /dev/null +++ b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Provides a function to register activation hooks for a model. +Activation hooks are used to compute activation scores for pruning.""" + +from typing import Type + +import torch + +from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook as ActivationsHook + +from ...tools.logger import aprint +from ...utils.dummy_modules import DummyBlock, DummyModule + +__all__ = ["register_activation_hooks"] + + +def register_activation_hooks( + model, + activation_hooks_kwargs: dict, + pruning_mixin, + hook_class: Type[ActivationsHook], +) -> dict[str, ActivationsHook]: + """Register activation hooks using the pruning mixin approach. + + Args: + model: The model to register hooks on. + activation_hooks_kwargs: Keyword arguments passed to hook constructors. + pruning_mixin: The pruning mixin that defines which modules to hook. + hook_class: The hook class to instantiate for each module. + + Returns: + Dictionary mapping module names to hook instances. + """ + activation_hooks_kwargs["model"] = model + + if hook_class not in pruning_mixin.supported_hooks(): + raise ValueError( + f"Hook class not supported for {pruning_mixin.__class__.__name__}, " + f"must be in {pruning_mixin.supported_hooks()}" + ) + + module_names_to_hook = pruning_mixin.get_module_names_to_hook(model) + activation_hooks = dict() + for block_idx, module_name in module_names_to_hook: + try: + module = model.get_submodule(module_name) + except AttributeError: + # Module doesn't exist on this rank's shard (e.g., in distributed setup) + continue + + # Skip dummy modules - they don't have real activations to hook + if isinstance(module, (DummyModule, DummyBlock)): + continue + + block_config = None + if block_idx is not None: + block_config = model.config.block_configs[block_idx] + curr_activation_hooks_kwargs = { + **activation_hooks_kwargs, + "block_config": block_config, + } + + hook = hook_class(module, curr_activation_hooks_kwargs) + module.register_forward_hook(hook) + activation_hooks[module_name] = hook + + if len(activation_hooks) == 0: + # In distributed mode, it's okay for a rank to have 0 hooks if it doesn't own + # the target modules (e.g., with hybrid patterns like "*-" where different + # ranks own different layer types). However, we still want to catch real bugs + # where no hooks are found at all. + is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + if is_distributed: + aprint( + "No hooks registered on this rank. This is expected if this rank " + "doesn't own any layers matching the hook pattern (e.g., in hybrid " + "patterns with distributed model sharding)." + ) + else: + raise ValueError("couldn't find any hooks") + + if len(activation_hooks) > 0: + aprint(f"Found the following hooks: {activation_hooks.keys()}") + return activation_hooks diff --git a/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py new file mode 100644 index 0000000000..27b7607e6b --- /dev/null +++ b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import torch +from omegaconf import DictConfig + +import modelopt.torch.utils.distributed as dist + +from ..tools.logger import mprint + +__all__ = ["launch_score_activations"] + + +def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: + """Determine if the activation hook method has proper checkpoint support implemented. + + Args: + activation_hooks_kwargs: Hook configuration + + Returns: + bool: True if the hook method has save_state/load_state implemented + """ + method = activation_hooks_kwargs.get("method", "") + + # Methods with implemented checkpoint support + supported_methods = { + "iterative", # IterativeChannelContributionHook: save_state/load_state implemented + "independent", # IndependentChannelContributionHook: save_state/load_state implemented + "stats", # RouterStatsHook: save_state/load_state implemented + "ranked_choice_voting", # RankedChoiceVotingHook: save_state/load_state implemented + } + + return method in supported_methods + + +def check_scoring_completion(activations_log_dir: str, activation_hooks_kwargs=None) -> bool: + """Check if scoring is already completed by looking for the expected output files. + Also checks if the scoring method is safe for resume. + + Args: + activations_log_dir: Directory where activation logs should be stored + activation_hooks_kwargs: Hook configuration to check if resume is safe + + Returns: + bool: True if scoring is completed (has rank files and args.json) + """ + # Only check completion on main process + if dist.is_master(): + log_dir = Path(activations_log_dir) + + # Check if directory exists + if not log_dir.exists(): + return False + + # Check for rank files (at least rank_0.pth should exist) + rank_files = list(log_dir.glob("rank_*.pth")) + + if not rank_files: + return False + + # Check for args.json (created by main process) + args_file = log_dir / "args.json" + has_args_json = args_file.exists() + + # Check for completion: if we have rank files and args.json, scoring is complete + if rank_files and has_args_json: + # Add optional completion info for debugging + mprint(f"Found completed scoring in {activations_log_dir}") + mprint(f" - Found {len(rank_files)} rank files") + mprint(f" - Found args.json: {has_args_json}") + + return True + + return False + + +def should_skip_scoring_completely(cfg: DictConfig) -> bool: + """Determine if we should skip scoring entirely (only if 100% complete). + Partial progress should proceed to validate_model for proper resume. + + Args: + cfg: Configuration object + + Returns: + bool: True if we should skip scoring (100% completed), False if we should run/resume it + """ + # Check if activations_log_dir is specified + if not hasattr(cfg.pruning, "activations_log_dir") or cfg.pruning.activations_log_dir is None: + mprint("No activations_log_dir specified, running scoring") + return False + + # Check for force restart flag + force_restart = getattr(cfg.pruning, "force_restart_scoring", False) + if force_restart: + mprint("Force restart flag set, will restart scoring regardless of existing artifacts") + return False + + # Get hook configuration to check if resume is mathematically safe + activation_hooks_kwargs = getattr(cfg.pruning, "activation_hooks_kwargs", {}) + + # Check if scoring is already completed + is_completed = check_scoring_completion( + cfg.pruning.activations_log_dir, activation_hooks_kwargs + ) + + # Broadcast the result to all processes in distributed mode + if dist.size() > 1: + should_skip = [is_completed] # Use list for mutable object + torch.distributed.broadcast_object_list(should_skip, src=0) + is_completed = should_skip[0] + + if is_completed: + mprint("Scoring 100% completed, skipping...") + + return is_completed + + +def launch_score_activations(cfg: DictConfig): + from ..tools.validate_model import validate_model + + # Check if we should skip scoring entirely (only if 100% complete) + if should_skip_scoring_completely(cfg): + return + + mprint("Starting pruning activation scoring...") + + # The checkpoint manager inside validate_model handles all progress tracking + validate_model(args=cfg.pruning) diff --git a/modelopt/torch/puzzletron/anymodel/README.md b/modelopt/torch/puzzletron/anymodel/README.md new file mode 100644 index 0000000000..1c8c68b60b --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/README.md @@ -0,0 +1,204 @@ +# AnyModel Guide + +This guide explains how to add support for new models in the Puzzletron pipeline. + +## Convert model + +Convert a HuggingFace model to Puzzletron format. + +Step 1: Create Model Descriptor + +Extend `ModelDescriptor` and implement `layer_name_predicates()` to define regex patterns for grouping weights into subblocks (embeddings, lm_head, block_N_ffn, block_N_attention). + +Key points: + +- Find weight names on the model's HuggingFace page → click "Files info" to see the safetensors structure with all tensor names (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json)) + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) + +Step 2: Create Converter + +Extend `Converter` and implement `create_block_configs_from_main_config()` to create per-layer BlockConfigs from the HuggingFace config. + +Key points: + +- Import correct HuggingFace config class (e.g., `MistralConfig`, `LlamaConfig`, `Qwen2Config`). Find it in the transformers source: `github.com/huggingface/transformers/tree/main/src/transformers/models//configuration_.py` + +See example: [llama_converter.py](models/llama/llama_converter.py) + +Step 3: Create `models//__init__.py` + +Export descriptor and converter classes: + +```python +from models.._model_descriptor import MyModelDescriptor +from models.._converter import MyConverter +``` + +Step 4: Register in `models/__init__.py` + +Add import to trigger factory registration: + +```python +from models. import * +``` + +## Usage + +```python +from modelopt.torch.puzzletron.anymodel import convert_model + +convert_model( + input_dir="path/to/hf_checkpoint", + output_dir="path/to/puzzletron_checkpoint", + converter="model_name", +) +``` + +## Compress model + +Run pruning and compression on a Puzzletron model. + +Step 1: Implement ModelDescriptor methods for compression + +Add to your `ModelDescriptor`: + +- `decoder_layer_cls()` - return the decoder layer class(es) to patch for heterogeneous config support +- `block_config_to_layer_overrides()` - map BlockConfig to layer override dict (see [details](#implementing-block_config_to_layer_overrides)) +- `init_rotary_embedding()` - reinitialize rotary embeddings after model loading (see [details](#implementing-init_rotary_embedding)) +- `input_embedding_name()` - return the name of the input embedding layer (see [details](#implementing-path-based-methods)) +- `output_embedding_name()` - return the name of the output embedding layer (see [details](#implementing-path-based-methods)) +- `layer_block_name()` - return the name pattern for decoder layers (see [details](#implementing-path-based-methods)) +- `final_norm_name()` - return the name of the final normalization layer (see [details](#implementing-path-based-methods)) +- `attn_no_op_post_init()` - replace attention sublayers with no-op modules +- `mlp_no_op_post_init()` - replace MLP sublayers with no-op modules + +Step 2: Create FFN Layer Descriptor + +Extend `FFNIntermediateLayerDescriptor` to define model-specific paths for FFN pruning hooks (`down_proj_name`, `ffn_prefix_name`, `linear_weight_names`). Derive values from your model's weight names in `layer_name_predicates()`. + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) → `LlamaFFNIntermediateLayerDescriptor` + +Step 3: Configure YAML files + +Update the main model config YAML: + +- Set `descriptor` to match the name used in `@ModelDescriptorFactory.register_decorator("your_model_name")` +- See example: [llama_3_1_8b_instruct.yaml](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml) + +Update pruning YAML files (`ffn_pruning.yaml`, `expert_pruning.yaml`, etc.): + +- Set `pruning_mixin._target_` to the appropriate mixin class +- Set `layer_descriptor._target_` to your layer descriptor class +- Set `hook_class` to the activation hook for scoring +- Set `target_layer` in `activation_hooks_kwargs` to the layer name for hook attachment +- See examples in [configs/llama_3_1_8b_instruct/pruning/](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/) + +## End-to-end example + +See [test_puzzletron.py](../../../../tests/gpu/torch/puzzletron/test_puzzletron.py) for a complete example that runs both convert and compression steps. For container setup and dependencies needed to run this test, see the [Puzzletron README environment section](../../../../examples/puzzletron/README.md#environment). + +--- + +## Advanced Topics + +## Pruning Configuration + +### Pruning YAML Structure + +Each pruning type has a YAML config with these key fields: + +```yaml +pruning_mixin: + _target_: pruning._pruning_mixin. + layer_descriptor: + _target_: models.. + +hook_class: ${get_object:utils.activation_hooks.hooks.} +activation_hooks_kwargs: + method: + target_layer: "" # e.g., "mlp.down_proj", "self_attn.o_proj" +``` + +| Field | Description | +|-------|-------------| +| `pruning_mixin._target_` | Mixin class that orchestrates this pruning type | +| `layer_descriptor._target_` | Model-specific class defining layer paths for hooks | +| `hook_class` | Activation hook class for importance scoring | +| `target_layer` | Layer name (relative to decoder block) where hooks attach | + +### Adding a New Hook Class + +1. **Implement the hook** under `modelopt/torch/prune/importance_hooks/` (e.g. `base_hooks.py` for generic hooks, `expert_removal_hooks.py` for MoE expert removal): + - Extend an existing hook base class (e.g., `RemoveExpertsIndependentHook` in `expert_removal_hooks.py`) + - Implement required methods (e.g., `get_router_logits_and_routed_experts`) + +2. **Register the hook** in the appropriate pruning mixin's `supported_hooks()`: + + For FFN pruning (`pruning/ffn_intermediate_pruning_mixin.py`): + + ```python + def supported_hooks(self) -> List[Type[ActivationsHook]]: + return [IndependentChannelContributionHook, IterativeChannelContributionHook, YourNewHook] + ``` + + For expert removal (`pruning/expert_removal_pruning_mixin.py`): + + ```python + def supported_hooks(self) -> List[Type[ActivationsHook]]: + return [RankedChoiceVotingHook, ..., YourNewHook] + ``` + +3. **Reference in YAML**: + + ```yaml + hook_class: ${get_object:utils.activation_hooks.hooks.YourNewHook} + ``` + +### Pruning Types Reference + +| Type | Mixin | Example Hooks | +|------|-------|---------------| +| FFN intermediate | [`FFNIntermediatePruningMixIn`](../pruning/ffn_intermediate_pruning_mixin.py) | [`IterativeChannelContributionHook`](../../prune/importance_hooks/base_hooks.py), [`IndependentChannelContributionHook`](../../prune/importance_hooks/base_hooks.py) | +| Expert removal | [`ExpertRemovalPruningMixIn`](../pruning/expert_removal_pruning_mixin.py) | [`NemotronHRemoveExpertsIndependentHook`](../../prune/importance_hooks/expert_removal_hooks.py), [`Qwen3VLRemoveExpertsIndependentHook`](../../prune/importance_hooks/expert_removal_hooks.py) | +| KV heads | [`KVHeadsPruningMixIn`](../pruning/kv_heads_pruning_mixin.py) | [`IndependentKvHeadContributionHook`](../../prune/importance_hooks/base_hooks.py) | + +## Implementing `block_config_to_layer_overrides` + +Maps Puzzletron's [`BlockConfig`](../block_config.py) fields to HuggingFace config attribute names. Only override attributes that change during pruning: + +| BlockConfig Field | HuggingFace Attribute (check `config.json`) | +|-------------------|---------------------------------------------| +| `attention.num_key_value_heads` | `num_key_value_heads` | +| `ffn.intermediate_size` | `intermediate_size` | +| `ffn.moe.num_local_experts` | `num_experts` or `n_routed_experts` (model-specific) | +| `ffn.moe.expert_intermediate_dim` | `moe_intermediate_size` | + +**Tip**: Check the model's `config.json` for exact attribute names - they vary between models. + +See examples: [qwen3_vl](models/qwen3_vl/qwen3_vl_model_descriptor.py), [nemotron_h](models/nemotron_h/nemotron_h_model_descriptor.py) + +--- + +## Implementing path-based methods + +These methods return paths derived from the model's weight names: + +- `input_embedding_name()`, `output_embedding_name()`, `layer_block_name()`, `final_norm_name()` + +Find them on the model's HuggingFace page → "Files info" → safetensors structure (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json)). + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) + +--- + +## Implementing `init_rotary_embedding` + +Rotary embeddings are computed modules (not saved weights). After model sharding, they need re-initialization on the correct device/dtype. + +Look in `github.com/huggingface/transformers/tree/main/src/transformers/models//modeling_.py` for: + +- `class.*Rotary` — the rotary embedding class name and constructor arguments +- `self.rotary_emb` — the attribute path + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) diff --git a/modelopt/torch/puzzletron/anymodel/__init__.py b/modelopt/torch/puzzletron/anymodel/__init__.py new file mode 100644 index 0000000000..1b4648c04f --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/__init__.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""AnyModel: Architecture-agnostic model compression for HuggingFace models. + +This module provides a declarative approach to model compression that works with +any HuggingFace model without requiring custom modeling code. Instead of duplicating +HuggingFace modeling classes, AnyModel uses ModelDescriptors that define: + +1. Which decoder layer class(es) to patch for heterogeneous configs +2. How to map BlockConfig to layer-specific overrides +3. Weight name patterns for subblock checkpointing + +Example usage: + >>> from modelopt.torch.puzzletron.anymodel import convert_model + >>> convert_model( + ... input_dir="path/to/hf_checkpoint", + ... output_dir="path/to/anymodel_checkpoint", + ... converter="llama", + ... ) + +Supported models: + - llama: Llama 2, Llama 3, Llama 3.1, Llama 3.2 + - (more to come: qwen2, mistral_small, etc.) +""" + +from . import models # trigger factory registration +from .converter import * +from .model_descriptor import * +from .puzzformer import * diff --git a/modelopt/torch/puzzletron/anymodel/converter/__init__.py b/modelopt/torch/puzzletron/anymodel/converter/__init__.py new file mode 100644 index 0000000000..9a444467ef --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Converters for transforming HuggingFace models to AnyModel format.""" + +from .base import * +from .convert_any_model import * +from .converter_factory import * diff --git a/modelopt/torch/puzzletron/anymodel/converter/base.py b/modelopt/torch/puzzletron/anymodel/converter/base.py new file mode 100644 index 0000000000..c8e01ffe28 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/base.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import copy +import fnmatch +import json +import os +import shutil +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import Dict, List + +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from transformers import PretrainedConfig +from transformers.integrations.mxfp4 import convert_moe_packed_tensors + +from ...block_config import BlockConfig +from ...tools.checkpoint_utils_hf import load_model_config, save_model_config +from ..model_descriptor import ModelDescriptor + +__all__ = ["Converter"] + + +class Converter(ABC): + """Base class for converting HuggingFace models to Puzzletron/AnyModel format.""" + + @staticmethod + def _get_weight_map(input_dir: Path) -> Dict[str, str]: + """Load weight map from checkpoint directory (supports both sharded and single-file models). + + Returns a dict mapping parameter names to their safetensors filenames. + """ + index_path = input_dir / "model.safetensors.index.json" + single_file_path = input_dir / "model.safetensors" + + if index_path.exists(): + # Sharded model + with open(index_path, "r") as f: + index = json.load(f) + return index["weight_map"] + elif single_file_path.exists(): + # Single file model - create a synthetic weight map + data = load_file(single_file_path) + return {name: "model.safetensors" for name in data.keys()} + else: + raise FileNotFoundError( + f"Neither {index_path} nor {single_file_path} found. Cannot determine model format." + ) + + @classmethod + def convert_model_weights( + cls, input_dir: Path, output_dir: Path, descriptor: ModelDescriptor, num_hidden_layers: int + ): + """Convert model weights to subblock format.""" + param_to_file = Converter._get_weight_map(input_dir) + all_param_names = list(param_to_file.keys()) + + # Reverse map: file -> set of params + file_to_params = defaultdict(set) + for name, file in param_to_file.items(): + file_to_params[file].add(name) + + # Determine subblocks needed + subblocks = descriptor.get_weight_groups( + all_param_names, num_hidden_layers=num_hidden_layers + ) + + # Output directory + out_dir = output_dir / "subblocks_safetensors" + os.makedirs(out_dir, exist_ok=True) + + # New weight index + new_index = {"metadata": {"format": "pt"}, "weight_map": {}} + + for subblock, param_names in tqdm(subblocks.items(), desc="Processing subblocks"): + param_files = set(param_to_file[name] for name in param_names) + tensors = {} + + # Load only needed files for this subblock + for file in param_files: + data = load_file(os.path.join(input_dir, file)) + for name in param_names: + if param_to_file[name] == file and name in data: + converted_name = cls.convert_weight_name(name) + # Convert MoE packed tensors if quantized is mxfp4 //gpt-oss-20b + if getattr(cls, "quantized", None) == "mxfp4": + if name.endswith("_blocks"): + converted_name = converted_name.replace("_blocks", "") + tensors[converted_name] = convert_moe_packed_tensors( + data[name], + data[name.replace("_blocks", "_scales")], + ) + elif name.endswith("_scales"): + continue + else: + tensors[converted_name] = data[name] + else: + tensors[converted_name] = data[name] + + # Save this subblock + print(f"\n✅ Group: {subblock} ({len(tensors)} layers)") + for layer in tensors.keys(): + print(f" - {layer}") + + subblock_file = f"{subblock}.safetensors" + save_file(tensors, os.path.join(out_dir, subblock_file)) + + # Update index + for new_name in tensors.keys(): + new_index["weight_map"][new_name] = f"subblocks_safetensors/{subblock_file}" + + # Save new index file + with (output_dir / "model.safetensors.index.json").open("w") as f: + json.dump(new_index, f, indent=2) + + print(f"✅ Finished saving subblocks and index to {output_dir}") + + @classmethod + def convert_configs_in_dirs( + cls, + input_dir: Path, + output_dir: Path, + trust_remote_code: bool = False, + ): + """Convert config and add block_configs.""" + config = load_model_config(input_dir, trust_remote_code=trust_remote_code) + + block_configs = cls.create_block_configs_from_main_config(config) + out_config = copy.deepcopy(config) + out_config.block_configs = block_configs + + save_model_config(out_config, output_dir) + return out_config + + @staticmethod + def copy_checkpoint_files(input_dir: Path, output_dir: Path): + """Copy checkpoint files except model weights (which will be converted).""" + ignore_patterns = [ + "model-*.safetensors", + "model.safetensors", + "model.safetensors.index.json", + "subblocks_safetensors", + ] + + def ignore_func(dir, files): + ignored = set() + for pattern in ignore_patterns: + ignored.update(fnmatch.filter(files, pattern)) + return ignored + + shutil.copytree(str(input_dir), str(output_dir), ignore=ignore_func, dirs_exist_ok=True) + + @classmethod + def convert( + cls, + descriptor: ModelDescriptor, + input_dir: Path, + output_dir: Path, + ): + """Convert a HuggingFace model to AnyModel format. + + Args: + descriptor: Model descriptor for the model type. + input_dir: Path to the input HuggingFace checkpoint. + output_dir: Path to the output AnyModel checkpoint. + """ + cls.copy_checkpoint_files(input_dir, output_dir) + trust_remote_code = descriptor.requires_trust_remote_code() + config = cls.convert_configs_in_dirs( + input_dir, output_dir, trust_remote_code=trust_remote_code + ) + cls.convert_model_weights( + input_dir, output_dir, descriptor=descriptor, num_hidden_layers=config.num_hidden_layers + ) + + @staticmethod + @abstractmethod + def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]: + """Create per-layer BlockConfig list from a HuggingFace model config. + + This method extracts layer-specific parameters (e.g., intermediate_size, + num_key_value_heads) from the main model config and creates a BlockConfig + for each layer. These BlockConfigs enable layer-specific pruning and + modifications during the compression pipeline. + + Args: + config: HuggingFace PretrainedConfig (e.g., LlamaConfig, Qwen2Config) + + Returns: + List of BlockConfig, one per hidden layer. Each BlockConfig contains: + - AttentionConfig: attention settings (no_op, num_key_value_heads) + - FFNConfig: FFN settings (no_op, intermediate_size) + + Example: + For a model with uniform layers (e.g., Llama): + return [BlockConfig(...)] * config.num_hidden_layers + + For a model with heterogeneous layers (e.g., NemotronH with Mamba/Attention): + return [BlockConfig(...) for layer_idx in range(num_layers)] + """ + raise NotImplementedError + + @staticmethod + def convert_weight_name(name: str) -> str: + """ + Convert weight names during checkpoint conversion. + + This method can be overridden by subclasses to apply model-specific weight name + transformations when converting checkpoints from HuggingFace format to Puzzletron format. + + Default implementation returns the name unchanged (identity function). + + Args: + name: Original weight name from HuggingFace checkpoint + + Returns: + Converted weight name for Puzzletron format + + Example: + For Qwen2.5-VL, this converts: + - visual.* → model.visual.* + - model.* → model.language_model.* + """ + return name diff --git a/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py b/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py new file mode 100644 index 0000000000..5e0b71753f --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Convert a HuggingFace model to AnyModel format.""" + +from pathlib import Path + +from ..model_descriptor import ModelDescriptorFactory +from .base import Converter +from .converter_factory import ConverterFactory + +__all__ = ["convert_model"] + + +def convert_model( + input_dir: str, + output_dir: str, + converter: Converter | str, +): + """Convert a HuggingFace model to AnyModel format. + + This function converts a HuggingFace checkpoint to the AnyModel format used + for compression. The conversion process: + + 1. Copies non-weight files (config, tokenizer, etc.) + 2. Creates block_configs for each layer + 3. Reorganizes weights into subblock checkpoints + + Args: + input_dir: Path to the input HuggingFace checkpoint directory. + output_dir: Path to the output AnyModel checkpoint directory. + converter: Either a converter name (e.g., "llama") or a Converter class. + + Example: + >>> convert_model( + ... input_dir="/path/to/Llama-3.1-8B-Instruct", + ... output_dir="/path/to/output/ckpts/teacher", + ... converter="llama", + ... ) + """ + input_dir = Path(input_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Get descriptor and converter from factories (they use the same name) + descriptor = ModelDescriptorFactory.get(converter) + converter = ConverterFactory.get(converter) + + converter.convert(descriptor=descriptor, input_dir=input_dir, output_dir=output_dir) + + +if __name__ == "__main__": + from fire import Fire + + Fire(convert_model) diff --git a/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py b/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py new file mode 100644 index 0000000000..a8f5cf5f32 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import inspect +from typing import Callable, Type + +from ..model_descriptor import ModelDescriptor + +__all__ = ["ConverterFactory"] + + +class ConverterFactory: + """Factory for registering and retrieving Converter classes.""" + + CLASS_MAPPING = {} + + @classmethod + def register(cls, **entries: Type): + """Register converter classes. + + Raises: + KeyError: if entry key is already in type_dict and points to a different class. + """ + for cls_name, cls_type in entries.items(): + if cls_name in cls.CLASS_MAPPING: + ref = cls.CLASS_MAPPING[cls_name] + # If ref and cls_name point to the same class ignore and don't raise an exception. + if cls_type == ref: + continue + raise KeyError( + f"Could not register `{cls_name}`: {cls_type}, " + f"`{cls_name}` is already registered and points to " + f"`{inspect.getmodule(ref).__name__}.{ref.__name__}`" + ) + cls.CLASS_MAPPING[cls_name] = cls_type + + @classmethod + def register_decorator(cls, name: str | None) -> Callable: + """Set up a register decorator. + + Args: + name: If specified, the decorated object will be registered with this name. + + Returns: + Decorator that registers the callable. + """ + + def decorator(cls_type: Type) -> Callable: + """Register the decorated callable.""" + cls_name = name if name is not None else cls_type.__name__ + cls.register(**{cls_name: cls_type}) + return cls_type + + return decorator + + @classmethod + def get(cls, value: str | ModelDescriptor): + """Get a registered converter by name or return the converter if already resolved.""" + if isinstance(value, str): + if value in cls.CLASS_MAPPING: + return cls.CLASS_MAPPING[value] + return value diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py new file mode 100644 index 0000000000..7b0ec18b75 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model descriptors for defining model-specific properties and layer naming conventions.""" + +from .base import * +from .model_descriptor_factory import * diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py new file mode 100644 index 0000000000..3c1749d46e --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Type + +import torch.nn as nn + +from ...block_config import BlockConfig +from ...utils.dummy_modules import DummyBlock + +__all__ = ["ModelDescriptor"] + + +class ModelDescriptor(ABC): + @staticmethod + @abstractmethod + def decoder_layer_cls() -> Type[nn.Module] | List[Type[nn.Module]]: + """Decoder layer class types to patch for heterogeneous config support. + + In most cases this class will hold as attributes both FFN & attention layers. + + Returns: + nn.Module class type or a list if several class types should be patched. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any]: + """Map between BlockConfig and layer config overrides. + + These overrides are consumed by a specific decoder layer and by the whole model. + Usage can be seen in `deci_x_patcher` under the method `_patched_decoder_layer_init`. + + Example implementation to override the FFN intermediate size of a block: + >>> def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any]: + >>> return {"intermediate_size": block_config.ffn.intermediate_size} + """ + raise NotImplementedError + + @staticmethod + def requires_trust_remote_code() -> bool: + """Whether this model descriptor requires trust_remote_code=True for loading. + + Models that use custom code (e.g., via auto_map in config) should override + this to return True. + + Returns: + True if trust_remote_code=True is required, False otherwise. + """ + return False + + @staticmethod + def mlp_no_op_post_init(decoder_layer: nn.Module): + """Post-init callback to alter a decoder layer so that FFN/mlp subblock performs as no-op. + + It is recommended to use the utils modules from `no_op.py` to replace layers to dummy + counterparts. + + Example for replacing a layernorm layer with identity: + + >>> decoder_layer.post_attention_layernorm = Same() + + Example for replacing an MLP layer with zeroes (zeroes since hidden_states are added to + the residuals hidden_states so a no-op implementation will leave residual the same): + + >>> decoder_layer.mlp = MatchingZeros() + + In case the MLP layer to replace returns multiple outputs i.e `hidden_states, _ = self.mlp()`, + use the util method `return_tuple_of_size` to return trailing None values: + + >>> decoder_layer.mlp = return_tuple_of_size(MatchingZeros, size=2)() + """ + raise NotImplementedError + + @staticmethod + def attn_no_op_post_init(decoder_layer: nn.Module): + """Post-init callback to alter a decoder layer so that Attention subblock performs as no-op. + + It is recommended to use the utils modules from `no_op.py` to replace layers to dummy + counterparts. + + Example for replacing a layernorm layer with identity: + + >>> decoder_layer.post_attention_layernorm = Same() + + Example for replacing an attention layer with zeroes: + + >>> decoder_layer.self_attn = MatchingZeros() + + In case the attention layer returns multiple outputs i.e `hidden_states, _ = self.self_attn()`, + use the util method `return_tuple_of_size` to return trailing None values: + + >>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def init_rotary_embedding(model, runtime): + """Re-initiate the rotary embeddings based on an existing model. + + In puzzletron we initiate a sharded model by first creating a meta model then replacing + to the actual device by loading the state_dict with the real weights. + + Rotary embeddings frequencies are tensor buffers that are created dynamically during init + and are not part of the model state_dict, so cannot be restored after a meta device + initialization. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def input_embedding_name(): + """Return the name of the input embedding layer.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def output_embedding_name(): + """Return the name of the output embedding layer.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def final_norm_name(): + """Return the name of the final normalization layer.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def layer_block_name(index: int): + """Return the name of the decoder layer at the given index.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + """Return predicates for grouping model weights to support subblock checkpointing. + + For every group name return a regex predicate whether a layer name is part of the group. + + Returns: + Dictionary of group name to regex pattern predicate. + """ + raise NotImplementedError + + @staticmethod + def uses_autocast() -> bool: + """Whether this model supports torch.autocast. + + Some models (e.g., Qwen3-VL MoE) have dtype bugs under autocast. + Override and return False for models that do not support autocast. + """ + return True + + @staticmethod + def get_language_model_config(config): + """Get the language model config from a PretrainedConfig. + + For regular LM models, returns the config itself. + For VL/multimodal models with nested configs, override to return the + language model portion (e.g., config.text_config for Qwen-VL). + """ + return config + + @staticmethod + def truncate_pattern_for_subblock( + lm_config: Any, parent_layer_index: int | None = None + ) -> None: + """Adjust per-layer config fields so a single-layer model represents the correct layer type. + + The default implementation handles ``hybrid_override_pattern`` for + hybrid architectures. It is a no-op when the field is absent. + Override if a model uses a different pattern alphabet. + """ + pattern = getattr(lm_config, "hybrid_override_pattern", None) + if not pattern: + return + # Strip cosmetic pipe separators (e.g. "M|-|*" -> "M-*") before indexing. + pattern = pattern.replace("|", "") + if not pattern: + raise ValueError( + f"hybrid_override_pattern is set but contains no layer-type characters " + f"(original: {lm_config.hybrid_override_pattern!r})" + ) + if parent_layer_index is not None and 0 <= parent_layer_index < len(pattern): + lm_config.hybrid_override_pattern = pattern[parent_layer_index] + return + lm_config.hybrid_override_pattern = pattern[0] + + @classmethod + def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module: + """Create a dummy block to replace a layer for sharded model initialization.""" + return DummyBlock(block_index=block_index) + + @classmethod + def mlp_no_op_supported(cls) -> bool: + """Check whether `mlp_no_op_post_init` is overridden for mlp no-op support.""" + method_name = ModelDescriptor.mlp_no_op_post_init.__name__ + return getattr(cls, method_name) is not getattr(ModelDescriptor, method_name) + + @classmethod + def attn_no_op_supported(cls): + """Check whether `attn_no_op_post_init` is overridden for attention no-op support.""" + method_name = ModelDescriptor.attn_no_op_post_init.__name__ + return getattr(cls, method_name) is not getattr(ModelDescriptor, method_name) + + @classmethod + def get_weight_groups( + cls, layer_names: Iterable[str], num_hidden_layers: int + ) -> Dict[str, List[str]]: + """Group model weights to support the puzzle subblock checkpointing format. + + This method uses the abstract method `layer_name_predicates` by default. + + Args: + layer_names: state_dict layer names of the model. + num_hidden_layers: number of decoder layers in the model. + + Returns: + Dictionary of group names to list of layer names per group, e.g.: + >>> { + ... "embedding": ["model.embed_tokens.weight"], + ... "lm_head": ["lm_head.weight", "model.norm.weight"], + ... "block_0_ffn": ["model.layers.0.mlp.down_proj", ...], + ... "block_0_attention": ["model.layers.0.self_attn.q_proj", ...], + ... } + """ + weight_groups = defaultdict(list) + for name in layer_names: + for group, pattern in cls.layer_name_predicates(num_hidden_layers).items(): + if pattern.match(name): + weight_groups[group].append(name) + break + else: + raise ValueError(f"Couldn't find a match for {name}") + return weight_groups diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py new file mode 100644 index 0000000000..74aaf311bf --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import inspect +from typing import Callable, Type + +from transformers import AutoConfig + +from ...tools.checkpoint_utils_hf import force_cache_dynamic_modules +from .base import ModelDescriptor + +__all__ = ["ModelDescriptorFactory", "resolve_descriptor_from_pretrained"] + +# Map from HuggingFace config.model_type (in checkpoint config.json) to ModelDescriptorFactory name. +# Local to this script; add entries when supporting new model types for auto-detection. +_MODEL_TYPE_TO_DESCRIPTOR = { + "llama": "llama", + "mistral": "mistral_small", + "qwen2": "qwen2", + "qwen3": "qwen3", + "nemotron_h": "nemotron_h", + "nemotron_h_v2": "nemotron_h_v2", + "gpt_oss_20b": "gpt_oss_20b", +} + + +def resolve_descriptor_from_pretrained(pretrained: str, trust_remote_code: bool = False): + """Resolve the model descriptor by loading the checkpoint config and mapping model_type. + + Args: + pretrained: Path to a pretrained model checkpoint or HuggingFace model identifier. + trust_remote_code: If True, allows execution of custom code from the model repository. + This is a security risk if the model source is untrusted. Only set to True if you + trust the source of the model. Defaults to False for security. + + Returns: + The resolved ModelDescriptor class for the detected model type. + + Raises: + ValueError: If pretrained is not provided or if the model type cannot be auto-detected. + """ + + config = AutoConfig.from_pretrained(pretrained, trust_remote_code=trust_remote_code) + force_cache_dynamic_modules(config, pretrained, trust_remote_code=trust_remote_code) + model_type = getattr(config, "model_type", None) + + if model_type and model_type in _MODEL_TYPE_TO_DESCRIPTOR: + detected = _MODEL_TYPE_TO_DESCRIPTOR[model_type] + print( + f"[resolve_descriptor_from_pretrained] Auto-detected model_type='{model_type}' → descriptor='{detected}'" + ) + return ModelDescriptorFactory.get(detected) + + known = sorted(_MODEL_TYPE_TO_DESCRIPTOR.keys()) + raise ValueError( + f"Cannot auto-detect descriptor for model_type='{model_type}'. " + f"Known model types: {known}. Add this model_type to _MODEL_TYPE_TO_DESCRIPTOR if supported." + ) + + +class ModelDescriptorFactory: + """Factory for registering and retrieving ModelDescriptor classes.""" + + CLASS_MAPPING = {} + + @classmethod + def register(cls, **entries: Type): + """Register model descriptor classes. + + Raises: + KeyError: if entry key is already in type_dict and points to a different class. + """ + for cls_name, cls_type in entries.items(): + if cls_name in cls.CLASS_MAPPING: + ref = cls.CLASS_MAPPING[cls_name] + # If ref and cls_name point to the same class ignore and don't raise an exception. + if cls_type == ref: + continue + raise KeyError( + f"Could not register `{cls_name}`: {cls_type}, " + f"`{cls_name}` is already registered and points to " + f"`{inspect.getmodule(ref).__name__}.{ref.__name__}`" + ) + cls.CLASS_MAPPING[cls_name] = cls_type + + @classmethod + def register_decorator(cls, name: str | None) -> Callable: + """Set up a register decorator. + + Args: + name: If specified, the decorated object will be registered with this name. + + Returns: + Decorator that registers the callable. + """ + + def decorator(cls_type: Type) -> Callable: + """Register the decorated callable.""" + cls_name = name if name is not None else cls_type.__name__ + cls.register(**{cls_name: cls_type}) + return cls_type + + return decorator + + @classmethod + def get(cls, value: str | ModelDescriptor): + """Get a registered model descriptor by name or return the descriptor if already resolved.""" + if isinstance(value, str): + if value in cls.CLASS_MAPPING: + return cls.CLASS_MAPPING[value] + return value diff --git a/modelopt/torch/puzzletron/anymodel/models/__init__.py b/modelopt/torch/puzzletron/anymodel/models/__init__.py new file mode 100644 index 0000000000..c126d61b88 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/__init__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from packaging.version import Version as _Version +from transformers import __version__ as _transformers_version + +# Import models to trigger factory registration +from .gpt_oss import * +from .llama import * +from .mistral_small import * +from .nemotron_h import * +from .nemotron_h_v2 import * +from .qwen2 import * +from .qwen3 import * + +if _Version(_transformers_version) >= _Version("4.57.0"): + from .qwen3_vl import * diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py new file mode 100644 index 0000000000..cd3872ab5c --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPT-OSS model support for AnyModel.""" + +from .gpt_oss_converter import * +from .gpt_oss_model_descriptor import * diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py new file mode 100644 index 0000000000..4b35bec4bf --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""GPT-OSS-20B converter for AnyModel compression.""" + +from typing import List + +from transformers import PretrainedConfig + +from ....block_config import AttentionConfig, BlockConfig, FFNConfig, MoEConfig +from ...converter import Converter, ConverterFactory + +__all__ = ["GptOssConverter"] + + +@ConverterFactory.register_decorator("gpt_oss") +class GptOssConverter(Converter): + """Converter for GPT-OSS models to AnyModel format. + + GPT-OSS is a pure MoE model with 32/128 experts per layer and 4/16 active experts. + All layers use MoE FFN (no standard dense FFN layers). + """ + + quantized = "mxfp4" + + @staticmethod + def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]: + """Create block configs for GPT-OSS layers. + + GPT-OSS uses MoE for all FFN layers with: + - 32/128 local experts (num_local_experts) + - 4/16 active experts per token (experts_per_token) + - No dense/standard FFN layers + """ + num_hidden_layers = config.num_hidden_layers + num_local_experts = config.num_local_experts + experts_per_token = config.experts_per_token + intermediate_size = config.intermediate_size + + block_configs = [] + for layer_idx in range(num_hidden_layers): + block_config = BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=config.num_key_value_heads + ), + ffn=FFNConfig( + no_op=False, + intermediate_size=None, # MoE doesn't use this field + moe=MoEConfig( + num_local_experts=num_local_experts, + num_experts_per_tok=experts_per_token, + expert_intermediate_dim=intermediate_size, + ), + ), + ).to_dict() + block_configs.append(block_config) + + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py new file mode 100644 index 0000000000..c8fd86b4bb --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py @@ -0,0 +1,232 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""GPT-OSS model descriptor for AnyModel compression.""" + +import re +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Type + +import torch.nn as nn +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssDecoderLayer, GptOssRotaryEmbedding + +from ....block_config import BlockConfig +from ....pruning.expert_removal_pruning_mixin import ( + ExpertRemovalLayerDescriptor, + ExpertRemovalPruningMixIn, +) + +# Expert removal is supported for unquantized models (test models). +# Production models use MXFP4 quantized MoE with combined tensors +# (gate_up_proj_blocks, down_proj_blocks), which is not yet supported. +from ....pruning.pruning_mixin import PruningMixIn +from ....utils.dummy_modules import DummyBlock +from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size + +__all__ = ["GptOssModelDescriptor", "GptOssExpertRemovalLayerDescriptor"] + + +@ModelDescriptorFactory.register_decorator("gpt_oss") +class GptOssModelDescriptor(ModelDescriptor): + """Model descriptor for GPT-OSS (pure MoE model).""" + + _DECODER_LAYER_CLS: Type[nn.Module] = None + + @classmethod + def create_dummy_block(cls, original_layer: GptOssDecoderLayer, block_index: int) -> nn.Module: + dummy_block = DummyBlock(block_index=block_index) + # Required by `GptOssModel.forward` in transformers<5.4 + if hasattr(original_layer, "attention_type"): + dummy_block.attention_type = original_layer.attention_type + return dummy_block + + @staticmethod + def decoder_layer_cls(): + """Get the decoder layer class for GPT-OSS models. + + GPT-OSS is a standard transformers model in recent versions. + Import directly from transformers.models.gpt_oss.modeling_gpt_oss. + """ + return GptOssDecoderLayer + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + """Map BlockConfig to layer constructor overrides.""" + override_kwargs = {} + + if block_config.attention.num_key_value_heads is not None: + override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads + + if block_config.ffn.moe is not None: + override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim + override_kwargs["num_local_experts"] = block_config.ffn.moe.num_local_experts + override_kwargs["num_experts_per_tok"] = block_config.ffn.moe.num_experts_per_tok + + return override_kwargs + + @staticmethod + def attn_no_op_post_init(decoder_layer): + """Replace attention sublayers with no-op modules.""" + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer): + """Replace MLP sublayers with no-op modules. + + Note: GPT-OSS MoE layers return (hidden_states, router_scores), so we need + to return a tuple of 2 values. + """ + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def init_rotary_embedding(model, runtime): + """Initialize rotary embeddings on the correct device.""" + # GPT-OSS uses RoPE with YARN scaling + + model.model.rotary_emb = GptOssRotaryEmbedding( + config=model.config, + device=runtime.device, + ) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + """Define regex patterns for grouping weights into subblocks.""" + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + """FFN is MoE in GPT-OSS with MXFP4 quantization.""" + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\." + r"(post_attention_layernorm\.weight" + r"|mlp\.router\.weight" + r"|mlp\.router\.bias" + r"|mlp\.experts\.(gate_up_proj|down_proj)(_(bias|blocks|scales))?)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\." + r"(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.q_proj\.bias" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.k_proj\.bias" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.v_proj\.bias" + r"|self_attn\.o_proj\.weight" + r"|self_attn\.o_proj\.bias" + r"|self_attn\.sinks)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update( + **build_ffn_predicates(), + **build_attention_predicates(), + ) + + return layer_name_patterns + + @staticmethod + def pruning_mixins() -> Dict[str, PruningMixIn]: + """Return available pruning mixins for GPT-OSS. + + Note: Expert removal works for unquantized models (test models). + Production models use MXFP4 quantization which is not yet supported. + """ + return {"expert_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())} + + +@dataclass +class GptOssExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): + """ + GPT-OSS MoE layer descriptor for expert removal. + + Note: This only works for unquantized models (e.g., test models). + Production GPT-OSS models use MXFP4 quantization with fused experts + (_blocks, _scales, _bias), which requires a different approach. + + Structure: + - Router: mlp.router with .weight and .bias + - Experts: mlp.experts.{idx}.{gate_up_proj,down_proj} with .weight and .bias + """ + + target_name: str = "mlp" + moe_prefix_name: str = "model.layers.{layer_idx}.mlp" + expert_prefix_name: str = "experts" + + # Router has both weight and bias + router_weights: List[str] = field(default_factory=lambda: ["router.weight"]) + router_biases: List[str] = field(default_factory=lambda: ["router.bias"]) + + # Fused format: experts stored as single tensors + is_fused_experts: bool = True + + # Fused format: single tensors containing all experts (test models) + fused_expert_weights: List[str] = field( + default_factory=lambda: [ + "experts.gate_up_proj", + "experts.gate_up_proj_bias", + "experts.down_proj", + "experts.down_proj_bias", + ] + ) + + # Not used for fused format, but kept for compatibility + expert_weights: List[str] = field(default_factory=lambda: ["gate_up_proj", "down_proj"]) + expert_biases: List[str] = field( + default_factory=lambda: ["gate_up_proj_bias", "down_proj_bias"] + ) + + def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]: + target_class_name = "GptOssTopKRouter" + + module_names_to_hook = [] + for module_name, module in model.named_modules(): + if ( + module_name.endswith(self.target_name) + and module.__class__.__name__ == target_class_name + ): + module_names_to_hook.append( + (self.block_idx_from_module_name(module_name), module_name) + ) + return module_names_to_hook diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py new file mode 100644 index 0000000000..85355146ab --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py @@ -0,0 +1,526 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Create a HuggingFace checkpoint with MXFP4 MoE weights from the original gpt-oss-120b model. + +This script: +1. Copies non-MoE weights from the student model (trained attention, embeddings, etc.) +2. Extracts MoE expert weights from the original gpt-oss-120b in MXFP4 format +3. Deduces expert mappings by comparing weights +4. Outputs a new pruned (heterogeneous) checkpoint with PACKED MXFP4 expert weights +""" + +import argparse +import json +import os +import shutil +from typing import Any, Dict, List, Tuple + +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from tqdm import tqdm +from transformers.integrations.mxfp4 import convert_moe_packed_tensors + +__all__ = [] + + +def deduce_experts_for_layer( + layer: int, + original_path: str, + original_index: Dict, + student_path: str, +) -> Tuple[List[int], int, int]: + """ + Deduce which original experts match the student experts by comparing weights. + + Compares dequantized MXFP4 weights from the original model against the student + model's BF16 weights using L2 distance. Finds the best 1-to-1 matching. + + Args: + layer: Layer index + original_path: Path to original model + original_index: Original model's safetensors index + student_path: Path to student model + num_student_experts: Number of experts in student model (if None, auto-detect) + + Returns: + Tuple of (expert_indices, num_student_experts, num_original_experts) + """ + # Load original tensors + orig_tensors = load_layer_tensors(original_path, layer, original_index) + mlp1_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] + mlp1_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] + mlp2_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] + mlp2_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_scales"] + + num_original_experts = mlp1_blocks.shape[0] + + # Load student tensors + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + student_ffn = os.path.join(student_subblocks, f"block_{layer}_ffn.safetensors") + if not os.path.exists(student_ffn): + print(f"FFN file not found at {student_ffn} - fallback to no_op") + return [], 0, num_original_experts + + student_experts = {} + with safe_open(student_ffn, framework="pt") as f: + for key in f.keys(): + if "experts" in key or "router" in key: + student_experts[key] = f.get_tensor(key) + + # Auto-detect number of student experts + num_student_experts = student_experts[f"model.layers.{layer}.mlp.experts.gate_up_proj"].size(0) + print( + f" Layer {layer}: Comparing {num_student_experts} student experts against {num_original_experts} original experts" + ) + + # Pre-dequantize all original experts once (optimization) + print(f" Pre-dequantizing {num_original_experts} original experts...") + deqexpert_mlp1 = convert_moe_packed_tensors(mlp1_blocks, mlp1_scales).cpu() + deqexpert_mlp2 = convert_moe_packed_tensors(mlp2_blocks, mlp2_scales).cpu() + original_experts_dequant = [] + for orig_idx in range(num_original_experts): + original_experts_dequant.append( + {"up": deqexpert_mlp1[orig_idx], "down": deqexpert_mlp2[orig_idx]} + ) + + # For each student expert, find best matching original expert + experts_to_keep = [] + used_original_indices = set() + + # Number of values to use for quick comparison (tune this) + quick_compare_size = 8 + # Number of candidates to keep for full comparison + top_k_candidates = min(10, num_original_experts) + + for student_idx in range(num_student_experts): + # Get student expert weights + prefix = f"model.layers.{layer}.mlp" + student_up = student_experts.get(f"{prefix}.experts.gate_up_proj")[student_idx] # type: ignore[index] + student_down = student_experts.get(f"{prefix}.experts.down_proj")[student_idx] # type: ignore[index] + + # if student_gate is None or student_up is None or student_down is None: + if student_up is None or student_down is None: + raise ValueError( + f"Missing student expert weights for layer {layer} expert {student_idx}" + ) + + # Step 1: Quick filtering using first N values + candidate_scores = [] + for orig_idx in range(num_original_experts): + if orig_idx in used_original_indices: + continue + + orig_expert = original_experts_dequant[orig_idx] + + up_quick = ( + ( + orig_expert["up"].flatten()[:quick_compare_size] + - student_up.float().flatten()[:quick_compare_size] + ) + .pow(2) + .mean() + .sqrt() + ) + down_quick = ( + ( + orig_expert["down"].flatten()[:quick_compare_size] + - student_down.float().flatten()[:quick_compare_size] + ) + .pow(2) + .mean() + .sqrt() + ) + + quick_score = (up_quick + down_quick) / 2.0 + candidate_scores.append((orig_idx, quick_score.item())) + + # Step 2: Get top-k candidates based on quick comparison + candidate_scores.sort(key=lambda x: x[1]) + top_candidates = [idx for idx, _ in candidate_scores[:top_k_candidates]] + + # Step 3: Full comparison only on top candidates + best_match_idx = None + best_match_score = float("inf") + + for orig_idx in top_candidates: + orig_expert = original_experts_dequant[orig_idx] + + # Full comparison across all values + up_diff = (orig_expert["up"] - student_up.float()).pow(2).mean().sqrt() + down_diff = (orig_expert["down"] - student_down.float()).pow(2).mean().sqrt() + + score = (up_diff + down_diff) / 2.0 + + if score < best_match_score: + best_match_score = score + best_match_idx = orig_idx + + if best_match_idx is None: + raise ValueError( + f"Could not find match for student expert {student_idx} in layer {layer}" + ) + + experts_to_keep.append(best_match_idx) + used_original_indices.add(best_match_idx) + print( + f" Student expert {student_idx} -> Original expert {best_match_idx} (RMSE: {best_match_score:.6f})" + ) + + return experts_to_keep, num_student_experts, num_original_experts + + +def load_original_index(path: str) -> Dict[str, Any]: + """Load the original model's safetensors index.""" + with open(path, "r") as f: + return json.load(f) + + +def load_layer_tensors(original_path: str, layer: int, index: Dict) -> Dict[str, torch.Tensor]: + """Load all MoE-related tensors for a layer, potentially from multiple files.""" + keys_to_load = [ + f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks", + f"model.layers.{layer}.mlp.experts.gate_up_proj_scales", + f"model.layers.{layer}.mlp.experts.gate_up_proj_bias", + f"model.layers.{layer}.mlp.experts.down_proj_blocks", + f"model.layers.{layer}.mlp.experts.down_proj_scales", + f"model.layers.{layer}.mlp.experts.down_proj_bias", + f"model.layers.{layer}.mlp.router.weight", # Router weight + f"model.layers.{layer}.mlp.router.bias", # Router bias + ] + + # Group by file + file_to_keys = {} + for key in keys_to_load: + if key in index["weight_map"]: + filename = index["weight_map"][key] + if filename not in file_to_keys: + file_to_keys[filename] = [] + file_to_keys[filename].append(key) + + # Load from each file + tensors = {} + for filename, keys in file_to_keys.items(): + filepath = os.path.join(original_path, filename) + with safe_open(filepath, framework="pt") as f: + for key in keys: + tensors[key] = f.get_tensor(key) + + return tensors + + +def copy_non_moe_weights(student_path: str, output_path: str, num_layers: int) -> Dict[str, str]: + """ + Copy non-MoE weights from student model. + Returns weight_map for the new index. + """ + weight_map = {} + subblocks_dir = os.path.join(output_path, "subblocks_safetensors") + os.makedirs(subblocks_dir, exist_ok=True) + + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + + # Copy embeddings + src_emb = os.path.join(student_subblocks, "embeddings.safetensors") + dst_emb = os.path.join(subblocks_dir, "embeddings.safetensors") + shutil.copy2(src_emb, dst_emb) + with safe_open(src_emb, framework="pt") as f: + for key in f.keys(): + weight_map[key] = "subblocks_safetensors/embeddings.safetensors" + + # Copy lm_head + src_head = os.path.join(student_subblocks, "lm_head.safetensors") + dst_head = os.path.join(subblocks_dir, "lm_head.safetensors") + shutil.copy2(src_head, dst_head) + with safe_open(src_head, framework="pt") as f: + for key in f.keys(): + weight_map[key] = "subblocks_safetensors/lm_head.safetensors" + + # Copy attention blocks + for layer in range(num_layers): + src_attn = os.path.join(student_subblocks, f"block_{layer}_attention.safetensors") + dst_attn = os.path.join(subblocks_dir, f"block_{layer}_attention.safetensors") + shutil.copy2(src_attn, dst_attn) + with safe_open(src_attn, framework="pt") as f: + for key in f.keys(): + weight_map[key] = f"subblocks_safetensors/block_{layer}_attention.safetensors" + + return weight_map + + +def process_single_layer( + layer: int, + original_path: str, + original_index: Dict, + student_path: str, + output_path: str, + experts_to_keep: List[int], +) -> Tuple[Dict[str, str], List[str]]: + """ + Process a single layer - loads tensors from potentially multiple files. + Returns (weight_map, verification_errors). + """ + weight_map = {} + verification_errors = [] + subblocks_dir = os.path.join(output_path, "subblocks_safetensors") + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + + # Load all tensors for this layer (may come from multiple files) + orig_tensors = load_layer_tensors(original_path, layer, original_index) + + # Load student FFN file + student_ffn = os.path.join(student_subblocks, f"block_{layer}_ffn.safetensors") + + tensors_to_save = {} + student_tensors = {} + + with safe_open(student_ffn, framework="pt") as f: + for key in f.keys(): + tensor = f.get_tensor(key) + if "experts" not in key and "router" not in key: + # Copy norm weights + tensors_to_save[key] = tensor + + # Get router from original model, sliced to kept experts + orig_router_weight = orig_tensors[f"model.layers.{layer}.mlp.router.weight"] + orig_router_bias = orig_tensors[f"model.layers.{layer}.mlp.router.bias"] + + kept_indices_tensor = torch.tensor(experts_to_keep, dtype=torch.long) + sliced_router_weight = orig_router_weight[kept_indices_tensor] + sliced_router_bias = orig_router_bias[kept_indices_tensor] + + tensors_to_save[f"model.layers.{layer}.mlp.router.weight"] = sliced_router_weight + tensors_to_save[f"model.layers.{layer}.mlp.router.bias"] = sliced_router_bias + + # Get MoE tensors + mlp1_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] + mlp1_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] + mlp2_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] + mlp2_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_scales"] + mlp1_bias = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_bias"] + mlp2_bias = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_bias"] + + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] = mlp1_blocks[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] = mlp1_scales[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_bias"] = mlp1_bias[ + kept_indices_tensor + ] + + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] = mlp2_blocks[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_scales"] = mlp2_scales[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_bias"] = mlp2_bias[ + kept_indices_tensor + ] + + # Save the FFN file + output_file = os.path.join(subblocks_dir, f"block_{layer}_ffn.safetensors") + save_file(tensors_to_save, output_file) + + # Build weight map + for key in tensors_to_save.keys(): + weight_map[key] = f"subblocks_safetensors/block_{layer}_ffn.safetensors" + + return weight_map, verification_errors + + +def copy_config_files(student_path: str, output_path: str): + """Copy configuration files from student model and update config.json.""" + files_to_copy = [ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "chat_template.jinja", + ] + + # Also copy transformers compatibility files + if os.path.exists(student_path): + for f in os.listdir(student_path): + if f.startswith("transformers_"): + files_to_copy.append(f) + + for filename in files_to_copy: + src = os.path.join(student_path, filename) + dst = os.path.join(output_path, filename) + + # Try student path first + if os.path.exists(src): + try: + shutil.copy2(src, dst) + continue + except PermissionError: + pass + + # If we get here, file doesn't exist or permission denied + if not os.path.exists(dst): + print(f" Warning: Could not copy {filename}") + + # Update config.json for DeciGptOssForCausalLM with MXFP4 + src_config = os.path.join(student_path, "config.json") + if not os.path.exists(src_config): + raise FileNotFoundError(f"config.json not found at {src_config}") + + with open(src_config, "r") as f: + config = json.load(f) # type: ignore[arg-type] + + # Set architecture to DeciGptOssForCausalLM for MXFP4 support + config["architectures"] = ["DeciGptOssForCausalLM"] + + # Add quantization_config so vllm calls _load_weights_mxfp4 + config["quantization_config"] = { + "quant_method": "mxfp4", + "modules_to_not_convert": [ + "model.layers.*.self_attn", + "model.layers.*.mlp.router", + "model.embed_tokens", + "lm_head", + ], + } + + dst_config = os.path.join(output_path, "config.json") + with open(dst_config, "w") as f: + json.dump(config, f, indent=2) # type: ignore[arg-type] + + +def main(): + parser = argparse.ArgumentParser(description="Create MXFP4 checkpoint from student model") + parser.add_argument( + "--student-path", type=str, required=True, help="Path to student model checkpoint" + ) + parser.add_argument( + "--original-path", + type=str, + required=True, + help="Path to original gpt-oss-120b model with MXFP4 weights", + ) + parser.add_argument( + "--output-path", type=str, required=True, help="Output path for the new checkpoint" + ) + parser.add_argument("--num-layers", type=int, default=36, help="Number of transformer layers") + args = parser.parse_args() + + print(f"Creating MXFP4 checkpoint...") + print(f" Student model: {args.student_path}") + print(f" Original model: {args.original_path}") + print(f" Output: {args.output_path}") + + # Load original model index + original_index = load_original_index( + os.path.join(args.original_path, "model.safetensors.index.json") + ) + + print("\nDeducing expert mappings by comparing weights...") + experts_to_keep = [] + layer_statistics = [] # Store (num_student, num_original) for each layer + + for layer in range(args.num_layers): + layer_experts, num_student, num_original = deduce_experts_for_layer( + layer, + args.original_path, + original_index, + args.student_path, + ) + experts_to_keep.append(layer_experts) + layer_statistics.append((num_student, num_original)) + + # Print statistics + print(f"\n{'=' * 70}") + print("EXPERT DEDUCTION STATISTICS") + print(f"{'=' * 70}") + print(f"{'Layer':<8} {'Student Experts':<18} {'Original Experts':<18} {'Kept %':<10}") + print(f"{'-' * 70}") + + total_student = 0 + total_original = 0 + for layer, (num_student, num_original) in enumerate(layer_statistics): + percentage = (num_student / num_original * 100) if num_original > 0 else 0 + print(f"{layer:<8} {num_student:<18} {num_original:<18} {percentage:<10.2f}") + total_student += num_student + total_original += num_original + + print(f"{'-' * 70}") + avg_percentage = (total_student / total_original * 100) if total_original > 0 else 0 + print(f"{'TOTAL':<8} {total_student:<18} {total_original:<18} {avg_percentage:<10.2f}") + print(f"{'=' * 70}") + print(f"\n Deduced experts_to_keep mapping for {len(experts_to_keep)} layers") + + # Create output directory + os.makedirs(args.output_path, exist_ok=True) + os.makedirs(os.path.join(args.output_path, "subblocks_safetensors"), exist_ok=True) + + # Copy config files + print("Copying configuration files...") + copy_config_files(args.student_path, args.output_path) + + # Save experts_to_keep.json + experts_to_keep_output = os.path.join(args.output_path, "experts_to_keep.json") + with open(experts_to_keep_output, "w") as f: + json.dump(experts_to_keep, f, indent=2) + print(f" Saved experts_to_keep mapping to {experts_to_keep_output}") + + # Copy non-MoE weights (embeddings, attention, lm_head) + print("Copying non-MoE weights...") + weight_map = copy_non_moe_weights(args.student_path, args.output_path, args.num_layers) + + # Load weights per layer (handles multi-file loading) + print(f"Processing {args.num_layers} layers...") + + all_verification_errors = [] + + # Process each layer + for layer in tqdm(range(args.num_layers), desc="Processing layers"): + if len(experts_to_keep[layer]) == 0: + print(f"Layer {layer} has no experts to keep - ffn->no_op") + continue + layer_weight_map, layer_errors = process_single_layer( + layer, + args.original_path, + original_index, + args.student_path, + args.output_path, + experts_to_keep[layer], + ) + weight_map.update(layer_weight_map) + all_verification_errors.extend(layer_errors) + + # Calculate total size + total_size = 0 + subblocks_dir = os.path.join(args.output_path, "subblocks_safetensors") + for filename in os.listdir(subblocks_dir): + filepath = os.path.join(subblocks_dir, filename) + total_size += os.path.getsize(filepath) + + # Create model.safetensors.index.json + index = {"metadata": {"total_size": total_size}, "weight_map": weight_map} + + index_path = os.path.join(args.output_path, "model.safetensors.index.json") + with open(index_path, "w") as f: + json.dump(index, f, indent=2) + + print(f"\nCheckpoint created successfully at: {args.output_path}") + print(f"Total size: {total_size / 1e9:.2f} GB") + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py b/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py new file mode 100644 index 0000000000..0c5c2d6370 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .llama_converter import * +from .llama_model_descriptor import * diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py new file mode 100644 index 0000000000..100ee1a450 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Llama converter for AnyModel compression.""" + +from typing import List + +from transformers import LlamaConfig + +from ....block_config import AttentionConfig, BlockConfig, FFNConfig +from ...converter import Converter, ConverterFactory + +__all__ = ["LlamaConverter"] + + +@ConverterFactory.register_decorator("llama") +class LlamaConverter(Converter): + """Converter for Llama models to AnyModel format.""" + + @staticmethod + def create_block_configs_from_main_config(config: LlamaConfig) -> List[BlockConfig]: + """Create uniform block configs for all Llama layers. + + Llama models have uniform architecture across all layers, so we create + the same BlockConfig for each layer. + """ + num_hidden_layers = config.num_hidden_layers + + block_configs = [ + BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=config.num_key_value_heads + ), + ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), + ).to_dict() + for _ in range(num_hidden_layers) + ] + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py new file mode 100644 index 0000000000..f528e223af --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Llama model descriptor for AnyModel compression.""" + +import re +from dataclasses import dataclass, field +from typing import Dict, List + +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaRotaryEmbedding, +) + +from ....block_config import BlockConfig +from ....pruning.ffn_intermediate_pruning_mixin import FFNIntermediateLayerDescriptor +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor +from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size + +__all__ = [ + "LlamaModelDescriptor", + "LlamaFFNIntermediateLayerDescriptor", + "LlamaKVHeadsLayerDescriptor", +] + + +@ModelDescriptorFactory.register_decorator("llama") +class LlamaModelDescriptor(ModelDescriptor): + """Model descriptor for Llama models (Llama 2, Llama 3, Llama 3.1, Llama 3.2).""" + + @staticmethod + def decoder_layer_cls(): + return LlamaDecoderLayer + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + return { + "intermediate_size": block_config.ffn.intermediate_size, + "num_key_value_heads": block_config.attention.num_key_value_heads, + } + + @staticmethod + def attn_no_op_post_init(decoder_layer: LlamaDecoderLayer): + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer: LlamaDecoderLayer): + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = MatchingZeros() + + @staticmethod + def init_rotary_embedding(model: LlamaForCausalLM, runtime): + model.model.rotary_emb = LlamaRotaryEmbedding(model.config, runtime.device) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.o_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class LlamaFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + """Layer descriptor for Llama FFN intermediate pruning.""" + + down_proj_name: str = "mlp.down_proj" + ffn_prefix_name: str = "model.layers.{layer_idx}.mlp" + linear_weight_names: List[str] = field( + default_factory=lambda: ["down_proj", "gate_proj", "up_proj"] + ) + + +@dataclass +class LlamaKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py new file mode 100644 index 0000000000..e1f4a3ef42 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .mistral_small_converter import * +from .mistral_small_model_descriptor import * diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py new file mode 100644 index 0000000000..64e30bf89b --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +from typing import List + +from transformers import MistralConfig + +from ....block_config import AttentionConfig, BlockConfig, FFNConfig +from ...converter import Converter, ConverterFactory + +__all__ = ["MistralSmallConverter"] + + +@ConverterFactory.register_decorator("mistral_small") +class MistralSmallConverter(Converter): + @staticmethod + def create_block_configs_from_main_config(config: MistralConfig) -> List[BlockConfig]: + num_hidden_layers = config.num_hidden_layers + + block_config = BlockConfig( + attention=AttentionConfig(no_op=False, num_key_value_heads=config.num_key_value_heads), + ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), + ).to_dict() + + block_configs = [block_config.copy() for _ in range(num_hidden_layers)] + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py new file mode 100644 index 0000000000..1c2d4af425 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import re +from dataclasses import dataclass, field +from typing import Dict, List + +from transformers.models.mistral.modeling_mistral import ( + MistralDecoderLayer, + MistralForCausalLM, + MistralRotaryEmbedding, +) + +from ....block_config import BlockConfig +from ....pruning.ffn_intermediate_pruning_mixin import FFNIntermediateLayerDescriptor +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor +from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size + +__all__ = [ + "MistralSmallModelDescriptor", + "MistralFFNIntermediateLayerDescriptor", + "MistralKVHeadsLayerDescriptor", +] + + +@ModelDescriptorFactory.register_decorator("mistral_small") +class MistralSmallModelDescriptor(ModelDescriptor): + @staticmethod + def decoder_layer_cls(): + return MistralDecoderLayer + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + return { + "intermediate_size": block_config.ffn.intermediate_size, + "num_key_value_heads": block_config.attention.num_key_value_heads, + } + + @staticmethod + def attn_no_op_post_init(decoder_layer: MistralDecoderLayer): + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer: MistralDecoderLayer): + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = MatchingZeros() + + @staticmethod + def init_rotary_embedding(model: MistralForCausalLM, runtime): + model.model.rotary_emb = MistralRotaryEmbedding(model.config, runtime.device) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.o_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class MistralFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + down_proj_name: str = "mlp.down_proj" + ffn_prefix_name: str = "model.layers.{layer_idx}.mlp" + linear_weight_names: List[str] = field( + default_factory=lambda: ["down_proj", "gate_proj", "up_proj"] + ) + + +@dataclass +class MistralKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py new file mode 100644 index 0000000000..d801afec5c --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .nemotron_h_converter import * +from .nemotron_h_model_descriptor import * diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py new file mode 100644 index 0000000000..45dc7274d1 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from ....block_config import AttentionConfig, BlockConfig, FFNConfig, MambaConfig, MoEConfig +from ...converter import Converter, ConverterFactory + +__all__ = ["NemotronHConverter"] + + +@ConverterFactory.register_decorator("nemotron_h") +class NemotronHConverter(Converter): + @staticmethod + def create_block_configs_from_main_config(config) -> List[BlockConfig]: + # Create block configs for each layer based on the hybrid_override_pattern + block_configs = [] + + # Parse the hybrid_override_pattern: "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-" + pattern = config.hybrid_override_pattern + print(f"Parsing hybrid pattern: {pattern}") + + for i, char in enumerate(pattern): + if char == "M": + _block_config = BlockConfig( + attention=AttentionConfig( + mamba=MambaConfig( # Those parameters are currently used only for calc_block_stats. + state_dim=config.ssm_state_size, + num_heads=config.mamba_num_heads, + head_dim=config.mamba_head_dim, + num_groups=config.n_groups, + ) + ), + ffn=FFNConfig(no_op=True), + ) + + elif char == "-": + _block_config = BlockConfig( + attention=AttentionConfig(no_op=True), + ffn=FFNConfig(intermediate_size=config.intermediate_size), + ) + + elif char == "*": + _block_config = BlockConfig( + attention=AttentionConfig(num_key_value_heads=config.num_key_value_heads), + ffn=FFNConfig(no_op=True), + ) + + elif char == "E": + _block_config = BlockConfig( + attention=AttentionConfig(no_op=True), + ffn=FFNConfig( + moe=MoEConfig( + num_local_experts=config.n_routed_experts, + expert_intermediate_dim=config.moe_intermediate_size, + num_experts_per_tok=config.num_experts_per_tok, + ) + ), + ) + else: + raise ValueError( + f"Unknown character '{char}' in hybrid_override_pattern at position {i}" + ) + + block_configs.append(_block_config) + + print(f"Created {len(block_configs)} block configs from pattern") + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py new file mode 100644 index 0000000000..1c5706d194 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import importlib +import inspect +import pkgutil +import re +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, Iterable, List, Tuple, Type + +import torch.nn as nn + +from ....block_config import BlockConfig +from ....pruning.expert_removal_pruning_mixin import ( + ExpertRemovalLayerDescriptor, + ExpertRemovalPruningMixIn, +) +from ....pruning.pruning_mixin import PruningMixIn +from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ...puzzformer.no_op import MatchingZeros, Same + +__all__ = ["NemotronHExpertRemovalLayerDescriptor", "NemotronHModelDescriptor"] + + +def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: + import transformers_modules + + matches = [] + for finder, modname, ispkg in pkgutil.walk_packages( + transformers_modules.__path__, transformers_modules.__name__ + "." + ): + module = importlib.import_module(modname) + for _, obj in inspect.getmembers(module, inspect.isclass): + if obj.__name__ == module_cls_str: + matches.append(obj) + + return matches + + +@dataclass +class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): + target_name: str = "mixer.gate" + moe_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + expert_prefix_name: str = "experts.{expert_idx}" + router_weights: List[str] = field(default_factory=lambda: ["gate.weight"]) + router_biases: List[str] = field(default_factory=lambda: ["gate.e_score_correction_bias"]) + expert_weights: List[str] = field( + default_factory=lambda: ["up_proj.weight", "down_proj.weight"] + ) + + def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]: + if self.target_name != "mixer": + return super().get_modules_names_to_hook(model) + + # when target is `mixer` we'll target moe layers of class type: `NemotronHMOE`, as NemotronH models use auto-map we'll check for class name instead of class type. + target_class_name = "NemotronHMOE" + + module_names_to_hook = [] + for module_name, module in model.named_modules(): + # restrict to attributes called "mixer" and with the desired class name + if ( + module_name.endswith(self.target_name) + and module.__class__.__name__ == target_class_name + ): + module_names_to_hook.append( + (self.block_idx_from_module_name(module_name), module_name) + ) + return module_names_to_hook + + +@ModelDescriptorFactory.register_decorator("nemotron_h") +class NemotronHModelDescriptor(ModelDescriptor): + _DECODER_LAYER_CLS: Type[nn.Module] = None + + @staticmethod + def decoder_layer_cls(): + decoder_cls_list = get_dynamic_modules("NemotronHBlock") + if not decoder_cls_list: + raise AssertionError( + "NemotronH contains dynamic modules that should be cached beforehand, make sure to load your config using `load_model_config` or manually call `force_cache_dynamic_modules(config, checkpoint_dir)`" + ) + return decoder_cls_list + + @staticmethod + def requires_trust_remote_code() -> bool: + return True + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + override_kwargs = {} + if block_config.ffn.intermediate_size is not None: + override_kwargs["intermediate_size"] = block_config.ffn.intermediate_size + + if block_config.attention.num_key_value_heads is not None: + override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads + + if block_config.ffn.moe is not None: + override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim + override_kwargs["n_routed_experts"] = block_config.ffn.moe.num_local_experts + + return override_kwargs + + @staticmethod + def _block_no_op_post_init(decoder_layer): + """ + Due to the subblock structure of NemotronH always one of the subblock is set to no-op, for a real no-op both attention & ffn no-op should be set to True. + """ + block_config = decoder_layer.config.block_configs[decoder_layer.layer_idx] + if block_config.ffn.no_op and block_config.attention.no_op: + decoder_layer.norm = Same() + decoder_layer.mixer = MatchingZeros() + + @staticmethod + def attn_no_op_post_init(decoder_layer): + NemotronHModelDescriptor._block_no_op_post_init(decoder_layer) + + @staticmethod + def mlp_no_op_post_init(decoder_layer): + NemotronHModelDescriptor._block_no_op_post_init(decoder_layer) + + @classmethod + def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module: + dummy_block = super().create_dummy_block(original_layer, block_index) + # Required by `NemotronHModel.forward`. + dummy_block.block_type = original_layer.block_type + # Preserve layer_idx if it exists (used by _block_no_op_post_init) + if hasattr(original_layer, "layer_idx"): + dummy_block.layer_idx = original_layer.layer_idx + # Preserve config if it exists (used by _block_no_op_post_init to access block_configs) + if hasattr(original_layer, "config"): + dummy_block.config = original_layer.config + return dummy_block + + @staticmethod + def init_rotary_embedding(model, runtime): + """ + NemotronH has no positional embeddings + """ + + @staticmethod + def input_embedding_name(): + return "backbone.embeddings" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "backbone.norm_f" + + @staticmethod + def layer_block_name(index: int): + return f"backbone.layers.{index}" + + @classmethod + def get_weight_groups( + cls, layer_names: Iterable[str], num_hidden_layers: int + ) -> Dict[str, List[str]]: + """ + Problem with NemotronH is that `norm.weight` can be in both block_{i}_ffn and block_{i}_attention. duplicate groups with `norm.weight` should be removed. + """ + weight_groups = defaultdict(list) + for name in layer_names: + is_matched = False + for group, pattern in cls.layer_name_predicates(num_hidden_layers).items(): + if pattern.match(name): + weight_groups[group].append(name) + is_matched = True + if not is_matched: + raise ValueError(f"Couldn't find a match for {name}") + + valid_weight_groups = {} + for group, names in weight_groups.items(): + if len(names) == 1: + only_name = names[0] + if only_name.endswith("norm.weight") and "layers" in only_name: + # Skip and don't append this group to valid_weight_groups + continue + valid_weight_groups[group] = names + + return valid_weight_groups + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile( + r"^(model\.embed_tokens\.weight|backbone\.embeddings?\.weight)$" + ), + "lm_head": re.compile(r"^(lm_head\.weight|backbone\.norm_f\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^backbone\.layers\.{layer_idx}\." + r"(norm\.weight|" # ← INCLUDED IN FFN + r"mixer\.(gate\.e_score_correction_bias" + r"|gate\.weight" + r"|experts\.\d+\.up_proj\.weight" + r"|experts\.\d+\.down_proj\.weight" + r"|shared_experts\.up_proj\.weight" + r"|shared_experts\.down_proj\.weight))$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^backbone\.layers\.{layer_idx}\." + r"(norm\.weight|" # ← INCLUDED IN ATTENTION + r"mixer\.(norm\.weight" + r"|A_log" + r"|D" + r"|conv1d\.weight" + r"|conv1d\.bias" + r"|dt_bias" + r"|in_proj\.weight" + r"|out_proj\.weight" + r"|q_proj\.weight" + r"|k_proj\.weight" + r"|v_proj\.weight" + r"|o_proj\.weight))$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update( + **build_ffn_predicates(), + **build_attention_predicates(), + ) + + return layer_name_patterns + + @staticmethod + def pruning_mixins() -> Dict[str, PruningMixIn]: + return { + "experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()), + } diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py new file mode 100644 index 0000000000..27c46317d9 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .nemotron_h_v2_converter import * +from .nemotron_h_v2_model_descriptor import * diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py new file mode 100644 index 0000000000..2f7634aa61 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from ....block_config import AttentionConfig, BlockConfig, FFNConfig, MambaConfig, MoEConfig +from ...converter import Converter, ConverterFactory + +__all__ = ["NemotronHV2Converter"] + + +@ConverterFactory.register_decorator("nemotron_h_v2") +class NemotronHV2Converter(Converter): + @staticmethod + def create_block_configs_from_main_config(config) -> List[BlockConfig]: + # Create block configs for each layer based on the hybrid_override_pattern + block_configs = [] + + # Parse the hybrid_override_pattern: "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-" + pattern = config.hybrid_override_pattern + print(f"Parsing hybrid pattern: {pattern}") + + for i, char in enumerate(pattern): + if char == "M": + _block_config = BlockConfig( + attention=AttentionConfig( + mamba=MambaConfig( # Those parameters are currently used only for calc_block_stats. + state_dim=config.ssm_state_size, + num_heads=config.mamba_num_heads, + head_dim=config.mamba_head_dim, + num_groups=config.n_groups, + ) + ), + ffn=FFNConfig(no_op=True), + ) + + elif char == "-": + _block_config = BlockConfig( + attention=AttentionConfig(no_op=True), + ffn=FFNConfig(intermediate_size=config.intermediate_size), + ) + + elif char == "*": + _block_config = BlockConfig( + attention=AttentionConfig(num_key_value_heads=config.num_key_value_heads), + ffn=FFNConfig(no_op=True), + ) + + elif char == "E": + _block_config = BlockConfig( + attention=AttentionConfig(no_op=True), + ffn=FFNConfig( + moe=MoEConfig( + num_local_experts=config.n_routed_experts, + expert_intermediate_dim=config.moe_intermediate_size, + num_experts_per_tok=config.num_experts_per_tok, + ) + ), + ) + else: + raise ValueError( + f"Unknown character '{char}' in hybrid_override_pattern at position {i}" + ) + + block_configs.append(_block_config) + + print(f"Created {len(block_configs)} block configs from pattern") + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py new file mode 100644 index 0000000000..a1e326f235 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py @@ -0,0 +1,255 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import pkgutil +import re +import sys +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, Iterable, List, Type + +import torch.nn as nn + +from ....block_config import BlockConfig +from ....pruning.ffn_intermediate_pruning_mixin import ( + FFNIntermediateLayerDescriptor, + FFNIntermediatePruningMixIn, +) +from ....pruning.pruning_mixin import PruningMixIn +from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ...puzzformer.no_op import MatchingZeros, Same + +__all__ = ["NemotronHV2FFNIntermediateLayerDescriptor", "NemotronHV2ModelDescriptor"] + + +def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: + import transformers_modules + + prefix = transformers_modules.__name__ + "." + + # Search already-imported modules first to avoid executing unrelated cached code. + matches = [] + for modname, module in list(sys.modules.items()): + if modname.startswith(prefix) and module is not None: + for _, obj in inspect.getmembers(module, inspect.isclass): + if obj.__name__ == module_cls_str: + matches.append(obj) + + if matches: + return matches + + # Fall back to walking only the transformers_modules namespace if nothing found yet. + for finder, modname, ispkg in pkgutil.walk_packages(transformers_modules.__path__, prefix): + module = importlib.import_module(modname) + for _, obj in inspect.getmembers(module, inspect.isclass): + if obj.__name__ == module_cls_str: + matches.append(obj) + + return matches + + +@dataclass +class NemotronHV2FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + down_proj_name: str = "mixer.down_proj" + ffn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + linear_weight_names: List[str] = field(default_factory=lambda: ["down_proj", "up_proj"]) + + +@ModelDescriptorFactory.register_decorator("nemotron_h_v2") +class NemotronHV2ModelDescriptor(ModelDescriptor): + _DECODER_LAYER_CLS: Type[nn.Module] = None + + @staticmethod + def decoder_layer_cls(): + decoder_cls_list = get_dynamic_modules("NemotronHBlock") + if not decoder_cls_list: + raise AssertionError( + "NemotronH contains dynamic modules that should be cached beforehand, make sure to load your config using `load_model_config` or manually call `force_cache_dynamic_modules(config, checkpoint_dir)`" + ) + return decoder_cls_list + + @staticmethod + def requires_trust_remote_code() -> bool: + return True + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + override_kwargs = {} + if block_config.ffn is not None and block_config.ffn.intermediate_size is not None: + override_kwargs["intermediate_size"] = block_config.ffn.intermediate_size + + if ( + block_config.attention is not None + and block_config.attention.num_key_value_heads is not None + ): + override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads + + if block_config.ffn is not None and block_config.ffn.moe is not None: + if block_config.ffn.moe.expert_intermediate_dim is not None: + override_kwargs["moe_intermediate_size"] = ( + block_config.ffn.moe.expert_intermediate_dim + ) + if block_config.ffn.moe.num_local_experts is not None: + override_kwargs["n_routed_experts"] = block_config.ffn.moe.num_local_experts + + return override_kwargs + + @staticmethod + def _block_no_op_post_init(decoder_layer): + """ + Due to the subblock structure of NemotronH always one of the subblock is set to no-op, for a real no-op both attention & ffn no-op should be set to True. + """ + block_config = decoder_layer.config.block_configs[decoder_layer.layer_idx] + ffn_no_op = block_config.ffn is not None and block_config.ffn.no_op + attn_no_op = block_config.attention is not None and block_config.attention.no_op + if ffn_no_op and attn_no_op: + decoder_layer.norm = Same() + decoder_layer.mixer = MatchingZeros() + + @staticmethod + def attn_no_op_post_init(decoder_layer): + NemotronHV2ModelDescriptor._block_no_op_post_init(decoder_layer) + + @staticmethod + def mlp_no_op_post_init(decoder_layer): + NemotronHV2ModelDescriptor._block_no_op_post_init(decoder_layer) + + @classmethod + def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module: + dummy_block = super().create_dummy_block(original_layer, block_index) + # Required by `NemotronHModel.forward`. + dummy_block.block_type = original_layer.block_type + # Preserve layer_idx if it exists (used by _block_no_op_post_init) + if hasattr(original_layer, "layer_idx"): + dummy_block.layer_idx = original_layer.layer_idx + # Preserve config if it exists (used by _block_no_op_post_init to access block_configs) + if hasattr(original_layer, "config"): + dummy_block.config = original_layer.config + return dummy_block + + @staticmethod + def init_rotary_embedding(model, runtime): + """ + NemotronH has no positional embeddings + """ + + @staticmethod + def input_embedding_name(): + return "backbone.embeddings" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "backbone.norm_f" + + @staticmethod + def layer_block_name(index: int): + return f"backbone.layers.{index}" + + @classmethod + def get_weight_groups( + cls, layer_names: Iterable[str], num_hidden_layers: int + ) -> Dict[str, List[str]]: + """ + Problem with NemotronH is that `norm.weight` can be in both block_{i}_ffn and block_{i}_attention. duplicate groups with `norm.weight` should be removed. + """ + weight_groups = defaultdict(list) + for name in layer_names: + is_matched = False + for group, pattern in cls.layer_name_predicates(num_hidden_layers).items(): + if pattern.match(name): + weight_groups[group].append(name) + is_matched = True + if not is_matched: + raise ValueError(f"Couldn't find a match for {name}") + + valid_weight_groups = {} + for group, names in weight_groups.items(): + if len(names) == 1: + only_name = names[0] + if re.fullmatch(r"backbone\.layers\.\d+\.norm\.weight", only_name): + # Skip the duplicated root layer norm; don't drop mixer.norm.weight etc. + continue + valid_weight_groups[group] = names + + return valid_weight_groups + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile( + r"^(model\.embed_tokens\.weight|backbone\.embeddings\.weight)$" + ), + "lm_head": re.compile(r"^(lm_head\.weight|backbone\.norm_f\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^backbone\.layers\.{layer_idx}\." + r"(norm\.weight|" # ← INCLUDED IN FFN + r"mixer\.(gate\.e_score_correction_bias" + r"|gate\.weight" + r"|experts\.\d+\.up_proj\.weight" + r"|experts\.\d+\.down_proj\.weight" + r"|shared_experts\.up_proj\.weight" + r"|shared_experts\.down_proj\.weight" + r"|up_proj\.weight" # Simple MLP (non-MoE) + r"|down_proj\.weight))$" # Simple MLP (non-MoE) + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^backbone\.layers\.{layer_idx}\." + r"(norm\.weight|" # ← INCLUDED IN ATTENTION + r"mixer\.(norm\.weight" + r"|A_log" + r"|D" + r"|conv1d\.weight" + r"|conv1d\.bias" + r"|dt_bias" + r"|in_proj\.weight" + r"|out_proj\.weight" + r"|q_proj\.weight" + r"|k_proj\.weight" + r"|v_proj\.weight" + r"|o_proj\.weight))$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update( + **build_ffn_predicates(), + **build_attention_predicates(), + ) + + return layer_name_patterns + + @staticmethod + def pruning_mixins() -> Dict[str, PruningMixIn]: + return { + "ffn_intermediate": FFNIntermediatePruningMixIn( + NemotronHV2FFNIntermediateLayerDescriptor() + ), + # TODO: Add expert removal support when ExpertRemovalPruningMixIn is migrated + } diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py new file mode 100644 index 0000000000..b59c7e937a --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .qwen2_converter import * +from .qwen2_model_descriptor import * diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py new file mode 100644 index 0000000000..a59b52e33e --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Qwen2 converter for AnyModel compression.""" + +from typing import List + +from transformers import Qwen2Config + +from ....block_config import AttentionConfig, BlockConfig, FFNConfig +from ...converter import Converter, ConverterFactory + +__all__ = ["Qwen2Converter"] + + +@ConverterFactory.register_decorator("qwen2") +class Qwen2Converter(Converter): + """Converter for Qwen2 models to AnyModel format.""" + + @staticmethod + def create_block_configs_from_main_config(config: Qwen2Config) -> List[BlockConfig]: + """Create uniform block configs for all Qwen2 layers. + + Qwen2 models have uniform architecture across all layers, so we create + the same BlockConfig for each layer. + """ + num_hidden_layers = config.num_hidden_layers + + block_config = BlockConfig( + attention=AttentionConfig(no_op=False, num_key_value_heads=config.num_key_value_heads), + ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), + ).to_dict() + + block_configs = [block_config.copy() for _ in range(num_hidden_layers)] + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py new file mode 100644 index 0000000000..46588abd19 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Qwen2 model descriptor for AnyModel compression.""" + +import re +from dataclasses import dataclass +from typing import Dict + +from torch import nn +from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2RotaryEmbedding + +from ....block_config import BlockConfig +from ....utils.dummy_modules import DummyBlock +from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size +from ..llama.llama_model_descriptor import LlamaFFNIntermediateLayerDescriptor + +__all__ = ["Qwen2ModelDescriptor", "Qwen2FFNIntermediateLayerDescriptor"] + + +@ModelDescriptorFactory.register_decorator("qwen2") +class Qwen2ModelDescriptor(ModelDescriptor): + """Model descriptor for Qwen2 models.""" + + @staticmethod + def decoder_layer_cls(): + return Qwen2DecoderLayer + + @classmethod + def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module: + """Create a dummy block that preserves Qwen2-specific attributes like attention_type. + + Qwen2's forward pass accesses decoder_layer.attention_type for attention mask selection. + """ + dummy = DummyBlock(block_index=block_index) + # Copy attention_type from original layer (required by Qwen2's forward pass) + if hasattr(original_layer, "attention_type"): + dummy.attention_type = original_layer.attention_type + return dummy + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + return { + "intermediate_size": block_config.ffn.intermediate_size, + "num_key_value_heads": block_config.attention.num_key_value_heads, + } + + @staticmethod + def attn_no_op_post_init(decoder_layer: nn.Module): + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer: nn.Module): + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = MatchingZeros() + + @staticmethod + def init_rotary_embedding(model: nn.Module, runtime): + model.model.rotary_emb = Qwen2RotaryEmbedding(config=model.config, device=runtime.device) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + # Qwen2 has biases on attention projections + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.q_proj\.bias" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.k_proj\.bias" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.v_proj\.bias" + r"|self_attn\.o_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class Qwen2FFNIntermediateLayerDescriptor(LlamaFFNIntermediateLayerDescriptor): + """Layer descriptor for Qwen2 FFN intermediate pruning. + + Qwen2 uses the same FFN structure as Llama (gate_proj, up_proj, down_proj). + """ diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py new file mode 100644 index 0000000000..bc7ba2bef7 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .qwen3_converter import * +from .qwen3_model_descriptor import * diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py new file mode 100644 index 0000000000..1911e23466 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +from typing import List + +from transformers import Qwen3Config + +from ....block_config import AttentionConfig, BlockConfig, FFNConfig +from ...converter import Converter, ConverterFactory + +__all__ = ["Qwen3Converter"] + + +@ConverterFactory.register_decorator("qwen3") +class Qwen3Converter(Converter): + @staticmethod + def create_block_configs_from_main_config(config: Qwen3Config) -> List[BlockConfig]: + num_hidden_layers = config.num_hidden_layers + + block_configs = [ + BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=config.num_key_value_heads + ), + ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), + ).to_dict() + for _ in range(num_hidden_layers) + ] + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py new file mode 100644 index 0000000000..30ffd28d0c --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +import re +from dataclasses import dataclass, field +from typing import Dict, List + +from torch import nn +from transformers.models.qwen3.modeling_qwen3 import ( + Qwen3DecoderLayer, + Qwen3ForCausalLM, + Qwen3RotaryEmbedding, +) + +from ....block_config import BlockConfig +from ....pruning.ffn_intermediate_pruning_mixin import FFNIntermediateLayerDescriptor +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor +from ....utils.dummy_modules import DummyBlock +from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size + +__all__ = [ + "Qwen3ModelDescriptor", + "Qwen3FFNIntermediateLayerDescriptor", + "Qwen3KVHeadsLayerDescriptor", +] + + +@ModelDescriptorFactory.register_decorator("qwen3") +class Qwen3ModelDescriptor(ModelDescriptor): + @staticmethod + def decoder_layer_cls(): + return Qwen3DecoderLayer + + @classmethod + def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module: + """Create a dummy block that preserves Qwen3-specific attributes like attention_type. + + Qwen3's forward pass accesses decoder_layer.attention_type for attention mask selection. + """ + dummy = DummyBlock(block_index=block_index) + # Copy attention_type from original layer (required by Qwen3's forward pass) + if hasattr(original_layer, "attention_type"): + dummy.attention_type = original_layer.attention_type + return dummy + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + return { + "intermediate_size": block_config.ffn.intermediate_size, + "num_key_value_heads": block_config.attention.num_key_value_heads, + } + + @staticmethod + def attn_no_op_post_init(decoder_layer: Qwen3DecoderLayer): + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer: Qwen3DecoderLayer): + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = MatchingZeros() + + @staticmethod + def init_rotary_embedding(model: Qwen3ForCausalLM, runtime): + model.model.rotary_emb = Qwen3RotaryEmbedding(model.config, runtime.device) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.o_proj\.weight" + r"|self_attn\.q_norm\.weight" + r"|self_attn\.k_norm\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class Qwen3FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + down_proj_name: str = "mlp.down_proj" + ffn_prefix_name: str = "model.layers.{layer_idx}.mlp" + linear_weight_names: List[str] = field( + default_factory=lambda: ["down_proj", "gate_proj", "up_proj"] + ) + + +@dataclass +class Qwen3KVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py new file mode 100644 index 0000000000..4637dca9c1 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .qwen3_vl_converter import * +from .qwen3_vl_model_descriptor import * diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py new file mode 100644 index 0000000000..9ceb98af34 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +from typing import TYPE_CHECKING, List + +from ....block_config import AttentionConfig, BlockConfig, FFNConfig, MoEConfig +from ...converter import Converter, ConverterFactory + +if TYPE_CHECKING: + from transformers import Qwen3VLMoeConfig + +__all__ = ["Qwen3VLConverter"] + + +@ConverterFactory.register_decorator("qwen3_vl") +class Qwen3VLConverter(Converter): + @staticmethod + def create_block_configs_from_main_config(config: "Qwen3VLMoeConfig") -> List[BlockConfig]: + # Qwen3-VL MoE has nested text_config + text_config = config.text_config if hasattr(config, "text_config") else config + + num_hidden_layers = text_config.num_hidden_layers + decoder_sparse_step = getattr(text_config, "decoder_sparse_step", 1) + mlp_only_layers = getattr(text_config, "mlp_only_layers", []) + + block_configs = [] + for layer_idx in range(num_hidden_layers): + # Check if this layer is MoE or dense + is_moe_layer = (layer_idx % decoder_sparse_step == 0) and ( + layer_idx not in mlp_only_layers + ) + + if is_moe_layer: + # MoE layer + block_config = BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=text_config.num_key_value_heads + ), + ffn=FFNConfig( + moe=MoEConfig( + num_local_experts=text_config.num_experts, + expert_intermediate_dim=text_config.moe_intermediate_size, + num_experts_per_tok=text_config.num_experts_per_tok, + ) + ), + ) + else: + # Dense layer + block_config = BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=text_config.num_key_value_heads + ), + ffn=FFNConfig(no_op=False, intermediate_size=text_config.intermediate_size), + ) + + block_configs.append(block_config) + + print( + f"Created {len(block_configs)} block configs for Qwen3-VL MoE (decoder_sparse_step={decoder_sparse_step})" + ) + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py new file mode 100644 index 0000000000..aeedd41992 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +import re +from dataclasses import dataclass, field +from typing import Dict, List + +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeTextDecoderLayer, + Qwen3VLMoeTextRotaryEmbedding, + Qwen3VLMoeVisionRotaryEmbedding, +) + +from ....block_config import BlockConfig +from ....pruning.expert_removal_pruning_mixin import ExpertRemovalLayerDescriptor +from ....pruning.ffn_intermediate_pruning_mixin import FFNIntermediateLayerDescriptor +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor +from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size + +__all__ = [ + "Qwen3VLModelDescriptor", + "Qwen3VLFFNIntermediateLayerDescriptor", + "Qwen3VLKVHeadsLayerDescriptor", + "Qwen3VLExpertRemovalLayerDescriptor", +] + + +@ModelDescriptorFactory.register_decorator("qwen3_vl") +class Qwen3VLModelDescriptor(ModelDescriptor): + @staticmethod + def uses_autocast() -> bool: + """ + Qwen3-VL MoE has a dtype bug in HuggingFace transformers under torch.autocast: + scatter() in MoE routing fails with dtype mismatch. Use native bfloat16 instead. + See: https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct (recommended approach) + """ + return False + + @staticmethod + def get_language_model_config(config): + """Qwen3-VL has nested text_config for language model parameters.""" + return config.text_config if hasattr(config, "text_config") else config + + @staticmethod + def decoder_layer_cls(): + return Qwen3VLMoeTextDecoderLayer + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + override_kwargs = {"num_key_value_heads": block_config.attention.num_key_value_heads} + + if block_config.ffn.moe: + override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim + override_kwargs["num_experts"] = block_config.ffn.moe.num_local_experts + else: + override_kwargs["intermediate_size"] = block_config.ffn.intermediate_size + + return override_kwargs + + @staticmethod + def attn_no_op_post_init(decoder_layer: Qwen3VLMoeTextDecoderLayer): + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer: Qwen3VLMoeTextDecoderLayer): + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = MatchingZeros() + + @staticmethod + def init_rotary_embedding(model, runtime): + # Re-initialize text rotary embedding on correct device and dtype + text_config = Qwen3VLModelDescriptor.get_language_model_config(model.config) + model.model.language_model.rotary_emb = Qwen3VLMoeTextRotaryEmbedding( + config=text_config + ).to(device=runtime.device, dtype=runtime.dtype) + # Re-initialize vision rotary embedding on correct device and dtype + vision_config = ( + model.config.vision_config if hasattr(model.config, "vision_config") else None + ) + if vision_config is not None: + head_dim = vision_config.hidden_size // vision_config.num_heads + model.model.visual.rotary_pos_emb = Qwen3VLMoeVisionRotaryEmbedding(head_dim // 2).to( + device=runtime.device, dtype=runtime.dtype + ) + + @staticmethod + def input_embedding_name(): + return "model.language_model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.language_model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.language_model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + # Qwen3-VL has text model under model.language_model.* prefix + layer_name_patterns = { + "embeddings": re.compile(r"^model\.language_model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.language_model\.norm\.weight|lm_head\.weight)$"), + # Vision encoder (includes merger under model.visual.deepstack_merger_list.*) + "vision_encoding": re.compile(r"^model\.visual\..*"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.language_model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + # MoE router + r"|mlp\.gate\.weight" + # MoE experts - fused format (gate_up_proj, down_proj without .weight suffix) + r"|mlp\.experts\.gate_up_proj" + r"|mlp\.experts\.down_proj" + # Shared expert (if present) + r"|mlp\.shared_expert\.up_proj\.weight" + r"|mlp\.shared_expert\.gate_proj\.weight" + r"|mlp\.shared_expert\.down_proj\.weight" + r"|mlp\.shared_expert_gate\.weight" + # Dense MLP fallback (for non-MoE layers) + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.language_model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.o_proj\.weight" + r"|self_attn\.q_norm\.weight" + r"|self_attn\.k_norm\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class Qwen3VLFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + down_proj_name: str = "mlp.down_proj" + ffn_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp" + linear_weight_names: List[str] = field( + default_factory=lambda: ["down_proj", "gate_proj", "up_proj"] + ) + + +@dataclass +class Qwen3VLKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.language_model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + +@dataclass +class Qwen3VLExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): + """ + Qwen3-VL MoE layer descriptor. + + Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py + - Qwen3VLMoeTextSparseMoeBlock: MoE block with .gate (router) and .experts + - Qwen3VLMoeTextTopKRouter: Router with .weight (no bias) + - Qwen3VLMoeTextExperts: Fused experts with .gate_up_proj and .down_proj tensors + """ + + target_name: str = "mlp" + moe_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp" + # Router: Qwen3VLMoeTextTopKRouter has self.weight, no bias + router_weights: List[str] = field(default_factory=lambda: ["gate.weight"]) + router_biases: List[str] = field(default_factory=list) + # Fused expert format: Qwen3VLMoeTextExperts stores all experts in single tensors + # with shape [num_experts, ...] instead of separate tensors per expert. + is_fused_experts: bool = True + fused_expert_weights: List[str] = field( + default_factory=lambda: ["experts.gate_up_proj", "experts.down_proj"] + ) diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py b/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py new file mode 100644 index 0000000000..695e4495b0 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for patching and transforming HuggingFace models to work with AnyModel. + +Provides no-op modules for layer replacement and patching utilities for heterogeneous +per-layer configurations. +""" + +from .no_op import * +from .patcher import * diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py b/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py new file mode 100644 index 0000000000..795ce6d67b --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""No-op modules for replacing layers during pruning.""" + +from functools import cache + +import torch +import torch.nn as nn + +__all__ = [ + "return_tuple_of_size", + "MatchingZeros", + "Same", +] + + +@cache +def return_tuple_of_size(cls: type[nn.Module], size: int) -> type[nn.Module]: + """Create a wrapper class that returns a tuple of the given size. + + Useful for replacing modules that return multiple outputs (e.g., attention layers + that return (hidden_states, attn_weights)). + + Args: + cls: The base module class to wrap. + size: The size of the tuple to return. + + Returns: + A new class that wraps the base class and returns a tuple of the given size. + + Example: + >>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + """ + + class Wrapped(cls): + def forward(self, *args, **kwargs): + result = super().forward(*args, **kwargs) + outputs = [None] * size + outputs[0] = result if isinstance(result, torch.Tensor) else result[0] + return tuple(outputs) + + def extra_repr(self) -> str: + return f"[{cls.__name__}]" + + return Wrapped + + +class MatchingZeros(nn.Module): + """Module that returns zeros matching the input shape. + + Used to replace MLP or attention layers with no-ops. Returns zeros because + the hidden_states are added to the residuals, so a no-op implementation + should leave the residual unchanged. + """ + + def forward(self, hidden_states, *args, **kwargs): + return torch.zeros_like(hidden_states) + + +class Same(nn.Module): + """Module that returns the input unchanged. + + Used to replace normalization layers with identity operations. + """ + + def forward(self, hidden_states, *args, **kwargs): + return hidden_states + + @property + def weight(self): + """Support NemotronH with scoring_activations, when lm_head is called `self.lm_head.weight.dtype`.""" + return torch.empty(0) diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/patcher.py b/modelopt/torch/puzzletron/anymodel/puzzformer/patcher.py new file mode 100644 index 0000000000..52c094b3c3 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/puzzformer/patcher.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import copy +import inspect +from contextlib import ExitStack, contextmanager +from functools import wraps +from typing import Any, Dict, List + +from transformers import PretrainedConfig + +from ...block_config import BlockConfig, maybe_cast_block_configs +from ..model_descriptor.base import ModelDescriptor + +__all__ = [ + "deci_x_patcher", + "override_config_with_block_configs", +] + + +def _get_variable_from_stack(names: list[str]) -> Any: + """Search the call stack for a variable with one of the given names.""" + f = inspect.currentframe().f_back + while f: + for name in names: + if name in f.f_locals: + return f.f_locals[name] + f = f.f_back + raise RuntimeError(f"{names} not found in caller stack") + + +@contextmanager +def deci_x_patcher( + model_descriptor: ModelDescriptor, + block_configs: List[BlockConfig | dict] | None = None, +): + """Context manager that patches decoder layer __init__ for heterogeneous per-layer configs. + + This is the core mechanism that enables AnyModel to work with any HuggingFace model. + It patches the decoder layer class(es) to read per-layer block_configs and apply + layer-specific overrides (e.g., different intermediate_size per layer). + + Args: + model_descriptor: The model descriptor that defines which classes to patch + and how to map block_configs to layer overrides. + block_configs: Optional list of BlockConfig (one per layer). If not provided, + will try to read from config.block_configs during model initialization. + + Example: + >>> with deci_x_patcher(LlamaModelDescriptor, block_configs): + ... model = AutoModelForCausalLM.from_config(config) + """ + decoder_layer_classes = model_descriptor.decoder_layer_cls() # Now a list of classes + if not isinstance(decoder_layer_classes, list): + decoder_layer_classes = [decoder_layer_classes] + + orig_inits = [] + for cls in decoder_layer_classes: + orig_inits.append(cls.__init__) + + block_configs = maybe_cast_block_configs(block_configs) + + @wraps(orig_inits[0]) + def _patched_decoder_layer_init(self, config, *args, **kwargs): + _block_configs = block_configs or getattr(config, "block_configs", None) + if _block_configs is None: + return orig_inits[decoder_layer_classes.index(self.__class__)]( + self, config, *args, **kwargs + ) + + _block_configs = maybe_cast_block_configs(_block_configs) + layer_idx = _get_variable_from_stack(["layer_idx", "idx"]) + _block_config = _block_configs[layer_idx] + override_block_config = model_descriptor.block_config_to_layer_overrides(_block_config) + _config = override_config_with_block_configs(config, override_block_config) + orig_inits[decoder_layer_classes.index(self.__class__)](self, _config, *args, **kwargs) + + # Apply no-op post-init + if _block_config.attention.no_op: + if not model_descriptor.attn_no_op_supported(): + raise NotImplementedError( + f"attn no-op not supported for `{model_descriptor.__class__.__name__}`, " + "please implement the method: `attn_no_op_post_init()`" + ) + model_descriptor.attn_no_op_post_init(decoder_layer=self) + + if _block_config.ffn.no_op: + if not model_descriptor.mlp_no_op_supported(): + raise NotImplementedError( + f"mlp no-op not supported for `{model_descriptor.__class__.__name__}`, " + "please implement the method: `mlp_no_op_post_init()`" + ) + model_descriptor.mlp_no_op_post_init(decoder_layer=self) + + with ExitStack() as stack: + # Patch every decoder layer class + for orig_init, cls in zip(orig_inits, decoder_layer_classes): + stack.callback(setattr, cls, "__init__", orig_init) # Restore on exit + cls.__init__ = _patched_decoder_layer_init + yield + + +def override_config_with_block_configs( + config: PretrainedConfig, block_configs: Dict[str, Any] +) -> PretrainedConfig: + """Create a copy of config with block_config overrides applied.""" + _config = copy.deepcopy(config) + # Model initialization requires fails with None in case of no-ops + _config_overrides = {k: v for k, v in block_configs.items() if v is not None} + _config.update(_config_overrides) + return _config diff --git a/modelopt/torch/puzzletron/block_config.py b/modelopt/torch/puzzletron/block_config.py new file mode 100644 index 0000000000..bf68a72c13 --- /dev/null +++ b/modelopt/torch/puzzletron/block_config.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +import dataclasses +import inspect +import warnings +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, List, Optional, Type, Union, get_args, get_origin + +__all__ = [ + "BaseDataclass", + "SubblockConfig", + "MoEConfig", + "MambaConfig", + "Llama4AttentionConfig", + "AttentionConfig", + "FFNConfig", + "SUBBLOCK_CLS_DICT", + "BlockConfig", + "maybe_cast_block_configs", +] + + +@dataclass(frozen=True, kw_only=True) +class BaseDataclass: + """ + A dataclass base class with several utilities: + 1. Comparison via string representation. + 2. Initialization of dataclasses fields from dicts. + 3. Setting attributes even though it's frozen (but only inside __post_init__!) + """ + + def __eq__(self, other: "BaseDataclass") -> bool: + return str(self) == str(other) + + def __hash__(self) -> int: + return hash(str(self)) + + def __lt__(self, other: "BaseDataclass") -> bool: + return str(self) < str(other) + + def _force_setattr(self, name: str, value: Any) -> None: + """ + Set an attribute even in frozen dataclasses. + Use only inside __post_init__! + """ + assert _is_called_from_post_init(), ( + "_force_setattr should only be called from __post_init__, " + "if you need to change an attribute use dataclasses.replace " + "or create a new instance :)" + ) + object.__setattr__(self, name, value) + + def __post_init__(self): + """ + Init dataclass fields from dicts + """ + for field in dataclasses.fields(self): + field_dict = getattr(self, field.name) + if isinstance(field_dict, dict) and _is_dataclass_type(field.type): + dataclass_cls = _get_dataclass_type(field.type) + sub_fields = [field.name for field in dataclasses.fields(dataclass_cls)] + unsupported_fields = [ + field_name for field_name in field_dict.keys() if field_name not in sub_fields + ] + if len(unsupported_fields) > 0: + warnings.warn( + f"Removed unsupported fields {unsupported_fields} from {dataclass_cls}" + ) + + field_dict = {k: v for k, v in field_dict.items() if k not in unsupported_fields} + self._force_setattr(field.name, dataclass_cls(**field_dict)) + + +def _is_called_from_post_init() -> bool: + frame = inspect.currentframe() + while frame: + if frame.f_code.co_name == "__post_init__": + return True + frame = frame.f_back + return False + + +def _is_dataclass_type(tp: Type) -> bool: + """ + Like dataclasses.is_dataclass but also works for Optional[] and Union[] of a dataclass type + """ + try: + _get_dataclass_type(tp) + return True + except: + return False + + +def _get_dataclass_type(tp: Type) -> dataclass: + """ + If the given type is a dataclass, the function returns it. + If it is a Union[] or Optional[], the function extracts the first dataclass type. + If no dataclass type is found, the function raises a ValueError. + """ + origin = get_origin(tp) + if origin is Union: + for type_in_union in get_args(tp): + if dataclasses.is_dataclass(type_in_union): + return type_in_union + if dataclasses.is_dataclass(tp): + return tp + raise ValueError("Not a dataclass") + + +@dataclass(frozen=True, kw_only=True) +class SubblockConfig(BaseDataclass): + """Base configuration for a subblock (e.g. attention or FFN) within a transformer block.""" + + no_op: bool = False + replace_with_linear: bool = False + sparsify: Optional[list[str]] = None + weights_precision: Optional[str] = "bf16" + + def __post_init__(self): + super().__post_init__() + assert not (self.no_op and self.replace_with_linear) + if self.no_op: + self._force_setattr("sparsify", None) + + @abstractmethod + def to_blockconfig(self) -> "BlockConfig": + """ " + Convert to a block including this subblock only. + """ + ... + + +@dataclass(frozen=True, kw_only=True) +class MoEConfig(BaseDataclass): + """ + Configuration class for Mixture of Experts parameters. + """ + + num_local_experts: int = 8 + num_experts_per_tok: int = 1 + expert_intermediate_dim: int = 8192 + shared_expert_intermediate_dim: int = 8192 + # router_aux_loss_coef: float = 0.01 + # router_z_loss_coef: float = 0.0 # Optional z-loss coefficient + + def __post_init__(self): + # Validate the configuration + if self.num_local_experts <= 0: + raise ValueError(f"num_local_experts must be positive, got {self.num_local_experts}") + if self.num_experts_per_tok <= 0: + raise ValueError( + f"num_experts_per_tok must be positive, got {self.num_experts_per_tok}" + ) + if self.num_experts_per_tok > self.num_local_experts: + raise ValueError( + f"num_experts_per_tok ({self.num_experts_per_tok}) cannot be greater than num_local_experts ({self.num_local_experts})" + ) + # if self.router_aux_loss_coef < 0: + # raise ValueError(f"router_aux_loss_coef must be non-negative, got {self.router_aux_loss_coef}") + + +@dataclass(frozen=True, kw_only=True) +class MambaConfig(BaseDataclass): + """Configuration for a Mamba (state-space model) subblock.""" + + state_dim: int + num_heads: int + head_dim: int + num_groups: int + + +@dataclass(frozen=True, kw_only=True) +class Llama4AttentionConfig(BaseDataclass): + """Configuration for Llama-4-specific attention parameters.""" + + attention_chunk_size: Optional[int] = None + use_rope: Optional[bool] = None + use_qk_norm: Optional[bool] = None + attn_scale: Optional[float] = None + floor_scale: Optional[float] = None + attn_temperature_tuning: Optional[bool] = None + attention_dropout: Optional[float] = None + + +@dataclass(frozen=True, kw_only=True) +class AttentionConfig(SubblockConfig): + """Configuration for an attention subblock within a transformer block.""" + + num_key_value_heads: Optional[int] = None + llama4: Optional[Llama4AttentionConfig] = None + mamba: Optional[MambaConfig] = None + + def __post_init__(self): + super().__post_init__() + + if self.no_op: + assert not self.is_mamba + assert not self.is_llama4 + + if self.no_op or self.is_mamba: + for irrelevant_att in [ + "num_key_value_heads", + ]: + self._force_setattr(irrelevant_att, None) + else: + assert self.num_key_value_heads is not None + + def to_blockconfig(self) -> "BlockConfig": + return BlockConfig(attention=self, ffn=FFNConfig(no_op=True)) + + @property + def is_llama4(self) -> bool: + return self.llama4 is not None + + @property + def is_mamba(self) -> bool: + return self.mamba is not None + + +@dataclass(frozen=True, kw_only=True) +class FFNConfig(SubblockConfig): + """Configuration for a feed-forward network subblock within a transformer block.""" + + moe: Optional[MoEConfig] = None + intermediate_size: Optional[int] = None + + def __post_init__(self): + super().__post_init__() + if self.no_op: + self._force_setattr("moe", None) + self._force_setattr("intermediate_size", None) + elif self.is_moe: + self._force_setattr("intermediate_size", None) + else: + assert self.intermediate_size is not None, ( + "Intermediate size must be provided for an FFN block" + ) + + def to_blockconfig(self) -> "BlockConfig": + return BlockConfig(attention=AttentionConfig(no_op=True), ffn=self) + + @property + def is_moe(self) -> bool: + return self.moe is not None + + +SUBBLOCK_CLS_DICT = { + "attention": AttentionConfig, + "ffn": FFNConfig, +} + + +@dataclass(frozen=True, kw_only=True) +class BlockConfig(BaseDataclass): + """Configuration for a single transformer block, including its attention and FFN subblocks.""" + + attention: Optional[AttentionConfig] = None + ffn: Optional[FFNConfig] = None + parallel_blocks: Optional[list["BlockConfig"]] = None + + def __post_init__(self): + super().__post_init__() + if (self.parallel_blocks is not None) and isinstance(self.parallel_blocks[0], dict): + initialized_block_configs = [ + BlockConfig(**block_config) for block_config in self.parallel_blocks + ] + self._force_setattr("parallel_blocks", initialized_block_configs) + + def to_dict(self) -> dict: + """Convert BlockConfig to a dictionary.""" + return dataclasses.asdict(self) + + +def maybe_cast_block_configs( + block_configs: List[BlockConfig | dict] | None, +) -> List[BlockConfig] | None: + """Cast a list of dicts to BlockConfig objects if needed. + + Args: + block_configs: List of BlockConfig or dict objects, or None. + + Returns: + List of BlockConfig objects, or None if input is None/empty. + """ + if not block_configs: + return block_configs + if isinstance(block_configs[0], dict): + return [BlockConfig(**conf) for conf in block_configs] + return block_configs diff --git a/modelopt/torch/puzzletron/build_library_and_stats.py b/modelopt/torch/puzzletron/build_library_and_stats.py new file mode 100644 index 0000000000..efa6747a9c --- /dev/null +++ b/modelopt/torch/puzzletron/build_library_and_stats.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unified command that runs build_replacement_library followed by calc_subblock_stats. + +This script combines the functionality of both commands into a single workflow: +1. First, it builds the replacement library for the puzzle +2. Then, it calculates subblock statistics + +Usage: + + python modelopt.torch.puzzletron.build_library_and_stats.py --config-dir configs --config-name Llama-3_1-8B puzzle_dir=/path/to/puzzle/dir dataset_path=/path/to/dataset + +The script uses the same Hydra configuration as the individual commands and supports +all the same configuration parameters for both build_replacement_library and calc_subblock_stats. +""" + +from omegaconf import DictConfig + +from .replacement_library.build_replacement_library import launch_build_replacement_library +from .subblock_stats.calc_subblock_stats import launch_calc_subblock_stats +from .tools.logger import mprint + +__all__ = ["launch_build_library_and_stats"] + + +def launch_build_library_and_stats(cfg: DictConfig) -> None: + """ + Launch both build_replacement_library and calc_subblock_stats in sequence. + + Args: + cfg: Hydra configuration containing settings for both commands + """ + mprint("=" * 80) + mprint("STARTING UNIFIED BUILD LIBRARY AND STATS WORKFLOW") + mprint("=" * 80) + + # Step 1: Build replacement library + mprint("=" * 50) + mprint("STEP 1: Building Replacement Library") + mprint("=" * 50) + + try: + launch_build_replacement_library(cfg) + mprint("✅ Replacement library built successfully!") + except Exception as e: + mprint(f"❌ Failed to build replacement library: {e}") + raise + + # Step 2: Calculate subblock statistics + mprint("=" * 50) + mprint("STEP 2: Calculating Subblock Statistics") + mprint("=" * 50) + + try: + launch_calc_subblock_stats(cfg) + mprint("✅ Subblock statistics calculated successfully!") + except Exception as e: + mprint(f"❌ Failed to calculate subblock statistics: {e}") + raise + + mprint("=" * 80) + mprint("UNIFIED WORKFLOW COMPLETED SUCCESSFULLY! 🎉") + mprint("=" * 80) + + mprint("Generated files:") + mprint(f" - {cfg.puzzle_dir}/block_library.json") + mprint(f" - {cfg.puzzle_dir}/subblock_library.json") + mprint(f" - {cfg.puzzle_dir}/replacement_library.json") + mprint(f" - {cfg.puzzle_dir}/single_sequence_replacement_solutions.json") + mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.subblock_stats_filename}") + if hasattr(cfg.calc_subblock_stats, "moe_stats_filename"): + mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.moe_stats_filename}") diff --git a/modelopt/torch/puzzletron/dataset/__init__.py b/modelopt/torch/puzzletron/dataset/__init__.py new file mode 100644 index 0000000000..f65aa299e8 --- /dev/null +++ b/modelopt/torch/puzzletron/dataset/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset preparation utilities for Puzzletron.""" + +from .prepare_dataset import * diff --git a/modelopt/torch/puzzletron/dataset/prepare_dataset.py b/modelopt/torch/puzzletron/dataset/prepare_dataset.py new file mode 100644 index 0000000000..0928b111af --- /dev/null +++ b/modelopt/torch/puzzletron/dataset/prepare_dataset.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import datasets +import fire +import numpy as np + +from ..tools.logger import mprint + +__all__ = ["process_and_save_dataset"] + + +def process_and_save_dataset( + dataset_name: str, + output_dir: str, + split: tuple = ("code", "math", "stem", "chat"), + overwrite: bool = False, +): + # Check if output_dir contains an existing dataset + dataset_dict_path = os.path.join(output_dir, "dataset_dict.json") + if os.path.exists(output_dir) and os.path.exists(dataset_dict_path): + if not overwrite: + mprint( + f"Output directory '{output_dir}' already contains a dataset. " + "Use '--overwrite True' to overwrite existing data." + ) + return + + ds = datasets.load_dataset(dataset_name, split=split) + ds = datasets.concatenate_datasets(ds) + # Filter out samples with reasoning = on + ds = ds.filter(lambda x: x["reasoning"] == "off") + # Hardcoded for dynamically create a deterministic train-val split + seed = 408 + generator = np.random.RandomState(seed=seed) + ds_split = ds.train_test_split(test_size=0.05, shuffle=True, generator=generator) + # Rename dataset names to follow previous conventions + ds_dict = datasets.DatasetDict( + { + "train": ds_split["train"], + "valid": ds_split["test"], + } + ) + # Save locally + os.makedirs(output_dir, exist_ok=True) + ds_dict.save_to_disk(output_dir) + + mprint(f"Dataset splits:\n{ds_dict}") + mprint(f"Saved processed datasets to {output_dir}") + + +if __name__ == "__main__": + fire.Fire(process_and_save_dataset) diff --git a/modelopt/torch/puzzletron/entrypoint.py b/modelopt/torch/puzzletron/entrypoint.py new file mode 100644 index 0000000000..6a4af31ce3 --- /dev/null +++ b/modelopt/torch/puzzletron/entrypoint.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module provides the main compression function for a model using MIP-based NAS search algorithm.""" + +import hydra +from omegaconf import DictConfig + +import modelopt.torch.utils.distributed as dist + +from .activation_scoring import launch_score_activations +from .build_library_and_stats import launch_build_library_and_stats +from .mip import launch_mip_and_realize_model +from .pruning import launch_prune_ckpt +from .scoring import launch_scoring +from .tools.hydra_utils import initialize_hydra_config_for_dir + +__all__ = ["puzzletron"] + + +def puzzletron( + hydra_config_dir: str, hydra_config: str, puzzle_dir: str, dataset_path: str +) -> DictConfig: + """Compress a model using the MIP-based NAS search algorithm from Puzzletron. + + Args: + hydra_config_dir (str): path to a hydra_config_dir that defines the search space + hydra_config (str): the corresponding hydra config file + puzzle_dir (str): directory with a puzzletron model to compress + dataset_path (str): dataset used for scoring and distillation + + Returns: + Hydra config object after compressing the model. + The same hydra configuration object is used across all compression steps. + TODO: Investigate if this config object is immutable across steps and clarify + """ + # Step 0: Load puzzletron hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config, + overrides=[ + f"puzzle_dir={puzzle_dir}", + f"dataset_path={dataset_path}", + ], + ) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) + + # Step 1: score_pruning_activations (distributed processing) + launch_score_activations(hydra_cfg) + + # Step 2: pruning_ckpts (single process) + if dist.is_master(): + launch_prune_ckpt(hydra_cfg) + dist.barrier() + + # Step 3: build_library_and_stats (single process) + if dist.is_master(): + launch_build_library_and_stats(hydra_cfg) + dist.barrier() + + # Step 4: calc_one_block_scores (distributed processing) + launch_scoring(hydra_cfg) + + # Step 5: mip_and_realize_models (distributed processing) + launch_mip_and_realize_model(hydra_cfg) + + return hydra_cfg diff --git a/modelopt/torch/puzzletron/mip/__init__.py b/modelopt/torch/puzzletron/mip/__init__.py new file mode 100644 index 0000000000..2941b77a4b --- /dev/null +++ b/modelopt/torch/puzzletron/mip/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MIP-based pruning: model optimization via mixed-integer programming.""" + +from .mip_and_realize_models import * +from .sweep import * diff --git a/modelopt/torch/puzzletron/mip/mip_and_realize_models.py b/modelopt/torch/puzzletron/mip/mip_and_realize_models.py new file mode 100644 index 0000000000..7a04322cce --- /dev/null +++ b/modelopt/torch/puzzletron/mip/mip_and_realize_models.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runs MIP (Mixed Integer Programming) optimization and realizes the resulting model solutions.""" + +# mypy: ignore-errors +from pathlib import Path + +import torch +from omegaconf import DictConfig + +import modelopt.torch.utils.distributed as dist + +from ..tools.logger import mprint +from ..tools.validate_puzzle_with_multi_replacements import validate_puzzle_solutions +from .run_puzzle import run_puzzle + +__all__ = [ + "launch_realize_model", + "launch_mip_and_realize_model", +] + + +def launch_mip(cfg: DictConfig) -> list[str]: + solution_paths = run_puzzle(args=cfg.mip) + return solution_paths + + +def launch_realize_model(cfg: DictConfig): + validate_puzzle_solutions(args=cfg.realize_model) + + +def launch_mip_and_realize_model(cfg: DictConfig) -> list[str]: + # Determine device for distributed operations (NCCL requires CUDA tensors) + device = "cpu" + if dist.size() > 1: + if torch.distributed.get_backend() == "nccl": + device = torch.cuda.current_device() + + if dist.is_master(): + solution_paths = launch_mip(cfg) + length_tensor = torch.tensor([len(solution_paths)], dtype=torch.long, device=device) + else: + solution_paths = None + length_tensor = torch.tensor([0], dtype=torch.long, device=device) + + if not cfg.skip_realize_model: + if dist.size() > 1: + torch.distributed.broadcast(length_tensor, src=0) + + list_length = length_tensor.item() + + if not dist.is_master(): + solution_paths = [None] * list_length + + if dist.size() > 1: + torch.distributed.broadcast_object_list(solution_paths, src=0) + + for solution_path in solution_paths: + mprint(f"Realize model for the solution: {solution_path}") + cfg.realize_model.solutions_path = Path(solution_path) + launch_realize_model(cfg) + dist.barrier() + + return solution_paths diff --git a/modelopt/torch/puzzletron/mip/mip_with_multi_layer_replacements.py b/modelopt/torch/puzzletron/mip/mip_with_multi_layer_replacements.py new file mode 100644 index 0000000000..a906f88636 --- /dev/null +++ b/modelopt/torch/puzzletron/mip/mip_with_multi_layer_replacements.py @@ -0,0 +1,204 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Solves multi-layer replacement optimization using Mixed Integer Programming.""" + +# mypy: ignore-errors +import math +import warnings +from collections import defaultdict +from collections.abc import Hashable, Iterable +from copy import deepcopy +from random import random +from typing import Any, TypeAlias + +from mip import BINARY, Model, maximize, minimize, xsum + +from .utils import consecutive_ngrams, get_nested_key, sort_replacements + +__all__ = ["run_mip"] + +ReplacementID: TypeAlias = Hashable +Replacement: TypeAlias = dict[str, Any] +ChosenReplacements: TypeAlias = list[Replacement] + + +def run_mip( + replacements: dict[ReplacementID, Replacement], + objective: str, + constraints: dict[str, float], + bigger_is_better: bool, + max_seconds_per_solution: float | None = None, +) -> tuple[ChosenReplacements, float, dict[str, float]]: + orig_num_replacements = len(replacements) + replacements = { + replacement_id: deepcopy(replacement) + for replacement_id, replacement in replacements.items() + if math.isfinite(get_nested_key(replacement, objective)) + } + if len(replacements) < orig_num_replacements: + print("\n\n\n") + warnings.warn( + f"mip: removed {orig_num_replacements - len(replacements)} replacements with NaN/inf objective value" + ) + print("\n\n\n") + + if not replacements: + return [], 0.0, {} + + mip_model = Model() + + objective_vars = [] + constraint_vars = {constraint_key: [] for constraint_key in constraints} + choice_indicators_by_layer = defaultdict(list) + for replacement_id, replacement in replacements.items(): + is_chosen = mip_model.add_var(var_type=BINARY) + replacement["is_chosen"] = is_chosen + + for parent_layer_idx in replacement["parent_layer_indices"]: + choice_indicators_by_layer[parent_layer_idx].append(is_chosen) + + objective_vars.append(is_chosen * get_nested_key(replacement, objective)) + + for constraint_key in constraints: + constraint_vars[constraint_key].append( + is_chosen * get_nested_key(replacement, constraint_key) + ) + + # MIP constraints: each parent layer must come from exactly one chosen replacement + for parent_layer_idx, curr_choice_indicators in choice_indicators_by_layer.items(): + mip_model += xsum(curr_choice_indicators) == 1 + + # MIP constraints: the sum of chosen replacement costs must be lower than the max cost + for constraint_key, max_cost in constraints.items(): + min_cost = None + if isinstance(max_cost, Iterable): + min_cost, max_cost = max_cost + + if max_cost is not None: + mip_model += xsum(constraint_vars[constraint_key]) <= max_cost + if min_cost is not None: + mip_model += xsum(constraint_vars[constraint_key]) >= min_cost + + # MIP objective + mip_model.objective = ( + maximize(xsum(objective_vars)) if bigger_is_better else minimize(xsum(objective_vars)) + ) + + if max_seconds_per_solution is not None: + mip_model.max_seconds = max_seconds_per_solution + + mip_model.optimize() + + if is_chosen.x is None: + return [] + # raise InfeasibleError() + + # Trust But Verify: calculate total value and costs, and check that all the constraints are filled + total_value = 0.0 + total_costs = dict.fromkeys(constraints.keys(), 0) + chosen_replacements: ChosenReplacements = [] + chosen_layers = [] + for replacement_id, replacement in replacements.items(): + is_chosen = replacement["is_chosen"].x >= 0.99 + if is_chosen: + assert replacement not in chosen_replacements + chosen_replacements.append(replacement) + total_value += get_nested_key(replacement, objective) + for constraint_key in constraints: + total_costs[constraint_key] += get_nested_key(replacement, constraint_key) + for parent_layer_idx in replacement["parent_layer_indices"]: + assert parent_layer_idx not in chosen_layers + chosen_layers.append(parent_layer_idx) + + missing_layers = set(choice_indicators_by_layer.keys()) - set(chosen_layers) + assert len(missing_layers) == 0, ( + f"The following layers were not chosen by any replacement:\n{missing_layers=}\n{chosen_replacements}" + ) + + for constraint_key, max_cost in constraints.items(): + min_cost = None + if isinstance(max_cost, Iterable): + min_cost, max_cost = max_cost + + if max_cost is not None: + assert total_costs[constraint_key] < max_cost or math.isclose( + total_costs[constraint_key], max_cost, rel_tol=1e-9 + ), ( + f"This max_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} > {max_cost=}" + ) + if min_cost is not None: + assert total_costs[constraint_key] > min_cost or math.isclose( + total_costs[constraint_key], min_cost, rel_tol=1e-9 + ), ( + f"This min_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} < {min_cost=}" + ) + + chosen_replacements = sort_replacements(chosen_replacements) + for cr in chosen_replacements: + del cr["is_chosen"] # not copyable, will cause errors in deep copy + if "block_config" in cr: + cr["child_block_configs"] = cr["block_config"] + # del cr['block_config'] for now the dump includes both keys (duplicated values) # we might wanna either delete one of them or keep both + # I prefer keeping block_config and deleting 'child_block_configs' from previous puzzle steps + + return [ + { + "chosen_replacements": chosen_replacements, + "total_value": total_value, + "total_costs": total_costs, + } + ] + + +def usage_example(): + num_layers = 32 + num_options_per_parent_replacement = 5 + + replacements = dict() + for num_layers_in_replacement in (1, 2, 3): + for i_option in range(num_options_per_parent_replacement): + for parent_layer_indices in consecutive_ngrams(num_layers, num_layers_in_replacement): + replacement_id = f"parent layers {parent_layer_indices} child config {i_option}" + replacement = { + "parent_layer_indices": parent_layer_indices, + "metrics": {"loss": random()}, + "stats": {"memory_mib": random() * 100, "runtime_ms": random() * 10}, + "replacement_id": replacement_id, + } + replacements[replacement_id] = replacement + + constraints = {"stats.memory_mib": num_layers * 15.0, "stats.runtime_ms": num_layers * 1.5} + (result,) = run_mip( + replacements, + objective="metrics.loss", + constraints=constraints, + bigger_is_better=False, + ) + chosen_replacements = result["chosen_replacements"] + total_value = result["total_value"] + total_costs = result["total_costs"] + + print() + print() + print(f"{total_value=}") + print(f"{total_costs=}") + print(f"{constraints=}") + print("chosen_replacements=") + print("\n".join([rep["replacement_id"] for rep in chosen_replacements])) + + +if __name__ == "__main__": + usage_example() diff --git a/modelopt/torch/puzzletron/mip/run_puzzle.py b/modelopt/torch/puzzletron/mip/run_puzzle.py new file mode 100644 index 0000000000..761534f6df --- /dev/null +++ b/modelopt/torch/puzzletron/mip/run_puzzle.py @@ -0,0 +1,764 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main entry point for running the puzzle optimization to find optimal layer configurations.""" + +# mypy: ignore-errors +import argparse +import dataclasses +import enum +import json +import sys +from collections.abc import Hashable, Iterable +from copy import deepcopy +from pathlib import Path +from typing import Any, Literal, TypeAlias + +import numpy as np +import yaml +from omegaconf import DictConfig, ListConfig, OmegaConf + +from modelopt.torch.utils import json_dump + +from ..anymodel.model_descriptor import ModelDescriptorFactory +from ..block_config import AttentionConfig, BlockConfig, FFNConfig +from ..replacement_library.replacement_utils import ( + extract_block_configs_and_locations, + parse_layer_replacement, + replacement_is_teacher, +) +from ..tools.checkpoint_utils import load_model_config +from ..tools.logger import mprint +from ..utils.misc import block_config_to_str, solution_to_str +from ..utils.parsing import get_nested_key, parse_json, parse_path +from .mip_with_multi_layer_replacements import run_mip as run_multi_layer_replacement_mip + +__all__ = [ + "PuzzleMetrics", + "MultiLayerPuzzleMetrics", + "run_puzzle", + "gather_multi_layer_puzzle_metrics", + "filter_subblock_stats_by_args", +] + +""" +Usage: +Must specify either --single_block_replacement_validation_dir and --subblock_stats_path (in which case the metrics will +be gathered from the validation output files) or --gathered_metrics_path (in which case the metrics will be read from +this json file). + +Constraints can be specified either as 'mip_constraints' (the actual constraints that go into the MIP, e.g. 'stats.memory_mib', 'stats.runtime_ms'), +or as "human constraints" (e.g. 'target_memory', 'target_throughput', for the full list see PuzzleConstraints._ALLOWED_HUMAN_CONSTRAINTS). + +""" + +PuzzleMetrics: TypeAlias = dict[Hashable, dict[Hashable, dict[str, float]]] +MultiLayerPuzzleMetrics: TypeAlias = dict[str, dict[str, Hashable]] + + +@dataclasses.dataclass +class PuzzleConstraints: + """A set of puzzle constraints can be expressed either directly as the mip constraints (e.g. 'stats.memory_mib') or as human constraints (e.g. 'target_throughput')""" + + class Type(enum.Enum): + MIP = "mip" + HUMAN = "human" + + _ALLOWED_HUMAN_CONSTRAINTS = { + "target_memory", + "target_throughput", + "target_latency", + "target_time_to_first_token", + "num_params", + "stats.has_attention", + } + type: Type + name: str = dataclasses.field(init=False) + constraints: dict[str, Any] + + @staticmethod + def sizeof_fmt(num, suffix=""): + for unit in ("", "K", "M", "G", "T"): + if abs(num) < 1000.0: + return f"{num:g}{unit}{suffix}" + num /= 1000.0 + return f"{num:.1f}P{suffix}" + + def _validate_human_constraints(self): + illegal_constraints = set(self.constraints.keys()) - self._ALLOWED_HUMAN_CONSTRAINTS + if illegal_constraints: + raise ValueError( + f"The following human_constraints are illegal: {','.join(illegal_constraints)}" + ) + + def format_num_params_to_float(self, num_params): + if isinstance(num_params, list): + return [self.format_num_params_to_float(x) for x in num_params] + if isinstance(num_params, str): + # we only deal with Billions of params scale + return float(num_params.replace("B", "")) * 1e9 + return num_params + + def format_num_params_to_str(self, num_params): + if isinstance(num_params, list): + return [self.format_num_params_to_str(x) for x in num_params] + if isinstance(num_params, float) or isinstance(num_params, int): + return f"{num_params / 1e9}B" + return num_params + + def __post_init__(self): + if self.type == PuzzleConstraints.Type.HUMAN: + self._validate_human_constraints() + + if "stats.active_params" in self.constraints: + self.constraints["stats.active_params"] = self.format_num_params_to_float( + self.constraints["stats.active_params"] + ) + + # Set self.name + constraints = deepcopy(self.constraints) # going to override with "human readable" versions + if "stats.active_params" in constraints: + constraints["stats.active_params"] = self.format_num_params_to_str( + constraints["stats.active_params"] + ) + + if self.type == PuzzleConstraints.Type.HUMAN: + # change values to a more human string form + if "target_memory" in constraints: + constraints["target_memory"] = str(constraints["target_memory"]) + "MiB" + if "num_params" in constraints: + constraints["num_params"] = self.sizeof_fmt(constraints["num_params"]) + + def build_constraint_name(constraint_name, constraint_value): + if isinstance(constraint_value, Iterable) and not isinstance(constraint_value, str): + return "-".join(f"{constraint_name}_{x}" for x in constraint_value) + else: + return f"{constraint_name}_{constraint_value}" + + self.name = "-".join(build_constraint_name(k, v) for k, v in constraints.items()).replace( + ".", "_" + ) + + def to_mip_constraints(self, subblock_stats_args) -> dict[str, Any]: + if self.type == PuzzleConstraints.Type.MIP: + return self.constraints + + assert all(key in subblock_stats_args for key in ("batch_size", "generation_seq_len")), ( + "Can't realize human constraints without 'batch_size' and 'generation_seq_len' in subblock_stats_args." + ) + batch_size = subblock_stats_args["batch_size"] + generation_seq_len = subblock_stats_args["generation_seq_len"] + + mip_constraints = {} + + # Memory constraints + if "target_memory" in self.constraints: + mip_constraints["stats.memory_mib"] = self.constraints["target_memory"] + + # Throughput constraints + throughput_constraints = [] + if "target_throughput" in self.constraints: + if self.constraints["target_throughput"] == 0: + raise ValueError("target_throughput must not be zero") + throughput_constraints.append( + batch_size * generation_seq_len / self.constraints["target_throughput"] + ) + if "target_latency" in self.constraints: + throughput_constraints.append(self.constraints["target_latency"]) + if throughput_constraints: + mip_constraints["stats.runtime_ms"] = 1000 * min(throughput_constraints) + + # Prefill runtime constraint + if "target_time_to_first_token" in self.constraints: + mip_constraints["stats.prefill_runtime_ms"] = ( + 1000 * self.constraints["target_time_to_first_token"] + ) + + # Num params + if "num_params" in self.constraints: + mip_constraints["stats.num_params"] = self.constraints["num_params"] + if "stats.has_attention" in self.constraints: + mip_constraints["stats.has_attention"] = self.constraints["stats.has_attention"] + return mip_constraints + + +def parse_args() -> DictConfig: + parser = argparse.ArgumentParser() + + parser.add_argument("--puzzle_profile", type=parse_path) + + parser.add_argument("--single_block_replacement_validation_dir", type=parse_path, default=None) + parser.add_argument( + "--gathered_metrics_path", + type=parse_path, + default=None, + help="Can be given explicitly instead of --single_block_replacement_validation_dir", + ) + + parser.add_argument("--subblock_stats_path", type=parse_path) + parser.add_argument("--subblock_stats_args", type=parse_json) + + parser.add_argument("--objective", type=str) + parser.add_argument("--mip_constraints", type=parse_json) + parser.add_argument("--human_constraints", type=parse_json) + parser.add_argument("--report_additional_costs", type=str, action="append", default=[]) + + parser.add_argument( + "--output_path", + type=parse_path, + help="The main folder under which all results will be stored.", + ) + + parser.add_argument("--max_seconds_per_solution", type=float, default=60.0) + parser.add_argument("--metric_overrides", type=parse_json, default=None) + parser.add_argument( + "--bigger_is_better", + action="store_true", + help="Set this if using accuracy objective, don't set if using loss objective", + ) + + args = parser.parse_args() + return DictConfig(vars(args)) + + +def run_single_puzzle_config( + args: DictConfig, + gathered_metrics: dict, + subblock_stats: dict, + subblock_stats_args: dict, + constraints: PuzzleConstraints, + output_folder, +) -> Path: + # we override the constraints and subblock_stats_args for this run to keep reporting out the same way. + args = deepcopy(args) + + subblock_stats = filter_subblock_stats_by_args(subblock_stats, subblock_stats_args) + _add_block_stats_to_gathered_metrics(gathered_metrics, subblock_stats) + + output_folder.mkdir(parents=True, exist_ok=True) + _dump_gathered_metrics(gathered_metrics, output_folder) + + non_block_stats = {"stats": _get_block_stats(subblock_stats, "non_block")} + batch_size = subblock_stats["args"]["batch_size"] + generation_seq_len = subblock_stats["args"]["generation_seq_len"] + + mip_constraints = constraints.to_mip_constraints(subblock_stats["args"]) + orig_mip_constraints = deepcopy(mip_constraints) + mprint(f"Solving for the following MIP constraints: {mip_constraints}") + args.mip_constraints = orig_mip_constraints + args.human_constraints = ( + constraints.constraints if constraints.type == PuzzleConstraints.Type.HUMAN else None + ) + args.subblock_stats_args = subblock_stats_args + + for stat_name, max_cost in mip_constraints.items(): + try: + non_block_cost = get_nested_key(non_block_stats, stat_name) + except KeyError: + non_block_cost = 0 + + is_min_max = isinstance(max_cost, Iterable) + min_cost = None + if is_min_max: + min_cost, max_cost = max_cost + + min_cost = min_cost - non_block_cost if (min_cost is not None) else None + max_cost = max_cost - non_block_cost if (max_cost is not None) else None + + if is_min_max: + mip_constraints[stat_name] = (min_cost, max_cost) + else: + mip_constraints[stat_name] = max_cost + + # If there's an additional cost that is not a constraint - set it to "inf" so MIP report the actual value of it. + for cost in set(args.report_additional_costs) - set(orig_mip_constraints.keys()): + mip_constraints[cost] = np.inf + + mprint(f"After non-block adjustments: {mip_constraints=}") + + solutions = run_multi_layer_replacement_mip( + replacements=gathered_metrics, + objective=args.objective, + constraints=mip_constraints, + bigger_is_better=args.bigger_is_better, + max_seconds_per_solution=args.max_seconds_per_solution, + ) + + for solution in solutions: + for stat_name in set([*orig_mip_constraints.keys(), *args.report_additional_costs]): + try: + non_block_cost = get_nested_key(non_block_stats, stat_name) + except KeyError: + non_block_cost = 0 + solution["total_costs"][stat_name] += non_block_cost + + # Calculate throughput from runtime_ms + if "stats.runtime_ms" in solution["total_costs"]: + total_runtime = solution["total_costs"]["stats.runtime_ms"] + solution["total_costs"]["throughput"] = ( + 1000 * batch_size * generation_seq_len / total_runtime + ) + + solution["total_value"] = {args.objective: solution["total_value"]} + solution["puzzle_args"] = ( + OmegaConf.to_container(args, resolve=True) + if isinstance(args, DictConfig) + else vars(args) + ) + solution["subblock_stats"] = subblock_stats["args"] + chosen_block_configs, _ = extract_block_configs_and_locations( + solution["chosen_replacements"] + ) + solution["chosen_block_configs"] = chosen_block_configs + solution["solution_repr"] = solution_to_str(chosen_block_configs) + + if len(solutions) > 0: + solution_repr_0 = solutions[0]["solution_repr"] + mprint(f"\n{solution_repr_0}") + mprint(f"Total costs: {solutions[0]['total_costs']}") + (output_folder / "solution_repr_0.txt").write_text(solution_repr_0) + + solutions_file = output_folder / "solutions.json" + json_dump(solutions, solutions_file) + mprint(solutions_file) + return solutions_file + + +def _dump_gathered_metrics(gathered_metrics: PuzzleMetrics, output_folder: Path) -> None: + for replacement_id, replacement_info in gathered_metrics.items(): + replacement_info["block_repr"] = block_config_to_str(replacement_info["block_config"]) + gathered_metrics_for_dump = gathered_metrics + + json_dump(gathered_metrics_for_dump, output_folder / "replacement_metrics_and_stats.json") + + +def _load_all_constraints(args, puzzle_profile): + def parse_constraints(constraints, constraints_type: PuzzleConstraints.Type): + if isinstance(constraints, (list, ListConfig)): + return [PuzzleConstraints(type=constraints_type, constraints=c) for c in constraints] + elif isinstance(constraints, (dict, DictConfig)): + return [PuzzleConstraints(type=constraints_type, constraints=constraints)] + raise TypeError(f"Invalid constraints type: {constraints_type}") + + # Constraints can be given explicitely + if args.mip_constraints is not None: + return parse_constraints(args.mip_constraints, PuzzleConstraints.Type.MIP) + + if args.human_constraints is not None: + return parse_constraints(args.human_constraints, PuzzleConstraints.Type.HUMAN) + + # Or through the puzzle_profile + if "mip_constraints" in puzzle_profile: + return parse_constraints(puzzle_profile["mip_constraints"], PuzzleConstraints.Type.MIP) + + if "human_constraints" in puzzle_profile: + return parse_constraints(puzzle_profile["human_constraints"], PuzzleConstraints.Type.HUMAN) + + raise ValueError( + "Constraints must be given either explicitely by --mip_constraints or --human_constraints arguments, or through the puzzle_profile." + ) + + +def _load_all_subblock_stats_args(args, puzzle_profile): + # If given explicitely in args + if args.subblock_stats_args is not None: + if isinstance(args.subblock_stats_args, dict): + return [args.subblock_stats_args] + else: + return args.subblock_stats_args + + # Or can be given inside puzzle_profile + if "subblock_stats_args" in puzzle_profile: + return puzzle_profile["subblock_stats_args"] + + raise ValueError( + "subblock_stats_args must be given either explicitely by the --subblock_stats_args argument, or through the puzzle_profile." + ) + + +def _override_args_from_profile(args, puzzle_profile): + for arg_name in vars(args): + if arg_name in puzzle_profile: + if arg_name not in ("mip_constraints", "human_constraints", "subblock_stats_args"): + setattr(args, arg_name, puzzle_profile[arg_name]) + + +def _assert_valid_config(args, puzzle_profile): + required_args = ( + "subblock_stats_path", + "objective", + "output_path", + ) + missing_args = [ + arg for arg in required_args if not hasattr(args, arg) or getattr(args, arg) is None + ] + if missing_args: + mprint(f"error: The following arguments are required: {', '.join(missing_args)}") + sys.exit(1) + + # Make sure we have specified subblock_stats_args + if not hasattr(args, "subblock_stats_args") and "subblock_stats_args" not in puzzle_profile: + mprint( + "error: Must specify `subblock_stats_args` in either puzzle_profile or as a commandline arg." + ) + sys.exit(1) + + # Make sure we have specified constraints + if ( + not hasattr(args, "mip_constraints") + and not hasattr(args, "human_constraints") + and "mip_constraints" not in puzzle_profile + and "human_constraints" not in puzzle_profile + ): + mprint( + "error: Must specify either `mip_constraints` or `human_constraints` in one of puzzle_profile or as a commandline argument." + ) + sys.exit(1) + + +def _get_minimal_unique_names(dicts: list[dict]) -> list[str]: + all_keys = set(k for d in dicts for k in d.keys()) + all_values = {k: set(d[k] for d in dicts if k in d) for k in all_keys} + non_common_keys = [k for k, values in all_values.items() if len(values) > 1] + + return ["-".join(f"{k}_{d[k]}".replace(".", "_") for k in non_common_keys) for d in dicts] + + +def run_puzzle(args: DictConfig) -> list[str]: + # Loads config from args/puzzle_profile + if args.puzzle_profile is not None: + with open(args.puzzle_profile) as f: + puzzle_profile = yaml.safe_load(f) + _override_args_from_profile(args, puzzle_profile) + mprint(f"Loaded Puzzle profile from {args.puzzle_profile}") + else: + puzzle_profile = {} + _assert_valid_config(args, puzzle_profile) + + # Read Metrics and Stats + if args.gathered_metrics_path is not None: + gathered_metrics = json.loads(args.gathered_metrics_path.read_text()) + else: + gathered_metrics = gather_multi_layer_puzzle_metrics( + args.single_block_replacement_validation_dir + ) + + if args.metric_overrides is not None: + gathered_metrics = {**gathered_metrics, **args.metric_overrides} + + subblock_stats = json.loads(args.subblock_stats_path.read_text()) + + all_subblock_args = _load_all_subblock_stats_args(args, puzzle_profile) + all_subblock_output_folders = [ + args.output_path / unique_name + for unique_name in _get_minimal_unique_names(all_subblock_args) + ] + all_constraints = _load_all_constraints(args, puzzle_profile) + + # Running all puzzles + solution_paths = [] + for subblock_stats_args, subblock_stats_output_folder in zip( + all_subblock_args, all_subblock_output_folders + ): + for constraints in all_constraints: + output_folder = subblock_stats_output_folder / constraints.name + _solution_path = run_single_puzzle_config( + args, + gathered_metrics, + subblock_stats, + subblock_stats_args, + constraints, + output_folder, + ) + solution_paths.append(_solution_path) + return solution_paths + + +def gather_puzzle_metrics( + single_block_replacement_validation_dir: Path, +) -> PuzzleMetrics: + single_block_metrics = [ + _parse_single_block_replacement_metrics(metrics_path) + for metrics_path in single_block_replacement_validation_dir.glob("*solution*.json") + ] + all_metric_names = tuple(single_block_metrics[0]["metrics"].keys()) + teacher_metrics = _parse_teacher_block_metrics( + single_block_replacement_validation_dir, all_metric_names + ) + + n_layer = len(teacher_metrics) + gathered_metrics = {f"block_{block_idx}": dict() for block_idx in range(n_layer)} + for variant_metrics in single_block_metrics + teacher_metrics: + block_config = variant_metrics["block_config"] + block_name = f"block_{variant_metrics['block_idx']}" + # if we explicitly measure teacher's blocks don't override them + gathered_metrics[block_name][block_config] = variant_metrics + # if not gathered_metrics[block_name].get(block_config): + # gathered_metrics[block_name][block_config] = variant_metrics + return gathered_metrics + + +def gather_multi_layer_puzzle_metrics( + single_replacement_validation_dir: Path, +) -> MultiLayerPuzzleMetrics: + single_sequence_metrics = [ + _parse_single_sequence_replacement_metrics(metrics_path) + for metrics_path in single_replacement_validation_dir.glob("*solution*.json") + ] + all_metric_names = tuple(single_sequence_metrics[0]["metrics"].keys()) + teacher_metrics = _parse_teacher_block_metrics( + single_replacement_validation_dir, all_metric_names + ) + + gathered_metrics = { + f"replacement_{replacement_id}": replacement_metrics + for replacement_id, replacement_metrics in enumerate( + single_sequence_metrics + teacher_metrics + ) + } + + return gathered_metrics + + +def _parse_single_block_replacement_metrics(metrics_path: Path) -> dict: + raw_metrics = json.loads(metrics_path.read_text()) + single_block_replacement = raw_metrics["puzzle_solution"]["single_block_replacement"] + variant_metrics = { + "block_config": BlockConfig(**single_block_replacement["block_config"]), + "block_idx": single_block_replacement["block_idx"], + "metrics": _extract_average_metrics(raw_metrics), + } + return variant_metrics + + +def _parse_single_sequence_replacement_metrics(metrics_path: Path) -> dict: + raw_metrics = json.loads(metrics_path.read_text()) + single_sequence_replacement = raw_metrics["puzzle_solution"]["single_sequence_replacement"] + if len(single_sequence_replacement["child_block_configs"]) > 1: + raise NotImplementedError( + "Currently we only support many-to-1 replacements, but we can support many-to-many! " + ) + variant_metrics = { + "block_config": BlockConfig(**single_sequence_replacement["child_block_configs"][0]), + # is there cases where child_block_configs has more than one entry? + "parent_layer_indices": single_sequence_replacement["parent_layer_indices"], + "metrics": _extract_average_metrics(raw_metrics), + "layer_replacement": parse_layer_replacement(single_sequence_replacement), + "is_teacher": False, + } + return variant_metrics + + +def _parse_teacher_block_metrics( + single_block_replacement_validation_dir: Path, + all_metric_names: Iterable[str] = ("kl_div_loss",), +) -> list[dict]: + raw_metrics = json.loads((single_block_replacement_validation_dir / "teacher.json").read_text()) + teacher_checkpoint_dir = Path(raw_metrics["args"]["teacher_dir"]).resolve() + descriptor_name = raw_metrics["args"]["descriptor"] + descriptor = ModelDescriptorFactory.get(descriptor_name) + trust_remote_code = descriptor.requires_trust_remote_code() + teacher_model_config = load_model_config( + teacher_checkpoint_dir, trust_remote_code=trust_remote_code + ) + + teacher_replacements = None + replacement_library_path = raw_metrics["args"].get("replacement_library_path") + if replacement_library_path is not None: + teacher_replacements = dict() + all_layer_replacements = json.loads(Path(replacement_library_path).read_text()) + for layer_replacement in all_layer_replacements: + layer_replacement = parse_layer_replacement(layer_replacement) + if replacement_is_teacher( + layer_replacement, teacher_model_config, teacher_checkpoint_dir + ): + block_idx = layer_replacement["parent_layer_indices"][0] + teacher_replacements[block_idx] = layer_replacement + + teacher_metrics = [ + { + "block_config": block_config, + "block_idx": block_idx, + "parent_layer_indices": [block_idx], + "metrics": { + **dict.fromkeys(all_metric_names, 0.0), # default value 0. for teacher + **_extract_average_metrics(raw_metrics), # override with real value if exists + }, + **( + {"layer_replacement": teacher_replacements[block_idx]} + if teacher_replacements is not None + else {} + ), + "is_teacher": True, + } + for block_idx, block_config in enumerate(teacher_model_config.block_configs) + ] + return teacher_metrics + + +def _extract_average_metrics(raw_metrics: dict[str, dict]) -> dict[str, float]: + average_metrics = dict() + for metric_name in raw_metrics: + metric_dict = raw_metrics[metric_name] + if isinstance(metric_dict, dict) and ("avg" in metric_dict.keys()): + metric_value = raw_metrics[metric_name]["avg"] + average_metrics[metric_name] = metric_value + average_metrics[f"one_minus_{metric_name}"] = 1 - metric_value + return average_metrics + + +def filter_subblock_stats_by_args( + all_subblock_stats: list[dict], + subblock_stats_args: dict[str, Any], + convert_dicts_to_dataclasses: bool = True, +) -> dict[str, dict]: + matching_subblock_stats = [ + subblock_stats + for subblock_stats in all_subblock_stats + if _dict_is_subset(subblock_stats_args, subblock_stats["args"]) + ] + assert len(matching_subblock_stats) == 1, ( + "The provided subblock_stats_args should match exactly one measurement " + f"scenario, instead matched {len(matching_subblock_stats)}:\n" + f"{[m['args'] for m in matching_subblock_stats]}" + ) + subblock_stats = deepcopy(matching_subblock_stats[0]) + + if convert_dicts_to_dataclasses: + class_name_to_class = {klass.__name__: klass for klass in [AttentionConfig, FFNConfig]} + subblocks_dict = dict() + for substats in subblock_stats["subblocks"]: + subblock_config_class = class_name_to_class[substats.pop("subblock_config_class")] + subblock_config = subblock_config_class(**substats.pop("subblock_config")) + dict_key = (subblock_config, None) + if "parent_layer_index" in substats: + dict_key = (subblock_config, substats["parent_layer_index"]) + subblocks_dict[dict_key] = substats + subblock_stats["subblocks"] = subblocks_dict + return subblock_stats + + +def _dict_is_subset(dict1: dict, dict2: dict) -> bool: + return all(item in dict2.items() for item in dict1.items()) + + +def _add_block_stats_to_gathered_metrics( + gathered_metrics: PuzzleMetrics, subblock_stats: dict +) -> None: + for block_name, block_variants in gathered_metrics.items(): + parent_layer_index = None + if "parent_layer_indices" in block_variants: + parent_layer_index = block_variants["parent_layer_indices"][0] + + if "metrics" in block_variants: + # this is a sequence stats object for multi-layer puzzle + block_variants["stats"] = _get_block_stats( + subblock_stats, block_variants["block_config"], parent_layer_index + ) + else: + for block_config, variant_metrics in block_variants.items(): + variant_metrics["stats"] = _get_block_stats(subblock_stats, block_config) + + +def _get_block_stats( + subblock_stats: dict, + block_config: BlockConfig | Literal["non_block"], + parent_layer_index: int = None, +) -> dict[str, float]: + if block_config == "non_block": + return subblock_stats["non_block"] + + if block_config.parallel_blocks is None: + attention_key = (block_config.attention, parent_layer_index) + ffn_key = (block_config.ffn, parent_layer_index) + attention_stats = subblock_stats["subblocks"][attention_key] + ffn_stats = subblock_stats["subblocks"][ffn_key] + assert set(attention_stats.keys()) == set(ffn_stats.keys()) + + block_stats = dict() + for k in attention_stats.keys(): + block_stats[k] = _none_add(attention_stats[k], ffn_stats[k]) + block_stats[f"attention_{k}"] = attention_stats[k] + block_stats[f"ffn_{k}"] = ffn_stats[k] + + block_stats["has_attention"] = int( + not block_config.attention.no_op and block_config.attention.mamba is None + ) + block_stats["has_ffn"] = int(not block_config.ffn.no_op) + block_stats["has_moe"] = int(block_config.ffn.moe is not None) + block_stats["not_no_op"] = int( + not (block_config.attention.no_op and block_config.ffn.no_op) + ) + block_stats["num_kv_heads"] = ( + block_config.attention.num_key_value_heads if block_stats["has_attention"] else 0 + ) + block_stats["num_local_experts"] = ( + block_config.ffn.moe.num_local_experts if block_stats["has_moe"] else 0 + ) + + return block_stats + + # this is a parallel block + ADDITIVE_METRICS = ("memory_mib", "num_params", "kv_cache_memory_mib") + ADDITIVE_METRICS = [ + f"{prefix}{metric}" for prefix in ("", "attention_", "ffn_") for metric in ADDITIVE_METRICS + ] + block_stats = [ + _get_block_stats(subblock_stats, sub_parallel) + for sub_parallel in block_config.parallel_blocks + ] + block_stats = { + k: _none_add_list([sub_parallel_stat[k] for sub_parallel_stat in block_stats]) + if k in ADDITIVE_METRICS + else _none_max_list([sub_parallel_stat[k] for sub_parallel_stat in block_stats]) + for k in block_stats[0].keys() + } + + return block_stats + + +def _none_add(a: float | int | None, b: float | int | None) -> float | int | None: + if a is None or b is None: + return None + return a + b + + +def _none_max(a: float | int | None, b: float | int | None) -> float | int | None: + if a is None or b is None: + return None + return max(a, b) + + +def _none_add_list(l) -> float | int | None: + r = l[0] + if len(l) == 1: + return r + for e in l[1:]: + r = _none_add(r, e) + return r + + +def _none_max_list(l) -> float | int | None: + r = l[0] + if len(l) == 1: + return r + for e in l[1:]: + r = _none_max(r, e) + return r + + +if __name__ == "__main__": + args = parse_args() + run_puzzle(args) diff --git a/modelopt/torch/puzzletron/mip/sweep.py b/modelopt/torch/puzzletron/mip/sweep.py new file mode 100644 index 0000000000..ea4e95dc3e --- /dev/null +++ b/modelopt/torch/puzzletron/mip/sweep.py @@ -0,0 +1,296 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MIP sweep functionality for exploring multiple memory compression rates.""" + +import json +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig, OmegaConf +from transformers import PretrainedConfig + +import modelopt.torch.utils.distributed as dist + +from ..anymodel import models # noqa: F401 — register ModelDescriptorFactory entries +from ..anymodel.model_descriptor import ModelDescriptorFactory +from ..tools.checkpoint_utils_hf import load_model_config +from ..tools.logger import mprint +from . import mip_and_realize_models +from .run_puzzle import _get_block_stats, filter_subblock_stats_by_args + +__all__ = [ + "get_teacher_memory_from_subblock_stats", + "get_teacher_num_params_from_subblock_stats", + "extract_solution_results", + "write_results_to_csv", + "run_mip_sweep", +] + + +def _load_teacher_subblock_stats(hydra_cfg: DictConfig) -> tuple[dict[str, Any], PretrainedConfig]: + """Load filtered subblock_stats and teacher ``model_config`` for the current MIP scenario.""" + puzzle_dir = Path(hydra_cfg.puzzle_dir) + teacher_dir = Path(hydra_cfg.teacher_dir) + + descriptor = ModelDescriptorFactory.get(hydra_cfg.descriptor) + trust_remote_code = descriptor.requires_trust_remote_code() + model_config = load_model_config(teacher_dir, trust_remote_code=trust_remote_code) + lm_config = descriptor.get_language_model_config(model_config) + hidden_size = lm_config.hidden_size + + mip_subblock_args = hydra_cfg.mip.subblock_stats_args[0] + subblock_stats_args = OmegaConf.to_container(mip_subblock_args, resolve=True) + # Subblock_stats.json can list multiple runs that share batch/dtypes but differ by hidden size; + # filter_subblock_stats_by_args needs n_embd so exactly one row matches the teacher. + subblock_stats_args = {**subblock_stats_args, "n_embd": hidden_size} + + batch_size = subblock_stats_args["batch_size"] + weights_dtype = str(subblock_stats_args["weights_dtype"]) + activations_dtype = str(subblock_stats_args["activations_dtype"]) + kv_cache_dtype = str(subblock_stats_args["kv_cache_dtype"]) + + subblock_stats_path = puzzle_dir / "subblock_stats.json" + if not subblock_stats_path.exists(): + raise FileNotFoundError( + f"subblock_stats.json not found at {subblock_stats_path}. " + "Please run the full pipeline first without --mip-only flag." + ) + + with open(subblock_stats_path) as f: + subblock_stats_list = json.load(f) + + try: + subblock_stats = filter_subblock_stats_by_args(subblock_stats_list, subblock_stats_args) + except AssertionError as e: + raise ValueError( + f"No unique subblock_stats entry for batch_size={batch_size}, " + f"dtypes=({weights_dtype}, {activations_dtype}, {kv_cache_dtype}), " + f"n_embd={hidden_size}" + ) from e + + return subblock_stats, model_config + + +def get_teacher_memory_from_subblock_stats(hydra_cfg: DictConfig) -> float: + """Calculate teacher model memory from subblock_stats.json. + + Sums ``non_block`` and per-layer ``_get_block_stats(subblock_stats, block_config, layer_index)`` + over ``model_config.block_configs``, matching :func:`run_puzzle._get_block_stats`. + + Args: + hydra_cfg: Hydra configuration object + + Returns: + Total teacher memory in MiB + """ + subblock_stats, model_config = _load_teacher_subblock_stats(hydra_cfg) + + total_memory = subblock_stats.get("non_block", {}).get("memory_mib", 0.0) + + for layer_idx, block_config in enumerate(model_config.block_configs): + block_stats = _get_block_stats(subblock_stats, block_config, layer_idx) + total_memory += block_stats["memory_mib"] + + return total_memory + + +def get_teacher_num_params_from_subblock_stats(hydra_cfg: DictConfig) -> int: + """Calculate total teacher parameter count from subblock_stats.json. + + Sums ``non_block`` and per-layer ``_get_block_stats(...)["num_params"]`` over + ``model_config.block_configs``, matching :func:`run_puzzle._get_block_stats`. + + Args: + hydra_cfg: Hydra configuration object + + Returns: + Total teacher parameter count (same units as subblock_stats JSON). + """ + subblock_stats, model_config = _load_teacher_subblock_stats(hydra_cfg) + + total_params = subblock_stats.get("non_block", {}).get("num_params", 0) + + for layer_idx, block_config in enumerate(model_config.block_configs): + block_stats = _get_block_stats(subblock_stats, block_config, layer_idx) + total_params += block_stats["num_params"] + + return int(total_params) + + +def extract_solution_results( + solution_path: Path, + target_memory_mib: float, + teacher_memory_mib: float, + compression_rate: float, +) -> dict: + """Extract results from a completed MIP solution. + + Args: + solution_path: Path to the solutions.json file (not the directory!) + target_memory_mib: Target memory constraint used for MIP + teacher_memory_mib: Teacher model memory in MiB + compression_rate: Compression rate applied + + Returns: + Dictionary containing extracted metrics + """ + result = { + "compression_rate": compression_rate, + "target_memory_mib": target_memory_mib, + "teacher_memory_mib": teacher_memory_mib, + } + + # solution_path is the path to solutions.json file, get parent directory + solution_dir = solution_path.parent + + # Load solutions.json for actual memory and parameters + solutions_file = solution_dir / "solutions.json" + with open(solutions_file) as f: + solutions_data = json.load(f) + solution = solutions_data[0] # First solution + total_costs = solution.get("total_costs", {}) + result["actual_memory_mib"] = total_costs.get("stats.memory_mib", None) + result["num_params"] = total_costs.get("stats.num_params", None) + + # Load solution_0.json for accuracy metrics + validation_dir = solution_dir / "solutions--validation" + # TODO: There could be multiple solutions, but we only need the first one. Is it the best solution? + solution_0_file = validation_dir / "solution_0.json" + + with open(solution_0_file) as f: + validation_data = json.load(f) + result["lm_loss"] = validation_data.get("lm_loss", {}).get("avg", None) + result["token_accuracy_top_1"] = validation_data.get("token_accuracy_top_1", {}).get( + "avg", None + ) + result["token_accuracy_top_5"] = validation_data.get("token_accuracy_top_5", {}).get( + "avg", None + ) + result["token_accuracy_top_10"] = validation_data.get("token_accuracy_top_10", {}).get( + "avg", None + ) + + return result + + +def write_results_to_csv(results: list, output_csv: str): + """Write sweep results to CSV file. + + Args: + results: List of result dictionaries + output_csv: Path to output CSV file + """ + import csv + + # Define CSV columns in desired order + columns = [ + "compression_rate", + "target_memory_mib", + "actual_memory_mib", + "teacher_memory_mib", + "num_params", + "lm_loss", + "token_accuracy_top_1", + "token_accuracy_top_5", + "token_accuracy_top_10", + ] + + # Write CSV + output_path = Path(output_csv) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=columns) + writer.writeheader() + writer.writerows(results) + + mprint(f"Results written to: {output_path}") + + +def run_mip_sweep(hydra_cfg): + """Run MIP for multiple memory compression rates and generate CSV with results. + + This function is called when mip.sweep.enabled is True in the config. + + Args: + hydra_cfg: Hydra configuration object with mip.sweep settings + """ + mprint("=" * 80) + mprint("MIP Sweep Mode Enabled") + mprint("=" * 80) + + # Get sweep configuration + sweep_cfg = hydra_cfg.mip.sweep + compression_rates = sweep_cfg.memory_compression_rates + output_csv = sweep_cfg.output_csv + puzzle_dir = Path(hydra_cfg.puzzle_dir) + + mprint(f"Compression rates: {compression_rates}") + mprint(f"Output CSV: {output_csv}") + mprint(f"Puzzle directory: {puzzle_dir}") + + # Calculate teacher memory from subblock_stats + teacher_memory = get_teacher_memory_from_subblock_stats(hydra_cfg) + mprint( + f"Teacher memory (from subblock_stats): {teacher_memory:.1f} MiB ({teacher_memory / 1024:.1f} GiB)" + ) + + # Collect results + all_results = [] + + # Run MIP for each compression rate + for compression_rate in compression_rates: + target_memory_mib = teacher_memory * compression_rate + mprint("\n" + "=" * 80) + mprint( + f"Running MIP for compression_rate={compression_rate:.2f} " + f"(target={target_memory_mib:.1f} MiB = {target_memory_mib / 1024:.1f} GiB)" + ) + mprint("=" * 80) + + # Modify config dynamically + hydra_cfg.mip.human_constraints.target_memory = target_memory_mib + + # Run MIP and realize models (reuse existing distributed logic!) + solution_paths = mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) + + # Extract results (only on master rank) + if dist.is_master(): + for solution_path in solution_paths: + result = extract_solution_results( + solution_path=Path(solution_path), + target_memory_mib=target_memory_mib, + teacher_memory_mib=teacher_memory, + compression_rate=compression_rate, + ) + all_results.append(result) + + mem = ( + f"{result['actual_memory_mib']:.1f}" + if result["actual_memory_mib"] is not None + else "N/A" + ) + loss = f"{result['lm_loss']:.4f}" if result["lm_loss"] is not None else "N/A" + mprint(f"✓ Results: actual_memory={mem} MiB, lm_loss={loss}") + + # Write results to CSV (only on master rank) + if dist.is_master(): + mprint("\n" + "=" * 80) + mprint("MIP Sweep Complete - Writing Results") + mprint("=" * 80) + write_results_to_csv(all_results, output_csv) + mprint(f"Completed {len(all_results)} sweep runs") + mprint("=" * 80) diff --git a/modelopt/torch/puzzletron/mip/utils.py b/modelopt/torch/puzzletron/mip/utils.py new file mode 100644 index 0000000000..122c09b368 --- /dev/null +++ b/modelopt/torch/puzzletron/mip/utils.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for MIP optimization.""" + +from typing import Any + +__all__ = ["InfeasibleError", "sort_replacements", "get_nested_key", "consecutive_ngrams"] + + +class InfeasibleError(Exception): + """Exception raised when optimization problem is infeasible.""" + + +def sort_replacements(layer_replacements: list[dict]) -> list[dict]: + """Sort layer replacements by parent layer indices. + + Args: + layer_replacements: List of replacement dictionaries containing "parent_layer_indices" + + Returns: + Sorted list of replacements + """ + return sorted(layer_replacements, key=lambda replacement: replacement["parent_layer_indices"]) + + +def get_nested_key(dictionary: dict[str, Any], nested_key: str) -> Any: + """Access nested dictionary values using dot notation. + + If nested_key is "a.b.c" returns dictionary["a"]["b"]["c"] + + Args: + dictionary: Dictionary to access + nested_key: Dot-separated key path (e.g., "a.b.c") + + Returns: + Value at the nested key location + """ + value = dictionary + for key in nested_key.split("."): + value = value[key] + return value + + +def consecutive_ngrams(l: int, n: int) -> list[list[int]]: + """Generate all consecutive n-grams from range(l). + + Splits range(l) into all consecutive n-grams. + + Args: + l: Length of the range + n: Size of each n-gram + + Returns: + List of consecutive n-grams + + Example: + consecutive_ngrams(7, 2) = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6]] + """ + ngrams = [] + for i in range(l - n + 1): + ngrams.append(list(range(i, i + n))) + return ngrams diff --git a/modelopt/torch/puzzletron/plugins/__init__.py b/modelopt/torch/puzzletron/plugins/__init__.py new file mode 100644 index 0000000000..13a749a826 --- /dev/null +++ b/modelopt/torch/puzzletron/plugins/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Export utilities for Puzzletron models.""" + +from modelopt.torch.utils import import_plugin + +with import_plugin("puzzletron_mbridge"): + from .mbridge import * # register bridge adapters diff --git a/modelopt/torch/puzzletron/plugins/mbridge/__init__.py b/modelopt/torch/puzzletron/plugins/mbridge/__init__.py new file mode 100644 index 0000000000..a3743095a1 --- /dev/null +++ b/modelopt/torch/puzzletron/plugins/mbridge/__init__.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron-Bridge adapters for Puzzletron AnyModel checkpoints. + +This module provides bridges for converting Puzzletron AnyModel checkpoints +(heterogeneous layer architectures) to Megatron-Core format via Megatron-Bridge. +""" + +# Import to register bridges (side effect) +from .base import * +from .llama import * +from .qwen3 import * diff --git a/modelopt/torch/puzzletron/plugins/mbridge/base.py b/modelopt/torch/puzzletron/plugins/mbridge/base.py new file mode 100644 index 0000000000..a2d230540e --- /dev/null +++ b/modelopt/torch/puzzletron/plugins/mbridge/base.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mixin class for bridges that support heterogeneous layer architectures. + +This module provides a mixin class for converting models with block_configs +(heterogeneous layer configurations) to Megatron-Core format via Megatron-Bridge. +""" + +import dataclasses +import json +from collections.abc import Callable +from dataclasses import dataclass, fields + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.bridge.models.transformer_config import ( + HeterogeneousTransformerConfig, + TransformerConfig, +) +from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( + get_gpt_heterogeneous_layer_spec, +) +from megatron.core.transformer.spec_utils import ModuleSpec + +# Monkey-patch: add get_config_for_layer to TransformerConfig if missing +# (needed for non-heterogeneous teacher models in this container version) +if not hasattr(TransformerConfig, "get_config_for_layer"): + TransformerConfig.get_config_for_layer = lambda self, layer_number: self + +__all__ = ["heterogeneous_layer_spec", "GenericHeterogeneousProvider", "HeterogeneousBridgeMixin"] + + +def heterogeneous_layer_spec(config) -> ModuleSpec: + """Get GPT heterogeneous layer spec using Transformer Engine.""" + return get_gpt_heterogeneous_layer_spec(config, use_te=True) + + +@dataclass +class GenericHeterogeneousProvider(GPTModelProvider, HeterogeneousTransformerConfig): + """Generic provider for AnyModel checkpoints with block_configs.""" + + # Heterogeneous configuration fields + heterogeneous_layers_config_path: str | None = None + heterogeneous_layers_config_encoded_json: str = "" + transformer_layer_spec: ModuleSpec | Callable = heterogeneous_layer_spec + + def __getattr__(self, name: str): + """Handle missing attributes for OmegaConf compatibility. + + Returns empty list for per_block_parameters if not yet initialized (before finalize()). + This allows OmegaConf to serialize/deserialize configs without errors. Actual usage + should call finalize() first to set per_block_parameters as a real attribute. + """ + if name == "per_block_parameters": + # Return existing attribute if set, otherwise [] for OmegaConf compatibility + try: + return object.__getattribute__(self, name) + except AttributeError: + return [] + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + +class HeterogeneousBridgeMixin: + """Mixin for bridges supporting heterogeneous layer architectures (block_configs). + + Must be used with multiple inheritance alongside a model-specific bridge. + Example: class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge) + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider: + """Convert HF AnyModel config to Megatron GPTModelProvider. + + This method: + 1. Calls the parent bridge's provider_bridge() to get a GPTModelProvider with all + model-specific settings (e.g., LlamaBridge sets normalization="RMSNorm", etc.) + 2. Converts the provider to a dict and filters to only fields accepted by + GenericHeterogeneousProvider (which inherits from GPTModelProvider, so all valid + GPTModelProvider fields are preserved) + 3. Adds heterogeneous configuration and returns GenericHeterogeneousProvider + + All parameters from the parent bridge (e.g., LlamaBridge) are maintained because + GenericHeterogeneousProvider inherits from GPTModelProvider, which includes all + the fields that the parent bridge sets. + """ + parent_provider = super().provider_bridge(hf_pretrained) # type: ignore[misc] + + # If no block_configs, fall back to standard (non-heterogeneous) provider. + if not (hasattr(hf_pretrained.config, "block_configs")): + return parent_provider + + provider_kwargs = dataclasses.asdict(parent_provider) + + # Filter to only fields that GenericHeterogeneousProvider accepts. + # GenericHeterogeneousProvider inherits from GPTModelProvider, so it includes all + # GPTModelProvider fields. Model-specific fields from subclasses (e.g., MistralModelProvider, + # GPTOSSModelProvider) are filtered out because GenericHeterogeneousProvider only inherits + # from GPTModelProvider, not from model-specific subclasses. + # + # Note: This logic may not work for bridges like MistralBridge or GPTOSSBridge if they + # use model-specific parameters not supported by GenericHeterogeneousProvider (e.g., + # scale_factor, yarn_rotary_scaling_factor, moe_* parameters). In such cases, create a + # model-specific heterogeneous provider that inherits from the model-specific provider. + valid_fields = {f.name for f in fields(GenericHeterogeneousProvider)} + + # Only keep kwargs that are valid fields + provider_kwargs = {k: v for k, v in provider_kwargs.items() if k in valid_fields} + + provider_kwargs["heterogeneous_layers_config_encoded_json"] = ( + self._build_heterogeneous_config_json(hf_pretrained.config) + ) + return GenericHeterogeneousProvider(**provider_kwargs) + + def _build_heterogeneous_config_json(self, hf_config) -> str: + """Build heterogeneous layers config JSON from HF config.""" + + hf_config_dict = json.loads(hf_config.to_json_string()) + + mcore_block_configs = [ + self._convert_block_config(block) for block in hf_config_dict["block_configs"] + ] + return json.dumps({"block_configs": mcore_block_configs}, ensure_ascii=False) + + def _convert_block_config(self, block: dict) -> dict: + """Convert a single block config from HF format to MCore format.""" + return { + "attention": self._convert_attention_config(block["attention"]), + "ffn": self._convert_ffn_config(block["ffn"]), + } + + def _convert_attention_config(self, attention_config: dict) -> dict: + """Convert attention config from HF format to MCore format.""" + attention_config = attention_config.copy() + attention_config["num_query_groups"] = attention_config.pop("num_key_value_heads") + return attention_config + + def _convert_ffn_config(self, ffn_config: dict) -> dict: + """Convert FFN/MLP config from HF format to MCore format.""" + ffn_config = ffn_config.copy() + ffn_config["ffn_hidden_size"] = ffn_config.pop("intermediate_size") + return ffn_config diff --git a/modelopt/torch/puzzletron/plugins/mbridge/llama.py b/modelopt/torch/puzzletron/plugins/mbridge/llama.py new file mode 100644 index 0000000000..1894182f52 --- /dev/null +++ b/modelopt/torch/puzzletron/plugins/mbridge/llama.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Bridge for Puzzletron Llama-based AnyModel heterogeneous checkpoints.""" + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.llama.llama_bridge import LlamaBridge +from megatron.core.models.gpt.gpt_model import GPTModel +from transformers import LlamaForCausalLM + +from .base import HeterogeneousBridgeMixin + +__all__ = ["PuzzletronLlamaAnyModelBridge"] + + +@MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel) +class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge): + """ + Megatron Bridge for Puzzletron Llama-based AnyModel checkpoints. + + Extends LlamaBridge with support for heterogeneous layer architectures (block_configs). + All Llama-specific settings are inherited from LlamaBridge. + """ + + # provider_bridge() is inherited from HeterogeneousBridgeMixin + # It automatically reuses LlamaBridge.provider_bridge() and adds heterogeneous config + # mapping_registry() is inherited from LlamaBridge diff --git a/modelopt/torch/puzzletron/plugins/mbridge/qwen3.py b/modelopt/torch/puzzletron/plugins/mbridge/qwen3.py new file mode 100644 index 0000000000..1c3f6c7384 --- /dev/null +++ b/modelopt/torch/puzzletron/plugins/mbridge/qwen3.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Bridge for Puzzletron Qwen3-based AnyModel heterogeneous checkpoints.""" + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.qwen.qwen3_bridge import Qwen3Bridge +from megatron.core.models.gpt.gpt_model import GPTModel +from transformers import Qwen3ForCausalLM + +from .base import HeterogeneousBridgeMixin + +__all__ = [] + + +@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel) +class PuzzletronQwen3AnyModelBridge(HeterogeneousBridgeMixin, Qwen3Bridge): + """ + Megatron Bridge for Puzzletron Qwen3-based AnyModel checkpoints. + + Extends Qwen3Bridge with support for heterogeneous layer architectures (block_configs). + All Qwen3-specific settings are inherited from Qwen3Bridge. + """ + + # provider_bridge() is inherited from HeterogeneousBridgeMixin + # It automatically reuses Qwen3Bridge.provider_bridge() and adds heterogeneous config + # mapping_registry() is inherited from Qwen3Bridge diff --git a/modelopt/torch/puzzletron/pruning/__init__.py b/modelopt/torch/puzzletron/pruning/__init__.py new file mode 100644 index 0000000000..872a68c9ab --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/__init__.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Structured pruning mixins and checkpoint utilities for Puzzletron.""" + +from .expert_removal_pruning_mixin import * +from .ffn_intermediate_pruning_mixin import * +from .kv_heads_pruning_mixin import * +from .pruning_ckpts import * +from .pruning_mixin import * +from .pruning_utils import * diff --git a/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py new file mode 100644 index 0000000000..2c547a6cac --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from transformers import PretrainedConfig + +from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook +from modelopt.torch.prune.importance_hooks.expert_removal_hooks import ( + NemotronHRemoveExpertsIndependentHook, + Qwen3VLRemoveExpertsIndependentHook, + RankedChoiceVotingHook, + RankedChoiceVotingHookNemotronH, +) + +from .pruning_mixin import LayerDescriptor, PruningMixIn +from .pruning_utils import MlpInitMode, _init_moe_module + +__all__ = [ + "ExpertRemovalLayerDescriptor", + "ExpertRemovalPruningMixIn", +] + + +@dataclass +class ExpertRemovalLayerDescriptor(LayerDescriptor): + """Descriptor for expert-removal pruning layers.""" + + # TODO: Add shared expert weights in case it's prunable. + # TODO: Consider removing the segmentation between weight and bias. + + #: Module name for hook registration; supports ``regex:`` prefix. + target_name: str + #: MoE prefix layer name with ``{layer_idx}`` placeholder, + #: e.g. ``model.layers.{layer_idx}.moe``. + moe_prefix_name: str + #: Expert prefix relative to *moe_prefix* with ``{expert_idx}`` placeholder, + #: e.g. ``experts.{expert_idx}``. + expert_prefix_name: str = "" + #: Router weight names relative to *moe_prefix*. + router_weights: List[str] = field(default_factory=list) + #: Router bias names relative to *moe_prefix*. + router_biases: List[str] = field(default_factory=list) + #: Per-expert weight names relative to *expert_prefix* (per-expert format). + expert_weights: List[str] = field(default_factory=list) + #: Per-expert bias names relative to *expert_prefix* (per-expert format). + expert_biases: List[str] = field(default_factory=list) + #: If ``True``, experts are stored as single fused tensors (shape ``[num_experts, ...]``). + is_fused_experts: bool = False + #: Fused expert weight names relative to *moe_prefix*, + #: e.g. ``["experts.gate_up_proj", "experts.down_proj"]``. + fused_expert_weights: List[str] = field(default_factory=list) + + def module_name_regex(self) -> str: + return self.target_name + + def moe_prefix(self, layer_idx: int) -> str: + return self.moe_prefix_name.format(layer_idx=layer_idx) + + def expert_prefix(self, layer_idx: int, expert_idx: int) -> str: + _expert_prefix = self.moe_prefix_name + "." + self.expert_prefix_name + return _expert_prefix.format(layer_idx=layer_idx, expert_idx=expert_idx) + + +class ExpertRemovalPruningMixIn(PruningMixIn): + def __init__(self, layer_descriptor: ExpertRemovalLayerDescriptor): + assert isinstance(layer_descriptor, ExpertRemovalLayerDescriptor) + super().__init__(layer_descriptor) + + def supported_hooks(self) -> List[Type[ForwardHook]]: + return [ + RankedChoiceVotingHook, + RankedChoiceVotingHookNemotronH, + NemotronHRemoveExpertsIndependentHook, + Qwen3VLRemoveExpertsIndependentHook, + ] + + def prune_single_layer( + self, + layer_idx: int, + parent_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + mlp_init_mode: MlpInitMode, + mlp_init_config: Optional[dict[str, Any]], + keys: dict, + **kwargs, + ) -> Dict[str, torch.Tensor]: + layer_out_state_dict = {} + + child_block_config = new_config.block_configs[layer_idx] + parent_block_config = original_config.block_configs[layer_idx] + + if not parent_block_config.ffn.is_moe: + return layer_out_state_dict + + new_num_experts = child_block_config.ffn.moe.num_local_experts + orig_num_experts = parent_block_config.ffn.moe.num_local_experts + + child_router_keys, new_experts_keys = self._generate_moe_keys(layer_idx, new_num_experts) + parent_router_keys, orig_experts_keys = self._generate_moe_keys(layer_idx, orig_num_experts) + + # Pop parent's router keys from copy list; child-only router keys will be initialized below + for rk in sum(parent_router_keys.values(), []): + if rk in keys: + keys.pop(rk) + for key in sum(orig_experts_keys.values(), []): + if key in keys: + keys.pop(key) + + if self.layer_descriptor.is_fused_experts: + # Fused format: unbundle single tensor [num_experts, ...] into list of per-expert tensors + orig_experts_weights = {} + for name, fused_keys in orig_experts_keys.items(): + fused_tensor = parent_state_dict[fused_keys[0]] # Single fused tensor + orig_experts_weights[name] = [fused_tensor[i] for i in range(orig_num_experts)] + + new_experts_weights = {} + for name, fused_keys in new_experts_keys.items(): + fused_tensor = new_state_dict[fused_keys[0]] # Single fused tensor + new_experts_weights[name] = [fused_tensor[i] for i in range(new_num_experts)] + else: + # Per-expert format: load each expert tensor separately + orig_experts_weights = { + name: [parent_state_dict[key] for key in orig_experts_module_keys] + for name, orig_experts_module_keys in orig_experts_keys.items() + } + new_experts_weights = { + name: [new_state_dict[key] for key in new_experts_module_keys] + for name, new_experts_module_keys in new_experts_keys.items() + } + + orig_router_weights = { + name: [parent_state_dict[key] for key in _module_router_keys] + for name, _module_router_keys in parent_router_keys.items() + } + new_router_weights = { + name: [new_state_dict[key] for key in _module_router_keys] + for name, _module_router_keys in child_router_keys.items() + } + + out_router_weights, out_experts_weights = _init_moe_module( + layer_idx=layer_idx, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + orig_router_weights=orig_router_weights, + orig_experts_weights=orig_experts_weights, + new_router_weights=new_router_weights, + new_experts_weights=new_experts_weights, + orig_num_experts=orig_num_experts, + new_num_experts=new_num_experts, + ) + assert new_experts_keys.keys() == out_experts_weights.keys(), ( + "new_experts_keys and out_experts_weights must have the same keys" + ) + assert child_router_keys.keys() == out_router_weights.keys(), ( + "child_router_keys and out_router_weights must have the same keys" + ) + + for name in child_router_keys.keys(): + layer_out_state_dict.update(zip(child_router_keys[name], out_router_weights[name])) + + if self.layer_descriptor.is_fused_experts: + # Fused format: rebundle list of per-expert tensors into single fused tensor + for name in new_experts_keys.keys(): + fused_key = new_experts_keys[name][0] # Single key for fused tensor + fused_tensor = torch.stack(out_experts_weights[name], dim=0) # [num_experts, ...] + layer_out_state_dict[fused_key] = fused_tensor + else: + # Per-expert format: each expert has its own key + for name in new_experts_keys.keys(): + layer_out_state_dict.update(zip(new_experts_keys[name], out_experts_weights[name])) + + return layer_out_state_dict + + def _generate_moe_keys( + self, layer_idx: int, num_experts: int + ) -> Tuple[Dict[str, List[str]], dict[str, list[str]]]: + """ + Generate MoE weight keys for router and experts. + TODO simplify or better define the data structure of the moe keys returned. + + :return: tuple of router_keys and expert_keys, all are absolute names relative to the model root: + * router_keys structure: + {"weight: [], bias: []"} + * expert_keys structure (per-expert format): + {": []} + i.e: + { + "down_proj.weight": ["model...experts.0.down_proj.weight", ..., "model...experts.N.down_proj.weight"], + ... + } + * expert_keys structure (fused format): + {": []} + i.e: + { + "experts.gate_up_proj": ["model...experts.gate_up_proj"], + "experts.down_proj": ["model...experts.down_proj"], + } + """ + self.layer_descriptor: ExpertRemovalLayerDescriptor + moe_prefix = self.layer_descriptor.moe_prefix(layer_idx) + + router_keys = { + "weight": [ + f"{moe_prefix}.{_weight}" for _weight in self.layer_descriptor.router_weights + ], + "bias": [f"{moe_prefix}.{_bias}" for _bias in self.layer_descriptor.router_biases], + } + + if self.layer_descriptor.is_fused_experts: + # Fused format: single tensor per weight type with shape [num_experts, ...] + experts_module_names = {} + for fused_weight in self.layer_descriptor.fused_expert_weights: + experts_module_names[fused_weight] = [f"{moe_prefix}.{fused_weight}"] + else: + # Per-expert format: separate tensor for each expert + expert_key_names = ( + self.layer_descriptor.expert_weights + self.layer_descriptor.expert_biases + ) + experts_module_names = {} + for key_name in expert_key_names: + experts_module_names[key_name] = [ + f"{self.layer_descriptor.expert_prefix(layer_idx, expert_idx)}.{key_name}" + for expert_idx in range(num_experts) + ] + + return router_keys, experts_module_names diff --git a/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py new file mode 100644 index 0000000000..f9e7b16248 --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Type + +import torch +from transformers import PretrainedConfig + +from modelopt.torch.prune.importance_hooks.base_hooks import ( + ForwardHook, + IndependentChannelContributionHook, + IterativeChannelContributionHook, +) + +from .pruning_mixin import LayerDescriptor, PruningMixIn +from .pruning_utils import MlpInitMode, _init_mlp_module + +__all__ = [ + "FFNIntermediateLayerDescriptor", + "FFNIntermediatePruningMixIn", +] + + +@dataclass +class FFNIntermediateLayerDescriptor(LayerDescriptor): + down_proj_name: str + ffn_prefix_name: str + linear_weight_names: List[str] = field(default_factory=list) + + def module_name_regex(self) -> str: + return self.down_proj_name + + def ffn_prefix(self, layer_idx: int) -> str: + return self.ffn_prefix_name.format(layer_idx=layer_idx) + + +class FFNIntermediatePruningMixIn(PruningMixIn): + def __init__(self, layer_descriptor: FFNIntermediateLayerDescriptor): + assert isinstance(layer_descriptor, FFNIntermediateLayerDescriptor) + super().__init__(layer_descriptor) + + def supported_hooks(self) -> List[Type[ForwardHook]]: + return [IndependentChannelContributionHook, IterativeChannelContributionHook] + + def prune_single_layer( + self, + layer_idx: int, + parent_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + mlp_init_mode: MlpInitMode, + mlp_init_config: Optional[dict[str, Any]], + keys: dict, + keys_to_remove: dict, + **kwargs, + ) -> Dict[str, torch.Tensor]: + layer_out_state_dict = {} + # Hardcoded strings + mlp_prefix = self.layer_descriptor.ffn_prefix(layer_idx) + mlp_key_names = [ + f"{mlp_prefix}.{name}.weight" for name in self.layer_descriptor.linear_weight_names + ] + mlp_keys = [keys.get(module_name) for module_name in mlp_key_names] + mlp_keys = [k for k in mlp_keys if k is not None] + + for key in mlp_keys: + keys_to_remove[f"{mlp_prefix}.{key.split('.')[-2]}.weight"] = key + + pruned_filters = None + projection_matrix = None + + for mlp_key in mlp_keys: + expanded_dim = 1 if self.layer_descriptor.down_proj_name in mlp_key else 0 + if mlp_key in new_state_dict.keys(): + mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module( + mlp_init_mode, + mlp_prefix, + expanded_dim, + layer_idx, + new_state_dict[mlp_key], + new_config, + parent_state_dict[mlp_key], + original_config, + mlp_init_config, + pruned_filters, + projection_matrix, + ) + layer_out_state_dict[mlp_key] = mlp_module_weight + + return layer_out_state_dict diff --git a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py new file mode 100644 index 0000000000..740d1fada3 --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +from dataclasses import dataclass, field +from typing import Any, List, Optional, Type + +from transformers import PretrainedConfig + +from modelopt.torch.prune.importance_hooks.base_hooks import ( + ForwardHook, + IndependentKvHeadContributionHook, +) + +from .pruning_mixin import LayerDescriptor, PruningMixIn +from .pruning_utils import GQAInitMode, _init_attention_biases, _init_attention_weights + +__all__ = [ + "KVHeadsLayerDescriptor", + "KVHeadsPruningMixIn", +] + + +@dataclass +class KVHeadsLayerDescriptor(LayerDescriptor): + o_proj_name: str + attn_prefix_name: str + qkvo_weight_names: List[str] = field(default_factory=list) + + def module_name_regex(self) -> str: + return self.o_proj_name + + def attn_prefix(self, layer_idx: int) -> str: + return self.attn_prefix_name.format(layer_idx=layer_idx) + + +class KVHeadsPruningMixIn(PruningMixIn): + def __init__(self, layer_descriptor: KVHeadsLayerDescriptor): + assert isinstance(layer_descriptor, KVHeadsLayerDescriptor) + super().__init__(layer_descriptor) + + def supported_hooks(self) -> List[Type[ForwardHook]]: + return [IndependentKvHeadContributionHook] + + def prune_single_layer( + self, + layer_idx: int, + parent_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + gqa_init_mode: GQAInitMode, + mlp_init_config: Optional[dict[str, Any]], + is_original_mha: bool, + keys: dict, + keys_to_remove: dict, + **kwargs, + ): + layer_out_state_dict = {} + + attn_prefix = self.layer_descriptor.attn_prefix(layer_idx) + q_name, k_name, v_name, o_name = [ + f"{attn_prefix}.{proj_name}" for proj_name in self.layer_descriptor.qkvo_weight_names + ] + + head_size = new_config.head_dim + for part in ["weight", "bias"]: + attn_keys = [f"{name}.{part}" for name in [q_name, k_name, v_name, o_name]] + q_key, k_key, v_key, o_key = attn_keys + + # Drop attn keys that don't exist and required to be in the new state_dict + attn_keys = [key for key in attn_keys if key in new_state_dict.keys()] + if len(attn_keys) > 0 and all(key in keys for key in attn_keys): + for key in attn_keys: + keys_to_remove[key] = keys[key] + is_student_and_teacher_have_same_attention_implementation = all( + key in new_state_dict.keys() for key in attn_keys + ) + if is_student_and_teacher_have_same_attention_implementation: + if part == "weight": + wq, wk, wv, wo = _init_attention_weights( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + layer_out_state_dict[q_key], layer_out_state_dict[k_key] = wq, wk + layer_out_state_dict[v_key], layer_out_state_dict[o_key] = wv, wo + else: + bias_sd = _init_attention_biases( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + for bias_key, sd_key in zip("qkvo", [q_key, k_key, v_key, o_key]): + if bias_key in bias_sd.keys(): + layer_out_state_dict[sd_key] = bias_sd[bias_key] + + return layer_out_state_dict diff --git a/modelopt/torch/puzzletron/pruning/pruning_ckpts.py b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py new file mode 100644 index 0000000000..245a115b3d --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py @@ -0,0 +1,359 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for creating pruned model checkpoints. + +This module provides functions to generate pruned checkpoints by modifying model architectures +(FFN intermediate sizes, attention head groups, hidden dimensions) and initializing child pruned models +from parent checkpoints. +""" + +# mypy: ignore-errors +import json +import os +import time +from typing import Optional + +from omegaconf import DictConfig + +from ..anymodel.model_descriptor import ModelDescriptorFactory +from ..tools.bypassed_training import init_child_from_parent +from ..tools.checkpoint_utils import load_model_config +from ..tools.logger import mprint +from .expert_removal_pruning_mixin import ExpertRemovalPruningMixIn +from .ffn_intermediate_pruning_mixin import FFNIntermediatePruningMixIn +from .kv_heads_pruning_mixin import KVHeadsPruningMixIn +from .pruning_utils import ( + GQAInitMode, + HiddenSizeInitMode, + LinearInitMode, + MlpInitMode, + resolve_pruning_mixin, +) + +__all__ = [ + "launch_ffn_intermediates_prune_ckpt", + "launch_attn_groups_prune_ckpt", + "launch_hidden_dim_prune_ckpt", + "launch_experts_prune_ckpt", + "launch_moe_ffn_intermediates_prune_ckpt", + "launch_prune_ckpt", +] + + +def launch_ffn_intermediates_prune_ckpt( + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None +): + for intermediate_size in cfg.pruning.intermediate_size_list: + dirname = f"ffn_{intermediate_size}_attn_no_op" + + if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", dirname)): + mprint(f"Process intermediate_size {intermediate_size} has already been pruned & saved") + continue + + mprint("Process intermediate_size {}".format(intermediate_size)) + + model_config_overrides_json = {"ffn": [{"intermediate_size": intermediate_size}]} + mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml + + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) + + # Profile the overall init_child_from_parent call with optimizations + mprint("Starting init_child_from_parent...") + start_time = time.time() + init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, + parent_checkpoint_dir=cfg.teacher_dir, + model_config_overrides_dict=model_config_overrides_json, + output_checkpoint_dir=output_dir, + gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), + mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), + mlp_init_config_yaml=mlp_init_config_yaml, + linear_init_mode=LinearInitMode.FromTeacher, # dummy default value + max_workers=max_save_workers, # Will auto-calculate if None + max_layer_workers=max_layer_workers, # Will auto-calculate if None + ) + init_child_from_parent_time = time.time() - start_time + mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds") + + # Create symlink in puzzle_dir/ckpts + ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts") + os.makedirs(ckpt_path, exist_ok=True) + os.symlink(output_dir, os.path.join(ckpt_path, dirname)) + + mprint(f"=== COMPLETED FFN PRUNING FOR FFN INTERMEDIATE SIZE={intermediate_size} ===") + mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n") + + +def launch_attn_groups_prune_ckpt( + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None +): + descriptor = cfg.descriptor + parent_model_config = load_model_config( + cfg.teacher_dir, trust_remote_code=descriptor.requires_trust_remote_code() + ) + num_attention_heads = parent_model_config.num_attention_heads + + for n_heads_in_group in cfg.pruning.n_heads_in_group_list: + dirname = f"n_heads_in_group{n_heads_in_group}" + + if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", dirname)): + mprint(f"Process n_heads_in_group {n_heads_in_group} has already been pruned & saved") + continue + + mprint("Process n_heads_in_group {}".format(n_heads_in_group)) + mprint(f"=== STARTING ATTENTION PRUNING FOR n_heads_in_group={n_heads_in_group} ===") + + num_key_value_heads = num_attention_heads // n_heads_in_group + model_config_overrides_json = {"attention": [{"num_key_value_heads": num_key_value_heads}]} + mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml + + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) + + # Profile the overall init_child_from_parent call with optimizations + mprint("Starting init_child_from_parent...") + start_time = time.time() + init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, + parent_checkpoint_dir=cfg.teacher_dir, + model_config_overrides_dict=model_config_overrides_json, + output_checkpoint_dir=output_dir, + gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), + mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), + mlp_init_config_yaml=mlp_init_config_yaml, + linear_init_mode=LinearInitMode.FromTeacher, # dummy default value + max_workers=max_save_workers, # Will auto-calculate if None + max_layer_workers=max_layer_workers, # Will auto-calculate if None + ) + init_child_from_parent_time = time.time() - start_time + mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds") + + # Create symlink in puzzle_dir/ckpts + ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts") + os.makedirs(ckpt_path, exist_ok=True) + os.symlink(output_dir, os.path.join(ckpt_path, dirname)) + + mprint(f"=== COMPLETED ATTENTION PRUNING FOR n_heads_in_group={n_heads_in_group} ===") + mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n") + + +def launch_hidden_dim_prune_ckpt(cfg: DictConfig): + """Launch hidden dimension pruning using channel importance ranking.""" + # Get channel importance results from the activations log directory + activations_log_dir = cfg.pruning.activations_log_dir + channel_importance_path = os.path.join(activations_log_dir, "channel_importance_results.json") + + if not os.path.exists(channel_importance_path): + raise FileNotFoundError( + f"Channel importance results not found at {channel_importance_path}. " + f"Make sure to run the activation collection step first." + ) + + # Load parent model config to get FFN configuration + descriptor = ModelDescriptorFactory.get(cfg.descriptor) + trust_remote_code = descriptor.requires_trust_remote_code() + parent_model_config = load_model_config( + cfg.pruning.model_name_or_path, trust_remote_code=trust_remote_code + ) + parent_hidden_size = parent_model_config.hidden_size + + # Get teacher's FFN configuration + intermediate_sizes = [] + for block_config in parent_model_config.block_configs: + if block_config.ffn.intermediate_size is not None: + intermediate_sizes.append(block_config.ffn.intermediate_size) + else: + intermediate_sizes.append(None) + + mprint(f"Teacher config:") + mprint(f" - hidden_size: {parent_hidden_size}") + mprint(f" - intermediate_sizes: {intermediate_sizes}") + os.makedirs(os.path.join(cfg.puzzle_dir, "ckpts"), exist_ok=True) + + for hidden_size in cfg.pruning.hidden_size_list: + mprint(f"\n######################################################################") + mprint(f"Hidden Size = {hidden_size}") + mprint(f"######################################################################\n") + + mprint(f"Child config:") + mprint(f" - hidden_size: {hidden_size}") + + # Create model config overrides with proper FFN configuration + model_config_overrides_json = json.dumps( + { + "hidden_size": hidden_size, + "ffn": [ + { + "intermediate_size": intermediate_size, + } + for intermediate_size in intermediate_sizes + ], + } + ) + + mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml + dirname = f"hidden_size_{hidden_size}" + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) + + mprint(f"Creating checkpoint with hidden_size={hidden_size}") + mprint(f"Model config overrides: {model_config_overrides_json}") + + init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, + parent_checkpoint_dir=cfg.pruning.model_name_or_path, + model_config_overrides_dict=model_config_overrides_json, + output_checkpoint_dir=output_dir, + gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), + mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), + mlp_init_config_yaml=mlp_init_config_yaml, + linear_init_mode=LinearInitMode(cfg.pruning.linear_init_mode), + hidden_size_init_mode=HiddenSizeInitMode(cfg.pruning.hidden_size_init_mode), + channel_importance_path=channel_importance_path, + ) + + # Create symlink in puzzle_dir/ckpts + ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts") + os.makedirs(ckpt_path, exist_ok=True) + os.symlink(output_dir, os.path.join(ckpt_path, dirname)) + mprint(f"Created pruned checkpoint at: {output_dir}") + + +def launch_experts_prune_ckpt( + cfg: DictConfig, + max_save_workers: Optional[int] = None, + max_layer_workers: Optional[int] = None, + symlink_suffix: Optional[str] = None, +): + for num_experts in cfg.pruning.num_experts_to_keep_list: + dirname = f"num_experts_{num_experts}" + # Create symlink name with optional suffix + symlink_name = f"{dirname}_{symlink_suffix}" if symlink_suffix else dirname + if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", symlink_name)): + mprint( + f"Process num_experts {num_experts} (symlink: {symlink_name}) has already been pruned & saved" + ) + continue + mprint(f"Process num_experts {num_experts}") + mprint(f"=== STARTING EXPERT PRUNING FOR num_experts={num_experts} ===") + model_config_overrides_json = {"ffn": [{"moe": {"num_local_experts": num_experts}}]} + + mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml + + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) + + # Profile the overall init_child_from_parent call with optimizations + mprint("Starting init_child_from_parent...") + start_time = time.time() + init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, + parent_checkpoint_dir=cfg.teacher_dir, + model_config_overrides_dict=model_config_overrides_json, + output_checkpoint_dir=output_dir, + gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), + mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), + mlp_init_config_yaml=mlp_init_config_yaml, + linear_init_mode=LinearInitMode.FromTeacher, # dummy default value + max_workers=max_save_workers, # Will auto-calculate if None + max_layer_workers=max_layer_workers, # Will auto-calculate if None + ) + init_child_from_parent_time = time.time() - start_time + mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds") + + # Create symlink in puzzle_dir/ckpts + ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts") + os.makedirs(ckpt_path, exist_ok=True) + os.symlink(output_dir, os.path.join(ckpt_path, symlink_name)) + + mprint(f"=== COMPLETED EXPERT PRUNING FOR num_experts={num_experts} ===") + mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n") + + +def launch_moe_ffn_intermediates_prune_ckpt( + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None +): + for intermediate_size in cfg.pruning.intermediate_size_list: + dirname = f"moe_ffn_{intermediate_size}_attn_no_op" + + if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", dirname)): + mprint(f"Process intermediate_size {intermediate_size} has already been pruned & saved") + continue + + mprint("Process intermediate_size {}".format(intermediate_size)) + + model_config_overrides_json = { + "attention": [{"no_op": True, "llama4": None}], + "ffn": [{"moe": {"expert_intermediate_dim": intermediate_size}}], + } + mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml + + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) + + # Profile the overall init_child_from_parent call with optimizations + mprint("Starting init_child_from_parent...") + start_time = time.time() + init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, + parent_checkpoint_dir=cfg.teacher_dir, + model_config_overrides_dict=model_config_overrides_json, + output_checkpoint_dir=output_dir, + gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), + mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), + mlp_init_config_yaml=mlp_init_config_yaml, + linear_init_mode=LinearInitMode.FromTeacher, # dummy default value + max_workers=max_save_workers, # Will auto-calculate if None + max_layer_workers=max_layer_workers, # Will auto-calculate if None + ) + init_child_from_parent_time = time.time() - start_time + mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds") + + # Create symlink in puzzle_dir/ckpts + os.symlink(output_dir, os.path.join(cfg.puzzle_dir, "ckpts", dirname)) + + mprint(f"=== COMPLETED MOE FFN PRUNING FOR FFN INTERMEDIATE SIZE={intermediate_size} ===") + mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n") + + +def launch_prune_ckpt(cfg: DictConfig): + cfg.descriptor = ModelDescriptorFactory.get(cfg.descriptor) + # Resolve pruning_mixin from config (could be string, enum, or PruningMixIn) + cfg.pruning.pruning_mixin = resolve_pruning_mixin(cfg.pruning.pruning_mixin, cfg.descriptor) + pruning_mixin = cfg.pruning.pruning_mixin + + # I/O optimization settings - same as FFN pruning + max_save_workers = None # Will auto-calculate as min(CPU count, num files) + if "PRUNING_SAVE_WORKERS" in os.environ: + max_save_workers = int(os.environ["PRUNING_SAVE_WORKERS"]) + + # Layer workers now auto-calculate but can still be overridden + max_layer_workers = None # Will auto-calculate as min(CPU count, num layers) + if "PRUNING_LAYER_WORKERS" in os.environ: + max_layer_workers = int(os.environ["PRUNING_LAYER_WORKERS"]) + + if isinstance(pruning_mixin, FFNIntermediatePruningMixIn): + launch_ffn_intermediates_prune_ckpt(cfg, max_save_workers, max_layer_workers) + elif isinstance(pruning_mixin, KVHeadsPruningMixIn): + launch_attn_groups_prune_ckpt(cfg, max_save_workers, max_layer_workers) + elif isinstance(pruning_mixin, ExpertRemovalPruningMixIn): + launch_experts_prune_ckpt(cfg, max_save_workers, max_layer_workers) + # elif target_layer == "layernorm": + # launch_hidden_dim_prune_ckpt(cfg) + else: + raise NotImplementedError( + f"checkpoint pruning is not currently supported for pruning mixin: {pruning_mixin.__class__.__name__}" + ) diff --git a/modelopt/torch/puzzletron/pruning/pruning_mixin.py b/modelopt/torch/puzzletron/pruning/pruning_mixin.py new file mode 100644 index 0000000000..a9a4264f5f --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/pruning_mixin.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import re +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Type + +from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook + +__all__ = [ + "LayerDescriptor", + "PruningMixIn", +] + + +class LayerDescriptor: + def module_name_regex(self) -> str: + return "" + + def block_idx_from_module_name(self, module_name: str) -> Optional[int]: + block_idx_match = re.search(r"\.(\d+)\.", module_name) + if block_idx_match: + return int(block_idx_match.group(1)) + return None + + def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]: + target_layer = self.module_name_regex() + if target_layer.startswith("regex:"): + target_layer_regex = target_layer[len("regex:") :] + pattern = re.compile(target_layer_regex) + match_predicate = lambda module_name: pattern.search(module_name) + else: + match_predicate = lambda module_name: module_name.endswith(target_layer) + + module_names_to_hook = [] + for module_name, module in model.named_modules(): + if match_predicate(module_name): + module_names_to_hook.append( + (self.block_idx_from_module_name(module_name), module_name) + ) + return module_names_to_hook + + +class PruningMixIn(ABC): + def __init__(self, layer_descriptor: LayerDescriptor): + self.layer_descriptor = layer_descriptor + + def get_module_names_to_hook(self, model) -> List[Tuple[int, str]]: + return self.layer_descriptor.get_modules_names_to_hook(model) + + @abstractmethod + def supported_hooks(self) -> List[Type[ForwardHook]]: + raise NotImplementedError + + # @abstractmethod + # def prune_single_layer( + # self, + # layer_idx: int, + # parent_state_dict: dict, + # new_state_dict: dict, + # original_config: PretrainedConfig, + # new_config: PretrainedConfig, + # **kwargs + # ): + # raise NotImplementedError diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py new file mode 100644 index 0000000000..c600e119cf --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py @@ -0,0 +1,663 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import json +import math +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +from transformers import PretrainedConfig + +from ..anymodel.model_descriptor import ModelDescriptor +from .pruning_mixin import PruningMixIn + +__all__ = [ + "GQAInitMode", + "MlpInitMode", + "LinearInitMode", + "HiddenSizeInitMode", + "resolve_pruning_mixin", +] + + +class GQAInitMode(Enum): + RandomKV = "RandomKV" + AverageKV = "AverageKV" + FirstKV = "FirstKV" + RandomBlock = "RandomBlock" + CopyAsIs = "CopyAsIs" + Degrouping = "Degrouping" + PruneKVHeads = "PruneKVHeads" + + +class MlpInitMode(Enum): + Random = "Random" + Truncate = "Truncate" + CopyAsIs = "CopyAsIs" + PruneByActivationsLog = "PruneByActivationsLog" + ExpertRemoval = "ExpertRemoval" + ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN" + + +class LinearInitMode(Enum): + Random = "Random" + FromTeacher = "FromTeacher" + + +class HiddenSizeInitMode(Enum): + Random = "Random" + Truncate = "Truncate" + PruneByChannelRanking = "PruneByChannelRanking" + CopyAsIs = "CopyAsIs" + + +def resolve_pruning_mixin( + pruning_mixin, descriptor: Type[ModelDescriptor] +) -> PruningMixIn | List[PruningMixIn]: + """ + Convert pruning_mixin argument to PruningMixIn instance(s). + + Args: + pruning_mixin: Can be a string identifier, PruningMixIn instance, + or a list of any of those types. + descriptor: ModelDescriptor class that provides the pruning_mixins() mapping. + + Returns: + PruningMixIn or List[PruningMixIn] depending on input type. + """ + # Handle list of values recursively + if isinstance(pruning_mixin, list): + return [resolve_pruning_mixin(item, descriptor) for item in pruning_mixin] + + # Handle single value + # If it's already a PruningMixIn, return as is + if isinstance(pruning_mixin, PruningMixIn): + return pruning_mixin + + # Get the pruning mixins mapping from the descriptor + mixins_dict = descriptor.pruning_mixins() + + if isinstance(pruning_mixin, str): + if pruning_mixin not in mixins_dict: + available_methods = list(mixins_dict.keys()) + raise ValueError( + f"Pruning method '{pruning_mixin}' is not supported by {descriptor.__name__}. " + f"Available methods: {available_methods}" + ) + return mixins_dict[pruning_mixin] + + raise ValueError(f"Unsupported pruning_mixin type: {type(pruning_mixin)}") + + +def _init_mlp_module( + mlp_init_mode: Union[MlpInitMode, str], + mlp_prefix: str, + expanded_dim: int, + layer_idx: int, + new_item: torch.Tensor, + new_config: PretrainedConfig, + orig_item: torch.Tensor, + original_config: PretrainedConfig, + mlp_init_config: Optional[dict[str, Any]], + pruned_filters: Optional[torch.Tensor] = None, + projection_matrix: Optional[dict[str, torch.Tensor]] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict[str, torch.Tensor]]]: + if isinstance(mlp_init_mode, str): + mlp_init_mode = MlpInitMode(mlp_init_mode) + assert orig_item.ndim == 2, f"{orig_item.ndim=}" + assert new_item.ndim == 2, f"{new_item.ndim=}" + + assert new_config.num_hidden_layers == original_config.num_hidden_layers, ( + f"({new_config.num_hidden_layers=}) != ({original_config.num_hidden_layers=})" + ) + + new_intermediate_size = new_config.block_configs[layer_idx].ffn.intermediate_size + original_intermediate_size = original_config.block_configs[layer_idx].ffn.intermediate_size + + if mlp_init_mode == MlpInitMode.CopyAsIs: + assert new_intermediate_size == original_intermediate_size, ( + f"({new_intermediate_size=}) != ({original_intermediate_size=}), can't be copied as is." + ) + mlp_module_weight = orig_item + + elif mlp_init_mode == MlpInitMode.Random: + mlp_module_weight = new_item + + elif new_intermediate_size == original_intermediate_size: + mlp_module_weight = orig_item + + elif mlp_init_mode in ( + MlpInitMode.Truncate, + MlpInitMode.PruneByActivationsLog, + ): + assert original_intermediate_size >= new_intermediate_size, ( + f"({original_intermediate_size=}) < ({new_intermediate_size=}), can't be truncated." + ) + orig_ffn_size = orig_item.shape[expanded_dim] + new_ffn_size = new_item.shape[expanded_dim] + + if mlp_init_mode == MlpInitMode.Truncate: + truncated_weight = torch.narrow( + orig_item, dim=expanded_dim, start=0, length=new_ffn_size + ) + mlp_module_weight = truncated_weight + + elif mlp_init_mode == MlpInitMode.PruneByActivationsLog: + if pruned_filters is None: + filter_importance = _load_activations_log( + mlp_init_config, module_name=f"{mlp_prefix}.down_proj" + ) + filters_sorted_by_importance = torch.argsort(filter_importance, descending=True) + pruned_filters = filters_sorted_by_importance[:new_ffn_size].to(orig_item.device) + + pruned_weight = torch.index_select(orig_item, dim=expanded_dim, index=pruned_filters) + if mlp_init_config.get("scale_pruned_weights", False) and expanded_dim == 1: + pruned_weight = pruned_weight * (orig_ffn_size / new_ffn_size) + mlp_module_weight = pruned_weight + + elif ( + mlp_init_mode == MlpInitMode.ExpertRemoval + ): # the case of mlp layers of maverick. for now we only support copy as is + assert new_intermediate_size == original_intermediate_size, ( + f"({new_intermediate_size=}) != ({original_intermediate_size=}), can't be copied as is." + ) + mlp_module_weight = orig_item + + else: + raise ValueError(f"Unsupported {mlp_init_mode=}") + + return mlp_module_weight, pruned_filters, projection_matrix + + +def _load_activations_log(mlp_init_config: dict[str, Any], module_name: str) -> torch.Tensor: + _cache_activations_log(mlp_init_config) + module_log = ACTIVATIONS_LOG[module_name] + filter_importance = module_log["score"] + return filter_importance + + +ACTIVATIONS_LOG = dict() + + +def _cache_activations_log(mlp_init_config: dict[str, Any]) -> None: + if len(ACTIVATIONS_LOG) == 0: + assert "activations_log_dir" in mlp_init_config + activations_log_dir = mlp_init_config["activations_log_dir"] + print(f"Loading activations_log from {activations_log_dir}") + # Only load rank_*.pth files to avoid loading hook_states_*.pth checkpoint files + ACTIVATIONS_LOG.update( + { + module_name: module_log + for p in Path(activations_log_dir).glob("rank_*.pth") + for module_name, module_log in torch.load(p).items() + } + ) + + +def _init_attention_weights( + gqa_init_mode, + layer_idx, + new_state_dict, + new_config, + original_state_dict, + q_key, + k_key, + v_key, + o_key, + original_config, + is_original_mha, + head_size, + mlp_init_config, +): + assert new_config.num_attention_heads == original_config.num_attention_heads, ( + f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + ) + num_q_heads = new_config.num_attention_heads + num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads + orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads + + # new_w* are typically randomly initialized + new_wq = new_state_dict[q_key] + new_wk = new_state_dict[k_key] + new_wv = new_state_dict[v_key] + new_wo = new_state_dict[o_key] + + # w* are from the parent model + wq = original_state_dict[q_key] + wk = original_state_dict[k_key] + wv = original_state_dict[v_key] + wo = original_state_dict[o_key] + + if "bias" in k_key: + for tensor in [wq, wk, wv, wo, new_wq, new_wk, new_wv, new_wo]: + assert tensor.ndim == 1 + wq, wk, wv, wo, new_wq, new_wk, new_wv, new_wo = [ + t.unsqueeze(1) for t in [wq, wk, wv, wo, new_wq, new_wk, new_wv, new_wo] + ] + dim1 = wk.shape[1] # this is the hidden_size in case of matrix weights, and 1 in case of biases + + if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock): + wk, wv = new_wk, new_wv + elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV): + assert orig_num_kv_heads % num_kv_heads == 0, ( + f"({orig_num_kv_heads=}) % ({num_kv_heads=}) != 0" + ) + n_heads_to_aggregate = orig_num_kv_heads // num_kv_heads + + wk = wk.view(-1, n_heads_to_aggregate, head_size, dim1) + wv = wv.view(-1, n_heads_to_aggregate, head_size, dim1) + + if gqa_init_mode == GQAInitMode.AverageKV: + wk = wk.mean(dim=1) + wv = wv.mean(dim=1) + else: + wk = wk[:, 0] + wv = wv[:, 0] + elif gqa_init_mode == GQAInitMode.CopyAsIs: + assert new_wk.shape == wk.shape, f"({new_wk.shape=}) != ({wk.shape=})" + assert new_wv.shape == wv.shape, f"({new_wv.shape=}) != ({wv.shape=})" + assert new_wq.shape == wq.shape, f"({new_wq.shape=}) != ({wq.shape=})" + assert new_wo.shape == wo.shape, f"({new_wo.shape=}) != ({wo.shape=})" + + elif gqa_init_mode == GQAInitMode.Degrouping: + assert not is_original_mha, ( + "Degrouping can only be done on original models that are GQA themselves." + ) + n_groups = num_kv_heads + orig_n_groups = orig_num_kv_heads + assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" + n_repeats = n_groups // orig_n_groups + if n_repeats > 1: + print(f"Degrouping {orig_n_groups} into {n_groups}") + + def degroup_w(w): + w = w.view(orig_n_groups, head_size, dim1) + w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) + w = w.reshape(n_groups * head_size, dim1) + return w + + wk = degroup_w(wk) + wv = degroup_w(wv) + + elif gqa_init_mode == GQAInitMode.PruneKVHeads: + wk = wk.view(orig_num_kv_heads, head_size, dim1) + wv = wv.view(orig_num_kv_heads, head_size, dim1) + wq = wq.view(orig_num_kv_heads, num_q_heads // orig_num_kv_heads, head_size, dim1) + wo = wo.view(dim1, orig_num_kv_heads, num_q_heads // orig_num_kv_heads, head_size) + + o_proj_module_name = o_key.replace(".weight", "") + kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) + kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) + kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] + kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] + + wk = wk[kv_heads_to_keep] + wv = wv[kv_heads_to_keep] + + reduction_factor = orig_num_kv_heads // num_kv_heads + + prune_via_duplication = False + if prune_via_duplication: + ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. + wq = wq[kv_heads_to_keep] + wq = torch.repeat_interleave(wq, repeats=reduction_factor, dim=0) + + ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. + ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. + wo = wo[:, kv_heads_to_keep] + wo = torch.repeat_interleave(wo, repeats=reduction_factor, dim=1) + wo = wo / reduction_factor + + else: # prune via zeroing out + ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. + ## We need to interleave them to keep the matching between queries and kv heads. + kv_heads_to_keep = kv_heads_to_keep.tolist() + kv_heads_to_remove = kv_heads_to_remove.tolist() + kv_head_ordering = [] + zero_out_mask = [] + for i_head in range(orig_num_kv_heads): + if i_head % reduction_factor == 0: + kv_head_ordering.append(kv_heads_to_keep.pop(0)) + zero_out_mask.append(False) + else: + kv_head_ordering.append(kv_heads_to_remove.pop(0)) + zero_out_mask.append(True) + + wq = wq[kv_head_ordering] + + ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. + ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. + ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. + ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. + wo = wo[:, kv_head_ordering] + wo[:, zero_out_mask] = 0.0 + + else: + raise ValueError(f"{gqa_init_mode=} not supported") + + wk = wk.reshape(-1, dim1) + wv = wv.reshape(-1, dim1) + wq = wq.reshape(-1, dim1) + wo = wo.reshape(dim1, -1) + return wq, wk, wv, wo + + +def _init_attention_biases( + gqa_init_mode, + layer_idx, + new_state_dict, + new_config, + original_state_dict, + q_key, + k_key, + v_key, + o_key, + original_config, + is_original_mha, + head_size, + mlp_init_config, +): + assert new_config.num_attention_heads == original_config.num_attention_heads, ( + f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + ) + num_q_heads = new_config.num_attention_heads + num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads + orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads + n_heads_in_group = num_q_heads // num_kv_heads + orig_n_heads_in_group = num_q_heads // orig_num_kv_heads + + o_proj_bias = new_config.o_proj_bias + attention_bias = new_config.attention_bias + + # If no biases + if not (o_proj_bias or attention_bias): + return {} + + new_bias_sd = {} + bias_sd = {} + # new_w* are typically randomly initialized + if o_proj_bias: + new_bias_sd["o"] = new_state_dict[o_key] + bias_sd["o"] = original_state_dict[o_key] + if attention_bias: + for bias_key, key in zip("qkv", [q_key, k_key, v_key]): + new_bias_sd[bias_key] = new_state_dict[key] + bias_sd[bias_key] = original_state_dict[key] + + # maybe unsqueeze all tensors (non-in-place to avoid mutating shared state dict entries) + for tensor in list(new_bias_sd.values()) + list(bias_sd.values()): + assert tensor.ndim == 1 + new_bias_sd = {k: v.unsqueeze(1) for k, v in new_bias_sd.items()} + bias_sd = {k: v.unsqueeze(1) for k, v in bias_sd.items()} + + dim1 = 1 # this is the hidden_size in case of matrix weights, and 1 in case of biases + if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock) and attention_bias: + bias_sd["k"] = torch.zeros( + new_bias_sd["k"].shape, dtype=bias_sd["k"].dtype, device=bias_sd["k"].device + ) + bias_sd["v"] = torch.zeros( + new_bias_sd["v"].shape, dtype=bias_sd["v"].dtype, device=bias_sd["v"].device + ) + elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV) and attention_bias: + assert n_heads_in_group % orig_n_heads_in_group == 0, ( + f"({n_heads_in_group=}) % ({orig_n_heads_in_group=}) != 0" + ) + n_heads_to_aggregate = n_heads_in_group // orig_n_heads_in_group + + bias_sd["k"] = bias_sd["k"].view(-1, n_heads_to_aggregate, head_size, dim1) + bias_sd["v"] = bias_sd["v"].view(-1, n_heads_to_aggregate, head_size, dim1) + + if gqa_init_mode == GQAInitMode.AverageKV: + bias_sd["k"] = bias_sd["k"].mean(dim=1) + bias_sd["v"] = bias_sd["v"].mean(dim=1) + else: + bias_sd["k"] = bias_sd["k"][:, 0] + bias_sd["v"] = bias_sd["v"][:, 0] + elif gqa_init_mode == GQAInitMode.CopyAsIs: + for key in bias_sd.keys(): + assert new_bias_sd[key].shape == bias_sd[key].shape, ( + f"({new_bias_sd[key].shape=}) != ({bias_sd[key].shape=})" + ) + + elif gqa_init_mode == GQAInitMode.Degrouping and attention_bias: + assert not is_original_mha, ( + "Degrouping can only be done on original models that are GQA themselves." + ) + n_groups = new_config.num_attention_heads // n_heads_in_group + orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group + assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" + n_repeats = n_groups // orig_n_groups + if n_repeats > 1: + print(f"Degrouping {orig_n_groups} into {n_groups}") + + def degroup_w(w): + w = w.view(orig_n_groups, head_size, dim1) + w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) + w = w.reshape(n_groups * head_size, dim1) + return w + + bias_sd["k"] = degroup_w(bias_sd["k"]) + bias_sd["v"] = degroup_w(bias_sd["v"]) + + elif gqa_init_mode == GQAInitMode.PruneKVHeads: + if o_proj_bias: + o_proj_module_name = o_key.rsplit(".", 1)[0] + else: + # Here we assume that the o_proj layer is called "o_proj" + o_proj_module_name = k_key.rsplit(".", 2)[0] + ".o_proj" + + kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) + kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) + kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] + kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] + + # view as KV groups + if attention_bias: + bias_sd["k"] = bias_sd["k"].view(orig_num_kv_heads, head_size, dim1) + bias_sd["v"] = bias_sd["v"].view(orig_num_kv_heads, head_size, dim1) + bias_sd["q"] = bias_sd["q"].view( + orig_num_kv_heads, orig_n_heads_in_group, head_size, dim1 + ) + # Keep important KV heads and prune the others + bias_sd["k"] = bias_sd["k"][kv_heads_to_keep] + bias_sd["v"] = bias_sd["v"][kv_heads_to_keep] + if o_proj_bias: + bias_sd["o"] = bias_sd["o"].view( + dim1, orig_num_kv_heads, orig_n_heads_in_group, head_size + ) + + reduction_factor = orig_num_kv_heads // num_kv_heads + + prune_via_duplication = False + if prune_via_duplication: + if attention_bias: + ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. + bias_sd["q"] = bias_sd["q"][kv_heads_to_keep] + bias_sd["q"] = torch.repeat_interleave( + bias_sd["q"], repeats=reduction_factor, dim=0 + ) + + if o_proj_bias: + ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. + ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. + bias_sd["o"] = bias_sd["o"][:, kv_heads_to_keep] + bias_sd["o"] = torch.repeat_interleave( + bias_sd["o"], repeats=reduction_factor, dim=1 + ) + bias_sd["o"] = bias_sd["o"] / reduction_factor + + else: # prune via zeroing out + ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. + ## We need to interleave them to keep the matching between queries and kv heads. + kv_heads_to_keep = kv_heads_to_keep.tolist() + kv_heads_to_remove = kv_heads_to_remove.tolist() + kv_head_ordering = [] + zero_out_mask = [] + for i_head in range(orig_num_kv_heads): + if i_head % reduction_factor == 0: + kv_head_ordering.append(kv_heads_to_keep.pop(0)) + zero_out_mask.append(False) + else: + kv_head_ordering.append(kv_heads_to_remove.pop(0)) + zero_out_mask.append(True) + + if attention_bias: + bias_sd["q"] = bias_sd["q"][kv_head_ordering] + + if o_proj_bias: + ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. + ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. + ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. + ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. + bias_sd["o"] = bias_sd["o"][:, kv_head_ordering] + bias_sd["o"][:, zero_out_mask] = 0.0 + + else: + raise ValueError(f"{gqa_init_mode=} not supported") + + if attention_bias: + for bias_key in "qkv": + bias_sd[bias_key] = bias_sd[bias_key].reshape(-1) + if o_proj_bias: + bias_sd["o"] = bias_sd["o"].reshape(-1) + return bias_sd + + +def _init_moe_module( + mlp_init_mode: Union[MlpInitMode, str], + mlp_init_config: Optional[Dict[str, Any]], + layer_idx: int, + orig_router_weights: Dict[str, List[torch.Tensor]], + orig_experts_weights: Dict[str, List[torch.Tensor]], + new_router_weights: Dict[str, List[torch.Tensor]], + new_experts_weights: Dict[str, List[torch.Tensor]], + orig_num_experts: int, + new_num_experts: int, +) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + if isinstance(mlp_init_mode, str): + mlp_init_mode = MlpInitMode(mlp_init_mode) + + if mlp_init_mode != MlpInitMode.ExpertRemoval: + raise ValueError(f"Unsupported {mlp_init_mode=}") + + selected_experts = _select_expert_indices( + mlp_init_config=mlp_init_config, + layer_idx=layer_idx, + orig_num_experts=orig_num_experts, + new_num_experts=new_num_experts, + ) + + # Router: prefer parent tensors when available; if child has bias only, slice from child + result_router_weights: dict[str, list[torch.Tensor]] = {} + for name, new_list in new_router_weights.items(): + result_router_weights[name] = [ + tensor_to_slice[selected_experts] for tensor_to_slice in orig_router_weights[name] + ] + + # Experts: for each name present in the child, take from parent if available, else from child + result_experts_weights: dict[str, list[torch.Tensor]] = {} + for name, new_list in new_experts_weights.items(): + if name in orig_experts_weights: + src_list = orig_experts_weights[name] + else: + src_list = new_list + result_experts_weights[name] = [src_list[i] for i in selected_experts] + + # Validate shapes + assert result_router_weights.keys() == new_router_weights.keys(), ( + "result_router_weights and new_router_weights must have the same keys" + ) + for name in new_router_weights.keys(): + assert len(new_router_weights[name]) == len(result_router_weights[name]) + for new_router_weight, result_router_weight in zip( + new_router_weights[name], result_router_weights[name] + ): + assert new_router_weight.shape == result_router_weight.shape + + assert result_experts_weights.keys() == new_experts_weights.keys(), ( + "result_experts_weights and new_experts_weights must have the same keys" + ) + for name in result_experts_weights.keys(): + assert len(new_experts_weights[name]) == len(result_experts_weights[name]) + for new_expert_weight, result_expert_weight in zip( + new_experts_weights[name], result_experts_weights[name] + ): + assert new_expert_weight.shape == result_expert_weight.shape + + return result_router_weights, result_experts_weights + + +def _select_expert_indices( + *, mlp_init_config: dict[str, Any], layer_idx: int, orig_num_experts: int, new_num_experts: int +) -> list[int]: + expert_scores = _load_expert_scores(mlp_init_config, layer_idx) + assert len(expert_scores) == orig_num_experts + higher_is_better = mlp_init_config.get("higher_is_better", True) + selected_experts = sorted( + range(orig_num_experts), + key=lambda i: ( + expert_scores[i] + if not math.isnan(expert_scores[i]) + else (float("-inf") if higher_is_better else float("inf")) + ), + reverse=higher_is_better, + )[:new_num_experts] + return selected_experts + + +def _load_expert_scores( + mlp_init_config: Optional[dict[str, Any]], layer_idx: int +) -> list[list[int | float]]: + assert mlp_init_config is not None + if "expert_scores_file" in mlp_init_config: + expert_scores_file = mlp_init_config["expert_scores_file"] + with open(expert_scores_file, "r") as f: + expert_scores = json.load(f) + elif "activations_log_dir" in mlp_init_config: + _cache_activations_log(mlp_init_config) + # Use layer_prefix_template from pruning config, or fall back to legacy nemotron_h format + # TODO - get from descriptors + layer_prefix_template = mlp_init_config.get( + "layer_prefix_template", "backbone.layers.{layer_idx}." + ) + layer_prefix = layer_prefix_template.format(layer_idx=layer_idx) + candidate_layer_keys = [ + key for key in ACTIVATIONS_LOG.keys() if key.startswith(layer_prefix) + ] + if len(candidate_layer_keys) == 0: + raise ValueError(f"No layer keys found for {layer_prefix=}. {ACTIVATIONS_LOG.keys()=}") + elif len(candidate_layer_keys) > 1: + if "layer_suffix" not in mlp_init_config: + raise ValueError( + f"Multiple candidate layer keys found for {layer_prefix=}, you must specify a layer_suffix in the mlp_init_config. {candidate_layer_keys=}" + ) + layer_suffix = mlp_init_config["layer_suffix"] + layer_key = f"{layer_prefix}{layer_suffix}" + else: + layer_key = candidate_layer_keys[0] + layer_log = ACTIVATIONS_LOG[layer_key] + + expert_scores_key = mlp_init_config.get("expert_scores_key", "expert_ranks") + if expert_scores_key not in layer_log: + raise ValueError( + f"Expert scores key {expert_scores_key=} not found in {layer_log.keys()=}" + ) + expert_scores = layer_log[expert_scores_key] + else: + raise ValueError(f"Unsupported {mlp_init_config=}") + return expert_scores diff --git a/modelopt/torch/puzzletron/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py new file mode 100644 index 0000000000..253674f97a --- /dev/null +++ b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Puzzletron NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146). + +It is used by mtn.convert() to convert a model from HF format to Puzzletron heterogeneous format + do pruning scoring +and save pruned checkpoints, and by mtn.search() to perform the MIP-based NAS search. +""" + +from pathlib import Path + +import hydra +from torch import nn + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.nas.conversion import NASModeRegistry +from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField +from modelopt.torch.opt.mode import ( + ConvertEntrypoint, + ConvertReturnType, + MetadataDict, + ModeDescriptor, + RestoreEntrypoint, +) +from modelopt.torch.opt.searcher import BaseSearcher, SearchStateDict + +from .activation_scoring import launch_score_activations +from .anymodel.converter import ConverterFactory +from .anymodel.model_descriptor import ModelDescriptorFactory +from .build_library_and_stats import launch_build_library_and_stats +from .mip import launch_mip_and_realize_model +from .pruning import launch_prune_ckpt +from .scoring import launch_scoring +from .tools.hydra_utils import initialize_hydra_config_for_dir +from .tools.logger import mprint + +__all__ = [ + "PuzzletronModel", + "PuzzletronConfig", + "PuzzletronDescriptor", + "PuzzletronSearcher", + "convert_puzzletron_model", + "restore_puzzletron_model", +] + + +class PuzzletronModel(nn.Module): + pass # No model implementation is needed for the puzzletron mode + + +class PuzzletronConfig(ModeloptBaseConfig): + """Configuration for Puzzletron NAS algorithm.""" + + # Input model path to compress in the HF format + input_model_path: str = ModeloptField( + default="", + title="", + description="", + ) + + # Hydra config directory containing the search space definition + hydra_config_dir: str = ModeloptField( + default="", + title="", + description="", + ) + + # Hydra config name containing the search space definition + hydra_config_name: str = ModeloptField( + default="", + title="", + description="", + ) + + # Directory to save the compressed model and intermediate results + puzzle_dir: str = ModeloptField( + default="", + title="", + description="", + ) + + # Dataset path to use for scoring in prunining and NAS search + dataset_path: str = ModeloptField( + default="", + title="", + description="", + ) + + +def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> ConvertReturnType: + """1. Convert the model from HF format to AnyModel format. + 2. Score the pruning activations. + 3. Prune the model and save pruned checkpoints + + The output of this step will be used by mnt.search() to perform the NAS search. + """ + # Required for mtn.search() to read NAS configuration + model.hydra_config_dir = config.hydra_config_dir + model.hydra_config_name = config.hydra_config_name + model.puzzle_dir = config.puzzle_dir + model.dataset_path = config.dataset_path + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=config.hydra_config_dir, + config_name=config.hydra_config_name, + overrides=[ + f"puzzle_dir={config.puzzle_dir}", + f"dataset_path={config.dataset_path}", + ], + ) + # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) + + # Convert HuggingFace model to Puzzletron heterogeneous format (generic, uses descriptor from config) + if dist.is_master(): + mprint( + "Puzzletron Progress 2/8: converting model to Puzzletron heterogeneous format (single-gpu)" + ) + hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable + + # Get descriptor and converter from the hydra config + descriptor_name = hydra_cfg.descriptor + descriptor = ModelDescriptorFactory.get(descriptor_name) + converter = ConverterFactory.get(descriptor_name) + + converter.convert( + descriptor=descriptor, + input_dir=Path(config.input_model_path), + output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, + ) + dist.barrier() + + # Score_pruning_activations (distributed processing) + mprint("Puzzletron Progress 3/8: scoring pruning activations (multi-gpu)") + launch_score_activations(hydra_cfg) + + # Prune the model and save pruned checkpoints + if dist.is_master(): + mprint( + "Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)" + ) + launch_prune_ckpt(hydra_cfg) + dist.barrier() + + return model, {} + + +def restore_puzzletron_model( + model: nn.Module, config: PuzzletronConfig, metadata: MetadataDict +) -> nn.Module: + """Restore is not needed for the puzzletron mode as we are not saving any model state""" + return model + + +@NASModeRegistry.register_mode +class PuzzletronDescriptor(ModeDescriptor): + """Descriptor for the Puzzletron mode.""" + + @property + def name(self) -> str: + """String identifier for this mode.""" + return "puzzletron" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Configuration class for this mode.""" + return PuzzletronConfig + + @property + def search_algorithm(self) -> type[BaseSearcher]: + """Return the associated searcher implementation.""" + + return PuzzletronSearcher + + @property + def convert(self) -> ConvertEntrypoint: + """Entrypoint to convert a model.""" + return convert_puzzletron_model + + @property + def restore(self) -> RestoreEntrypoint: + """Entrypoint to restore a model.""" + return restore_puzzletron_model + + @property + def export_mode(self) -> str | None: + """The mode that corresponds to the export mode. + For now, this will be a no-op as there is no modelopt's concept of search space defined + for the puzzletron algorithm. + """ + return "export_nas" + + +class PuzzletronSearcher(BaseSearcher): + """Runs NAS search for the Puzzletron mode.""" + + @property + def default_state_dict(self) -> SearchStateDict: + """Not needed for the puzzletron mode as we are not saving any model state""" + return {} + + def run_search(self) -> None: + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=self.model.hydra_config_dir, + config_name=self.model.hydra_config_name, + overrides=[ + f"puzzle_dir={self.model.puzzle_dir}", + f"dataset_path={self.model.dataset_path}", + ], + ) + # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) + + # Build_library_and_stats (single process) + if dist.is_master(): + mprint( + "Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu)" + ) + launch_build_library_and_stats(hydra_cfg) + dist.barrier() + + # Calc_one_block_scores (distributed processing) + mprint("Puzzletron Progress 6/8: calculating one block scores (multi-gpu)") + launch_scoring(hydra_cfg) + + # mip_and_realize_models (distributed processing) + mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") + launch_mip_and_realize_model(hydra_cfg) diff --git a/modelopt/torch/puzzletron/replacement_library/__init__.py b/modelopt/torch/puzzletron/replacement_library/__init__.py new file mode 100644 index 0000000000..c3d67d190c --- /dev/null +++ b/modelopt/torch/puzzletron/replacement_library/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Replacement library for Puzzletron layer substitution.""" + +from .library import * +from .replacement_utils import * diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py new file mode 100644 index 0000000000..b5d0c754f1 --- /dev/null +++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py @@ -0,0 +1,621 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module constructs the replacement library JSON files from a puzzle directory containing +multiple trained model checkpoints. It analyzes checkpoints to extract unique block and subblock +configurations, builds a library of available replacements, and generates solutions for layer +replacement in compressed models. The resulting replacement library can then be used by +ReplacementLibrary to efficiently load models with mixed teacher/student layers. + +Standard Puzzle Usage: +====================== +python -m modelopt.torch.puzzletron.replacement_library.build_replacement_library PUZZLE_DIR + +Teacher checkpoint dir is assumed to be inside PUZZLE_DIR/ckpts/teacher (symlink is recommended) +though you can supply an explicit --teacher_checkpoint_dir. + +--add_ffn_no_ops and --add_attention_no_ops are optional (default True), + + +""" +# mypy: ignore-errors + +import json +from pathlib import Path +from typing import Any, Type + +import pandas as pd +from omegaconf import DictConfig + +from modelopt.torch.utils import json_dump + +from ..anymodel.model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ..block_config import AttentionConfig, BlockConfig, FFNConfig +from ..mip.utils import sort_replacements +from ..tools.checkpoint_utils import ( + SAFETENSORS_SUBBLOCKS_DIR_NAME, + is_valid_decilm_checkpoint, + load_model_config, +) +from ..tools.logger import mprint +from ..utils.misc import block_config_to_str, subblock_config_to_str +from ..utils.parsing import format_global_config +from .replacement_utils import is_replacement_identical_to_teacher, replacement_is_teacher + +__all__ = [ + "UNIQUE_SUBBLOCK_IDENTIFIER", + "CHECKPOINTS_DIR_NAME", + "build_replacement_library", + "launch_build_replacement_library", + "infer_teacher_dir", +] + +UNIQUE_SUBBLOCK_IDENTIFIER = ["block_config", "attention_config", "ffn_config", "block_idx"] +CHECKPOINTS_DIR_NAME = "ckpts" + + +def build_replacement_library( + master_puzzle_dir: Path | str, + descriptor: ModelDescriptor, + teacher_checkpoint_dir: Path | str | None = None, + add_ffn_no_ops: bool = True, + add_attention_no_ops: bool = True, +) -> None: + """ + For normal puzzle runs, use default values. + For advanced use cases, see the Usage section. + """ + master_puzzle_dir = Path(master_puzzle_dir) + (master_puzzle_dir / "ckpts").mkdir(exist_ok=True) + teacher_checkpoint_dir = infer_teacher_dir(master_puzzle_dir, teacher_checkpoint_dir) + trust_remote_code = descriptor.requires_trust_remote_code() + subblocks_df = _build_subblocks_df( + master_puzzle_dir, + teacher_checkpoint_dir, + add_ffn_no_ops, + add_attention_no_ops, + trust_remote_code=trust_remote_code, + ) + block_library_df = _build_block_library_from_subblocks(subblocks_df, master_puzzle_dir) + + layer_replacements = _build_layer_replacements( + block_library_df, master_puzzle_dir, teacher_checkpoint_dir, trust_remote_code + ) + + single_sequence_replacement_solutions = _build_single_sequence_replacement_solutions( + layer_replacements, teacher_checkpoint_dir, trust_remote_code + ) + + json_dump(block_library_df.to_dict(orient="records"), master_puzzle_dir / "block_library.json") + json_dump(subblocks_df.to_dict(orient="records"), master_puzzle_dir / "subblock_library.json") + json_dump(layer_replacements, master_puzzle_dir / "replacement_library.json") + json_dump( + single_sequence_replacement_solutions, + master_puzzle_dir / "single_sequence_replacement_solutions.json", + ) + mprint("done") + + +def launch_build_replacement_library(cfg: DictConfig) -> None: + """ + Launch the build replacement library function with Hydra configuration. + """ + mprint(f"Building replacement library for puzzle directory: {cfg.puzzle_dir}") + mprint(f"Teacher directory: {cfg.teacher_dir}") + mprint( + f"Build replacement library config: {format_global_config(cfg.build_replacement_library, title='Build replacement library')}" + ) + + descriptor = ModelDescriptorFactory.get(cfg.descriptor) + build_replacement_library( + master_puzzle_dir=cfg.puzzle_dir, + teacher_checkpoint_dir=cfg.teacher_dir, + add_ffn_no_ops=cfg.build_replacement_library.add_ffn_no_ops, + add_attention_no_ops=cfg.build_replacement_library.add_attention_no_ops, + descriptor=descriptor, + ) + + +def infer_teacher_dir( + master_puzzle_dir: Path | str, + teacher_checkpoint_dir: Path | str | None = None, +) -> Path: + if teacher_checkpoint_dir is None: + teacher_checkpoint_dir = Path(master_puzzle_dir) / CHECKPOINTS_DIR_NAME / "teacher" + if not teacher_checkpoint_dir.exists(): + raise ValueError( + f"You must either provide the --teacher_checkpoint_dir argument, or create a link to the " + f"teacher dir under '{{PUZZLE_DIR}}/ckpts'." + ) + teacher_checkpoint_dir = Path(teacher_checkpoint_dir).resolve().absolute() + return teacher_checkpoint_dir + + +def _build_block_library_from_subblocks( + subblocks_df: pd.DataFrame, output_dir: Path +) -> pd.DataFrame: + joint_blocks_df = subblocks_df.dropna(subset=["block_config"]).copy() + constructed_blocks_df = _construct_blocks_from_subblocks(subblocks_df) + + is_constructed_block_has_joint_variant = pd.Series( + map(tuple, constructed_blocks_df[["block_config", "block_idx"]].values) + ).isin(pd.Series(map(tuple, joint_blocks_df[["block_config", "block_idx"]].values))) + constructed_blocks_df = constructed_blocks_df[~is_constructed_block_has_joint_variant] + + block_library_df = pd.concat([joint_blocks_df, constructed_blocks_df]) + block_library_df["block_repr"] = block_library_df["block_config"].apply(block_config_to_str) + + dups = block_library_df.loc[ + block_library_df[["block_config", "block_idx"]].duplicated() + ].sort_values(by=["block_config", "block_idx"]) + if len(dups) > 0: + mprint(f"Found {len(dups)} duplicate blocks in the block library. Here are some examples:") + dup_block_idx = dups["block_idx"].iloc[0] + dups_with_same_block_idx = dups[dups["block_idx"] == dup_block_idx] + for _, row in dups_with_same_block_idx.head(10).iterrows(): + mprint(row.to_dict()) + json_dump( + block_library_df.to_dict(orient="records"), output_dir / "ERROR_block_library.json" + ) + json_dump( + subblocks_df.to_dict(orient="records"), output_dir / "ERROR_subblock_library.json" + ) + raise ValueError( + f"Found {len(dups)} duplicate blocks in the block library. See ERROR_block_library.json and ERROR_subblock_library.json for more details." + ) + + return block_library_df + + +def _construct_blocks_from_subblocks(subblocks_df: pd.DataFrame) -> pd.DataFrame: + columns = subblocks_df.columns + decomp_blocks_df = subblocks_df[subblocks_df["block_config"].isna()].drop( + columns=columns[columns.str.contains("block_config|joint|block_repr")] + ) + + attention_df = decomp_blocks_df.dropna(subset="attention_config").drop( + columns=columns[columns.str.contains("ffn")] + ) + ffn_df = decomp_blocks_df.dropna(subset="ffn_config").drop( + columns=columns[columns.str.contains("attention")] + ) + constructed_blocks_df = pd.merge(attention_df, ffn_df, on="block_idx") + + constructed_blocks_df["block_config"] = constructed_blocks_df.apply( + lambda row: BlockConfig(ffn=row["ffn_config"], attention=row["attention_config"]), axis=1 + ) + + return constructed_blocks_df + + +def _build_subblocks_df( + master_puzzle_dir: Path | str, + teacher_checkpoint_dir: Path | str, + add_ffn_no_ops: bool, + add_attention_no_ops: bool, + trust_remote_code: bool = False, +) -> pd.DataFrame: + teacher_checkpoint_dir = Path(teacher_checkpoint_dir) + checkpoint_dirs = _get_last_checkpoint_from_each_experiment( + master_puzzle_dir, trust_remote_code=trust_remote_code + ) + checkpoint_dirs = [teacher_checkpoint_dir] + list(checkpoint_dirs - {teacher_checkpoint_dir}) + checkpoints_to_split = [teacher_checkpoint_dir] + + subblock_rows = [] + for checkpoint_dir in checkpoint_dirs: + subblocks_to_extract = _infer_subblocks_to_extract(checkpoint_dir, checkpoints_to_split) + if len(subblocks_to_extract) > 0: + subblock_rows_from_current_checkpoint = ( + _construct_subblock_rows_from_current_checkpoint( + checkpoint_dir, subblocks_to_extract, trust_remote_code=trust_remote_code + ) + ) + subblock_rows.extend(subblock_rows_from_current_checkpoint) + + subblocks_df = pd.DataFrame(subblock_rows) + + subblocks_df = _drop_duplicates_of_decomp_no_op(subblocks_df) + assert subblocks_df.duplicated().sum() == 0 + + if add_ffn_no_ops or add_attention_no_ops: + subblocks_df = _add_no_op_subblock_rows(subblocks_df, add_ffn_no_ops, add_attention_no_ops) + + subblocks_df = _drop_duplicates_of_teacher(subblocks_df, teacher_checkpoint_dir) + + subblocks_that_have_multiple_sources = list( + subblocks_df[subblocks_df.duplicated(UNIQUE_SUBBLOCK_IDENTIFIER, keep=False)].groupby( + UNIQUE_SUBBLOCK_IDENTIFIER, dropna=False + ) + ) + if len(subblocks_that_have_multiple_sources) > 0: + mprint( + f"Found {len(subblocks_that_have_multiple_sources)} subblock types with multiple sources. Dropping duplicates..." + ) + for subblock_identifier, duplicates_df in subblocks_that_have_multiple_sources: + mprint("\n================================") + mprint(dict(zip(UNIQUE_SUBBLOCK_IDENTIFIER, subblock_identifier))) + for _, row in duplicates_df.iterrows(): + mprint(row.to_dict()) + + # Drop duplicates, keeping the first occurrence (which should be from teacher) + mprint(f"Dropping duplicates. Original count: {len(subblocks_df)}") + subblocks_df = subblocks_df.drop_duplicates(subset=UNIQUE_SUBBLOCK_IDENTIFIER, keep="first") + mprint(f"After dropping duplicates: {len(subblocks_df)}") + + subblocks_df["ffn_repr"] = subblocks_df["ffn_config"].apply(subblock_config_to_str) + subblocks_df["attention_repr"] = subblocks_df["attention_config"].apply(subblock_config_to_str) + subblocks_df["block_repr"] = subblocks_df["block_config"].apply(block_config_to_str) + + return subblocks_df + + +def _drop_duplicates_of_teacher( + subblocks_df: pd.DataFrame, + teacher_checkpoint_dir: Path | str, +) -> pd.DataFrame: + orig_subblocks_df = subblocks_df.copy() + + attention_is_teacher = subblocks_df["attention_checkpoint_dir"] == str(teacher_checkpoint_dir) + ffn_is_teacher = subblocks_df["ffn_checkpoint_dir"] == str(teacher_checkpoint_dir) + is_joint_teacher = attention_is_teacher & ffn_is_teacher + + is_decomp_attention = subblocks_df["ffn_config"].isna() + is_decomp_ffn = subblocks_df["attention_config"].isna() + is_joint_block = ~is_decomp_attention & ~is_decomp_ffn + + student_indices_that_have_teacher_dups = [] + + for current_subset, is_teacher in [ + (is_decomp_attention, attention_is_teacher), + (is_decomp_ffn, ffn_is_teacher), + (is_joint_block, is_joint_teacher), + ]: + subblocks_df = orig_subblocks_df.copy().loc[current_subset] + + subblocks_df["is_student"] = ~is_teacher.loc[current_subset] + + def get_student_indices_that_have_teacher_dups(grouped_is_student: pd.Series) -> list: + if grouped_is_student.all(): + return [] + return grouped_is_student.index[grouped_is_student].tolist() + + current_student_indices_that_have_teacher_dups = [ + dup_index + for dup_list in subblocks_df.groupby(UNIQUE_SUBBLOCK_IDENTIFIER, dropna=False)[ + "is_student" + ].apply(get_student_indices_that_have_teacher_dups) + for dup_index in dup_list + ] + student_indices_that_have_teacher_dups.extend( + current_student_indices_that_have_teacher_dups + ) + + dedup_subblocks_df = orig_subblocks_df.drop(index=student_indices_that_have_teacher_dups) + return dedup_subblocks_df + + +def _drop_duplicates_of_decomp_no_op(subblocks_df: pd.DataFrame) -> pd.DataFrame: + is_decomp = subblocks_df["block_config"].isna() + is_ffn_no_op = subblocks_df["ffn_config"].apply(lambda conf: conf is not None and conf.no_op) + is_attention_no_op = subblocks_df["attention_config"].apply( + lambda conf: conf is not None and conf.no_op + ) + is_duplicated = subblocks_df.duplicated(subset=UNIQUE_SUBBLOCK_IDENTIFIER, keep="first") + is_dup_of_decomp_no_op = is_duplicated & is_decomp & (is_ffn_no_op | is_attention_no_op) + subblocks_df = subblocks_df[~is_dup_of_decomp_no_op] + return subblocks_df + + +def _construct_subblock_rows_from_current_checkpoint( + checkpoint_dir: Path, subblocks_to_extract: list[str], trust_remote_code: bool = False +) -> list[dict[str, Any]]: + subblock_rows_from_current_checkpoint = [] + model_config = load_model_config(checkpoint_dir, trust_remote_code=trust_remote_code) + for block_idx, block_config in enumerate(model_config.block_configs): + for subblock_to_extract in subblocks_to_extract: + subblock_row = _init_empty_subblock_row(block_idx) + + if subblock_to_extract == "block": + subblock_row["block_config"] = block_config + subblock_row["attention_config"] = block_config.attention + subblock_row["attention_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.attention.no_op else None + ) + subblock_row["ffn_config"] = block_config.ffn + subblock_row["ffn_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.ffn.no_op else None + ) + elif subblock_to_extract == "ffn": + subblock_row["ffn_config"] = block_config.ffn + subblock_row["ffn_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.ffn.no_op else None + ) + elif subblock_to_extract == "attention": + subblock_row["attention_config"] = block_config.attention + subblock_row["attention_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.attention.no_op else None + ) + else: + raise ValueError() + + subblock_rows_from_current_checkpoint.append(subblock_row) + return subblock_rows_from_current_checkpoint + + +def _add_no_op_subblock_rows( + subblocks_df: pd.DataFrame, + add_ffn_no_op: bool, + add_attention_no_op: bool, +) -> pd.DataFrame: + n_layer = subblocks_df["block_idx"].max() + 1 + + no_op_subblocks = [] + if add_ffn_no_op: + no_op_subblocks.append("ffn") + if add_attention_no_op: + no_op_subblocks.append("attention") + + additional_no_op_rows = [] + for no_op_subblock in no_op_subblocks: + rows_with_no_op_subblock, subblock_cls = _get_rows_with_no_op_subblock( + subblocks_df, no_op_subblock + ) + existing_no_op_indices = rows_with_no_op_subblock["block_idx"].values + missing_no_op_indices = list(set(range(n_layer)) - set(existing_no_op_indices)) + for block_idx in missing_no_op_indices: + no_op_subblock_row = { + **_init_empty_subblock_row(block_idx), + f"{no_op_subblock}_config": subblock_cls(no_op=True), + } + additional_no_op_rows.append(no_op_subblock_row) + + subblocks_df = pd.concat([subblocks_df, pd.DataFrame(additional_no_op_rows)]) + + for no_op_subblock in no_op_subblocks: + rows_with_no_op_subblock, _ = _get_rows_with_no_op_subblock(subblocks_df, no_op_subblock) + assert len(rows_with_no_op_subblock) == n_layer, ( + f"Got {len(rows_with_no_op_subblock)} rows with {no_op_subblock}=no_op, but we have {n_layer} layers" + ) + return subblocks_df + + +def _get_rows_with_no_op_subblock( + subblocks_df: pd.DataFrame, no_op_subblock: str +) -> tuple[pd.DataFrame, Type[AttentionConfig] | Type[FFNConfig]]: + other_subblock = "ffn" if no_op_subblock == "attention" else "attention" + subblock_cls = AttentionConfig if no_op_subblock == "attention" else FFNConfig + no_op_subblock_config = subblock_cls(no_op=True) + rows_with_no_op_subblock = subblocks_df[ + (subblocks_df[f"{no_op_subblock}_config"] == no_op_subblock_config) + & subblocks_df[f"{other_subblock}_config"].isna() + ] + return rows_with_no_op_subblock, subblock_cls + + +def _get_last_checkpoint_from_each_experiment( + master_puzzle_dir: Path | str, trust_remote_code: bool = False +) -> set[Path]: + master_puzzle_dir = Path(master_puzzle_dir) + master_checkpoints_dir = master_puzzle_dir / CHECKPOINTS_DIR_NAME + subdirs_of_master_checkpoints_dir = [ + p.resolve() for p in master_checkpoints_dir.iterdir() if p.is_dir() + ] + checkpoint_dirs = [ + p.parent + for subdir in subdirs_of_master_checkpoints_dir + for p in subdir.rglob("config.json") + ] + + for checkpoint_dir in checkpoint_dirs: + if checkpoint_dir == master_checkpoints_dir: + raise ValueError( + f"We need at least 1 hierarchy level under the '{CHECKPOINTS_DIR_NAME}' dir. " + "Name your checkpoints, preferably with meaningful names. " + "If you are Ido Galil, tell Tomer that you got this exception ;) " + ) + + # Filter out checkpoints without block_configs (e.g. unconverted raw HF layouts) + valid_checkpoint_dirs = [ + cp + for cp in checkpoint_dirs + if is_valid_decilm_checkpoint(cp, trust_remote_code=trust_remote_code) + ] + + experiment_dirs = [ + p if (p in subdirs_of_master_checkpoints_dir) else p.parent for p in valid_checkpoint_dirs + ] + + deduped_checkpoint_dirs = set( + pd.DataFrame({"checkpoint_dir": valid_checkpoint_dirs, "experiment_dir": experiment_dirs}) + .sort_values("checkpoint_dir") + .drop_duplicates(subset="experiment_dir", keep="last")["checkpoint_dir"] + .tolist() + ) + return deduped_checkpoint_dirs + + +def _infer_subblocks_to_extract( + checkpoint_dir: Path, + checkpoints_to_split: list[Path], +) -> list[str]: + if (checkpoint_dir / "replacement_library.json").exists(): + return [] + bypass_config_path = checkpoint_dir / "bypass_config.json" + if (checkpoint_dir in checkpoints_to_split) or (not bypass_config_path.exists()): + subblocks_to_extract = ["block", "attention", "ffn"] + else: + bypass_config = json.loads(bypass_config_path.read_text()) + keys_to_learn = bypass_config.get("keys_to_learn", "entire_block") + if keys_to_learn == "entire_block": + subblocks_to_extract = ["block"] + elif "mlp" in keys_to_learn and "attn" not in keys_to_learn: + subblocks_to_extract = ["ffn"] + elif "attn" in keys_to_learn and "mlp" not in keys_to_learn: + subblocks_to_extract = ["attention"] + else: + raise ValueError(f"Unrecognized {keys_to_learn=}") + return subblocks_to_extract + + +def _init_empty_subblock_row(block_idx: int) -> dict[str, Any]: + return { + "attention_checkpoint_dir": None, + "ffn_checkpoint_dir": None, + "block_config": None, + "attention_config": None, + "ffn_config": None, + "block_idx": block_idx, + "block_repr": None, + "attention_repr": None, + "ffn_repr": None, + } + + +def _build_layer_replacements( + block_library_df: pd.DataFrame, + master_puzzle_dir: Path, + teacher_checkpoint_dir: Path, + trust_remote_code: bool = False, +) -> list[dict]: + layer_replacements_from_blocks = _build_layer_replacements_from_block_library(block_library_df) + layer_replacements_from_checkpoints = _gather_layer_replacements_from_checkpoints( + master_puzzle_dir, trust_remote_code=trust_remote_code + ) + layer_replacements = layer_replacements_from_blocks + layer_replacements_from_checkpoints + layer_replacements = _filter_duplicate_teacher_replacements( + layer_replacements, teacher_checkpoint_dir, trust_remote_code + ) + return layer_replacements + + +def _build_layer_replacements_from_block_library(block_library_df: pd.DataFrame) -> list[dict]: + layer_replacements = [] + for _, row in block_library_df.iterrows(): + block_idx = row["block_idx"] + block_config = row["block_config"] + weight_paths = [] + for subblock_name in ["attention", "ffn"]: + checkpoint_dir = row[f"{subblock_name}_checkpoint_dir"] + if checkpoint_dir is not None: + subblock_path = ( + Path(checkpoint_dir) + / SAFETENSORS_SUBBLOCKS_DIR_NAME + / f"block_{block_idx}_{subblock_name}.safetensors" + ) + weight_paths.append(subblock_path) + weight_paths = sorted(set(weight_paths)) + layer_replacement = { + "parent_layer_indices": [block_idx], + "child_block_configs": [block_config], + "weight_paths": weight_paths, + } + layer_replacements.append(layer_replacement) + return layer_replacements + + +def _gather_layer_replacements_from_checkpoints( + master_puzzle_dir: str | Path, trust_remote_code: bool = False +) -> list[dict]: + gathered_layer_replacements = [] + checkpoint_dirs = _get_last_checkpoint_from_each_experiment( + master_puzzle_dir, trust_remote_code=trust_remote_code + ) + for checkpoint_dir in checkpoint_dirs: + if (layer_replacements_path := checkpoint_dir / "replacement_library.json").exists(): + layer_replacements = json.loads(layer_replacements_path.read_text()) + for layer_replacement in layer_replacements: + layer_replacement["child_block_configs"] = [ + BlockConfig(**block_config_dict) + for block_config_dict in layer_replacement["child_block_configs"] + ] + layer_replacement["weight_paths"] = sorted( + set(Path(p) for p in layer_replacement["weight_paths"]) + ) + gathered_layer_replacements.extend(layer_replacements) + return gathered_layer_replacements + + +def _filter_duplicate_teacher_replacements( + layer_replacements: list[dict], + teacher_checkpoint_dir: Path, + trust_remote_code: bool = False, +) -> list[dict]: + teacher_model_config = load_model_config( + teacher_checkpoint_dir, trust_remote_code=trust_remote_code + ) + filtered_layer_replacements = [] + for layer_replacement in layer_replacements: + if replacement_is_teacher( + layer_replacement, teacher_model_config, teacher_checkpoint_dir + ) or not is_replacement_identical_to_teacher(layer_replacement, teacher_model_config): + filtered_layer_replacements.append(layer_replacement) + return filtered_layer_replacements + + +def _build_single_sequence_replacement_solutions( + layer_replacements: list[dict], + teacher_checkpoint_dir: Path, + trust_remote_code: bool = False, +) -> list[dict]: + teacher_model_config = load_model_config( + teacher_checkpoint_dir, trust_remote_code=trust_remote_code + ) + n_layer = teacher_model_config.num_hidden_layers + + teacher_replacements = dict() + student_replacements = [] + for layer_replacement in layer_replacements: + if replacement_is_teacher(layer_replacement, teacher_model_config, teacher_checkpoint_dir): + block_idx = layer_replacement["parent_layer_indices"][0] + teacher_replacements[block_idx] = layer_replacement + else: + student_replacements.append(layer_replacement) + + teacher_indices_represented_in_replacements = sorted(teacher_replacements.keys()) + assert teacher_indices_represented_in_replacements == list(range(n_layer)), ( + f"{n_layer=}, {teacher_indices_represented_in_replacements=}" + ) + + student_replacements = sort_replacements(student_replacements) + + solutions = [] + for layer_replacement in student_replacements: + block_indices_not_represented_in_replacement = sorted( + set(range(n_layer)) - set(layer_replacement["parent_layer_indices"]) + ) + chosen_replacements = sort_replacements( + [layer_replacement] + + [ + teacher_replacements[block_idx] + for block_idx in block_indices_not_represented_in_replacement + ] + ) + + block_configs = [ + block_config + for replacement in chosen_replacements + for block_config in replacement["child_block_configs"] + ] + + solutions.append( + { + "single_sequence_replacement": layer_replacement, + "chosen_replacements": chosen_replacements, + "block_configs": block_configs, + } + ) + + return solutions diff --git a/modelopt/torch/puzzletron/replacement_library/library.py b/modelopt/torch/puzzletron/replacement_library/library.py new file mode 100644 index 0000000000..d6012f596a --- /dev/null +++ b/modelopt/torch/puzzletron/replacement_library/library.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Replacement library for loading models with layer replacements (AnyModel / sharded HF checkpoints). +""" +# mypy: ignore-errors + +import copy +import json +import tempfile +from pathlib import Path +from typing import List, Optional + +from immutabledict import immutabledict +from safetensors import safe_open +from transformers import PretrainedConfig, PreTrainedModel + +from ..anymodel.converter import Converter +from ..tools.checkpoint_utils import SAFETENSORS_SUBBLOCKS_DIR_NAME, load_model_config +from ..tools.checkpoint_utils_hf import save_model_config +from ..tools.sharded_checkpoint_utils import load_and_shard_model +from .replacement_utils import ( + extract_block_configs_and_locations, + parse_layer_replacement, + weights_path_to_checkpoint_dir, +) + +__all__ = [ + "ReplacementLibrary", +] + + +class ReplacementLibrary: + def __init__( + self, + replacement_library_path: str | Path, + descriptor, + model_config_overrides: Optional[dict] = None, + ): + self.descriptor = descriptor + self.replacement_library = self._load_replacement_library(replacement_library_path) + self._ensure_all_checkpoints_are_split() + self.model_config_overrides = ( + immutabledict(model_config_overrides) if (model_config_overrides is not None) else None + ) + + self._model_config = None + self._arbitrary_checkpoint_dir = None + + @staticmethod + def _load_replacement_library(replacement_library_path: str | Path) -> list[dict]: + replacement_library = json.loads(Path(replacement_library_path).read_text()) + replacement_library = [ + parse_layer_replacement(layer_replacement) for layer_replacement in replacement_library + ] + return replacement_library + + def _ensure_all_checkpoints_are_split(self) -> None: + checkpoint_dirs = self._get_all_checkpoint_dirs() + unsplit_checkpoints = [] + for checkpoint_dir in checkpoint_dirs: + if not (checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists(): + unsplit_checkpoints.append(checkpoint_dir) + assert len(unsplit_checkpoints) == 0, f"Found unsplit checkpoints: {unsplit_checkpoints}" + + @property + def model_config(self) -> PretrainedConfig: + if self._model_config is None: + trust_remote_code = self.descriptor.requires_trust_remote_code() + self._model_config = load_model_config( + self.get_arbitrary_checkpoint_dir(), + self.model_config_overrides, + ignore_unexpected_config_keys=True, + trust_remote_code=trust_remote_code, + ) + return self._model_config + + def create_model_config(self, layer_replacements: list[dict]): + block_configs, _ = extract_block_configs_and_locations(layer_replacements) + model_config = copy.deepcopy(self.model_config) + model_config.block_configs = block_configs + model_config.num_hidden_layers = len(block_configs) + return model_config + + def _get_arbitrary_non_block_checkpoint_paths(self): + checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) + subblocks_dir = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME + non_block_paths = [p for p in subblocks_dir.glob("*.safetensors") if "block_" not in p.name] + return non_block_paths + + def create_index_file_from_weights(self, weight_paths: List[str]): + weight_map = {} + for weight_path in weight_paths: + weight_path = Path(weight_path) + with safe_open(str(weight_path), framework="pt", device="cpu") as f: + for tensor_name in f.keys(): + weight_map[tensor_name] = f"{SAFETENSORS_SUBBLOCKS_DIR_NAME}/{weight_path.name}" + index = {"metadata": {"format": "pt"}, "weight_map": weight_map} + return index + + def prepare_tmp_checkpoint_dir( + self, + tmpdir: Path, + model_config: PretrainedConfig, + layer_replacements: List[dict], + ): + arbitrary_checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) + + weight_paths = self._get_arbitrary_non_block_checkpoint_paths() + for layer_replacement in layer_replacements: + weight_paths += layer_replacement["weight_paths"] + + weights_index = self.create_index_file_from_weights(weight_paths) + index_path = tmpdir / "model.safetensors.index.json" + with index_path.open("w", encoding="utf-8") as out: + json.dump(weights_index, out, indent=2, sort_keys=True) + + Converter.copy_checkpoint_files(arbitrary_checkpoint_dir, tmpdir) + save_model_config(model_config, tmpdir) + + # create symlinks inside tmpdir + subblocks_dir = tmpdir / SAFETENSORS_SUBBLOCKS_DIR_NAME + subblocks_dir.mkdir(exist_ok=True) + for weight_path in weight_paths: + link_path = subblocks_dir / weight_path.name + link_path.symlink_to(weight_path) + + def load_model( + self, + layer_replacements: list[dict], + ) -> PreTrainedModel: + """Load model using AnyModel approach with temporary checkpoint directory.""" + model_config = self.create_model_config(layer_replacements) + with tempfile.TemporaryDirectory(prefix="replacement_solution_") as tmpdir: + tmpdir = Path(tmpdir) + self.prepare_tmp_checkpoint_dir( + tmpdir, model_config=model_config, layer_replacements=layer_replacements + ) + model = load_and_shard_model(descriptor=self.descriptor, checkpoint_path=tmpdir) + return model + + def get_arbitrary_checkpoint_dir(self) -> Path: + if self._arbitrary_checkpoint_dir is None: + self._arbitrary_checkpoint_dir = self._get_arbitrary_checkpoint_dir() + return self._arbitrary_checkpoint_dir + + def _get_arbitrary_checkpoint_dir(self) -> Path: + for layer_replacement in self.replacement_library: + weight_paths = layer_replacement["weight_paths"] + if len(weight_paths) > 0: + return weights_path_to_checkpoint_dir(weight_paths[0]) + + def _get_all_checkpoint_dirs(self) -> list[Path]: + checkpoint_dirs = set() + for layer_replacement in self.replacement_library: + weight_paths = layer_replacement["weight_paths"] + for weights_path in weight_paths: + checkpoint_dir = weights_path_to_checkpoint_dir(weights_path) + checkpoint_dirs.add(checkpoint_dir) + return list(checkpoint_dirs) diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_utils.py b/modelopt/torch/puzzletron/replacement_library/replacement_utils.py new file mode 100644 index 0000000000..066dde0de1 --- /dev/null +++ b/modelopt/torch/puzzletron/replacement_library/replacement_utils.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module provides helper functions for parsing, sorting, and analyzing layer replacement +configurations used in the replacement library for model compression. +""" + +# mypy: ignore-errors +import json +from copy import deepcopy +from pathlib import Path + +from transformers import PretrainedConfig + +from ..block_config import BlockConfig + +__all__ = [ + "parse_layer_replacement", + "extract_block_configs_and_locations", + "weights_path_to_checkpoint_dir", + "replacement_is_teacher", + "is_replacement_identical_to_teacher", + "split_replacements_to_teacher_and_student", +] + + +def parse_layer_replacement(layer_replacement: dict | str) -> dict: + if isinstance(layer_replacement, str): + layer_replacement = json.loads(layer_replacement) + else: + layer_replacement = deepcopy(layer_replacement) + + if "layer_replacement" in layer_replacement: # happens in puzzle solutions + layer_replacement = layer_replacement["layer_replacement"] + + layer_replacement["child_block_configs"] = [ + BlockConfig(**block_config) if isinstance(block_config, dict) else block_config + for block_config in layer_replacement["child_block_configs"] + ] + layer_replacement["weight_paths"] = [Path(p) for p in layer_replacement["weight_paths"]] + return layer_replacement + + +def extract_block_configs_and_locations( + layer_replacements: list[dict], +) -> tuple[list[BlockConfig], list[tuple[dict, int]]]: + from ..mip.utils import sort_replacements # local import to avoid circular dependency + + layer_replacements = sort_replacements(layer_replacements) + block_configs = [] + block_locations = [] + for layer_replacement in layer_replacements: + child_block_configs = layer_replacement["child_block_configs"] + if not isinstance(child_block_configs, list | tuple): + child_block_configs = [child_block_configs] + for block_idx_in_replacement, block_config in enumerate(child_block_configs): + block_configs.append(block_config) + block_locations.append((layer_replacement, block_idx_in_replacement)) + return block_configs, block_locations + + +def weights_path_to_checkpoint_dir(weights_path: Path) -> Path: + checkpoint_dir: Path = weights_path + while checkpoint_dir != Path("/"): + if (checkpoint_dir / "config.json").exists(): + return checkpoint_dir + checkpoint_dir = checkpoint_dir.parent + raise FileNotFoundError(f"Couldn't find checkpoint dir for weights path {weights_path}") + + +def replacement_is_teacher( + layer_replacement: dict, + teacher_model_config: PretrainedConfig, + teacher_checkpoint_dir: Path, +) -> bool: + paths_all_teacher = all( + p.is_relative_to(teacher_checkpoint_dir) for p in layer_replacement["weight_paths"] + ) + return paths_all_teacher and is_replacement_identical_to_teacher( + layer_replacement, teacher_model_config + ) + + +def is_replacement_identical_to_teacher( + layer_replacement: dict, + teacher_model_config: PretrainedConfig, +) -> bool: + if len(layer_replacement["parent_layer_indices"]) == 1: + block_idx = layer_replacement["parent_layer_indices"][0] + teacher_block_config = teacher_model_config.block_configs[block_idx] + if len(child_block_configs := layer_replacement["child_block_configs"]) == 1: + replacement_block_config: BlockConfig = child_block_configs[0] + if replacement_block_config == teacher_block_config: + return True + else: + parallel_blocks = getattr(replacement_block_config, "parallel_blocks", None) + if ( + parallel_blocks is not None + and len(parallel_blocks) == 1 + and parallel_blocks[0].attention == teacher_block_config.attention + and parallel_blocks[0].ffn == teacher_block_config.ffn + ): + return True + return False + + +def split_replacements_to_teacher_and_student( + replacements: list[dict], + teacher_model_config: PretrainedConfig, + teacher_checkpoint_dir: Path, +) -> tuple[list[dict], list[dict]]: + teacher_replacements, student_replacements = [], [] + for replacement in replacements: + if replacement_is_teacher(replacement, teacher_model_config, teacher_checkpoint_dir): + teacher_replacements.append(replacement) + else: + student_replacements.append(replacement) + return teacher_replacements, student_replacements diff --git a/modelopt/torch/puzzletron/scoring.py b/modelopt/torch/puzzletron/scoring.py new file mode 100644 index 0000000000..5482c8913b --- /dev/null +++ b/modelopt/torch/puzzletron/scoring.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validates and scores model compression solutions by evaluating puzzle solution candidates.""" + +# mypy: ignore-errors +import os +import re +from glob import glob + +import hydra +import numpy as np +import pandas as pd +from omegaconf import DictConfig + +import modelopt.torch.utils.distributed as dist + +from .tools.hydra_utils import register_hydra_resolvers +from .tools.logger import mprint +from .tools.validate_puzzle_with_multi_replacements import validate_puzzle_solutions + +__all__ = ["launch_scoring"] + + +def extract_solution_id(filename): + pattern = r"solution_(\d+)\.json" + match = re.search(pattern, filename) + + if match: + solution_id = match.group(1) + return int(solution_id) + else: + mprint(f"Couldn't extract solutions_id from file {filename}") + + +def find_missing_solutions(solutions_df, validation_dir): + all_solutions = np.arange(solutions_df.shape[0]) + + benchmarked_solutions = list(glob(f"{validation_dir}/solution*.json")) + benchmarked_solutions = [ + extract_solution_id(os.path.basename(s)) for s in benchmarked_solutions + ] + benchmarked_solutions = [s for s in benchmarked_solutions if s is not None] + + unbenchmarked_solutions = np.setdiff1d(all_solutions, benchmarked_solutions) + return unbenchmarked_solutions.tolist() + + +def get_solutions_to_validate(cfg: DictConfig): + _solutions_to_validate = cfg.scoring.solutions_to_validate + if _solutions_to_validate is None: + single_block_replacement_solutions = pd.read_json(cfg.scoring.solutions_path) + if cfg.scoring.skip_existing_solutions: + _solutions_to_validate = find_missing_solutions( + single_block_replacement_solutions, cfg.scoring.output_dir + ) + else: + _solutions_to_validate = np.arange(single_block_replacement_solutions.shape[0]).tolist() + return _solutions_to_validate + + +def launch_scoring(cfg: DictConfig): + cfg.scoring.solutions_to_validate = get_solutions_to_validate(cfg) + mprint(f"Solutions to validate: {cfg.scoring.solutions_to_validate}") + validate_puzzle_solutions(args=cfg.scoring) + + +@hydra.main("", version_base="1.3") +def main(cfg: DictConfig) -> None: + cfg = hydra.utils.instantiate(cfg) + mprint(cfg) + dist.setup(timeout=cfg.nccl_timeout_minutes) + try: + launch_scoring(cfg) + finally: + dist.cleanup() + + +if __name__ == "__main__": + register_hydra_resolvers() + main() diff --git a/modelopt/torch/puzzletron/sewing_kit/__init__.py b/modelopt/torch/puzzletron/sewing_kit/__init__.py new file mode 100644 index 0000000000..963828afc1 --- /dev/null +++ b/modelopt/torch/puzzletron/sewing_kit/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +from .core import * +from .passage import * +from .utils import * diff --git a/modelopt/torch/puzzletron/sewing_kit/core.py b/modelopt/torch/puzzletron/sewing_kit/core.py new file mode 100644 index 0000000000..9ffbc5e80d --- /dev/null +++ b/modelopt/torch/puzzletron/sewing_kit/core.py @@ -0,0 +1,897 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +from __future__ import annotations + +from abc import ABC +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Union + +from typing_extensions import override + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + +import torch +import torch.distributed +import torch.nn as nn + +from .passage import ( + InputArgs, + OutputValue, + Passage, + PassageInputAdapter, + PassageInputOverrides, + PassageOutputAdapter, + PassageOutputOverrides, + Predicate, + always_false_predicate, +) +from .utils import distributed_isend_obj, distributed_recv_obj, dynamo_skip + +__all__ = [ + "ExternalTarget", + "FunctionTarget", + "ModuleTarget", + "RemoteTarget", + "Needle", + "StitchedModule", + "CantResolveNodeDependenciesException", + "OutputsLoopFoundException", + "KnotException", + "MultipleExternalNodesException", + "OnlyInternalNodesException", + "InputReducer", +] + + +InputAdapter = Callable[[InputArgs], InputArgs] +OutputAdapter = Callable[..., OutputValue] + + +def default_input_adapter_fn(input_values: InputArgs) -> InputArgs: + return input_values + + +def default_output_adapter_fn(v: OutputValue) -> OutputValue: + return v + + +@dataclass +class IOReducer: + pass + + +def default_input_reducer_fn(acc: InputArgs, input_override: InputArgs, *args): + return acc + input_override + + +@dataclass +class InputReducer(IOReducer): + reducer_fn: Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs] = ( + default_input_reducer_fn + ) + + def __call__( + self, + acc: InputArgs, + input_override: InputArgs, + original_input: InputArgs, + index: int, + all_input_overrides: list[InputArgs], + ) -> InputArgs: + result = self.reducer_fn(acc, input_override, original_input, index, all_input_overrides) + return result + + @classmethod + def default(cls) -> InputReducer: + return InputReducer() + + +def default_output_reducer_fn(acc: OutputValue, input_override: OutputValue, *args): + return input_override + + +@dataclass +class OutputReducer(IOReducer): + reducer_fn: Callable[ + [OutputValue, OutputValue, Optional[OutputValue], int, list[OutputValue]], OutputValue + ] = default_output_reducer_fn + requires_original_output: bool = False + + def __call__( + self, + acc: OutputValue, + output_override: OutputValue, + original_output: Optional[OutputValue], + index: int, + all_output_overrides: list[OutputValue], + ) -> InputArgs: + result = self.reducer_fn(acc, output_override, original_output, index, all_output_overrides) + return result + + @classmethod + def default(cls) -> OutputReducer: + return OutputReducer() + + +class Singleton(type): + _instances = {} + + @override + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +@dataclass +class Target: + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class TargetWithInput(Target): + @override + def __hash__(self) -> int: + return super().__hash__() + + def input( + self, + adapter: InputAdapter = default_input_adapter_fn, + reducer: InputReducer = InputReducer.default(), + ) -> InputDescriptor: + result = InputDescriptor(self, input_name="", input_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class TargetWithNamedInputs(Target): + @override + def __hash__(self) -> int: + return super().__hash__() + + def input( + self, + name: str, + adapter: InputAdapter = default_input_adapter_fn, + reducer: InputReducer = InputReducer.default(), + ) -> InputDescriptor: + result = InputDescriptor(self, input_name=name, input_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class TargetWithOutput(Target): + @override + def __hash__(self) -> int: + return super().__hash__() + + def output( + self, + adapter: OutputAdapter = default_output_adapter_fn, + reducer: OutputReducer = OutputReducer.default(), + ) -> OutputDescriptor: + result = OutputDescriptor(self, output_name="", output_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class TargetWithNamedOutputs(Target): + @override + def __hash__(self) -> int: + return super().__hash__() + + def output( + self, + name: str, + adapter: OutputAdapter = default_output_adapter_fn, + reducer: OutputReducer = OutputReducer.default(), + ) -> OutputDescriptor: + result = OutputDescriptor(self, output_name=name, output_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class ExternalTarget(TargetWithNamedInputs, TargetWithNamedOutputs, metaclass=Singleton): + """External target for stitched modules.""" + + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class ConstantTarget(TargetWithOutput): + name: str + value: Any + + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class FunctionTarget(TargetWithInput, TargetWithOutput): + name: str + function: Callable[..., Any] + + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class ModuleTarget(TargetWithNamedInputs, TargetWithNamedOutputs): + name: str + module: nn.Module + + @override + def __str__(self) -> str: + return f"ModuleTarget({self.name})" + + @override + def __repr__(self) -> str: + return str(self) + + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class RemoteTarget(Target): + peer_rank: Union[int, Sequence[int]] + process_group: Optional[torch.distributed.ProcessGroup] = None + blocking: bool = True + + @override + def __hash__(self) -> int: + return super().__hash__() + + def value( + self, + name: str, + adapter: OutputAdapter = default_output_adapter_fn, + reducer: OutputReducer = OutputReducer.default(), + ) -> OutputDescriptor: + result = OutputDescriptor(self, output_name=name, output_adapter=adapter, reducer=reducer) + return result + + +@dataclass(frozen=True, eq=True) +class RemoteDataDescriptor(ABC): + key: str + + +@dataclass(frozen=True, eq=True) +class RemoteTensorDataDescriptor(RemoteDataDescriptor): + device: Literal["cuda", "cpu"] + dtype: torch.dtype + shape: torch.Size + + +@dataclass(frozen=True, eq=True) +class RemotePythonDataDescriptor(RemoteDataDescriptor): + value: Any + + +@dataclass +class Node: + target: Target + stitches_to: list[StitchDescriptor] = field(default_factory=list) + stitches_from: list[StitchDescriptor] = field(default_factory=list) + + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class InputDescriptor: + target: Target + input_name: str = "" + input_adapter: InputAdapter = field(default=default_input_adapter_fn) + reducer: InputReducer = field(default_factory=InputReducer.default) + + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class OutputDescriptor: + target: Target + output_name: str = "" + output_adapter: OutputAdapter = field(default=default_output_adapter_fn) + reducer: OutputReducer = field(default_factory=OutputReducer.default) + + @override + def __hash__(self) -> int: + return id(self) + + +IODescriptor = Union[InputDescriptor, OutputDescriptor] + + +@dataclass +class StitchDescriptor: + source_descriptor: IODescriptor + destination_descriptor: IODescriptor + + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class StitchedModuleOutput: + captured_inputs: dict[str, InputArgs] + captured_outputs: dict[str, Any] + + +class StitchedModuleException(Exception): + pass + + +class CantResolveNodeDependenciesException(StitchedModuleException): + pass + + +class StitchedModule(nn.Module): + def __init__( + self, + nodes: dict[Target, Node], + capture_cache_outputs_predicate: Predicate = always_false_predicate, + early_exit=True, + ignore_extra_overrides=False, + ) -> None: + super().__init__() + self.nodes = nodes + self.ignore_extra_overrides = ignore_extra_overrides + external_nodes = [n for n in nodes.values() if isinstance(n.target, ExternalTarget)] + remote_nodes = [n for n in nodes.values() if isinstance(n.target, RemoteTarget)] + assert len(external_nodes) <= 1 + assert len(remote_nodes) + len(external_nodes) > 0 + self.external_node = external_nodes[0] if len(external_nodes) > 0 else None + self.internal_nodes = [ + n for n in nodes.values() if not isinstance(n.target, ExternalTarget) + ] + self.values_from_node: dict[Node, dict[IODescriptor, Any]] = defaultdict(dict) + self.values_to_node: dict[Node, dict[IODescriptor, Any]] = defaultdict(dict) + + self.node_passages: dict[Node, Passage] = { + node: Passage.create( + module=node.target.module, + inputs_to_capture=set( + s.source_descriptor.input_name + for s in node.stitches_from + if isinstance(s.source_descriptor, InputDescriptor) + ), + outputs_to_capture=set( + s.source_descriptor.output_name + for s in node.stitches_from + if isinstance(s.source_descriptor, OutputDescriptor) + ), + capture_cache_outputs_predicate=capture_cache_outputs_predicate, + early_exit=early_exit, + name=getattr(node.target, "name", None), + ) + for node in self.internal_nodes + if isinstance(node.target, ModuleTarget) + } + + self.passage_modules = nn.ModuleDict( + { + f"node_{node_index}": self.node_passages[node] + for node_index, node in enumerate(nodes.values()) + if node in self.node_passages + } + ) + self.adapter_modules = nn.ModuleDict( + { + f"node_{node_index}__stitch_{stitch_index}__{descriptor_name}": adapter + for node_index, node in enumerate(nodes.values()) + for stitch_index, stitch in enumerate(node.stitches_from + node.stitches_to) + for descriptor_name, descriptor in ( + ("source", stitch.source_descriptor), + ("destination", stitch.destination_descriptor), + ) + for adapter in [ + descriptor.input_adapter + if isinstance(descriptor, InputDescriptor) + else descriptor.output_adapter + ] + if isinstance(adapter, nn.Module) + } + ) + + def create_input_overrides( + self, values_to_node: dict[IODescriptor, Any] + ) -> PassageInputOverrides: + input_descriptors_by_group = defaultdict[str, list[InputDescriptor]](list) + for io_descriptor in values_to_node.keys(): + if isinstance(io_descriptor, InputDescriptor): + input_descriptors_by_group[io_descriptor.input_name].append(io_descriptor) + + input_overrides = PassageInputOverrides() + for group, input_descriptors in input_descriptors_by_group.items(): + reducers = [d.reducer for d in input_descriptors] + + def create_reducer(input_descriptors=input_descriptors, reducers=reducers): + inputs = [values_to_node[d] for d in input_descriptors] + + def reducer_fn( + original_input: InputArgs, + module_name: Optional[str], + module: Optional[nn.Module], + ) -> InputArgs: + acc = InputArgs() + for i, (input_, reducer) in enumerate(zip(inputs, reducers)): + acc = reducer(acc, input_, original_input, i, inputs) + return acc + + return reducer_fn + + input_override = PassageInputAdapter(create_reducer()) + input_overrides[group] = input_override + + return input_overrides + + def create_output_overrides( + self, values_to_node: dict[IODescriptor, Any] + ) -> PassageOutputOverrides: + output_descriptors_by_group = defaultdict[str, list[OutputDescriptor]](list) + for io_descriptor in values_to_node.keys(): + if isinstance(io_descriptor, OutputDescriptor): + output_descriptors_by_group[io_descriptor.output_name].append(io_descriptor) + + output_overrides = PassageOutputOverrides() + for group, output_descriptors in output_descriptors_by_group.items(): + reducers = [d.reducer for d in output_descriptors] + requires_original_output = any(r.requires_original_output for r in reducers) + + def create_reducer(reducers=reducers): + outputs = [values_to_node[d] for d in output_descriptors] + + def reducer_fn( + original_output: Optional[OutputValue], + module_name: Optional[str], + module: Optional[nn.Module], + ) -> OutputValue: + acc = None + for i, (output, reducer) in enumerate(zip(outputs, reducers)): + acc = reducer(acc, output, original_output, i, outputs) + return acc + + return reducer_fn + + reducer_fn = create_reducer() + if requires_original_output: + output_override = PassageOutputAdapter(reducer_fn) + else: + output_override = reducer_fn(None, None, None) + + output_overrides[group] = output_override + + return output_overrides + + @override + def __call__( + self, + input_overrides: dict[str, Any], + output_overrides: dict[str, Any], + *args, + **kwargs, + ) -> StitchedModuleOutput: + return super().__call__(input_overrides, output_overrides, *args, **kwargs) + + @override + @dynamo_skip + def forward( + self, + input_overrides: dict[str, Any], + output_overrides: dict[str, Any], + *args, + **kwargs, + ) -> StitchedModuleOutput: + input_overrides = {k: InputArgs.from_value(v) for k, v in input_overrides.items()} + + self.values_from_node.clear() + self.values_to_node.clear() + + unresolved_count: int = 0 + nodes_stack: list[Node] = ( + [] if self.external_node is None else [self.external_node] + ) + self.internal_nodes + while len(nodes_stack) > 0: + node = nodes_stack.pop(0) + values_from_node = self.values_from_node[node] + values_to_node = self.values_to_node[node] + + if isinstance(node.target, ExternalTarget): + assert self.external_node is not None + + if not self.ignore_extra_overrides: + input_override_names = set(input_overrides.keys()) + external_node_input_names = set( + s.source_descriptor.input_name + for s in self.external_node.stitches_from + if isinstance(s.source_descriptor, InputDescriptor) + ) + assert input_override_names == external_node_input_names + output_override_names = set(output_overrides.keys()) + external_node_output_names = set( + s.source_descriptor.output_name + for s in self.external_node.stitches_from + if isinstance(s.source_descriptor, OutputDescriptor) + ) + assert output_override_names == external_node_output_names + + for stitch in self.external_node.stitches_from: + if isinstance(stitch.source_descriptor, InputDescriptor): + orig_input_override = input_overrides[stitch.source_descriptor.input_name] + input_override = stitch.source_descriptor.input_adapter(orig_input_override) + values_from_node[stitch.source_descriptor] = input_override + elif isinstance(stitch.source_descriptor, OutputDescriptor): + orig_output_override = output_overrides[ + stitch.source_descriptor.output_name + ] + output_override = stitch.source_descriptor.output_adapter( + orig_output_override + ) + values_from_node[stitch.source_descriptor] = output_override + else: + raise RuntimeError("Shouldn't happen") + + else: + if len(values_to_node) < len(node.stitches_to): + nodes_stack.append(node) + unresolved_count += 1 + if unresolved_count >= len(nodes_stack): + raise CantResolveNodeDependenciesException( + "Can't resolve nodes dependencies" + ) + continue + + if isinstance(node.target, ConstantTarget): + assert len(values_to_node) == 0 + + output_value = node.target.value + + for stitch in node.stitches_from: + assert isinstance(stitch.source_descriptor, OutputDescriptor) + assert stitch.source_descriptor.output_name == "" + value = stitch.source_descriptor.output_adapter(output_value) + values_from_node[stitch.source_descriptor] = value + + elif isinstance(node.target, FunctionTarget): + assert all( + isinstance(v, InputDescriptor) and v.input_name == "" + for v in values_to_node + ) + + function_input_overrides = self.create_input_overrides(values_to_node)[""] + + if isinstance(function_input_overrides, InputArgs): + input_args = function_input_overrides + else: + input_args = function_input_overrides(InputArgs(), None, None) + + function_output = node.target.function(*input_args.args, **input_args.kwargs) + + for stitch in node.stitches_from: + assert isinstance(stitch.source_descriptor, OutputDescriptor) + assert stitch.source_descriptor.output_name == "" + value = stitch.source_descriptor.output_adapter(function_output) + values_from_node[stitch.source_descriptor] = value + + elif isinstance(node.target, ModuleTarget): + passage = self.node_passages[node] + passage.input_overrides = self.create_input_overrides(values_to_node) + passage.output_overrides = self.create_output_overrides(values_to_node) + passage_output = passage(*args, **kwargs) + + for stitch in node.stitches_from: + if isinstance(stitch.source_descriptor, InputDescriptor): + captured_input = passage_output.captured_inputs[ + stitch.source_descriptor.input_name + ] + value = stitch.source_descriptor.input_adapter(captured_input) + values_from_node[stitch.source_descriptor] = value + elif isinstance(stitch.source_descriptor, OutputDescriptor): + captured_output = passage_output.captured_outputs[ + stitch.source_descriptor.output_name + ] + value = stitch.source_descriptor.output_adapter(captured_output) + values_from_node[stitch.source_descriptor] = value + else: + raise RuntimeError("Shouldn't happen") + + elif isinstance(node.target, RemoteTarget): + assert all( + isinstance(v, OutputDescriptor) and v.output_name != "" + for v in values_from_node + ) + assert all( + isinstance(v, OutputDescriptor) and v.output_name != "" + for v in values_to_node + ) + + process_group = node.target.process_group + peers = node.target.peer_rank + if not isinstance(peers, Sequence): + peers = [peers] + + if len(values_to_node) > 0: + items_to_send = list(self.create_output_overrides(values_to_node).items()) + + data_descriptors: list[RemoteDataDescriptor] = [] + tensors_to_send: list[torch.Tensor] = [] + + for key, value in items_to_send: + if isinstance(value, torch.Tensor): + if value.is_cuda: + tensor_device = "cuda" + elif value.is_cpu: + tensor_device = "cpu" + else: + raise RuntimeError( + f"Invalid tensor device to send to remote target: {value.device}" + ) + + data_descriptor = RemoteTensorDataDescriptor( + key=key, + device=tensor_device, + dtype=value.dtype, + shape=value.shape, + ) + tensors_to_send.append(value) + + else: + data_descriptor = RemotePythonDataDescriptor( + key=key, + value=value, + ) + + data_descriptors.append(data_descriptor) + + works: list[Optional[torch.distributed.Work]] = [] + for peer in peers: + peer_works = distributed_isend_obj( + data_descriptors, dst=peer, group=process_group + ) + works.extend(peer_works) + + for tensor in tensors_to_send: + work = torch.distributed.isend( + tensor, dst=peer, group=process_group + ) + works.append(work) + + if node.target.blocking: + for work in works: + if work is not None: + work.wait() + + if len(node.stitches_from) > 0: + assert len(peers) == 1, ( + f"Cannot use multiple peers when using RemoteTarget as a source ({peers=})" + ) + (peer,) = peers + + data_descriptors = distributed_recv_obj(src=peer, group=process_group) + assert isinstance(data_descriptors, list) + + tensors_to_recv: list[torch.Tensor] = [] + received_values: dict[str, Any] = {} + for data_descriptor in data_descriptors: + if isinstance(data_descriptor, RemoteTensorDataDescriptor): + tensor = torch.empty( + data_descriptor.shape, + dtype=data_descriptor.dtype, + device=data_descriptor.device, + ) + tensors_to_recv.append(tensor) + received_values[data_descriptor.key] = tensor + elif isinstance(data_descriptor, RemotePythonDataDescriptor): + received_values[data_descriptor.key] = data_descriptor.value + else: + raise RuntimeError( + f"Received invalid data descriptor from remote peer: {data_descriptor}" + ) + + works: list[Optional[torch.distributed.Work]] = [] + for tensor in tensors_to_recv: + work = torch.distributed.irecv(tensor, src=peer, group=process_group) + works.append(work) + + for work in works: + if work is not None: + work.wait() + + for stitch in node.stitches_from: + if isinstance(stitch.source_descriptor, OutputDescriptor): + remote_output = received_values[ + stitch.source_descriptor.output_name + ] + value = stitch.source_descriptor.output_adapter(remote_output) + values_from_node[stitch.source_descriptor] = value + else: + raise RuntimeError("Shouldn't happen") + else: + raise RuntimeError("Shouldn't happen") + + for stitch in node.stitches_from: + dst_node = self.nodes[stitch.destination_descriptor.target] + value = values_from_node[stitch.source_descriptor] + + if isinstance(stitch.destination_descriptor, InputDescriptor): + value = stitch.destination_descriptor.input_adapter(value) + elif isinstance(stitch.destination_descriptor, OutputDescriptor): + value = stitch.destination_descriptor.output_adapter(value) + else: + raise RuntimeError("Shouldn't happen") + + self.values_to_node[dst_node][stitch.destination_descriptor] = value + + unresolved_count = 0 + + values_to_external_node = ( + {} if self.external_node is None else self.values_to_node[self.external_node] + ) + output = StitchedModuleOutput( + captured_inputs={ + k.input_name: v + for k, v in values_to_external_node.items() + if isinstance(k, InputDescriptor) + }, + captured_outputs={ + k.output_name: v + for k, v in values_to_external_node.items() + if isinstance(k, OutputDescriptor) + }, + ) + + self.values_from_node.clear() + self.values_to_node.clear() + + return output + + +class KnotException(Exception): + pass + + +class LoopFoundException(KnotException): + pass + + +class InputsLoopFoundException(LoopFoundException): + pass + + +class OutputsLoopFoundException(LoopFoundException): + pass + + +class MultipleExternalNodesException(KnotException): + pass + + +class OnlyInternalNodesException(KnotException): + pass + + +class Needle: + def __init__(self) -> None: + self.nodes = dict[Target, Node]() + + def get_node_for_target(self, target: Target) -> Node: + if target not in self.nodes: + node = Node(target=target) + self.nodes[target] = node + else: + node = self.nodes[target] + + return node + + def stitch(self, src: IODescriptor, dst: IODescriptor) -> Self: + descriptor = StitchDescriptor(source_descriptor=src, destination_descriptor=dst) + + src_node = self.get_node_for_target(descriptor.source_descriptor.target) + dst_node = self.get_node_for_target(descriptor.destination_descriptor.target) + + if descriptor not in src_node.stitches_from: + src_node.stitches_from.append(descriptor) + + if descriptor not in dst_node.stitches_to: + dst_node.stitches_to.append(descriptor) + + return self + + def _search_loops( + self, + node: Node, + expand_fn: Callable[[Node], Iterable[IODescriptor]], + traversed_nodes: Optional[set[Node]] = None, + ) -> bool: + if isinstance(node.target, ExternalTarget): + return False + + if traversed_nodes is None: + traversed_nodes = set() + + if node in traversed_nodes: + found_loop = True + else: + traversed_nodes = traversed_nodes | {node} + found_loop = False + descriptors = expand_fn(node) + for descriptor in descriptors: + stitch_node = self.get_node_for_target(descriptor.target) + found_loop |= self._search_loops(stitch_node, expand_fn, traversed_nodes) + + return found_loop + + def _validate_nodes(self): + # internal_nodes = [n for n in self.nodes.values() if not isinstance(n.target, (ExternalTarget, RemoteTarget))] + external_nodes = [n for n in self.nodes.values() if isinstance(n.target, ExternalTarget)] + remote_nodes = [n for n in self.nodes.values() if isinstance(n.target, RemoteTarget)] + + if len(external_nodes) + len(remote_nodes) == 0: + raise OnlyInternalNodesException(f"Has only internal nodes") + + if len(external_nodes) > 1: + raise MultipleExternalNodesException( + f"Expected no more than 1 external node, found {len(external_nodes)}" + ) + + for i, node in enumerate(self.nodes.values()): + found_inputs_loop = self._search_loops( + node, lambda n: [s.source_descriptor for s in n.stitches_to] + ) + if found_inputs_loop: + raise InputsLoopFoundException(f"Found a loop in inputs of node {i}: {node}") + + found_outputs_loop = self._search_loops( + node, lambda n: [s.destination_descriptor for s in n.stitches_from] + ) + if found_outputs_loop: + raise OutputsLoopFoundException(f"Found a loop in outputs of node {i}: {node}") + + def knot( + self, + capture_cache_outputs_predicate=always_false_predicate, + early_exit=True, + ignore_extra_overrides=False, + ) -> StitchedModule: + self._validate_nodes() + + module = StitchedModule( + nodes=self.nodes, + capture_cache_outputs_predicate=capture_cache_outputs_predicate, + early_exit=early_exit, + ignore_extra_overrides=ignore_extra_overrides, + ) + + return module diff --git a/modelopt/torch/puzzletron/sewing_kit/passage.py b/modelopt/torch/puzzletron/sewing_kit/passage.py new file mode 100644 index 0000000000..d8fa1f51cf --- /dev/null +++ b/modelopt/torch/puzzletron/sewing_kit/passage.py @@ -0,0 +1,474 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +from __future__ import annotations + +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Any, ContextManager, Iterable, Mapping, Optional, Union + +import torch.nn as nn +from typing_extensions import override + +from .utils import ( + ActivityContext, + dynamo_skip, + fake_tensors, + has_fake_tensor, + is_submodule_of, + is_submodule_or_same, + real_tensors, +) + +__all__ = [ + "InputArgs", + "OutputValue", + "PassageInputAdapter", + "PassageOutputAdapter", + "PassageInputOverrides", + "PassageOutputOverrides", + "NoActivePassageContextError", + "RequiredPassageOutputsCapturedSignal", + "PassageOutput", + "Predicate", + "always_false_predicate", + "Passage", + "patch_module", +] + + +@dataclass +class InputArgs: + """Container for input arguments to modules.""" + + args: list[Any] + kwargs: dict[str, Any] + + def __init__(self, *args, **kwargs): + self.args = list(args) + self.kwargs = dict(kwargs) + + def __add__(self, other: Any) -> InputArgs: + assert isinstance(other, InputArgs) + result = InputArgs(*self.args, *other.args, **{**self.kwargs, **other.kwargs}) + return result + + def drop_args(self, index: int | slice | None = None) -> InputArgs: + new_args = InputArgs(*self.args, **self.kwargs) + if index is None: + new_args.args.clear() + else: + del new_args.args[index] + + return new_args + + def drop_kwargs(self, keys: Sequence[str] | None = None) -> InputArgs: + new_args = InputArgs(*self.args, **self.kwargs) + if keys is None: + new_args.kwargs.clear() + else: + for key in keys: + new_args.kwargs.pop(key, None) + + return new_args + + @classmethod + def from_value(cls, v): + if isinstance(v, cls): + return v + elif isinstance(v, InputArgs): + return cls(*v.args, **v.kwargs) + elif isinstance(v, Sequence): + return cls(*v) + else: + return cls(v) + + +OutputValue = Any + + +@dataclass +class PassageInputAdapter: + adapter_fn: Callable[[InputArgs, Optional[str], Optional[nn.Module]], InputArgs] + + def __call__( + self, original_input: InputArgs, module_name: Optional[str], module: Optional[nn.Module] + ) -> InputArgs: + result = self.adapter_fn(original_input, module_name, module) + return result + + +@dataclass +class PassageOutputAdapter: + adapter_fn: Callable[[Any, Optional[str], Optional[nn.Module]], Any] + + def __call__( + self, original_output: Any, module_name: Optional[str], module: Optional[nn.Module] + ) -> Any: + result = self.adapter_fn(original_output, module_name, module) + return result + + +class PassageInputOverrides(dict[str, Union[PassageInputAdapter, InputArgs]]): + def __init__(self, input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = None): + if input_overrides is None: + input_overrides = {} + for k, v in input_overrides.items(): + self[k] = v + + # def __setitem__(self, key: str, value: InputAdapter | InputArgs) -> None: + # if isinstance(key, InputArgs): + # def adapter_fn(original_input: InputArgs) -> InputArgs: + # assert isinstance(value, InputArgs) + # return value + # self[key] = InputAdapter(adapter_fn) + # else: + # self[key] = value + + +class PassageOutputOverrides(dict[str, Union[PassageOutputAdapter, Any]]): + def __init__(self, output_overrides: Mapping[str, PassageOutputAdapter | Any] = None): + if output_overrides is None: + output_overrides = {} + for k, v in output_overrides.items(): + self[k] = v + + +class NoActivePassageContextError(RuntimeError): + pass + + +class RequiredPassageOutputsCapturedSignal(Exception): + pass + + +@dataclass +class PassageOutput: + captured_inputs: dict[str, InputArgs] + captured_outputs: dict[str, Any] + captured_fake_outputs: dict[str, Any] + module_output: Any + + +Predicate = Callable[[str, nn.Module], bool] + + +def always_false_predicate(module_name: str, module: nn.Module) -> bool: + return False + + +def always_true_predicate(module_name: str, module: nn.Module) -> bool: + return True + + +class Passage(nn.Module): + create_fn_context = ActivityContext[None](max_depth=1) + active_passages_context = ActivityContext["Passage"](no_duplicates=True, reversed=True) + + def __init__( + self, + module: nn.Module, + *, + inputs_to_capture: Iterable[str] = [], + outputs_to_capture: Iterable[str] = [], + input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {}, + output_overrides: Mapping[str, PassageOutputAdapter | Any] = {}, + outputs_cache: dict[str, Any] = {}, + capture_fake_outputs_predicate: Predicate = always_false_predicate, + capture_cache_outputs_predicate: Predicate = always_false_predicate, + early_exit: bool = False, + name: Optional[str] = None, + ): + super().__init__() + + if not self.create_fn_context.is_active(): + raise RuntimeError("Please use Passage.create(...) in order to create a new Passage") + + self.active_context_manager: Optional[ContextManager] = None + + self.name = name + self.module = module + self.module_to_name_mapping = {id(v): k for k, v in module.named_modules()} + self.inputs_to_capture = set(inputs_to_capture) + self.outputs_to_capture = set(outputs_to_capture) + self.input_overrides = input_overrides + self.output_overrides = output_overrides + self.outputs_cache = outputs_cache + self.capture_fake_outputs_predicate = capture_fake_outputs_predicate + self.capture_cache_outputs_predicate = capture_cache_outputs_predicate + self.early_exit = early_exit + + self.reset() + + @property + def input_overrides(self) -> PassageInputOverrides: + return self._input_overrides + + @input_overrides.setter + def input_overrides(self, value: Mapping[str, PassageInputAdapter | InputArgs]): + self._input_overrides = PassageInputOverrides(value) + + @property + def output_overrides(self) -> PassageOutputOverrides: + return self._output_overrides + + @output_overrides.setter + def output_overrides(self, value: Mapping[str, PassageOutputAdapter | Any]): + self._output_overrides = PassageOutputOverrides(value) + + def reset(self): + self.required_capture_count = ( + (len(self.inputs_to_capture) + len(self.outputs_to_capture)) + if self.early_exit + else None + ) + self.captured_outputs: dict[str, Any] = {} + self.captured_inputs: dict[str, InputArgs] = {} + self.captured_fake_outputs: dict[str, Any] = {} + + @classmethod + def module_name_relative_to_active_passage(cls, module: PatchedModule) -> str: + root_passage = Passage.active_passages_context.get_active() + assert root_passage is not None + module_name = root_passage.module_to_name_mapping[id(module)] + return module_name + + @classmethod + def create( + cls, + module: nn.Module, + *, + inputs_to_capture: Iterable[str] = [], + outputs_to_capture: Iterable[str] = [], + input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {}, + output_overrides: Mapping[str, PassageOutputAdapter | Any] = {}, + outputs_cache: dict[str, Any] = {}, + capture_fake_outputs_predicate: Predicate = always_false_predicate, + capture_cache_outputs_predicate: Predicate = always_false_predicate, + early_exit: bool = False, + name: Optional[str] = None, + ) -> Passage: + with cls.create_fn_context(None): + passage = cls( + module=module, + inputs_to_capture=inputs_to_capture, + outputs_to_capture=outputs_to_capture, + input_overrides=input_overrides, + output_overrides=output_overrides, + outputs_cache=outputs_cache, + capture_fake_outputs_predicate=capture_fake_outputs_predicate, + capture_cache_outputs_predicate=capture_cache_outputs_predicate, + early_exit=early_exit, + name=name, + ) + + for submodule_name, submodule in module.named_modules(remove_duplicate=False): + patch_module(submodule_name, submodule) + + # register_passage_hooks(module, descriptor) + + return passage + + def is_active(self) -> bool: + result = self.active_context_manager is not None + return result + + def __enter__(self): + assert self.active_context_manager is None + self.active_context_manager = Passage.active_passages_context(self) + self.active_context_manager.__enter__() + self.module_to_name_mapping = {id(v): k for k, v in self.named_modules()} + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.active_context_manager is not None + self.active_context_manager.__exit__(exc_type, exc_val, exc_tb) + self.active_context_manager = None + + def freeze(self): + self.eval() + self.requires_grad_(False) + + def unfreeze(self): + self.train() + self.requires_grad_(True) + + def run(self, *args, **kwargs) -> PassageOutput: + return self(*args, **kwargs) + + @override + def __call__(self, *args, **kwargs) -> PassageOutput: + return super().__call__(*args, **kwargs) + + @dynamo_skip + @override + def forward(self, *args, **kwargs) -> PassageOutput: + self.reset() + + with Passage.active_passages_context(self): + try: + module_output = self.module(*args, **kwargs) + except RequiredPassageOutputsCapturedSignal: + module_output = None + + output = PassageOutput( + captured_inputs=self.captured_inputs, + captured_outputs=self.captured_outputs, + captured_fake_outputs=self.captured_fake_outputs, + module_output=module_output, + ) + + self.reset() + + return output + + +class PatchedModule: ... + + +def patch_module(module_name_: str, module: nn.Module): + # orig_forward = module.forward + + if isinstance(module, PatchedModule): + # if module_name != Passage.module_name_relative_to_active_passage(module): + # logger.warn(f'Module "{module_name}" already patched for module "{Passage.module_name_relative_to_active_passage(module)}". Could lead to bugs.') + return + + orig_class = module.__class__ + + class PassageModuleWrapper(orig_class, PatchedModule): + # Defined as a static method to avoid potential collision with original class methods + @staticmethod + @dynamo_skip + def can_be_skipped(_self: PassageModuleWrapper, depth: int) -> bool: + passages_beyond_depth = Passage.active_passages_context[depth:] + module_name = Passage.module_name_relative_to_active_passage(_self) + + results = [ + ( + module_name in passage.outputs_cache + and not any( + is_submodule_or_same(k, module_name) for k in passage.outputs_to_capture + ) + and not any( + is_submodule_of(k, module_name) + for k, v in passage.input_overrides.items() + if v is not None + ) + and not any( + is_submodule_of(k, module_name) + for k, v in passage.output_overrides.items() + if v is not None + ) + ) + for passage in passages_beyond_depth + ] + + result = all(results) + + return result + + # Defined as a static method to avoid potential collision with original class methods + @staticmethod + @dynamo_skip + def run_passage(_self: PassageModuleWrapper, depth: int, args, kwargs): + if depth + 1 > len(Passage.active_passages_context): + output = super(PassageModuleWrapper, _self).__call__(*args, **kwargs) + return output + + module_name = Passage.module_name_relative_to_active_passage(_self) + passage = Passage.active_passages_context[depth] + + has_output_override = module_name in passage.output_overrides + output_override = passage.output_overrides.get(module_name) + + if has_output_override and not isinstance(output_override, PassageOutputAdapter): + output = output_override + else: + input_override = passage.input_overrides.get(module_name) + if input_override is not None: + original_input_args = InputArgs(*args, **kwargs) + + if isinstance(input_override, PassageInputAdapter): + new_input_args = input_override(original_input_args, module_name, module) + else: + new_input_args = input_override + + args, kwargs = new_input_args.args, new_input_args.kwargs + + if ( + output_override is None + and PassageModuleWrapper.can_be_skipped(_self, depth) + and (has_fake_tensor(args) or has_fake_tensor(kwargs)) + ): + cached_output = passage.outputs_cache[module_name] + return cached_output + + output = PassageModuleWrapper.run_passage( + _self=_self, + depth=depth + 1, + args=args, + kwargs=kwargs, + ) + + if isinstance(output_override, PassageOutputAdapter): + output = output_override(output, module_name, module) + + if passage.capture_fake_outputs_predicate(module_name, module): + fake_output = fake_tensors(output) + passage.captured_fake_outputs[module_name] = fake_output + + if not module_name in passage.outputs_cache and passage.capture_cache_outputs_predicate( + module_name, module + ): + fake_output = fake_tensors(output) + passage.outputs_cache[module_name] = fake_output + + if module_name in passage.inputs_to_capture: + real_args, real_kwargs = real_tensors(args), real_tensors(kwargs) + passage.captured_inputs[module_name] = InputArgs(*real_args, **real_kwargs) + + if passage.required_capture_count is not None: + passage.required_capture_count -= 1 + + if module_name in passage.outputs_to_capture: + real_output = real_tensors(output) + output_value = real_output + passage.captured_outputs[module_name] = output_value + + if passage.required_capture_count is not None: + passage.required_capture_count -= 1 + + if passage.required_capture_count == 0: + raise RequiredPassageOutputsCapturedSignal() + + return output + + @dynamo_skip + @override + def __call__(self, *args, **kwargs): + output = self.run_passage( + _self=self, + depth=0, + args=args, + kwargs=kwargs, + ) + return output + + # module.forward = forward + PassageModuleWrapper.__name__ = f"ModuleWrapper({module.__class__.__name__})" + module.__class__ = PassageModuleWrapper diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py new file mode 100644 index 0000000000..3db63f6001 --- /dev/null +++ b/modelopt/torch/puzzletron/sewing_kit/utils.py @@ -0,0 +1,453 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +from __future__ import annotations + +import inspect +from contextlib import contextmanager +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Generic, + Optional, + Protocol, + TypeVar, + cast, + overload, +) + +import torch +import torch._C +import torch._dynamo +import torch.distributed +import torch.nn as nn +import torch.utils._pytree as pytree +from torch import Tensor +from torch._subclasses import FakeTensor, FakeTensorMode +from typing_extensions import override + +if TYPE_CHECKING: + from collections.abc import Sequence + +__all__ = [ + "ActivityContext", + "ActivityContextDuplicateException", + "dynamo_skip", + "dynamo_disable", + "is_submodule_of", + "is_submodule_or_same", + "fake_mode", + "fake_tensor", + "fake_tensor_like", + "fake_tensors", + "real_tensors", + "has_fake_tensor", + "distributed_isend_obj", + "distributed_send_obj", + "distributed_recv_obj", +] + +Fn = TypeVar("Fn", bound=Callable) + + +class DynamoSkip(Protocol): + @overload + def __call__(self, fn: None = None) -> Callable[[Fn], Fn]: ... + @overload + def __call__(self, fn: Fn) -> Fn: ... + + +class DynamoDisable(Protocol): + @overload + def __call__(self, fn: None = None, disable: bool = False) -> Callable[[Fn], Fn]: ... + @overload + def __call__(self, fn: Fn, disable: bool = False) -> Fn: ... + + +try: + dynamo_skip: DynamoSkip = cast("Any", torch._dynamo.decorators).skip + dynamo_disable: DynamoDisable = cast("Any", torch._dynamo.decorators).disable +except: + dynamo_skip: DynamoSkip = cast("Any", torch._dynamo.eval_frame).skip + dynamo_disable: DynamoDisable = cast("Any", torch._dynamo.eval_frame).disable + + +TModule = TypeVar("TModule", bound=nn.Module) + + +class ModuleRef(Generic[TModule]): + def __init__(self, module: TModule): + self.module = module + + +class ActivityContextMaxDepthException(Exception): + pass + + +class ActivityContextDuplicateException(Exception): + pass + + +T = TypeVar("T") + + +class ActivityContext(Generic[T]): + def __init__(self, max_depth: Optional[int] = None, no_duplicates=False, reversed=False): + self.activity_stack: list[T] = [] + self.max_depth = max_depth + self.no_duplicates = no_duplicates + self.reversed = reversed + + def __contains__(self, value: T) -> bool: + result = value in self.activity_stack + return result + + def __call__(self, value: T) -> ContextManager: + @contextmanager + def fn(): + inserted = False + try: + if self.no_duplicates and value in self.activity_stack: + raise ActivityContextDuplicateException( + f"Activity stack cannot have a duplicate of item {value}" + ) + + if self.reversed: + self.activity_stack.insert(0, value) + else: + self.activity_stack.append(value) + inserted = True + + if self.max_depth is not None and len(self) > self.max_depth: + raise ActivityContextMaxDepthException( + f"Activity stack exceeds max depth of {self.max_depth}" + ) + + yield + finally: + if inserted: + assert self.is_active() + self.activity_stack.pop(0 if self.reversed else -1) + + return fn() + + def __len__(self) -> int: + result = len(self.activity_stack) + return result + + @overload + def __getitem__(self, key: int) -> T: ... + @overload + def __getitem__(self, key: slice) -> Sequence[T]: ... + def __getitem__(self, key: int | slice) -> T | Sequence[T]: + result = self.activity_stack[key] + return result + + def is_active(self) -> bool: + result = len(self) > 0 + return result + + def get_active(self) -> Optional[T]: + if self.is_active(): + return self.activity_stack[-1] + return None + + +def is_submodule_of(module_name: str, other_module_name: str) -> bool: + result = module_name.startswith(f"{other_module_name}.") or ( + module_name != "" and other_module_name == "" + ) + return result + + +def is_submodule_or_same(module_name: str, other_module_name: str) -> bool: + result = module_name == other_module_name or is_submodule_of(module_name, other_module_name) + return result + + +fake_mode = FakeTensorMode( + allow_non_fake_inputs=True, + # allow_fallback_kernels=False, +) + + +@overload +def fake_tensor(t: Tensor, *, dtype: Optional[torch.dtype] = None, use_meta=False) -> Tensor: ... + + +@overload +def fake_tensor( + size: Sequence[int] | torch.Size, *, dtype: Optional[torch.dtype] = None, use_meta=False +) -> Tensor: ... + + +@overload +def fake_tensor(*args: int, dtype: Optional[torch.dtype] = None, use_meta=False) -> Tensor: ... + + +class MyFakeTensor(Tensor): + @dynamo_disable + def __init__(self, *args, **kwargs): + super().__init__() + self._t: FakeTensor + + @override + @dynamo_disable + def __repr__(self, *, tensor_contents=None): + return f"MyFakeTensor(shape={list(self._t.shape)}, dtype={self._t.dtype}, device={self._t.device})" + + @classmethod + @override + @dynamo_disable + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + args, kwargs = pytree.tree_map_only(MyFakeTensor, lambda t: t._t, (args, kwargs)) + + types = pytree.tree_map_only(type(MyFakeTensor), lambda t: FakeTensor, types) + + out = func(*args, **kwargs) + + out = pytree.tree_map_only(Tensor, lambda t: MyFakeTensor.create(t), out) + + return out + + __torch_function__ = torch._C._disabled_torch_function_impl + + # @dynamo_disable + # def __getattribute__(self, attr: str): + # if attr in {'_t', 'device', '__repr__', '__torch_function__', '__class__'}: + # return object.__getattribute__(self, attr) + + # result = getattr(self._t, attr) + + # result = pytree.tree_map_only( + # Tensor, lambda t: MyFakeTensor.create(t), result + # ) + # print('__getattribute__', 'attr', attr, 'ret', result) + + # return result + + @property + @dynamo_disable + def device(self): + return self._t.device + + # @property + # @dynamo_disable + # def shape(self): + # return self._t.shape + + # @dynamo_disable + # def size(self): + # return self._t.size() + + # @classmethod + # @dynamo_disable + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # if kwargs is None: + # kwargs = {} + + # args, kwargs = pytree.tree_map_only( + # MyFakeTensor, lambda t: t._t, (args, kwargs) + # ) + + # ret = func(*args, **kwargs) + + # ret = pytree.tree_map_only( + # Tensor, lambda t: MyFakeTensor.create(t), ret + # ) + # print('__torch_function__', 'func', func, 'ret', ret) + + # return ret + + @staticmethod + @dynamo_disable + def __new__(cls, elem, device) -> MyFakeTensor: + self = torch.Tensor._make_subclass( + cls, + elem, + elem.requires_grad, + dispatch_device=True, + device_for_backend_keys=device, + ) + return cast("MyFakeTensor", self) + + @classmethod + @dynamo_disable + def create(cls, data: Tensor) -> MyFakeTensor: + if isinstance(data, MyFakeTensor): + return data + + if isinstance(data, FakeTensor): + t = data + else: + t = FakeTensor.from_tensor(data, fake_mode=fake_mode) + + # my_fake_tensor = MyFakeTensor(torch.empty(t.shape, dtype=t.dtype, device='meta')) + my_fake_tensor = MyFakeTensor( + torch.empty(t.shape, dtype=t.dtype, device="meta"), + t.device, + ) + my_fake_tensor._t = t + + return my_fake_tensor + + +@dynamo_disable +def fake_tensor(*args, **kwargs) -> Tensor: + dtype: Optional[torch.dtype] = kwargs.get("dtype") + use_meta = kwargs.get("use_meta", False) + device = kwargs.get("device", "meta") + + if len(args) == 1 and isinstance(args[0], Tensor): + if use_meta: + fake_tensor = torch.empty(args[0].size(), dtype=dtype or args[0].dtype, device="meta") + else: + fake_tensor = MyFakeTensor.create(args[0]) + else: + fake_tensor = torch.empty(*args, dtype=dtype, device=device) + if not use_meta: + fake_tensor = MyFakeTensor.create(fake_tensor) + + return fake_tensor + + +@dynamo_skip +def fake_tensor_like(t: Tensor, use_meta=False) -> Tensor: + return fake_tensor(t, use_meta=use_meta) + + +T = TypeVar("T") + + +@dynamo_skip +def fake_tensors(value: T, use_meta=False) -> T: + result = pytree.tree_map_only(Tensor, lambda t: fake_tensor_like(t, use_meta), value) + return result + # if isinstance(value, Mapping): + # return cast(Any, value.__class__)({k: fake_tensors(v, use_meta) for k, v in value.items()}) + # if isinstance(value, Sequence): + # return cast(Any, value.__class__)([fake_tensors(v, use_meta) for v in value]) + # if isinstance(value, Tensor): + # return fake_tensor_like(value, use_meta) + # return value + + +@dynamo_skip +def real_tensors(value: Any) -> Any: + result = pytree.tree_map_only(Tensor, lambda t: None if is_fake_tensor(t) else t, value) + return result + # if isinstance(value, Mapping): + # return cast(Any, value.__class__)({k: real_tensors(v) for k, v in value.items()}) + # if isinstance(value, Sequence): + # return cast(Any, value.__class__)([real_tensors(v) for v in value]) + # if is_fake_tensor(value): + # return None + # return value + + +@dynamo_skip +def is_fake_tensor(t: Any) -> bool: + return isinstance(t, (MyFakeTensor, FakeTensor)) or (isinstance(t, Tensor) and t.is_meta) + + +@dynamo_skip +def has_fake_tensor(v: Any) -> bool: + result = pytree.tree_any(is_fake_tensor, v) + return result + + +def _get_device_for_distributed( + group: Optional[torch.distributed.ProcessGroup] = None, +) -> torch.device: + """ + Determine the appropriate device for distributed communication based on the backend. + NCCL backend requires CUDA tensors, while Gloo supports both CPU and CUDA. + """ + if not torch.distributed.is_initialized(): + return torch.device("cpu") + + backend = torch.distributed.get_backend(group) + if backend == "nccl": + # NCCL requires CUDA tensors + return torch.device("cuda", torch.cuda.current_device()) + else: + # Gloo and other backends support CPU tensors + return torch.device("cpu") + + +def distributed_isend_obj( + obj: Any, + dst: int = 0, + group: Optional[torch.distributed.ProcessGroup] = None, +) -> list[Optional[torch.distributed.Work]]: + device = _get_device_for_distributed(group) + obj_tensor, obj_size_tensor = torch.distributed.distributed_c10d._object_to_tensor( + obj, device=device, **_get_group_kwarg_if_necessary() + ) + works: list[Optional[torch.distributed.Work]] = [ + torch.distributed.isend(obj_size_tensor, dst, group), + torch.distributed.isend(obj_tensor, dst, group), + ] + # p2p_ops = [ + # torch.distributed.P2POp(torch.distributed.isend, obj_size_tensor, dst, group), + # torch.distributed.P2POp(torch.distributed.isend, obj_tensor, dst, group), + # ] + + # works = torch.distributed.batch_isend_irecv(p2p_ops) + + return works + + +def distributed_send_obj( + obj: Any, + dst: int = 0, + group: Optional[torch.distributed.ProcessGroup] = None, +): + works = distributed_isend_obj(obj=obj, dst=dst, group=group) + for work in works: + if work is not None: + work.wait() + + +def distributed_recv_obj( + src: Optional[int] = None, + group: Optional[torch.distributed.ProcessGroup] = None, +) -> Any: + device = _get_device_for_distributed(group) + obj_size_tensor = torch.LongTensor(1).to(device) + torch.distributed.recv(obj_size_tensor, src=src, group=group) + obj_size = int(obj_size_tensor.item()) + + obj_tensor = torch.ByteTensor(obj_size).to(device) + torch.distributed.recv(obj_tensor, src=src, group=group) + + obj = torch.distributed.distributed_c10d._tensor_to_object( + obj_tensor, obj_size, **_get_group_kwarg_if_necessary() + ) + + return obj + + +def _get_group_kwarg_if_necessary() -> dict: + """For newer versions of torch""" + arg_names = inspect.signature( + torch.distributed.distributed_c10d._object_to_tensor + ).parameters.keys() + return dict(group=None) if "group" in arg_names else dict() diff --git a/modelopt/torch/puzzletron/subblock_stats/__init__.py b/modelopt/torch/puzzletron/subblock_stats/__init__.py new file mode 100644 index 0000000000..fbbeb3ff70 --- /dev/null +++ b/modelopt/torch/puzzletron/subblock_stats/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Subblock statistics collection for Puzzletron.""" + +from .calc_subblock_params_and_memory import * +from .calc_subblock_stats import * diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py new file mode 100644 index 0000000000..d893eb55bb --- /dev/null +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py @@ -0,0 +1,357 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Calculate memory usage and parameter counts for neural network subblocks. + +This module provides utilities to compute memory footprints and parameter counts +for different subblock types (FFN, Attention, Mamba, MoE) in large language models, +considering various data types, batch sizes, and sequence lengths. +""" + +import copy +import json +import math +from pathlib import Path +from typing import Type + +import numpy as np +import torch +from transformers import PretrainedConfig + +from ..anymodel.model_descriptor import ModelDescriptor +from ..block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, + MambaConfig, + maybe_cast_block_configs, +) +from ..tools.checkpoint_utils_hf import init_model_from_config +from ..utils.misc import ( + EmptyInitOnDevice, + calculate_kv_dim, + raise_unknown_subblock_config_error, + sizeof_dtype, +) + +__all__ = [ + "calculate_subblock_memory", + "calculate_subblock_params", + "calc_subblock_active_params", + "load_moe_stats", + "estimate_num_active_experts", + "calculate_mamba_memory", + "calculate_mamba_state_size", + "calculate_ffn_memory", + "calculate_non_block_memory", + "calculate_non_block_params", +] + + +def calculate_subblock_memory( + subblock_config: FFNConfig | AttentionConfig, + batch_size: int, + prefill_seq_len: int, + generation_seq_len: int, + prefill_queue_size: int, + n_embd: int, + n_head: int, + weights_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, + allocate_prefill_query: bool, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], +) -> float | dict[str, float]: + """``model_config`` / ``descriptor`` are required (puzzletron-style); FFN uses them for meta init.""" + if subblock_config.no_op: + return 0 + if isinstance(subblock_config, FFNConfig): + return calculate_ffn_memory( + subblock_config, + model_config, + descriptor, + weights_dtype, + ) + if isinstance(subblock_config, AttentionConfig): + if subblock_config.is_mamba: + return calculate_mamba_memory( + subblock_config, + model_config, + descriptor, + batch_size, + weights_dtype, + kv_cache_dtype, + ) + else: + return calculate_attention_memory( + subblock_config, + model_config, + descriptor, + batch_size, + prefill_seq_len, + generation_seq_len, + prefill_queue_size, + n_embd, + n_head, + weights_dtype, + kv_cache_dtype, + allocate_prefill_query, + ) + raise_unknown_subblock_config_error(subblock_config) + + +def calculate_subblock_params( + config: PretrainedConfig, + layer_config: BlockConfig | FFNConfig | AttentionConfig, + descriptor: Type[ModelDescriptor], +) -> int: + """Count parameters on one meta decoder layer. + + The caller is responsible for adjusting per-layer config fields (e.g. + ``hybrid_override_pattern``) before passing ``config``; see + ``ModelDescriptor.truncate_pattern_for_subblock``. + """ + if isinstance(layer_config, FFNConfig): + block_config = layer_config.to_blockconfig() + elif isinstance(layer_config, AttentionConfig): + block_config = layer_config.to_blockconfig() + else: + block_config = layer_config + + ffn = block_config.ffn + attn = block_config.attention + ffn_no_op = ffn is None or ffn.no_op + attn_no_op = attn is None or attn.no_op + if not (ffn_no_op or attn_no_op): + raise AssertionError( + "One of ffn or attention must be no-op for sublayer param calculation " + "(single subblock at a time)." + ) + if ffn_no_op and attn_no_op: + return 0 + + _config = copy.deepcopy(config) + lm_config = descriptor.get_language_model_config(_config) + lm_config.num_hidden_layers = 1 + + block_configs = maybe_cast_block_configs([block_config]) + _config.block_configs = block_configs + if lm_config is not _config: + lm_config.block_configs = block_configs + + # Replaced earlier pattern: + # with EmptyInitOnDevice("meta"), deci_x_patcher(..., block_configs=block_configs): + # model = init_model_from_config(_config, ...) + # + # That fails on GPT-OSS with recent Transformers: ``deci_x_patcher`` runs + # ``attn_no_op_post_init`` / ``mlp_no_op_post_init`` inside ``DecoderLayer.__init__``, so norms + # / attn / mlp are swapped for placeholders before ``GptOssModel.__init__`` finishes. At the end + # of ``GptOssModel.__init__`` the stack calls ``self.post_init()`` — inherited from + # ``PreTrainedModel`` — which then raises + # ``ValueError`` (e.g. ``post_attention_layernorm`` in ``_keep_in_fp32_modules`` no longer matches + # the tree). Below we merge per-layer fields manually, init without the patcher, then call the + # same descriptor no-op hooks on the built layer (equivalent param count for + # ``num_hidden_layers == 1``). + + # ``block_config_to_layer_overrides`` may include keys with value ``None``; we omit those so + # ``lm_config.update`` does not overwrite existing fields with ``None`` (same rule as + # ``override_config_with_block_configs`` inside ``deci_x_patcher``). + layer_overrides = descriptor.block_config_to_layer_overrides(block_configs[0]) + lm_config.update({k: v for k, v in layer_overrides.items() if v is not None}) + + with EmptyInitOnDevice("meta"): + model = init_model_from_config( + _config, + trust_remote_code=descriptor.requires_trust_remote_code(), + ) + + decoder_layer = model.get_submodule(descriptor.layer_block_name(index=0)) + if attn_no_op: + descriptor.attn_no_op_post_init(decoder_layer) + if ffn_no_op: + descriptor.mlp_no_op_post_init(decoder_layer) + return sum(p.numel() for p in decoder_layer.parameters()) + + +def calc_subblock_active_params( + sublayer_config: FFNConfig | AttentionConfig, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], + n_embd: int, + moe_stats_file: str, + batch_size: int, + block_idx: int, +) -> int: + if not (isinstance(sublayer_config, FFNConfig) and sublayer_config.is_moe): + return calculate_subblock_params(model_config, sublayer_config, descriptor) + return estimate_moe_active_params( + sublayer_config, n_embd, moe_stats_file, batch_size, block_idx + ) + + +def load_moe_stats(stats_file: str) -> dict: + with open(stats_file) as f: + stats = json.load(f) + return [np.array(l) / np.sum(l) if len(l) > 0 else 0 for l in stats] + + +def estimate_num_active_experts( + dist_over_experts: np.ndarray, batch_size: int, num_experts: int +) -> int: + # cut the tail and renormalize + dist_over_experts = np.sort(dist_over_experts)[::-1][:num_experts] + dist_over_experts = dist_over_experts / (dist_over_experts.sum()) + # calculate the probability of at least one expert being active + # (expectation on indicators is the expected number of active experts) + return (1 - (1 - dist_over_experts) ** batch_size).sum() + + +def estimate_moe_active_params( + subblock_config: FFNConfig, + n_embd: int, + moe_stats_file: Path | str, + batch_size: int, + block_idx: int, +) -> int: + assert Path(moe_stats_file).exists() + # if not Path(moe_stats_file).exists(): # if path is not provided, should we assume uniform distribution? + # return calculate_subblock_params(subblock_config, n_embd, n_head=None) + moe_stats = load_moe_stats(moe_stats_file) + dist_over_experts = moe_stats[block_idx] + num_experts = subblock_config.moe.num_local_experts + + expected_num_active_experts = estimate_num_active_experts( + dist_over_experts, batch_size, num_experts + ) + expert_dim = subblock_config.moe.expert_intermediate_dim + shared_expert_dim = subblock_config.moe.shared_expert_intermediate_dim + num_linear_layers = 3 # all moe experts have 3 linear layers + + router_num_params = n_embd * num_experts + expected_num_active_experts_params = ( + num_linear_layers * expert_dim * n_embd * expected_num_active_experts + ) + shared_expert_num_params = num_linear_layers * shared_expert_dim * n_embd + + expected_total_params = ( + router_num_params + expected_num_active_experts_params + shared_expert_num_params + ) + return expected_total_params + + +def calculate_attention_memory( + attention_config: AttentionConfig, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], + batch_size: int, + prefill_seq_len: int, + generation_seq_len: int, + prefill_queue_size: int, + n_embd: int, + n_head: int, + weights_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, + allocate_prefill_query: bool, +) -> dict[str, float]: + """allocate_prefill_query: infery-llm style. + Infery used a unified Wqkv matrix, so before extracting the kv-cache, + the query also had to be kept in-memory, once per layer. + """ + seq_len = prefill_seq_len + generation_seq_len + if ( + attention_config.is_llama4 + and (attention_chunk_size := attention_config.llama4.attention_chunk_size) is not None + ): + seq_len = min(seq_len, attention_chunk_size) + + kv_dim = calculate_kv_dim(attention_config.num_key_value_heads, n_head, n_embd) + total_num_tokens = seq_len * (batch_size + prefill_queue_size) + kv_cache_size = total_num_tokens * kv_dim + query_prefill_size = seq_len * n_embd if allocate_prefill_query else 0 + num_params = calculate_subblock_params(model_config, attention_config, descriptor) + total_memory = ( + kv_cache_size * sizeof_dtype(kv_cache_dtype) + + query_prefill_size * sizeof_dtype(weights_dtype) + + num_params * sizeof_dtype(weights_dtype) + ) / 2**20 + kv_cache_memory = kv_cache_size * sizeof_dtype(kv_cache_dtype) / 2**20 + return {"memory_mib": total_memory, "kv_cache_memory_mib": kv_cache_memory} + + +def calculate_mamba_memory( + attention_config: AttentionConfig, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], + batch_size: int, + weights_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, +) -> int: + assert attention_config.mamba is not None + mamba_config = attention_config.mamba + num_params = calculate_subblock_params(model_config, attention_config, descriptor) + return ( + num_params * sizeof_dtype(weights_dtype) + + calculate_mamba_state_size(mamba_config, batch_size) * sizeof_dtype(kv_cache_dtype) + ) / 2**20 + + +def calculate_mamba_state_size( + mamba_config: MambaConfig, + batch_size: int, +) -> int: + d_inner, in_proj_dim, conv_dim, kernel_size = _calculate_mamba_intermediates(mamba_config) + conv_state_size = math.prod((batch_size, conv_dim, kernel_size)) + ssm_state_size = math.prod( + (batch_size, mamba_config.num_heads, mamba_config.head_dim, mamba_config.state_dim) + ) + return conv_state_size + ssm_state_size + + +def _calculate_mamba_intermediates(mamba_config: MambaConfig) -> tuple[int, ...]: + d_inner = mamba_config.num_heads * mamba_config.head_dim + in_proj_dim = ( + d_inner * 2 + 2 * mamba_config.num_groups * mamba_config.state_dim + mamba_config.num_heads + ) + conv_dim = d_inner + 2 * mamba_config.num_groups * mamba_config.state_dim + kernel_size = 4 + return d_inner, in_proj_dim, conv_dim, kernel_size + + +def calculate_ffn_memory( + ffn_config: FFNConfig, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], + weights_dtype: torch.dtype | str, + experts_dtype: torch.dtype | str | None = None, +) -> float: + # TODO: How to separate between expert weights and the rest for any model (same as puzzletron). + num_params = calculate_subblock_params(model_config, ffn_config, descriptor) + return num_params * sizeof_dtype(weights_dtype) / 2**20 + + +def calculate_non_block_memory( + n_embd: int, + vocab_size: int, + weight_dtype: torch.dtype, +) -> float: + return calculate_non_block_params(n_embd, vocab_size) * sizeof_dtype(weight_dtype) / 2**20 + + +def calculate_non_block_params( + n_embd: int, + vocab_size: int, +) -> int: + return vocab_size * n_embd * 2 + n_embd diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py new file mode 100644 index 0000000000..dc89a1f645 --- /dev/null +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py @@ -0,0 +1,567 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Calc subblock stats to compute memory and runtime statistics for subblocks.""" + +import copy +import dataclasses +import json +import os +import warnings +from functools import partial +from itertools import product +from pathlib import Path +from typing import Iterable, Optional, Type, TypeVar + +import pandas as pd +import torch +from immutabledict import immutabledict +from omegaconf import DictConfig, ListConfig, OmegaConf +from tqdm import tqdm +from transformers import PretrainedConfig + +from modelopt.torch.utils import json_dump + +from ..anymodel.model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ..block_config import AttentionConfig, BlockConfig, FFNConfig, SubblockConfig +from ..replacement_library.replacement_utils import parse_layer_replacement +from ..tools.checkpoint_utils import load_model_config +from ..tools.logger import mprint +from ..utils.parsing import format_global_config +from .calc_subblock_params_and_memory import ( + calc_subblock_active_params, + calculate_non_block_memory, + calculate_non_block_params, + calculate_subblock_memory, + calculate_subblock_params, +) + +__all__ = [ + "calculate_subblock_stats", + "launch_calc_subblock_stats", + "add_int8_runtime_estimates", +] + +# Type variable for dataclasses +T_DataClass = TypeVar("T_DataClass") + +""" +Usage: +python -m modelopt.torch.puzzletron.subblock_stats.calc_subblock_stats PUZZLE_DIR [ --benchmark_iterations 1000 ] + +--benchmark_iterations=None (the default) means that the code won't use infery to benchmark runtime, + only memory stats will be calculated. If you want to benchmark runtime, run inside an infery-llm docker. + +""" + + +def calculate_subblock_stats( + calc_subblock_stats_config: DictConfig, + teacher_dir: Path, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], + master_puzzle_dir: Path, + subblock_configs: list[immutabledict[str, AttentionConfig | FFNConfig]], + batch_size: int, + prefill_seq_len: int, + generation_seq_len: int, + prefill_queue_size: int, + n_embd: int, + n_head: int, + vocab_size: int, + benchmark_iterations: Optional[int], + use_cuda_graph: bool, + weights_dtype: torch.dtype, + activations_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, + allocate_prefill_query: bool, + moe_stats_file: str | Path | None = None, +) -> dict: + is_calc_runtime = benchmark_iterations is not None + if is_calc_runtime: + raise NotImplementedError("Runtime stats calculation is not implemented yet") + + gpu = None if not torch.cuda.is_available() else torch.cuda.get_device_name() + subblock_stats = { + "args": dict( + is_calc_runtime=is_calc_runtime, + gpu=gpu, + batch_size=batch_size, + prefill_seq_len=prefill_seq_len, + generation_seq_len=generation_seq_len, + prefill_queue_size=prefill_queue_size, + n_embd=n_embd, + n_head=n_head, + vocab_size=vocab_size, + benchmark_iterations=benchmark_iterations, + use_cuda_graph=use_cuda_graph, + weights_dtype=str(weights_dtype), + activations_dtype=str(activations_dtype), + kv_cache_dtype=str(kv_cache_dtype), + ), + "non_block": dict(), + "subblocks": list(), + } + # Compute runtime stats for unique subblocks only + if is_calc_runtime: + raise NotImplementedError("Runtime stats calculation is not implemented yet") + subblock_configs_nolayerindex = set( + [subblock_config["subblock_config"] for subblock_config in subblock_configs] + ) + + # dict[SubblockConfig, float], float + # TODO: Manage default values for calc_subblock_stats_config in one place, e.g. within a dataclass for hydra config. + synth_dataset_num_requests = calc_subblock_stats_config.get("runtime_stats", {}).get( + "synth_dataset_num_requests", 200 + ) + backend = calc_subblock_stats_config.get("runtime_stats", {}).get("backend", "trt_torch") + runtime_by_subblock_dict, non_block_runtime_ms = calc_runtime_ms_for_subblocks( + subblock_configs_nolayerindex, + vocab_size, + n_embd, + n_head, + master_puzzle_dir, + teacher_dir, + synth_dataset_num_requests, + backend, + ) + + sorted_subblock_config = sorted( + subblock_configs, key=lambda subblock_config: subblock_config["subblock_config"] + ) + it = ( + tqdm(sorted_subblock_config, desc="Measuring subblock runtimes") + if is_calc_runtime + else sorted_subblock_config + ) + for subblock_config_indexed in it: + subblock_config = subblock_config_indexed["subblock_config"] + parent_layer_indices = subblock_config_indexed["parent_layer_indices"] + + layer_model_config = copy.deepcopy(model_config) + ModelDescriptor.truncate_pattern_for_subblock( + descriptor.get_language_model_config(layer_model_config), parent_layer_indices[0] + ) + + if is_calc_runtime: + total_runtime_ms = runtime_by_subblock_dict[subblock_config] + prefill_runtime_ms = None + decode_runtime_ms = None + else: + total_runtime_ms, prefill_runtime_ms, decode_runtime_ms = None, None, None + + subblock_memory = calculate_subblock_memory( + subblock_config, + batch_size, + prefill_seq_len, + generation_seq_len, + prefill_queue_size, + n_embd, + n_head, + weights_dtype, + kv_cache_dtype, + allocate_prefill_query, + model_config=layer_model_config, + descriptor=descriptor, + ) + if not isinstance(subblock_memory, dict): + subblock_memory = {"memory_mib": subblock_memory, "kv_cache_memory_mib": 0.0} + + subblock_params = calculate_subblock_params(layer_model_config, subblock_config, descriptor) + if moe_stats_file is not None: + subblock_active_params = calc_subblock_active_params( + subblock_config, + layer_model_config, + descriptor, + n_embd, + moe_stats_file, + batch_size, + parent_layer_indices[0], + ) + else: + subblock_active_params = subblock_params + subblock_stats["subblocks"].append( + { + "subblock_config": subblock_config, + "subblock_config_class": type(subblock_config).__name__, + "runtime_ms": total_runtime_ms, + "prefill_runtime_ms": prefill_runtime_ms, + "decode_runtime_ms": decode_runtime_ms, + "num_params": subblock_params, + "active_params": subblock_active_params, + "parent_layer_index": parent_layer_indices[0], + **subblock_memory, + } + ) + + if is_calc_runtime: + # TODO: fix + # from puzzle_tools.calc_subblock_runtime import measure_non_block_runtime_ms + # non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = \ + # measure_non_block_runtime_ms(batch_size, prefill_seq_len, generation_seq_len, n_embd, vocab_size, + # benchmark_iterations, use_cuda_graph) + embedding_runtime_ms, lm_head_runtime_ms = None, None + else: + non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = None, None, None + non_block_memory = calculate_non_block_memory(n_embd, vocab_size, weights_dtype) + non_block_params = calculate_non_block_params(n_embd, vocab_size) + + # TODO + # the semantics here is wrong why do we refer, prefill_runtime_ms as embedding_runtime_ms and lm_head_runtime_ms as decode_runtime_ms ? + # Prefill is the first the user prompt inference, and Decode refer to the next generation process. both processes use all the model layers. + subblock_stats["non_block"] = { + "runtime_ms": non_block_runtime_ms, + "prefill_runtime_ms": embedding_runtime_ms, + "decode_runtime_ms": lm_head_runtime_ms, + "memory_mib": non_block_memory, + "num_params": non_block_params, + } + return subblock_stats + + +def launch_calc_subblock_stats(cfg: DictConfig) -> None: + """ + Launch the calc subblock stats function with Hydra configuration. + """ + mprint(f"Calculating subblock stats for puzzle directory: {cfg.puzzle_dir}") + mprint(f"Teacher directory: {cfg.teacher_dir}") + mprint( + f"Calc subblock stats config: {format_global_config(cfg.calc_subblock_stats, title='Calc subblock stats')}" + ) + + descriptor = ModelDescriptorFactory.get(cfg.descriptor) + calculate_subblock_stats_for_puzzle_dir( + cfg.calc_subblock_stats, + master_puzzle_dir=cfg.puzzle_dir, + teacher_dir=cfg.teacher_dir, + descriptor=descriptor, + model_hidden_sizes=cfg.calc_subblock_stats.get("model_hidden_sizes", OmegaConf.create([])), + ffn_hidden_sizes=cfg.calc_subblock_stats.get("ffn_hidden_sizes", OmegaConf.create([])), + batch_sizes=cfg.calc_subblock_stats.batch_sizes, + prefill_seq_len=cfg.calc_subblock_stats.prefill_seq_len, + generation_seq_len=cfg.calc_subblock_stats.generation_seq_len, + num_active_tokens_override=cfg.calc_subblock_stats.get("num_active_tokens_override", None), + prefill_queue_size=cfg.calc_subblock_stats.prefill_queue_size, + allocate_prefill_query=cfg.calc_subblock_stats.get("allocate_prefill_query", False), + benchmark_iterations=cfg.calc_subblock_stats.get("benchmark_iterations", None), + merge_with_existing_stats=cfg.calc_subblock_stats.merge_with_existing_stats, + subblock_stats_filename=cfg.calc_subblock_stats.subblock_stats_filename, + moe_stats_filename=cfg.calc_subblock_stats.moe_stats_filename, + ) + + +def calculate_subblock_stats_for_puzzle_dir( + calc_subblock_stats_config: DictConfig, + master_puzzle_dir: Path | str, + teacher_dir: Path | str, + descriptor: Type[ModelDescriptor], + model_hidden_sizes: ListConfig, + ffn_hidden_sizes: ListConfig, + batch_sizes: Iterable[int] = (1, 8, 16, 32, 64, 128, 256), + prefill_seq_len: int = 2048, + generation_seq_len: int = 2048, + num_active_tokens_override: int | None = None, + prefill_queue_size: int = 0, # it's an infery-llm thing + allocate_prefill_query: bool = False, + benchmark_iterations: ( + int | None + ) = None, # If set then compute runtime performance statistics. TODO: recommend default value, is 1000 good? + merge_with_existing_stats: bool = False, + subblock_stats_filename: str = "subblock_stats.json", + moe_stats_filename: str = "moe_stats.json", +) -> None: + # ==== START === Setup for attach-helper ==== + # import sys + # import os + # sys.path.insert(0, os.environ["ATTACH_HELPER_INSTALLATION_PATH"]) + # from attach_helper import debugging_setup + # debugging_setup() # You can optionally pass a name to identify the job (e.g. `debugging_setup(name="my_script")`) + # ==== END === Setup for attach-helper ==== + if isinstance(batch_sizes, str): + batch_sizes = [ + int(batch_size) for batch_size in batch_sizes.strip("[]").replace(" ", "").split(",") + ] + + master_puzzle_dir = Path(master_puzzle_dir) + teacher_dir = ( + Path(teacher_dir) if teacher_dir is not None else master_puzzle_dir / "ckpts" / "teacher" + ) + trust_remote_code = descriptor.requires_trust_remote_code() + model_config = load_model_config(teacher_dir, trust_remote_code=trust_remote_code) + # Get language model config for LM-specific attributes (VL models have nested config) + lm_config = descriptor.get_language_model_config(model_config) + subblock_configs = _load_subblock_configs(master_puzzle_dir, ffn_hidden_sizes) + + subblock_stats_file = master_puzzle_dir / subblock_stats_filename + if subblock_stats_file.exists() and not merge_with_existing_stats: + raise ValueError( + f"Subblock stats file {subblock_stats_file} already exists and `merge_with_existing_stats` was set to False." + ) + + if subblock_stats_file.exists(): + with open(subblock_stats_file) as f: + subblock_stats = json.load(f) + else: + subblock_stats = [] + + moe_stats_file = master_puzzle_dir / moe_stats_filename + if not moe_stats_file.exists(): + warnings.warn( + f"MOE stats file {moe_stats_file} does not exist, can't calculate num active params" + ) + moe_stats_file = None + + subblock_stats_args = {immutabledict(x["args"]) for x in subblock_stats} + + data_types = [ + ("nvfp4", "nvfp4", "nvfp4"), + (torch.int8, torch.int8, torch.int8), + (torch.int8, torch.int8, torch.bfloat16), + (torch.bfloat16, torch.bfloat16, torch.bfloat16), + ] + + model_hidden_sizes = model_hidden_sizes + [ + lm_config.hidden_size + ] # add a teacher model hidden size + for batch_size, ( + weights_dtype, + activations_dtype, + kv_cache_dtype, + ), model_hidden_size in product(batch_sizes, data_types, model_hidden_sizes): + if num_active_tokens_override is not None: + prefill_seq_len = generation_seq_len = int(num_active_tokens_override / batch_size / 2) + + curr_benchmark_iterations = ( + benchmark_iterations if weights_dtype == torch.bfloat16 else None + ) + + curr_subblock_stats = calculate_subblock_stats( + calc_subblock_stats_config, + teacher_dir=teacher_dir, + model_config=model_config, + descriptor=descriptor, + master_puzzle_dir=master_puzzle_dir, + subblock_configs=subblock_configs, + batch_size=batch_size, + prefill_seq_len=prefill_seq_len, + generation_seq_len=generation_seq_len, + prefill_queue_size=prefill_queue_size, + n_embd=model_hidden_size, + n_head=lm_config.num_attention_heads, + vocab_size=lm_config.vocab_size, + benchmark_iterations=curr_benchmark_iterations, + use_cuda_graph=True, + weights_dtype=weights_dtype, + activations_dtype=activations_dtype, + kv_cache_dtype=kv_cache_dtype, + allocate_prefill_query=allocate_prefill_query, + moe_stats_file=moe_stats_file, + ) + + if immutabledict(curr_subblock_stats["args"]) in subblock_stats_args: + raise ValueError( + f"Failed merging subblock_stats. The following arguments already existed in the file: {curr_subblock_stats['args']}" + ) + + subblock_stats.append(curr_subblock_stats) + + # TODO fix: add_int8_runtime_estimates(subblock_stats) + + json_dump(subblock_stats, subblock_stats_file) + + mprint(subblock_stats_file) + + +def _load_subblock_configs( + master_puzzle_dir: Path, ffn_hidden_sizes: ListConfig +) -> list[SubblockConfig]: + try: + subblock_configs = _load_subblock_configs_from_replacement_library(master_puzzle_dir) + except FileNotFoundError: + subblock_configs = _load_subblock_configs_from_subblock_library(master_puzzle_dir) + + # Extend subblock stats calculation space with ffn_hidden_sizes defined in the calc_subblock_stats section of the model config yaml file. + extra_ffn_subblock_configs = [] + for ffn_hidden_size in ffn_hidden_sizes: + # Use FFNConfig defaults (hidden_act will use its default value) + ffn_config = FFNConfig(intermediate_size=ffn_hidden_size) + extra_ffn_subblock_configs.append( + immutabledict({"subblock_config": ffn_config, "parent_layer_indices": tuple([-1])}) + ) # -1 to indicate that this sublock has no parent layer + subblock_configs.extend(extra_ffn_subblock_configs) + + return subblock_configs + + +def _load_subblock_configs_from_subblock_library(master_puzzle_dir: Path) -> list[SubblockConfig]: + subblocks_df = pd.read_json(master_puzzle_dir / "subblock_library.json") + subblocks_df["attention_config"] = subblocks_df["attention_config"].apply( + partial(_dataclass_from_dict, cls=AttentionConfig) + ) + subblocks_df["ffn_config"] = subblocks_df["ffn_config"].apply( + partial(_dataclass_from_dict, cls=FFNConfig) + ) + attention_configs = subblocks_df["attention_config"].dropna().drop_duplicates().tolist() + ffn_configs = subblocks_df["ffn_config"].dropna().drop_duplicates().tolist() + # Wrap in the same dict format expected by calculate_subblock_stats() callers. + # Use parent_layer_indices=(-1,) to indicate no specific parent layer. + subblock_configs = [ + immutabledict({"subblock_config": cfg, "parent_layer_indices": (-1,)}) + for cfg in attention_configs + ffn_configs + ] + return subblock_configs + + +def _load_subblock_configs_from_replacement_library( + master_puzzle_dir: Path, +) -> list[SubblockConfig]: + """Load unique subblocks from replacement_library.json, e.g., + 256 = 32*8 unique sublocks will be returned for a model with 32 layers and the search space of + 4 intermediate_size + teacher_intermediate_size + ffn_noop + att_op (teacher) + att_noop. + + Args: + master_puzzle_dir (Path): Directory with "replacement_library.json" file + + Returns: + list[SubblockConfig]: + """ + replacement_library = json.loads((master_puzzle_dir / "replacement_library.json").read_text()) + subblock_configs = set() + for layer_replacement in replacement_library: + layer_replacement = parse_layer_replacement(layer_replacement) + + for block_config in layer_replacement["child_block_configs"]: + block_config: BlockConfig + attention_frozen_dict = immutabledict( + { + "subblock_config": block_config.attention, + "parent_layer_indices": tuple(layer_replacement["parent_layer_indices"]), + } + ) + ffn_frozen_dict = immutabledict( + { + "subblock_config": block_config.ffn, + "parent_layer_indices": tuple(layer_replacement["parent_layer_indices"]), + } + ) + subblock_configs.add(attention_frozen_dict) + subblock_configs.add(ffn_frozen_dict) + + if block_config.parallel_blocks is not None: + for block_idx, internal_block_config in enumerate(block_config.parallel_blocks): + attention_frozen_dict = immutabledict( + { + "subblock_config": internal_block_config.attention, + "parent_layer_indices": tuple( + layer_replacement["parent_layer_indices"] + ), + "inner_block_idx": block_idx, + } + ) + ffn_frozen_dict = immutabledict( + { + "subblock_config": internal_block_config.ffn, + "parent_layer_indices": tuple( + layer_replacement["parent_layer_indices"] + ), + "inner_block_idx": block_idx, + } + ) + subblock_configs.add(attention_frozen_dict) + subblock_configs.add(ffn_frozen_dict) + + subblock_configs = list(subblock_configs) + return subblock_configs + + +T_DataClass: TypeVar = Type[dataclasses.dataclass] + + +def _dataclass_from_dict( + d: dict | T_DataClass | None, + cls: T_DataClass, +) -> T_DataClass | None: + if isinstance(d, cls): + return d + if isinstance(d, dict): + return cls(**d) + if pd.isna(d): + return None + raise ValueError(f"_dataclass_from_dict: unrecognized {type(d)=} {d=}") + + +def add_int8_runtime_estimates(subblock_stats: list[dict]) -> None: + for curr_subblock_stats in subblock_stats: + args = curr_subblock_stats["args"] + if args["weights_dtype"] == "torch.int8": + assert args["activations_dtype"] == "torch.int8" + ffn_factor = 0.5 + attention_factor = 0.5 if args["kv_cache_dtype"] == "torch.int8" else 0.8 + + bf16_stats = _find_corresponding_bf16_stats(args, subblock_stats) + if bf16_stats is not None: + curr_subblocks = curr_subblock_stats["subblocks"] + [ + curr_subblock_stats["non_block"] + ] + bf16_subblocks = bf16_stats["subblocks"] + [bf16_stats["non_block"]] + for curr_subblock, bf16_subblock in zip(curr_subblocks, bf16_subblocks): + assert curr_subblock.get("subblock_config", None) == bf16_subblock.get( + "subblock_config", None + ) + is_attention = False + if (subblock_config := curr_subblock.get("subblock_config")) is not None: + if hasattr(subblock_config, "__dataclass_fields__"): + subblock_config = dataclasses.asdict(subblock_config) + is_attention = subblock_config.get("num_key_value_heads", None) is not None + runtime_factor = attention_factor if is_attention else ffn_factor + for stat_name, stat_value in bf16_subblock.items(): + if "runtime" in stat_name: + curr_subblock[stat_name] = stat_value * runtime_factor + + +def _find_corresponding_bf16_stats(args: dict, subblock_stats: list[dict]) -> dict | None: + scenario_keys = [ + "batch_size", + "prefill_seq_len", + "generation_seq_len", + "prefill_queue_size", + "gpu", + "n_embd", + "n_head", + "vocab_size", + ] + corresponding_bf16_args = { + **{k: v for k, v in args.items() if k in scenario_keys}, + "is_calc_runtime": True, + "weights_dtype": "torch.bfloat16", + "activations_dtype": "torch.bfloat16", + "kv_cache_dtype": "torch.bfloat16", + } + matching_bf16_stats = [ + stats + for stats in subblock_stats + if all( + [ + stats["args"][key] == corresponding_bf16_args[key] + for key in corresponding_bf16_args.keys() + ] + ) + ] + if len(matching_bf16_stats) == 0: + return None + if len(matching_bf16_stats) == 1: + return matching_bf16_stats[0] + raise ValueError(f"Found more than 1 matching bf16 stats for {args=}") diff --git a/modelopt/torch/puzzletron/tools/__init__.py b/modelopt/torch/puzzletron/tools/__init__.py new file mode 100644 index 0000000000..9de42e1553 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared tools: logging, hydra config, checkpoint utilities, and validation helpers.""" + +from .common import * +from .hydra_utils import * +from .logger import * diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/__init__.py b/modelopt/torch/puzzletron/tools/bypassed_training/__init__.py new file mode 100644 index 0000000000..5e11250245 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/bypassed_training/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for initializing child models from parent models via bypassed training.""" + +from .child_init import * +from .init_child_from_parent import * diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py new file mode 100644 index 0000000000..b242c7d48a --- /dev/null +++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py @@ -0,0 +1,1172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Core logic for creating pruned child model state dicts from parent models. Used by init_child_from_parent.""" + +import concurrent.futures +import dataclasses +import json +import os +import re +import time +from copy import deepcopy +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from transformers import PretrainedConfig +from typeguard import check_type + +from ...block_config import SUBBLOCK_CLS_DICT, BlockConfig, _get_dataclass_type, _is_dataclass_type +from ...pruning.pruning_utils import ( + GQAInitMode, + HiddenSizeInitMode, + LinearInitMode, + MlpInitMode, + _init_attention_biases, + _init_attention_weights, + _init_mlp_module, + _init_moe_module, + _load_expert_scores, +) +from ..logger import aprint, mprint + +__all__ = ["create_child_state_dict", "update_model_config"] + +IgnoreFn = Callable[[str], bool] + +default_ignore_fn: IgnoreFn = lambda _: False + + +class Printer: + @staticmethod + def print(s: str) -> None: + print(s) + + +def _process_single_layer( + layer_idx: int, + pruning_mixin, + descriptor, + parent_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + gqa_init_mode: GQAInitMode, + mlp_init_mode: MlpInitMode, + mlp_init_config: Optional[dict[str, Any]], + linear_init_mode: LinearInitMode, + ignored_keys: set, + keys: dict, + is_original_mha: bool, + head_size: int, + hidden_size: int, +) -> Tuple[Dict[str, torch.Tensor], Dict[str, str]]: + """ + Process a single layer in parallel. Returns (layer_state_dict, keys_to_remove). + Thread-safe function for parallel layer processing. + """ + keys_to_remove = {} + layer_out_state_dict = {} + + # Delegate to pruning_mixin if available + if pruning_mixin is not None: + _layer_out = pruning_mixin.prune_single_layer( + layer_idx=layer_idx, + parent_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + original_config=original_config, + new_config=new_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + keys_to_remove=keys_to_remove, + ) + layer_out_state_dict.update(_layer_out) + return layer_out_state_dict, keys_to_remove + + # Legacy inline processing (fallback when no pruning_mixin) + + parent_block_config = original_config.block_configs[layer_idx] + child_block_config = new_config.block_configs[layer_idx] + + # Attention processing + for part in ["weight", "bias"]: + attn_prefix = f"model.layers.{layer_idx}.self_attn" + q_key = f"{attn_prefix}.q_proj.{part}" + k_key = f"{attn_prefix}.k_proj.{part}" + v_key = f"{attn_prefix}.v_proj.{part}" + o_key = f"{attn_prefix}.o_proj.{part}" + attn_keys = [q_key, k_key, v_key, o_key] + # Drop attn keys that don't exist and required to be in the new state_dict + attn_keys = [key for key in attn_keys if key in new_state_dict.keys()] + if len(attn_keys) > 0 and all(key in keys for key in attn_keys): + for key in attn_keys: + keys_to_remove[key] = keys[key] + if all(key not in ignored_keys for key in attn_keys): + is_student_and_teacher_have_same_attention_implementation = all( + key in new_state_dict.keys() for key in attn_keys + ) + if is_student_and_teacher_have_same_attention_implementation: + if part == "weight": + wq, wk, wv, wo = _init_attention_weights( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + layer_out_state_dict[q_key], layer_out_state_dict[k_key] = wq, wk + layer_out_state_dict[v_key], layer_out_state_dict[o_key] = wv, wo + else: + bias_sd = _init_attention_biases( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + for bias_key, sd_key in zip("qkvo", [q_key, k_key, v_key, o_key]): + if bias_key in bias_sd.keys(): + layer_out_state_dict[sd_key] = bias_sd[bias_key] + + else: + linear_attn_key = f"{attn_prefix}.linear_attn.weight" + is_student_attn_replaced_with_linear = linear_attn_key in new_state_dict.keys() + if is_student_attn_replaced_with_linear: + if linear_init_mode == LinearInitMode.Random: + layer_out_state_dict[linear_attn_key] = new_state_dict[linear_attn_key] + elif linear_init_mode == LinearInitMode.FromTeacher: + layer_out_state_dict[linear_attn_key] = _init_linear_attn( + parent_state_dict, original_config, layer_idx, v_key, o_key + ) + else: + raise ValueError(f"Unknown {linear_init_mode=}") + else: + # student attn random init + for new_key in new_state_dict.keys(): + if attn_prefix in new_key: + layer_out_state_dict[new_key] = new_state_dict[new_key] + + # MLP/MoE processing + is_parent_moe = parent_block_config.ffn.is_moe + if not is_parent_moe: # not MoE, init the MLP + mlp_prefix = f"model.layers.{layer_idx}.mlp" + linear_mlp_key = f"{mlp_prefix}.linear_mlp.weight" + + is_student_mlp_replaced_with_linear = linear_mlp_key in new_state_dict.keys() + if is_student_mlp_replaced_with_linear: + if linear_init_mode == LinearInitMode.Random: + layer_out_state_dict[linear_mlp_key] = new_state_dict[linear_mlp_key] + elif linear_init_mode == LinearInitMode.FromTeacher: + teacher_mlp_state_dict = { + k.split(mlp_prefix + ".")[1]: v + for k, v in parent_state_dict.items() + if mlp_prefix in k + } + layer_out_state_dict[linear_mlp_key] = _init_linear_mlp(teacher_mlp_state_dict) + else: + raise ValueError(f"Unknown {linear_init_mode=}") + else: + layer_out_state_dict.update( + _init_mlp( + mlp_init_mode=mlp_init_mode, + layer_idx=layer_idx, + original_config=original_config, + mlp_init_config=mlp_init_config, + original_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + new_config=new_config, + keys=keys, + ignored_keys=ignored_keys, + ) + ) + else: + is_child_moe = child_block_config.ffn.is_moe + if is_child_moe: + parent_moe_config = original_config.block_configs[layer_idx].ffn.moe + child_moe_config = new_config.block_configs[layer_idx].ffn.moe + if parent_moe_config == child_moe_config: + pass # copy the MoE as is + elif mlp_init_mode == MlpInitMode.MoEChannelPruning: + for expert_idx in range(parent_moe_config.num_local_experts): + layer_out_state_dict.update( + _init_mlp( + mlp_init_mode=mlp_init_mode, + layer_idx=layer_idx, + original_config=original_config, + mlp_init_config=mlp_init_config, + original_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + new_config=new_config, + keys=keys, + ignored_keys=ignored_keys, + expert_idx=expert_idx, + ) + ) + + elif mlp_init_mode == MlpInitMode.ExpertRemoval: # remove some of the routed experts + router_key, new_experts_keys = _generate_moe_keys( + layer_idx, child_block_config.ffn.moe.num_local_experts + ) + _, orig_experts_keys = _generate_moe_keys( + layer_idx, parent_block_config.ffn.moe.num_local_experts + ) + keys_to_remove[router_key] = keys.get(router_key) + for key in sum(orig_experts_keys.values(), []): + keys_to_remove[key] = keys.get(key) + + orig_experts_weights = { + name: [parent_state_dict[key] for key in orig_experts_module_keys] + for name, orig_experts_module_keys in orig_experts_keys.items() + } + new_experts_weights = { + name: [new_state_dict[key] for key in new_experts_module_keys] + for name, new_experts_module_keys in new_experts_keys.items() + } + out_router_weights, out_experts_weights = _init_moe_module( + layer_idx=layer_idx, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + orig_router_weight=parent_state_dict[router_key], + orig_experts_weights=orig_experts_weights, + new_router_weight=new_state_dict[router_key], + new_experts_weights=new_experts_weights, + ) + layer_out_state_dict[router_key] = out_router_weights + for name in new_experts_keys.keys(): + layer_out_state_dict.update( + zip(new_experts_keys[name], out_experts_weights[name]) + ) + elif child_block_config.ffn.no_op: # no-op, drop this layer + parent_mlp_prefix = f"model.layers.{layer_idx}.mlp" + for key in list(keys.keys()): + if key.startswith(parent_mlp_prefix): + keys_to_remove[key] = keys[key] + else: + assert mlp_init_mode == MlpInitMode.ConcatExpertsIntoDenseFFN, ( + "The parent layer is MoE and the child layer is a normal FFN. The only supported mode is ConcatExpertsAsMLP." + ) + + child_ffn_state_dict = _concatenate_experts_into_dense_ffn( + parent_state_dict, + mlp_init_config, + hidden_size, + layer_idx, + child_block_config, + parent_block_config, + ) + layer_out_state_dict.update(child_ffn_state_dict) + + for key in list(keys.keys()): + if key.startswith(f"model.layers.{layer_idx}.mlp"): + keys_to_remove[key] = keys[key] + + # Handle missing keys + for key_possibly_missing_in_student in [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.o_proj", + "mlp.gate_proj", + "mlp.up_proj", + "mlp.down_proj", + "input_layernorm", + "post_attention_layernorm", + ]: + key_possibly_missing_in_student = f".{layer_idx}.{key_possibly_missing_in_student}" + is_key_missing_from_student = ( + len([k for k in new_state_dict.keys() if key_possibly_missing_in_student in k]) == 0 + ) + if is_key_missing_from_student: + for k in list(keys.keys()): + if key_possibly_missing_in_student in k: + keys_to_remove[k] = keys[k] + + return layer_out_state_dict, keys_to_remove + + +@torch.no_grad() +def create_child_state_dict( + pruning_mixin, + descriptor, + original_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + gqa_init_mode: GQAInitMode, + ignore_fn: IgnoreFn = default_ignore_fn, + mlp_init_mode: MlpInitMode = MlpInitMode.CopyAsIs, + mlp_init_config: Optional[dict[str, Any]] = None, + owned_block_indexes: Optional[set[int]] = None, + linear_init_mode: LinearInitMode = LinearInitMode.Random, + hidden_size_init_mode: HiddenSizeInitMode = HiddenSizeInitMode.CopyAsIs, + channel_importance_path: Optional[str] = None, + max_layer_workers: Optional[int] = None, # Now optional - will auto-calculate if None +): + mprint("=== Starting create_child_state_dict with optimizations ===") + total_start_time = time.time() + + # Phase 1: Initial setup and validation + setup_start_time = time.time() + if owned_block_indexes is None: + owned_block_indexes = set(range(new_config.num_hidden_layers)) + + # Auto-calculate optimal layer workers: min(cpu_count, num_layers) + if max_layer_workers is None: + cpu_count = os.cpu_count() or 1 + num_layers = len(owned_block_indexes) + max_layer_workers = min(cpu_count, num_layers) + mprint( + f"Auto-calculated layer workers: min({cpu_count} CPUs, {num_layers} layers) = {max_layer_workers}" + ) + else: + mprint(f"Using specified layer workers: {max_layer_workers}") + + # Memory optimization: Pre-allocate output state dict with known shapes + expected_keys_and_shapes = {k: v.shape for k, v in new_state_dict.items()} + out_state_dict = {} + + # Pre-allocate tensors where possible to reduce memory fragmentation + for key, shape in expected_keys_and_shapes.items(): + if key in new_state_dict: + tensor = new_state_dict[key] + # Only make contiguous if necessary (memory optimization) + if not tensor.is_contiguous(): + out_state_dict[key] = tensor.contiguous() + else: + out_state_dict[key] = tensor + + # Get language model config for LM-specific attributes (VL models have nested config) + original_lm_config = descriptor.get_language_model_config(original_config) + new_lm_config = descriptor.get_language_model_config(new_config) + + # Check if original model is MHA (all layers have num_key_value_heads == num_attention_heads) + original_num_kv_heads_per_layer = [ + b.attention.num_key_value_heads for b in original_config.block_configs + ] + num_attention_heads = original_lm_config.num_attention_heads + is_original_mha = all(kv == num_attention_heads for kv in original_num_kv_heads_per_layer) + is_same_hidden_size = original_lm_config.hidden_size == new_lm_config.hidden_size + head_size = _get_head_dim(new_lm_config) + orig_head_size = _get_head_dim(original_lm_config) + assert head_size == orig_head_size, f"head_size {head_size} != orig_head_size {orig_head_size}" + + # Allow different hidden sizes for pruning + if not is_same_hidden_size: + assert new_lm_config.hidden_size <= original_lm_config.hidden_size, ( + f"New hidden size ({new_lm_config.hidden_size}) must be <= original ({original_lm_config.hidden_size})" + ) + assert hidden_size_init_mode != HiddenSizeInitMode.CopyAsIs, ( + "Cannot copy as is when hidden sizes differ" + ) + + hidden_size = original_lm_config.hidden_size + + ignored_keys = set([key for key in original_state_dict.keys() if ignore_fn(key)]) + for key in ignored_keys: + aprint(f"Ignoring key {key} and taking its init from new_state_dict") + out_state_dict[key] = new_state_dict[key] + + keys = { + match.group(1) if (match := re.search(r"(h\.\d+\..*)", key)) is not None else key: key + for key in original_state_dict.keys() + } + setup_time = time.time() - setup_start_time + mprint(f"Phase 1 - Setup and memory pre-allocation: {setup_time:.2f}s") + + # Phase 2: Parallel layer processing + layer_processing_start_time = time.time() + + # Prepare arguments for parallel processing + process_layer_partial = partial( + _process_single_layer, + pruning_mixin=pruning_mixin, + descriptor=descriptor, + parent_state_dict=original_state_dict, + new_state_dict=new_state_dict, + original_config=original_config, + new_config=new_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + ) + + # Process layers in parallel with optimal worker count + mprint( + f"Processing {len(owned_block_indexes)} layers in parallel with {max_layer_workers} workers..." + ) + layer_results = [] + all_keys_to_remove = {} + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_layer_workers) as executor: + future_to_layer = { + executor.submit(process_layer_partial, layer_idx): layer_idx + for layer_idx in owned_block_indexes + } + + completed = 0 + for future in concurrent.futures.as_completed(future_to_layer): + layer_idx = future_to_layer[future] + try: + layer_state_dict, keys_to_remove = future.result() + layer_results.append((layer_idx, layer_state_dict)) + all_keys_to_remove.update(keys_to_remove) + + completed += 1 + if completed % 20 == 0 or completed == len( + owned_block_indexes + ): # More frequent progress updates + mprint(f"Completed {completed}/{len(owned_block_indexes)} layers") + except Exception as exc: + mprint(f"Layer {layer_idx} generated an exception: {exc}") + raise exc + + # Merge layer results into main state dict (memory efficient) + for layer_idx, layer_state_dict in layer_results: + out_state_dict.update(layer_state_dict) + + # Remove processed keys from the keys dict + for key_to_remove in all_keys_to_remove: + keys.pop(key_to_remove, None) + + layer_processing_time = time.time() - layer_processing_start_time + mprint( + f"Phase 2 - Parallel layer processing: {layer_processing_time:.2f}s ({max_layer_workers} workers)" + ) + + # Phase 3: Copy remaining keys from original model + copy_start_time = time.time() + keys_to_copy_from_orig_model = set(keys.values()) - ignored_keys + for key in keys_to_copy_from_orig_model: + # Memory optimization: avoid unnecessary copies + tensor = original_state_dict[key] + if not tensor.is_contiguous(): + out_state_dict[key] = tensor.contiguous() + else: + out_state_dict[key] = tensor + copy_time = time.time() - copy_start_time + mprint( + f"Phase 3 - Copy remaining keys: {copy_time:.2f}s ({len(keys_to_copy_from_orig_model)} keys)" + ) + + # Handle hidden size pruning for remaining keys + if not is_same_hidden_size: + out_state_dict = _apply_hidden_size_pruning( + out_state_dict, + original_state_dict, + new_config, + original_config, + descriptor, + hidden_size_init_mode, + channel_importance_path, + owned_block_indexes, + ) + + # Phase 4: Verification + verify_start_time = time.time() + _verify_state_dicts_match(out_state_dict, expected_keys_and_shapes) + verify_time = time.time() - verify_start_time + mprint(f"Phase 4 - Verification: {verify_time:.2f}s") + + total_time = time.time() - total_start_time + mprint(f"=== create_child_state_dict completed in {total_time:.2f}s ===") + mprint( + f"Breakdown: Setup {setup_time:.1f}s + ParallelProcessing {layer_processing_time:.1f}s + Copy {copy_time:.1f}s + Verify {verify_time:.1f}s" + ) + mprint( + f"Speedup: Used {max_layer_workers} workers for {len(owned_block_indexes)} layers (CPU utilization: {max_layer_workers}/{os.cpu_count() or 1})" + ) + + return out_state_dict + + +def _generate_moe_keys(layer_idx: int, num_experts: int) -> tuple[str, dict[str, list[str]]]: + mlp_prefix = f"model.layers.{layer_idx}.mlp" + router_key = f"{mlp_prefix}.router.weight" + names = ["gate_proj", "up_proj", "down_proj"] + experts_module_names = { + name: f"{mlp_prefix}.experts.{{expert_idx}}.{name}.weight" for name in names + } + return router_key, { + name: [module_name.format(expert_idx=expert_idx) for expert_idx in range(num_experts)] + for name, module_name in experts_module_names.items() + } + + +def _concatenate_experts_into_dense_ffn( + original_state_dict: dict[str, torch.Tensor], + mlp_init_config: Optional[dict], + hidden_size: int, + layer_idx: int, + child_block_config: BlockConfig, + parent_block_config: BlockConfig, +) -> dict[str, torch.Tensor]: + # Llama4 experts use SwiGLU (gated + silu); FFNConfig does not track these fields directly. + + # verify sizes + child_intermediate_size = child_block_config.ffn.intermediate_size + parent_moe_config = parent_block_config.ffn.moe + shared_expert_intermediate_dim = parent_moe_config.shared_expert_intermediate_dim + routed_expert_intermediate_dim = parent_moe_config.expert_intermediate_dim + total_concatenated_routed_experts_size = ( + child_intermediate_size - shared_expert_intermediate_dim + ) + assert total_concatenated_routed_experts_size % routed_expert_intermediate_dim == 0, ( + f"{child_intermediate_size=} " + f"{shared_expert_intermediate_dim=} " + f"{routed_expert_intermediate_dim=} " + f"{total_concatenated_routed_experts_size=} " + f"{total_concatenated_routed_experts_size % routed_expert_intermediate_dim=} != 0" + ) + num_concatenated_routed_experts = ( + total_concatenated_routed_experts_size // routed_expert_intermediate_dim + ) + + # if needed, concatenate some of the routed experts + if num_concatenated_routed_experts == 0: + print( + f"Removing all routed experts from layer {layer_idx}, turning the shared expert into a dense FFN." + ) + concat_routed_state_dict = dict() + else: + print( + f"Concatenating {num_concatenated_routed_experts} routed experts to the shared expert in layer {layer_idx}" + ) + router_key, orig_experts_keys = _generate_moe_keys( + layer_idx, parent_moe_config.num_local_experts + ) + orig_experts_weights = { + name: [original_state_dict[key] for key in orig_experts_module_keys] + for name, orig_experts_module_keys in orig_experts_keys.items() + } + _, experts_weights = _prune_experts_by_score( + mlp_init_config=mlp_init_config, + layer_idx=layer_idx, + orig_router_weight=original_state_dict[router_key], + orig_experts_weights=orig_experts_weights, + new_num_experts=num_concatenated_routed_experts, + ) + concat_dims = {"gate_proj": 0, "up_proj": 0, "down_proj": 1} + assert list(concat_dims) == list(experts_weights), ( + "concat_dims and experts_weights must have the same keys" + ) + concat_routed_state_dict = { + name: torch.cat(experts_weights[name], dim=concat_dims[name]) + for name in concat_dims.keys() + } + + # turn the shared expert into a normal FFN. concatenate the pruned routed experts if needed. + parent_shared_expert_prefix = f"model.layers.{layer_idx}.mlp.shared_expert" + child_ffn_prefix = f"model.layers.{layer_idx}.mlp" + child_ffn_state_dict = dict() + + for module_name in [ + "gate_proj", + "up_proj", + "down_proj", + ]: + shared_expert_key = f"{parent_shared_expert_prefix}.{module_name}.weight" + child_ffn_key = f"{child_ffn_prefix}.{module_name}.weight" + shared_expert_weight = original_state_dict[shared_expert_key] + concat_routed_weight = concat_routed_state_dict.get(module_name) + + if concat_routed_weight is None: + child_weight = shared_expert_weight + else: + child_weight = torch.cat( + [shared_expert_weight, concat_routed_weight], + dim=1 if module_name == "down_proj" else 0, + ) + child_ffn_state_dict[child_ffn_key] = child_weight + + return child_ffn_state_dict + + +def _verify_state_dicts_match( + state_dict: dict[str, torch.Tensor], + expected_keys_and_shapes: dict[str, torch.Size], +) -> None: + # Verify keys match + expected_keys = expected_keys_and_shapes.keys() + missing_keys = set(expected_keys) - set(state_dict.keys()) + unexpected_keys = set(state_dict.keys()) - set(expected_keys) + assert len(missing_keys) == 0 and len(unexpected_keys) == 0, ( + f"Missing keys: {missing_keys}\nUnexpected keys: {unexpected_keys}" + ) + + # Verify shapes match + shape_mismatches = [] + for key in expected_keys: + expected_shape = expected_keys_and_shapes[key] + actual_shape = state_dict[key].shape + if expected_shape != actual_shape: + shape_mismatches.append(f"{key}: expected {expected_shape}, got {actual_shape}") + + assert len(shape_mismatches) == 0, "Shape mismatches found:\n" + "\n".join(shape_mismatches) + print(""" +############################ +create_child_state_dict: all keys and shapes matched successfully. +############################ +""") + + +def _init_mlp( + *, + mlp_init_mode: Union[MlpInitMode, str], + layer_idx: int, + original_config: PretrainedConfig, + mlp_init_config: Optional[dict[str, Any]], + original_state_dict: dict, + new_state_dict: dict, + new_config: PretrainedConfig, + keys: dict[str, str], + ignored_keys: set[str], + expert_idx: Optional[int] = None, +) -> dict[str, torch.Tensor]: + out_state_dict = {} + + if mlp_init_mode == MlpInitMode.MoEChannelPruning: + if expert_idx is None: + return {} + mlp_prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}" + else: + mlp_prefix = f"model.layers.{layer_idx}.mlp" + + key = f"{mlp_prefix}.down_proj.weight" + if key not in keys: + return {} + + mlp_c_proj_key = keys[key] + if mlp_c_proj_key not in ignored_keys: + mlp_keys = [ + keys.pop(f"{mlp_prefix}.{module_name}.weight") + for module_name in ["down_proj", "gate_proj", "up_proj"] + ] + pruned_filters = None + projection_matrix = None + for mlp_key in mlp_keys: + expanded_dim = 1 if "down_proj" in mlp_key else 0 + if mlp_key in new_state_dict.keys(): + mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module( + mlp_init_mode, + mlp_prefix, + expanded_dim, + layer_idx, + new_state_dict[mlp_key], + new_config, + original_state_dict[mlp_key], + original_config, + mlp_init_config, + pruned_filters, + projection_matrix, + ) + out_state_dict[mlp_key] = mlp_module_weight + else: + mprint(f"mlp_key {mlp_key} not in new_state_dict") + return out_state_dict + + +def _prune_experts_by_score( + *, + mlp_init_config: dict[str, Any], + layer_idx: int, + orig_router_weight: torch.Tensor, + orig_experts_weights: dict[str, list[torch.Tensor]], + new_num_experts: int, +) -> tuple[torch.Tensor, dict[str, list[torch.Tensor]]]: + orig_num_experts = orig_router_weight.shape[0] + assert all( + len(orig_experts_module_weights) == orig_num_experts + for orig_experts_module_weights in orig_experts_weights.values() + ) + expert_scores = _load_expert_scores(mlp_init_config)[layer_idx] + assert len(expert_scores) == orig_num_experts + selected_experts = sorted( + range(orig_num_experts), + key=lambda i: expert_scores[i], + reverse=mlp_init_config.get("higher_is_better", True), + )[:new_num_experts] + result_router_weight = orig_router_weight[selected_experts] + result_experts_weights = { + name: [orig_experts_module_weights[i] for i in selected_experts] + for name, orig_experts_module_weights in orig_experts_weights.items() + } + return result_router_weight, result_experts_weights + + +def _init_linear_attn( + parent_state_dict: dict[str, torch.Tensor], + parent_config: PretrainedConfig, + layer_idx: int, + v_key: str, + o_key: str, +) -> torch.Tensor: + """ + Init a linear layer that operates like an attention layer that assigns score 1 to the current token + and score 0 to all others: out = (Wo @ Wv) @ x + """ + n_embd = parent_config.hidden_size + head_size = _get_head_dim(parent_config) + # Get num_kv_heads from config, compute n_heads_in_group + n_kv_heads = parent_config.block_configs[layer_idx].attention.num_key_value_heads + n_heads_in_group = parent_config.num_attention_heads // n_kv_heads + + wv = parent_state_dict[v_key] + wv = wv.view(n_kv_heads, head_size, n_embd) + wv_expanded = torch.repeat_interleave(wv, n_heads_in_group, dim=0).reshape(n_embd, n_embd) + + wo = parent_state_dict[o_key] + + w_linear = wo @ wv_expanded + return w_linear + + +def _init_linear_mlp(teacher_mlp_state_dict: dict[str, torch.Tensor]) -> torch.Tensor: + """ + A linear layer that does (W_down @ W_up) @ x, ignoring W_gate. + """ + if "linear_mlp.weight" in teacher_mlp_state_dict: # if the teacher itself is a linear layer + return teacher_mlp_state_dict["linear_mlp.weight"] + + w_up = teacher_mlp_state_dict["up_proj.weight"] + w_down = teacher_mlp_state_dict["down_proj.weight"] + w_linear = w_down @ w_up + return w_linear + + +def update_model_config( + model_config: PretrainedConfig, + model_config_overrides: None | list[dict[str, Any]] | str | dict | Path = None, +) -> PretrainedConfig: + new_model_config = deepcopy(model_config) + if model_config_overrides is None: + return new_model_config + + model_config_overrides = _parse_model_config_overrides( + model_config_overrides, model_config.num_hidden_layers + ) + + def override(item, item_overrides): + if item_overrides is None: + return item_overrides + if dataclasses.is_dataclass(item): + assert isinstance(item_overrides, dict) + return dataclass_override(item, item_overrides) + if isinstance(item, list): + assert isinstance(item_overrides, list) + return list_override(item, item_overrides) + return item_overrides + + def list_override(ls, ls_overrides: list): + assert len(ls) == len(ls_overrides) + return [override(item, item_overrides) for item, item_overrides in zip(ls, ls_overrides)] + + def dataclass_override(dc, dc_overrides: dict): + if not set(dc_overrides.keys()).issubset(dataclasses.asdict(dc).keys()): + raise ValueError( + f"Uknown overrides for dataclass {type(dc)}: {', '.join(set(dc_overrides.keys()) - dataclasses.asdict(dc).keys())}" + ) + field_types = {field.name: field.type for field in dataclasses.fields(dc)} + dc_changes = {} + for key, item_overrides in dc_overrides.items(): + previous_value, item_type = getattr(dc, key), field_types[key] + # if original block was no_op, we should not override it + if getattr(dc, "no_op", False): + return dc + + if previous_value is None and _is_dataclass_type(item_type): + new_value = _get_dataclass_type(item_type)(**item_overrides) + else: + new_value = override(previous_value, item_overrides) + check_type(new_value, item_type) + dc_changes[key] = new_value + return dataclasses.replace(dc, **dc_changes) + + new_model_config.block_configs = list_override( + new_model_config.block_configs, model_config_overrides + ) + + return new_model_config + + +def _parse_model_config_overrides( + model_config_overrides_json: str | dict | Path | list[dict], + n_layer: int, +) -> list[dict[str, Any]]: + """ + example model_config_overrides_dict: + { + "attention": [{"num_key_value_heads": 4}], + "ffn": [{"intermediate_size": 14336}] + } + """ + if isinstance(model_config_overrides_json, list) and isinstance( + model_config_overrides_json[0], dict + ): + return model_config_overrides_json + + if isinstance(model_config_overrides_json, dict): + model_config_overrides_dict = model_config_overrides_json + else: + if os.path.exists( + model_config_overrides_json + ): # using os.path.exists, because Path.exists throws an exception on long strings + model_config_overrides_json = Path(model_config_overrides_json).read_text() + print(f"I'm json loadsing over here. {model_config_overrides_json=}") + model_config_overrides_dict = json.loads(model_config_overrides_json) + + # Sanity checks and conversion to list of dictionaries + layer_wise_overrides = [{} for _ in range(n_layer)] + for config_key, config_value in model_config_overrides_dict.items(): + assert config_key in SUBBLOCK_CLS_DICT, f"Unknown config key: {config_key}" + assert isinstance(config_value, list), ( + f"Expected a list for {config_key}, got {config_value}" + ) + assert len(config_value) == n_layer or len(config_value) == 1, ( + f"Number of elements in {config_key} must be 1 or equal to the number of layers in the model" + ) + + if len(config_value) == 1: + model_config_overrides_dict[config_key] = config_value * n_layer + + for layer_idx in range(n_layer): + layer_wise_overrides[layer_idx][config_key] = model_config_overrides_dict[config_key][ + layer_idx + ] + + return layer_wise_overrides + + +def _apply_hidden_size_pruning( + out_state_dict: dict[str, torch.Tensor], + original_state_dict: dict[str, torch.Tensor], + new_config: PretrainedConfig, + original_config: PretrainedConfig, + descriptor, + hidden_size_init_mode: HiddenSizeInitMode, + channel_importance_path: Optional[str] = None, + owned_block_indexes: Optional[list[int]] = None, +) -> dict[str, torch.Tensor]: + """ + Apply hidden size pruning to all layers that depend on hidden_size. + This includes embeddings, layer norms, and any linear layers that haven't been handled yet. + """ + if isinstance(hidden_size_init_mode, str): + hidden_size_init_mode = HiddenSizeInitMode(hidden_size_init_mode) + + # Get language model config (for VL models this extracts the nested config) + original_lm_config = descriptor.get_language_model_config(original_config) + new_lm_config = descriptor.get_language_model_config(new_config) + + original_hidden_size = original_lm_config.hidden_size + new_hidden_size = new_lm_config.hidden_size + + if hidden_size_init_mode == HiddenSizeInitMode.CopyAsIs: + return out_state_dict + + # Load channel ranking if needed + channel_ranking = None + if hidden_size_init_mode == HiddenSizeInitMode.PruneByChannelRanking: + if channel_importance_path is not None: + with open(channel_importance_path, "r") as f: + channel_ranking = json.load(f)["channel_importance_ranking"] + else: + raise ValueError( + "channel_ranking_path must be provided in hidden_size_init_config for PruneByChannelRanking mode" + ) + + # Handle embedding layer + embed_key = "model.embed_tokens.weight" + if embed_key in out_state_dict and embed_key in original_state_dict: + out_state_dict[embed_key] = _prune_hidden_size_dimension( + original_state_dict[embed_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1, + ) + else: + raise ValueError( + f"Embed key {embed_key} not found in out_state_dict or original_state_dict" + ) + + # Handle final layer norm + norm_key = "model.norm.weight" + if norm_key in out_state_dict and norm_key in original_state_dict: + out_state_dict[norm_key] = _prune_hidden_size_dimension( + original_state_dict[norm_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=0, + ) + + # Handle LM head + lm_head_key = "lm_head.weight" + if lm_head_key in out_state_dict and lm_head_key in original_state_dict: + if out_state_dict[lm_head_key].shape[1] != new_hidden_size: + out_state_dict[lm_head_key] = _prune_hidden_size_dimension( + original_state_dict[lm_head_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1, + ) + + for block_idx in owned_block_indexes: + if new_config.block_configs[block_idx].parallel_blocks is None: + key_prefix = f"model.layers.{block_idx}" + out_state_dict = _prune_hidden_size_dimension_block( + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + new_config.block_configs[block_idx], + key_prefix, + ) + else: + for internal_block_idx in range( + len(new_config.block_configs[block_idx].parallel_blocks) + ): + block_config = new_config.block_configs[block_idx].parallel_blocks[ + internal_block_idx + ] + key_prefix = f"model.layers.{block_idx}.parallel_blocks.{internal_block_idx}" + out_state_dict = _prune_hidden_size_dimension_block( + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + block_config, + key_prefix, + ) + return out_state_dict + + +def _prune_hidden_size_dimension_block( + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + block_config, + key_prefix, +): + for layer_norm in ["input_layernorm", "post_attention_layernorm"]: + for part in ["weight", "bias"]: + key = f"{key_prefix}.{layer_norm}.{part}" + if key in out_state_dict: + out_state_dict[key] = _prune_hidden_size_dimension( + out_state_dict[key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=0, + ) + attn_prefix = f"{key_prefix}.self_attn" + if block_config.attention.replace_with_linear: + linear_attn_key = f"{attn_prefix}.linear_attn.weight" + for dim in [0, 1]: + out_state_dict[linear_attn_key] = _prune_hidden_size_dimension( + out_state_dict[linear_attn_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=dim, + ) + elif block_config.attention.is_mamba: + for proj in ["in", "out"]: + mamba_key = f"{attn_prefix}.mamba_mixer.{proj}_proj.weight" + out_state_dict[mamba_key] = _prune_hidden_size_dimension( + out_state_dict[mamba_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1 if proj == "in" else 0, + ) + else: + for k in "qkvo": + for part in ["weight", "bias"]: + if k in "qkv" and part == "bias": + continue + key = f"{attn_prefix}.{k}_proj.{part}" + if key in out_state_dict: + out_state_dict[key] = _prune_hidden_size_dimension( + out_state_dict[key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1 if part == "weight" and k in "qkv" else 0, + ) + ffn_prefix = f"{key_prefix}.mlp" + if block_config.ffn.replace_with_linear: + linear_mlp_key = f"{ffn_prefix}.linear_mlp.weight" + for dim in [0, 1]: + out_state_dict[linear_mlp_key] = _prune_hidden_size_dimension( + out_state_dict[linear_mlp_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=dim, + ) + elif block_config.ffn.moe is not None: + router_key = f"{ffn_prefix}.router.weight" + out_state_dict[router_key] = _prune_hidden_size_dimension( + out_state_dict[router_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1, + ) + _prune_hidden_size_dimension_mlp( + f"{ffn_prefix}.shared_expert", + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + ) + for expert_idx in range(block_config.ffn.moe.num_local_experts): + _prune_hidden_size_dimension_mlp( + f"{ffn_prefix}.experts.{expert_idx}", + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + ) + else: + _prune_hidden_size_dimension_mlp( + ffn_prefix, out_state_dict, new_hidden_size, hidden_size_init_mode, channel_ranking + ) + return out_state_dict + + +def _prune_hidden_size_dimension_mlp( + name_prefix, out_state_dict, new_hidden_size, hidden_size_init_mode, channel_ranking +): + for proj in ["gate_proj", "up_proj", "down_proj"]: + for part in ["weight", "bias"]: + if proj != "down_proj" and part == "bias": + continue + key = f"{name_prefix}.{proj}.{part}" + if key in out_state_dict: + out_state_dict[key] = _prune_hidden_size_dimension( + out_state_dict[key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1 if part == "weight" and proj != "down_proj" else 0, + ) + + +def _prune_hidden_size_dimension( + original_tensor: torch.Tensor, + new_hidden_size: int, + hidden_size_init_mode: HiddenSizeInitMode, + channel_ranking: Optional[list[int]] = None, + dim: int = -1, +) -> torch.Tensor: + """ + Prune a tensor along the specified dimension to match the new hidden size. + """ + original_size = original_tensor.shape[dim] + + if hidden_size_init_mode == HiddenSizeInitMode.Random: + # Initialize with random weights + new_shape = list(original_tensor.shape) + new_shape[dim] = new_hidden_size + return torch.randn(new_shape, dtype=original_tensor.dtype, device=original_tensor.device) + + elif hidden_size_init_mode == HiddenSizeInitMode.Truncate: + # Simple truncation - take the first new_hidden_size elements + if dim == -1: + return original_tensor[..., :new_hidden_size] + elif dim == 0: + return original_tensor[:new_hidden_size, ...] + elif dim == 1: + return original_tensor[:, :new_hidden_size, ...] + else: + # Handle other dimensions + slices = [slice(None)] * original_tensor.ndim + slices[dim] = slice(new_hidden_size) + return original_tensor[tuple(slices)] + + elif hidden_size_init_mode == HiddenSizeInitMode.PruneByChannelRanking: + if channel_ranking is None: + raise ValueError("Channel ranking must be provided for PruneByChannelRanking mode") + + # Use channel ranking to select the most important channels + if len(channel_ranking) < new_hidden_size: + raise ValueError( + f"Channel ranking has {len(channel_ranking)} channels but need {new_hidden_size}" + ) + + # Take the top new_hidden_size channels according to ranking + selected_channels = channel_ranking[:new_hidden_size] + + if dim == -1: + return original_tensor[..., selected_channels] + elif dim == 0: + return original_tensor[selected_channels, ...] + elif dim == 1: + return original_tensor[:, selected_channels, ...] + else: + # Handle other dimensions + slices = [slice(None)] * original_tensor.ndim + slices[dim] = selected_channels + return original_tensor[tuple(slices)] + + else: + raise ValueError(f"Unsupported hidden_size_init_mode: {hidden_size_init_mode}") + + +def _get_head_dim(config) -> int: + """Get head dimension from config in a model-agnostic way. + + Some models like Llama have `head_dim` as a direct attribute, while others + like Qwen2 don't. This helper computes it from hidden_size and num_attention_heads. + """ + if hasattr(config, "head_dim") and config.head_dim is not None: + return config.head_dim + return config.hidden_size // config.num_attention_heads diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py new file mode 100644 index 0000000000..b5c027797c --- /dev/null +++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py @@ -0,0 +1,204 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Initialize child models from parent models using AnyModel approach with deci_x_patcher.""" + +import json +import time +from typing import Optional + +import torch +import yaml +from transformers import AutoModelForCausalLM + +from modelopt.torch.export import copy_hf_ckpt_remote_code + +from ...anymodel.model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ...anymodel.puzzformer import deci_x_patcher +from ..checkpoint_utils import copy_tokenizer, load_state_dict +from ..checkpoint_utils_hf import ( + _get_auto_class_for_trust_remote_code, + _save_checkpoint, + load_model_config, +) +from ..logger import mprint +from ..sharded_checkpoint_utils import _get_model_class_from_config +from .child_init import ( + GQAInitMode, + HiddenSizeInitMode, + LinearInitMode, + MlpInitMode, + create_child_state_dict, + update_model_config, +) + +__all__ = ["init_child_from_parent"] + + +def init_child_from_parent( + descriptor: ModelDescriptor, + pruning_mixin, + parent_checkpoint_dir: str, + model_config_overrides_dict: dict | str, + output_checkpoint_dir: str, + gqa_init_mode: GQAInitMode, + mlp_init_mode: MlpInitMode, + mlp_init_config_yaml: Optional[str], + linear_init_mode: LinearInitMode, + hidden_size_init_mode: Optional[HiddenSizeInitMode] = None, + channel_importance_path: Optional[str] = None, + max_workers: Optional[int] = None, # Auto-calculate optimal workers if None + max_layer_workers: Optional[int] = None, # Auto-calculate optimal workers if None +) -> None: + """ + Init child models from parent models in the style of bypass training, + but without having to run the entire bypass pipeline. + + Uses AnyModel approach with deci_x_patcher for heterogeneous layer configurations. + + I/O Optimization Parameters: + - max_workers: Number of threads for parallel file I/O (default: auto-calculate min(CPU count, num files)) + - max_layer_workers: Number of threads for parallel layer processing (default: auto-calculate min(CPU count, num layers)) + """ + assert ( + gqa_init_mode not in [GQAInitMode.RandomKV, GQAInitMode.RandomBlock] + and mlp_init_mode != MlpInitMode.Random + and linear_init_mode != LinearInitMode.Random + ), ( + "We do not support random init of any subblock in this script to avoid initializing the student model" + ) + + descriptor = ModelDescriptorFactory.get(descriptor) + + copy_tokenizer( + parent_checkpoint_dir, + output_checkpoint_dir, + trust_remote_code=descriptor.requires_trust_remote_code(), + ) + + if descriptor.requires_trust_remote_code(): + copy_hf_ckpt_remote_code(parent_checkpoint_dir, output_checkpoint_dir) + + parent_model_config = load_model_config( + parent_checkpoint_dir, trust_remote_code=descriptor.requires_trust_remote_code() + ) + parent_state_dict = load_state_dict(parent_checkpoint_dir) + + # Parse JSON if string + if isinstance(model_config_overrides_dict, str): + model_config_overrides_dict = json.loads(model_config_overrides_dict) + + # Separate global config overrides from block-level overrides + global_config_overrides = {} + block_config_overrides = {} + + for key, value in model_config_overrides_dict.items(): + if key in ["hidden_size"]: + global_config_overrides[key] = value + else: + block_config_overrides[key] = value + + # Load child model config with global overrides + child_model_config = load_model_config( + parent_checkpoint_dir, + model_config_overrides=global_config_overrides, + ignore_unexpected_config_keys=True, + trust_remote_code=descriptor.requires_trust_remote_code(), + ) + + # Apply block-level overrides if any + if block_config_overrides: + child_model_config = update_model_config( + model_config=child_model_config, + model_config_overrides=block_config_overrides, + ) + + with torch.device("meta"): + # Pass block_configs explicitly so patcher works for VL models where + # decoder layers receive nested config (e.g., text_config) without block_configs + with deci_x_patcher( + model_descriptor=descriptor, block_configs=child_model_config.block_configs + ): + model_class = _get_model_class_from_config(child_model_config) + trust_remote_code = descriptor.requires_trust_remote_code() + if trust_remote_code: + auto_cls = _get_auto_class_for_trust_remote_code(child_model_config) + child_model = auto_cls.from_config( + child_model_config, trust_remote_code=trust_remote_code + ) + elif model_class is AutoModelForCausalLM: + child_model = AutoModelForCausalLM.from_config(child_model_config) + else: + child_model = model_class._from_config(child_model_config) + + child_state_dict_with_meta_tensors = child_model.state_dict() + + mlp_init_config = ( + yaml.safe_load(mlp_init_config_yaml) + if isinstance(mlp_init_config_yaml, str) + else mlp_init_config_yaml + ) + + # Profile create_child_state_dict with automatic layer parallelization + mprint("Starting create_child_state_dict...") + start_time = time.time() + child_state_dict = create_child_state_dict( + pruning_mixin=pruning_mixin, + descriptor=descriptor, + original_state_dict=parent_state_dict, + new_state_dict=child_state_dict_with_meta_tensors, + original_config=parent_model_config, + new_config=child_model_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + hidden_size_init_mode=hidden_size_init_mode or HiddenSizeInitMode.CopyAsIs, + channel_importance_path=channel_importance_path, + max_layer_workers=max_layer_workers, + ) + create_child_state_dict_time = time.time() - start_time + mprint(f"create_child_state_dict completed in {create_child_state_dict_time:.2f} seconds") + + # Profile _save_checkpoint with automatic I/O worker calculation + mprint("Starting _save_checkpoint...") + actual_io_workers = max_workers if max_workers else "auto" + mprint(f"I/O Settings: max_workers={actual_io_workers}") + start_time = time.time() + _save_checkpoint( + child_model_config, + child_state_dict, + output_checkpoint_dir, + descriptor, + max_workers=max_workers, + ) + save_checkpoint_time = time.time() - start_time + mprint(f"_save_checkpoint completed in {save_checkpoint_time:.2f} seconds") + + # Print profiling summary with actual worker counts used + total_core_time = create_child_state_dict_time + save_checkpoint_time + actual_layer_workers = max_layer_workers if max_layer_workers else "auto" + actual_io_workers = max_workers if max_workers else "auto" + mprint(f"\n=== PROFILING SUMMARY ===") + mprint( + f"create_child_state_dict: {create_child_state_dict_time:.2f}s ({create_child_state_dict_time / total_core_time * 100:.1f}%)" + ) + mprint( + f"_save_checkpoint: {save_checkpoint_time:.2f}s ({save_checkpoint_time / total_core_time * 100:.1f}%)" + ) + mprint(f"Total core processing: {total_core_time:.2f}s") + mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}") + mprint(f"=========================\n") diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils.py b/modelopt/torch/puzzletron/tools/checkpoint_utils.py new file mode 100644 index 0000000000..becabf0431 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Utilities for loading and initializing PyTorch model checkpoints (AnyModel / HF layouts).""" + +import concurrent.futures +import warnings +from functools import partial +from pathlib import Path +from typing import Literal, TypeVar + +import torch +from safetensors.torch import load_file as safe_load_file +from torch import nn +from transformers import AutoTokenizer +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME + +from .checkpoint_utils_hf import load_model_config +from .common import infer_weights_dtype + +__all__ = [ + "SAFETENSORS_SUBBLOCKS_DIR_NAME", + "PTH_SUBBLOCKS_DIR_NAME", + "STATE_DICT_FILE_NAME", + "load_state_dict", + "load_model_config", + "init_module_with_state_dict", + "init_empty_module", + "skip_init", + "is_valid_decilm_checkpoint", + "copy_tokenizer", +] + +SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" +PTH_SUBBLOCKS_DIR_NAME = "subblocks" +STATE_DICT_FILE_NAME = "model.pth" + + +def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]: + checkpoint_dir = _normalize_checkpoint_dir(checkpoint_dir) + + if (state_dict_path := checkpoint_dir / STATE_DICT_FILE_NAME).exists(): + return torch.load(state_dict_path, map_location="cpu", weights_only=True) + + if (safetensors_subblocks_dir := checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists(): + return _load_state_dict_from_subblocks(safetensors_subblocks_dir) + + if (pth_subblocks_dir := checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME).exists(): + return _load_state_dict_from_subblocks(pth_subblocks_dir) + + if (checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists() or ( + checkpoint_dir / SAFE_WEIGHTS_NAME + ).exists(): + from .sharded_checkpoint_utils import ( + load_sharded_state_dict, # local import to avoid circular import + ) + + return load_sharded_state_dict(checkpoint_dir) + + raise FileNotFoundError( + f"Couldn't find state dict path or subblocks dir inside {checkpoint_dir}" + ) + + +def _normalize_checkpoint_dir(checkpoint_dir: Path | str) -> Path: + checkpoint_dir = Path(checkpoint_dir) + if checkpoint_dir.is_file(): + checkpoint_dir = checkpoint_dir.parent + return checkpoint_dir + + +def _load_state_dict_from_subblocks(subblocks_dir: Path) -> dict[str, torch.Tensor]: + torch_paths = list(subblocks_dir.glob("*.pth")) + safetensors_paths = list(subblocks_dir.glob("*.safetensors")) + + if len(torch_paths) != 0: + load_fn = partial(torch.load, map_location="cpu", weights_only=True) + file_paths = torch_paths + elif len(safetensors_paths) != 0: + load_fn = safe_load_file + file_paths = safetensors_paths + else: + raise ValueError(f"No tensor files found in {subblocks_dir=}") + + with concurrent.futures.ThreadPoolExecutor() as executor: + state_dict_shards = list(executor.map(load_fn, file_paths)) + + state_dict = {k: v for shard in state_dict_shards for k, v in shard.items()} + return state_dict + + +NNModule = TypeVar("NNModule", bound=nn.Module) + + +def init_module_with_state_dict( + state_dict: dict[str, torch.Tensor], + module_cls: type[NNModule], + *init_args, + **init_kwargs, +) -> NNModule: + weights_dtype = infer_weights_dtype(state_dict) + module = init_empty_module(module_cls, weights_dtype, *init_args, **init_kwargs) + module.load_state_dict(state_dict) + return module + + +def init_empty_module( + module_cls: type[NNModule], + dtype: torch.dtype, + *init_args, + **init_kwargs, +) -> NNModule: + default_dtype = torch.get_default_dtype() + current_device = torch.ones(1).device + torch.set_default_dtype(dtype) + try: + module = skip_init(module_cls, *init_args, device=current_device, **init_kwargs) + finally: + torch.set_default_dtype(default_dtype) + return module + + +def skip_init(module_cls, *args, **kwargs) -> nn.Module: + """Heavily inspired by torch.nn.utils.skip_init but does not require the module to accept a "device" kwarg.""" + if not issubclass(module_cls, torch.nn.Module): + raise RuntimeError(f"Expected a Module; got {module_cls}") + + final_device = kwargs.pop("device", "cpu") + with torch.device("meta"): + module = module_cls(*args, **kwargs) + + module = module.to_empty(device=final_device) + return module + + +def is_valid_decilm_checkpoint(checkpoint_dir: Path | str, trust_remote_code: bool = False) -> bool: + """True if the checkpoint config loads and defines ``block_configs`` (AnyModel / puzzletron layout). + + Args: + checkpoint_dir: Path to checkpoint directory + trust_remote_code: If True, allows execution of custom code from the model repository. + This is a security risk if the model source is untrusted. Only set to True if you + trust the source of the model. Defaults to False for security. + + Returns: + True if the config has ``block_configs``, False otherwise + """ + try: + model_config = load_model_config(checkpoint_dir, trust_remote_code=trust_remote_code) + if not hasattr(model_config, "block_configs") or model_config.block_configs is None: + warnings.warn( + f"Skipping checkpoint '{checkpoint_dir}' - missing block_configs (not an AnyModel-style layout)" + ) + return False + return True + except Exception as e: + warnings.warn(f"Skipping checkpoint '{checkpoint_dir}' - failed to load config: {e}") + return False + + +def copy_tokenizer( + source_dir_or_tokenizer_name: Path | str, + target_dir: Path | str, + on_failure: Literal["raise", "warn"] = "raise", + trust_remote_code: bool = False, +) -> None: + """Prefer loading the tokenizer from huggingface hub (when tokenizer_name.txt file is available) + to avoid collision between transformers versions. + """ + source_tokenizer_name_path = Path(source_dir_or_tokenizer_name) / "tokenizer_name.txt" + if source_tokenizer_name_path.exists(): + source_dir_or_tokenizer_name = source_tokenizer_name_path.read_text().strip() + + tokenizer = None + try: + tokenizer = AutoTokenizer.from_pretrained( + source_dir_or_tokenizer_name, trust_remote_code=trust_remote_code + ) + except Exception: + message = f"Couldn't load tokenizer from '{source_dir_or_tokenizer_name}'" + if on_failure == "raise": + raise FileNotFoundError(message) + else: + warnings.warn(message) + + if tokenizer is not None: + target_dir = Path(target_dir) + target_dir.mkdir(exist_ok=True, parents=True) + tokenizer.save_pretrained(target_dir) + + target_tokenizer_name_path = target_dir / "tokenizer_name.txt" + is_given_tokenizer_name_as_argument = not Path(source_dir_or_tokenizer_name).exists() + if is_given_tokenizer_name_as_argument: + target_tokenizer_name_path.write_text(source_dir_or_tokenizer_name) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py new file mode 100644 index 0000000000..69b8e5e29d --- /dev/null +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -0,0 +1,450 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +""" +Utilities for loading and saving Hugging Face-format checkpoints (``AutoConfig`` + optional ``block_configs``). +""" + +import concurrent.futures +import dataclasses +import fcntl +import os +import time +from collections import defaultdict +from collections.abc import Callable, Mapping +from pathlib import Path +from typing import TYPE_CHECKING, Any, BinaryIO + +import torch +import transformers +from safetensors.torch import save_file as safe_save_file +from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel +from transformers.dynamic_module_utils import get_class_from_dynamic_module +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from modelopt.torch.utils import json_dumps + +from ..block_config import maybe_cast_block_configs + +if TYPE_CHECKING: + from ..anymodel.model_descriptor import ModelDescriptor +from .logger import mprint + +__all__ = [ + "SAFETENSORS_SUBBLOCKS_DIR_NAME", + "PTH_SUBBLOCKS_DIR_NAME", + "RELATIVE_SUBBLOCKS_DIR", + "force_cache_dynamic_modules", + "load_model_config", + "init_model_from_config", + "save_checkpoint", + "save_subblocks", + "save_model_config", +] + +SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" +PTH_SUBBLOCKS_DIR_NAME = "subblocks" +RELATIVE_SUBBLOCKS_DIR = Path(SAFETENSORS_SUBBLOCKS_DIR_NAME) + + +# TODO: (esegal) Should ask the model for something like this +NON_LAYER_MODULE_TO_FILE_TYPE = { + "model.embed_tokens": "embeddings", + "model.norm": "lm_head", + "lm_head": "lm_head", +} +MODULE_WITHIN_LAYER_TO_FILE_TYPE = { + "input_layernorm": "attention", + "self_attn": "attention", + "post_attention_layernorm": "ffn", + "mlp": "ffn", + "parallel_blocks": "multi_block", +} +LAYERS_MODULE_NAME = "model.layers" + + +def force_cache_dynamic_modules( + config: PretrainedConfig, checkpoint_dir: Path | str, trust_remote_code: bool = False +): + has_remote_code = ( + hasattr(config, "auto_map") + and isinstance(config.auto_map, dict) + and "AutoConfig" in config.auto_map.keys() + ) + if has_remote_code and trust_remote_code: + for class_reference in config.auto_map.values(): + _ = get_class_from_dynamic_module(class_reference, checkpoint_dir) + + +def load_model_config( + checkpoint_dir: Path | str, + model_config_overrides: Mapping | None = None, + ignore_unexpected_config_keys: bool = False, + trust_remote_code: bool = False, +): + """Load model configuration from a checkpoint directory. + + Args: + checkpoint_dir: Path to the checkpoint directory (e.g. containing config.json). + model_config_overrides: Optional mapping of config overrides. + ignore_unexpected_config_keys: If True, ignore unexpected config keys. + trust_remote_code: If True, allows execution of custom code from the model repository. + This is a security risk if the model source is untrusted. Only set to True if you + trust the source of the model. Defaults to False for security. + + Returns: + Loaded model configuration (PretrainedConfig). + """ + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + if model_config_overrides is None: + model_config_overrides = {} + + config, unused_kwargs = AutoConfig.from_pretrained( + checkpoint_dir, + trust_remote_code=trust_remote_code, + return_unused_kwargs=True, + **model_config_overrides, + ) + if hasattr(config, "block_configs"): + config.block_configs = maybe_cast_block_configs(config.block_configs) + + force_cache_dynamic_modules(config, checkpoint_dir, trust_remote_code=trust_remote_code) + + if not ignore_unexpected_config_keys: + if unused_kwargs: + raise ValueError(f"Unexpected config keys: {unused_kwargs.keys()}") + + return config + + +def _get_model_class_from_config(config: PretrainedConfig) -> type: + """Resolve HuggingFace model class from ``config.architectures`` (see puzzletron checkpoint_utils_hf).""" + if hasattr(config, "architectures") and config.architectures: + model_class_name = config.architectures[0] + if hasattr(transformers, model_class_name): + return getattr(transformers, model_class_name) + mprint( + f"Warning: {model_class_name} not found in transformers, " + "falling back to AutoModelForCausalLM" + ) + return AutoModelForCausalLM + + +def _get_auto_class_for_trust_remote_code(config: PretrainedConfig) -> type: + """Pick the right Auto class for a trust_remote_code model by inspecting auto_map. + + When a model requires trust_remote_code, the native transformers class resolved from + config.architectures must NOT be used directly — it may have a different module structure + than the trust_remote_code class (e.g. NemotronH: native uses ``model.`` prefix, but the + trust_remote_code class uses ``backbone.`` prefix, causing key mismatches throughout the + pipeline). Instead, we route through the appropriate Auto class so that from_config() + resolves the class via auto_map, picking up the correct trust_remote_code implementation. + + Models declare which Auto class they support via config.auto_map. We walk a priority list + so that CausalLM models and VL models (AutoModelForConditionalGeneration or similar) are + both handled correctly. + """ + auto_map = getattr(config, "auto_map", {}) + priority = [ + "AutoModelForCausalLM", + "AutoModelForConditionalGeneration", + "AutoModelForImageTextToText", + "AutoModel", + ] + for name in priority: + if name in auto_map and hasattr(transformers, name): + return getattr(transformers, name) + return AutoModelForCausalLM + + +def init_model_from_config( + config: PretrainedConfig, + *, + trust_remote_code: bool = False, + **kwargs, +) -> PreTrainedModel: + """Build a model from config on meta/uninitialized weights (used e.g. for subblock param counts). + + ``trust_remote_code`` defaults to False (only ``AutoModelForCausalLM.from_config`` uses it). + Pass True when loading configs that rely on custom modeling code from the checkpoint. + """ + model_class = _get_model_class_from_config(config) + if trust_remote_code: + auto_cls = _get_auto_class_for_trust_remote_code(config) + return auto_cls.from_config(config, trust_remote_code=trust_remote_code, **kwargs) + if model_class is AutoModelForCausalLM: + return AutoModelForCausalLM.from_config(config, **kwargs) + # Concrete model classes (e.g. GptOssForCausalLM, Qwen3VLMoeForConditionalGeneration): + # _from_config forwards kwargs to __init__, which does not accept trust_remote_code. + return model_class._from_config(config, **kwargs) + + +def save_checkpoint( + model: PreTrainedModel, checkpoint_dir: Path | str, descriptor: "ModelDescriptor" +) -> None: + _save_checkpoint(model.config, model.state_dict(), checkpoint_dir, descriptor) + + +def _save_checkpoint( + model_config: PretrainedConfig, + state_dict: dict[str, torch.Tensor], + checkpoint_dir: Path | str, + descriptor: "ModelDescriptor", + max_workers: int | None = None, # Now optional - will auto-calculate if None +) -> None: + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # Phase 1: Save config + save_model_config(model_config, checkpoint_dir) + + # Phase 2: Build weight map using descriptor and write index + subblock_keys = descriptor.get_weight_groups( + layer_names=state_dict.keys(), + num_hidden_layers=model_config.num_hidden_layers, + ) + + weight_map = {} + for subblock, layer_keys in subblock_keys.items(): + weight_map_entries = { + key: f"subblocks_safetensors/{subblock}.safetensors" for key in layer_keys + } + weight_map.update(weight_map_entries) + + # Handle tie_word_embeddings - remove from state_dict and weight_map BEFORE writing index + output_emb_weight_name = f"{descriptor.output_embedding_name()}.weight" + if getattr(model_config, "tie_word_embeddings", False) and output_emb_weight_name in state_dict: + state_dict = {k: v for k, v in state_dict.items() if k != output_emb_weight_name} + weight_map = {k: v for k, v in weight_map.items() if k != output_emb_weight_name} + + # Write index (now without tied embedding) + index = {"metadata": {"format": "pt"}, "weight_map": weight_map} + index_path = checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME + index_json = json_dumps(index) + _write_file_process_safe(index_json, index_path) + + # Phase 3: Save subblocks + save_subblocks( + state_dict, + checkpoint_dir, + weight_map=weight_map, + multi_threaded=True, + max_workers=max_workers, + ) + + +def save_subblocks( + state_dict: dict[str, torch.Tensor], + checkpoint_dir: Path | str, + weight_map: dict[str, str] | None = None, + multi_threaded: bool = True, + max_workers: int | None = None, # Now optional - will auto-calculate if None +) -> None: + mprint("=== Starting save_subblocks detailed profiling ===") + subblocks_start_time = time.time() + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + # Step 1: Build weight map (use provided or build from state_dict) + weight_map_start_time = time.time() + if weight_map is None: + weight_map = _build_safetensors_weight_map( + state_dict=state_dict, + non_layer_module_to_file_type=NON_LAYER_MODULE_TO_FILE_TYPE, + module_within_layer_to_file_type=MODULE_WITHIN_LAYER_TO_FILE_TYPE, + layers_module_name=LAYERS_MODULE_NAME, + ) + weight_name_to_filename = {k: checkpoint_dir / v for k, v in weight_map.items()} + weight_map_time = time.time() - weight_map_start_time + mprint(f" Step 1 - Build weight map: {weight_map_time:.2f}s ({len(weight_map)} mappings)") + + # Step 2: Create subblocks directory + dir_create_start_time = time.time() + subblocks_path = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME + subblocks_path.mkdir(parents=True, exist_ok=True) + dir_create_time = time.time() - dir_create_start_time + mprint(f" Step 2 - Create directory: {dir_create_time:.2f}s") + + # Step 3: Organize tensors by file + organize_start_time = time.time() + filename_to_partial_state_dict = defaultdict(dict) + total_tensor_size = 0 + for weight_name, weight in state_dict.items(): + if weight_name in weight_map: + # Ensure tensor is contiguous and on CPU for faster I/O + tensor = ( + weight.contiguous().cpu() if weight.device.type != "cpu" else weight.contiguous() + ) + filename_to_partial_state_dict[weight_name_to_filename[weight_name]][weight_name] = ( + tensor + ) + total_tensor_size += weight.numel() * weight.element_size() + organize_time = time.time() - organize_start_time + mprint( + f" Step 3 - Organize tensors: {organize_time:.2f}s ({total_tensor_size / (1024**3):.2f}GB total)" + ) + + # Step 4: Prepare save arguments and auto-calculate optimal I/O workers + prepare_start_time = time.time() + safe_save_kwargs = [ + {"tensors": partial_state_dict, "filename": filename, "metadata": {"format": "pt"}} + for filename, partial_state_dict in filename_to_partial_state_dict.items() + ] + + # Auto-calculate optimal I/O workers: min(cpu_count, num_files) + if max_workers is None: + cpu_count = os.cpu_count() or 1 + num_files = len(safe_save_kwargs) + max_workers = min(cpu_count, num_files) + mprint( + f" Auto-calculated I/O workers: min({cpu_count} CPUs, {num_files} files) = {max_workers}" + ) + else: + mprint(f" Using specified I/O workers: {max_workers}") + + prepare_time = time.time() - prepare_start_time + mprint(f" Step 4 - Prepare save args: {prepare_time:.2f}s ({len(safe_save_kwargs)} files)") + + # Step 5: Save files with optimal worker count + save_start_time = time.time() + if multi_threaded: + mprint(f" Using multi-threaded saving with {max_workers} workers...") + + def optimized_safe_save(kwargs): + try: + safe_save_file(**kwargs) + return True + except Exception as e: + mprint(f" Error saving {kwargs['filename']}: {e}") + return False + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + results = list(executor.map(optimized_safe_save, safe_save_kwargs)) + + # Check for any failures + failed_saves = sum(1 for r in results if not r) + if failed_saves > 0: + raise RuntimeError(f" {failed_saves} shard file(s) failed to save") + else: + mprint(" Using single-threaded saving...") + for kwargs in safe_save_kwargs: + safe_save_file(**kwargs) + + save_time = time.time() - save_start_time + mprint(f" Step 5 - Save files: {save_time:.2f}s ({max_workers} workers)") + + subblocks_total_time = time.time() - subblocks_start_time + mprint(f"=== save_subblocks completed in {subblocks_total_time:.2f}s ===") + mprint( + f" Breakdown: WeightMap {weight_map_time:.1f}s + DirCreate {dir_create_time:.1f}s + " + f"Organize {organize_time:.1f}s + Prepare {prepare_time:.1f}s + Save {save_time:.1f}s" + ) + + # Calculate effective I/O speed + io_speed_gbps = (total_tensor_size / (1024**3)) / save_time if save_time > 0 else 0 + mprint(f" Effective I/O speed: {io_speed_gbps:.2f} GB/s ({max_workers} workers)") + mprint(f" Save operation was {save_time / subblocks_total_time * 100:.1f}% of total time") + + +def _write_text(content: str, f: BinaryIO) -> None: + f.write(content.encode("utf-8")) + + +def _write_file_process_safe( + content: Any, + path: Path | str, + write_fn: Callable[[Any, BinaryIO], None] = _write_text, +) -> None: + """ + Write a file in a multi-process safe way. + If another process tries to write the same file using this method, the current process + "gives up" and assumes that the matter is being taken care of by another process. + + write_fn is a function that receives file contents and a binary file object, + and writes the content to the file. It can be _write_text (defined above), or torch.save, + or a similar function (not safetensors.torch.save_file since it expects a path). + """ + # Open with "ab+" so the file is not truncated before the lock is acquired. + # Once we hold the exclusive lock we seek to the start and truncate explicitly. + with open(path, "ab+") as f: + # Try to acquire an exclusive, non-blocking lock + try: + fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError: + return # Exit immediately if the lock is not acquired + + f.seek(0) + f.truncate() + write_fn(content, f) # Write the content if lock is acquired + f.flush() # Ensure data is written to disk + + # Release the lock + fcntl.flock(f, fcntl.LOCK_UN) + + +def _build_safetensors_weight_map( + *, + state_dict: dict[str, torch.Tensor], + non_layer_module_to_file_type: dict[str, str], + module_within_layer_to_file_type: dict[str, str], + layers_module_name: str, +) -> dict[str, Path]: + weight_map = {} + unmapped_weight_names = [] + for weight_name in state_dict: + found_match = False + for module_name, file_type in non_layer_module_to_file_type.items(): + if weight_name.startswith(f"{module_name}."): + weight_map[weight_name] = str(RELATIVE_SUBBLOCKS_DIR / f"{file_type}.safetensors") + found_match = True + if not found_match: + if weight_name.startswith(f"{layers_module_name}."): + name_parts = weight_name[len(layers_module_name) + 1 :].split(".") + layer_index = name_parts[0] + name_within_layer = ".".join(name_parts[1:]) + + for module_name, file_type in module_within_layer_to_file_type.items(): + if name_within_layer.startswith(f"{module_name}."): + weight_map[weight_name] = str( + RELATIVE_SUBBLOCKS_DIR / f"block_{layer_index}_{file_type}.safetensors" + ) + found_match = True + + if not found_match: + unmapped_weight_names.append(weight_name) + + if len(unmapped_weight_names) > 0: + raise ValueError( + f"Unmapped weight names: {unmapped_weight_names}\n" + f"Add them to the `non_layer_module_to_file_type` or " + f"`module_within_layer_to_file_type` dictionaries." + ) + + return weight_map + + +def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str) -> None: + if hasattr(model_config, "block_configs"): + model_config.block_configs = [ + dataclasses.asdict(conf) if dataclasses.is_dataclass(conf) else conf + for conf in model_config.block_configs + ] + model_config.save_pretrained(checkpoint_dir) diff --git a/modelopt/torch/puzzletron/tools/common.py b/modelopt/torch/puzzletron/tools/common.py new file mode 100644 index 0000000000..2c7b65ea67 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/common.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +__all__ = [ + "resolve_torch_dtype", + "infer_weights_dtype", +] + + +def resolve_torch_dtype(dtype: str | torch.dtype) -> torch.dtype: + """Resolve a dtype that may be a string (e.g. from Hydra/OmegaConf config) to torch.dtype. + + Accepts ``torch.dtype`` objects (returned as-is) and strings like + ``"torch.bfloat16"`` or ``"bfloat16"``. + """ + if isinstance(dtype, torch.dtype): + return dtype + name = dtype.removeprefix("torch.") + try: + result = getattr(torch, name) + except AttributeError: + raise ValueError(f"Unknown torch dtype: {dtype!r}") from None + if not isinstance(result, torch.dtype): + raise ValueError(f"torch.{name} is not a dtype (got {type(result).__name__})") + return result + + +def infer_weights_dtype(state_dict: dict[str, torch.Tensor]) -> torch.dtype: + weights_dtype = [p.dtype for p in state_dict.values() if torch.is_floating_point(p)] + weights_dtype = weights_dtype[0] if len(weights_dtype) > 0 else torch.get_default_dtype() + return weights_dtype diff --git a/modelopt/torch/puzzletron/tools/hydra_utils.py b/modelopt/torch/puzzletron/tools/hydra_utils.py new file mode 100644 index 0000000000..c30be4efde --- /dev/null +++ b/modelopt/torch/puzzletron/tools/hydra_utils.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for hydra config initialization. +""" + +import datetime +import random +from pathlib import Path + +from hydra import compose, initialize, initialize_config_dir +from hydra.utils import get_object +from omegaconf import DictConfig, OmegaConf + +__all__ = [ + "register_hydra_resolvers", + "initialize_hydra_config_for_dir", + "initialize_hydra_config", +] + + +def warmup_steps(tokens: int, block: int, mbs: int, pct: float = 0.05) -> int: + """ + Calculate warmup steps based on total tokens, block size, micro batch size, and warmup percentage. + Used as a resolver in hydra configs. + """ + steps = (int(tokens) // int(block)) // int(mbs) + w = pct * steps + return max(1, round(w)) + + +def register_hydra_resolvers(): + OmegaConf.register_new_resolver("to_path", lambda x: Path(x)) + OmegaConf.register_new_resolver( + "random_int", lambda low, high: random.randint(int(low), int(high)) + ) + OmegaConf.register_new_resolver( + "timedelta_minutes", lambda x: datetime.timedelta(minutes=x) if x is not None else None + ) + OmegaConf.register_new_resolver("warmup_steps", lambda t, b, m, p: warmup_steps(t, b, m, p)) + OmegaConf.register_new_resolver("get_object", lambda x: get_object(x)) + + +def initialize_hydra_config_for_dir( + config_dir: str, config_name: str, overrides: list[str] +) -> DictConfig: + """Initialize a hydra config from an absolute path for a config directory + + Args: + config_dir (str): + config_name (str): + overrides (List[str]): + + Returns: + DictConfig: + """ + + with initialize_config_dir(version_base=None, config_dir=config_dir): + args = compose(config_name, overrides) + args._set_flag("allow_objects", True) + OmegaConf.resolve(args) # resolve object attributes + OmegaConf.set_struct(args, False) + + return args + + +def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig: + with initialize(version_base=None, config_path=config_path): + args = compose(config_name, overrides) + args._set_flag("allow_objects", True) + OmegaConf.resolve(args) # resolve object attributes + OmegaConf.set_struct(args, False) + + return args diff --git a/modelopt/torch/puzzletron/tools/kd_model.py b/modelopt/torch/puzzletron/tools/kd_model.py new file mode 100644 index 0000000000..8a3ec0af13 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/kd_model.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Knowledge distillation loss functions. + +Provides normalized_mse_loss and cosine_embedding_loss_batched for comparing +model outputs. Used by validation.py. +""" +# mypy: ignore-errors + +from typing import Literal + +import torch +import torch.nn.functional as F +from torch import Tensor + +__all__ = ["normalized_mse_loss", "cosine_embedding_loss_batched"] + + +def normalized_mse_loss( + input: Tensor, + target: Tensor, + reduction: Literal["none", "mean", "sum"] = "mean", + epsilon: float = 1e-6, +) -> Tensor: + loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss( + target, torch.zeros_like(target) + epsilon, reduction=reduction + ) + return loss + + +def cosine_embedding_loss_batched(input: Tensor, target: Tensor) -> Tensor: + # inputs are of shape (B,T,H) + batch_size = input.size(0) + input = input.view(batch_size, -1) + target = target.view(batch_size, -1) + target_tensor = input.new(input.size(0)).fill_(1) + loss = F.cosine_embedding_loss( + input1=input, input2=target, target=target_tensor, reduction="none" + ) + return loss diff --git a/modelopt/torch/puzzletron/tools/logger.py b/modelopt/torch/puzzletron/tools/logger.py new file mode 100644 index 0000000000..9931f68251 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/logger.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import inspect +import logging +import os +import sys + +import torch.distributed.launch # noqa: F401 + +__all__ = [ + "logger", + "aprint", + "lmprint", + "mprint", + "lprint", +] + + +logging.getLogger("fsspec.local").setLevel(logging.ERROR) +logging.getLogger("websockets.client").setLevel(logging.WARN) +logging.getLogger("websockets.server").setLevel(logging.WARN) +logging.getLogger("websockets.server:connection").setLevel(logging.WARN) + + +class LogColors: + BLUE = "\033[94m" + CYAN = "\033[96m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + RESET = "\033[0m" + + +class DistributedLogger(logging.Logger): + verbosity = logging.ERROR + + def __init__(self, name, level=logging.DEBUG): + super().__init__(name, level) + self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + self.global_rank = int(os.environ.get("RANK", 0)) + self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + + def dist_log(self, msg: str, ranks: str = "main"): + """Log parameter msg with the given ranks. + + Args: + msg: The message to log. + ranks: The ranks to log the message to. Choices are: + "all": log with all ranks + "main": log with only rank 0 in node 0 + "last": log with only rank -1 in node 0 + "local_main": log with only rank 0 in all nodes + """ + # print(msg, ranks) + if ranks not in ["all", "main", "local_main", "last"]: + raise NotImplementedError( + f"Could not broadcast msg {msg} - " + f"ranks parameters choices are ['all', 'main', 'local_main', 'last']. Got {ranks}" + ) + # All ranks to print + if ranks == "all": + pass + + # Only main rank at node 0 to print + elif ( + (ranks == "main" and self.global_rank != 0) + or (ranks == "last" and self.global_rank != self.world_size - 1) + or (ranks == "local_main" and self.local_rank != 0) + ): + return + + message_source = self.get_caller_location() + + self.info( + f"{LogColors.GREEN}[rank-{self.global_rank}]{LogColors.RESET}[{message_source}]\t{msg}" + ) + + # def dist_warning(self, msg): + # if self.verbosity <= logging.WARNING: + # self.warning(f"[rank-{self.global_rank}] " + msg) + + @staticmethod + def get_caller_location() -> str: + # Get the caller's stack frame + frame = inspect.currentframe() + + # f_back -> class method, 2 x f_back -> utils method, 3 x f_back -> original source + caller_frame = frame.f_back.f_back.f_back + + # Get the filename and line number from the caller's stack frame + filename = os.path.basename(caller_frame.f_code.co_filename) + lineno = caller_frame.f_lineno + return f"{filename}:{lineno}" + + +# Initialize logger without modifying global logger class or torch logger +logger = DistributedLogger(__name__) +logger.propagate = False + +formatter = logging.Formatter("[%(asctime)s]%(message)s") +handler = logging.StreamHandler(sys.stdout) +handler.setFormatter(formatter) +handler.setLevel(logging.DEBUG) +logger.addHandler(handler) + + +# Define a custom function to redirect warnings to logger +# def custom_warning_handler(message, category, filename, lineno, file=None, line=None): +# logger.dist_warning(f'{category.__name__}: {message} (in {filename}, line {lineno})') + + +# Use the custom warning handler +# warnings.showwarning = custom_warning_handler + +logger: DistributedLogger + + +def aprint(msg: str | None): + """ + All ranks from all nodes prints + """ + return logger.dist_log(msg=msg, ranks="all") + + +def lmprint(msg: str | None): + """ + All local main ranks prints (rank 0 in each node) + """ + return logger.dist_log(msg=msg, ranks="local_main") + + +def mprint(msg: str | None): + """ + Master prints only (rank 0 in node 0) + """ + return logger.dist_log(msg=msg, ranks="main") + + +def lprint(msg: str | None): + """ + Last rank prints only (rank -1 in node 0) + """ + return logger.dist_log(msg=msg, ranks="last") diff --git a/modelopt/torch/puzzletron/tools/post_init_sparse.py b/modelopt/torch/puzzletron/tools/post_init_sparse.py new file mode 100644 index 0000000000..ab939e3bb4 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/post_init_sparse.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import torch +from torch import nn +from torch.nn.utils.prune import custom_from_mask + +""" +Converts a state dictionary from PyTorch's pruning format (with _orig and _mask suffixes) +into a standard format with sparsified weights. +""" + +__all__ = [] + + +class SparsityMethod: + def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Gets a model state_dict, returns a state_dict-like mask_dict with masks""" + + @staticmethod + def fix_state_dict_inplace(state_dict, verbose=False, change_dtype=False): + sparsity_masks = {} + for name in list(state_dict.keys()): + original_name = name.replace("_orig", "") + mask_name = original_name + "_mask" + if name[-4:] == "orig" and mask_name in state_dict: + val = state_dict[name] + mask = state_dict[name[:-4] + "mask"] + val[mask == 0] = 0 + sparsity = (val == 0).sum() / mask.numel() + sparsity_masks[original_name[:-7]] = mask + if verbose: + print(f"fix_state_dict_inplace: {name} {sparsity=}") + del state_dict[mask_name] + del state_dict[name] + state_dict[original_name] = val + if change_dtype: + for name in state_dict: + state_dict[name] = state_dict[name].to(torch.bfloat16) + return state_dict, sparsity_masks + + def filter_function(self): + pass + + def apply_masks(self, model: nn.Module, mask_dict: dict[str, torch.Tensor]) -> None: + for name, module in model.named_modules(): + if name in mask_dict: + custom_from_mask(module, "weight", mask_dict[name].to(module.weight.device)) + print(name) + print(torch.sum(mask_dict[name]) / mask_dict[name].numel()) + + def do_sparsity(self, model: nn.Module, mask_dict=None): + full_name_layers = [] + for block_idx, block_config in enumerate(model.config.block_configs): + ffn_names = block_config.ffn.sparsify # layers_to_sparsify_pattern[block_idx] + att_name = block_config.attention.sparsify + block = model.model.layers[block_idx] + if hasattr(block, "mlp"): + for name, m in block.mlp.named_modules(): + if isinstance(m, torch.nn.Linear) and self.filter_function(name, ffn_names): + full_name_layers.append( + "model.layers." + str(block_idx) + "." + "mlp." + name + ) + if hasattr(block, "self_attn"): + for name, m in block.self_attn.named_modules(): + if isinstance(m, torch.nn.Linear) and self.filter_function(name, att_name): + full_name_layers.append( + "model.layers." + str(block_idx) + "." + "self_attn." + name + ) + + if mask_dict is None: + state_dict_for_sparsifying = { + k.rstrip(".weight"): v + for k, v in model.state_dict().items() + if k.rstrip(".weight") in full_name_layers + } + mask_dict = self.calculate_masks(state_dict_for_sparsifying) + # print('Apply sparsity') + # print(full_name_layers) + # print(model.state_dict().keys()) + # print(list(mask_dict.keys())) + + self.apply_masks(model, mask_dict) + + +class SparsityMethod2o4(SparsityMethod): + def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Gets a model state_dict, returns a state_dict-like mask_dict with masks""" + mask_dict = {} + for key, val in state_dict.items(): + orig_size = val.shape + scores = val.flatten() ** 2 + mask = self.create_mask(scores) + mask = mask.reshape(orig_size) + mask_dict[key] = mask + return mask_dict + + def create_mask(self, score, value=0): + score = score # .cpu() + orig_size = score.shape + score = score.view(-1, 4) + mask = torch.zeros(score.shape) + values, indices = torch.topk(score, 2, dim=1) + rows = torch.arange(mask.size(0)).unsqueeze(-1) + mask[rows, indices] = 1 + mask = mask.view(orig_size) + return mask # dev = score.device, return mask.to(dev) + + @staticmethod + def filter_function(name, modules_to_sparsify_in_block): + if modules_to_sparsify_in_block is None: + return False + return name in modules_to_sparsify_in_block diff --git a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py new file mode 100644 index 0000000000..986c5c0107 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py @@ -0,0 +1,404 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +""" +Provides utilities for distributed loading, saving, and manipulation of +large language model checkpoints across multiple GPUs/processes. + +Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations. +""" + +import json +from collections.abc import Iterable +from pathlib import Path +from types import SimpleNamespace +from typing import Literal + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +import transformers +from safetensors import safe_open +from safetensors.torch import load_file as safe_load_file +from safetensors.torch import save_file as safe_save_file +from transformers import AutoModelForCausalLM, PretrainedConfig +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from transformers.utils.hub import cached_file, get_checkpoint_shard_files + +import modelopt.torch.utils.distributed as dist + +from ..utils.dummy_modules import DummyLMHead, DummyWTE +from ..utils.misc import EmptyInitOnDevice +from .checkpoint_utils import load_model_config, load_state_dict +from .checkpoint_utils_hf import _get_auto_class_for_trust_remote_code +from .logger import mprint + +__all__ = [ + "set_submodule", + "load_and_shard_model", + "create_sharded_model", + "load_sharded_state_dict", + "is_in_safetensors_format", +] + + +def set_submodule(model: nn.Module, module_name: str, new_submodule: nn.Module) -> None: + """Set a submodule on a model by dotted path.""" + parts = module_name.split(".") + parent_path = ".".join(parts[:-1]) + attr = parts[-1] + parent_module = model.get_submodule(parent_path) if parent_path else model + setattr(parent_module, attr, new_submodule) + + +def create_local_shard_(model, owned_block_indexes: set[int], descriptor, runtime): + # Get language model config (handles nested configs like Qwen3-VL's text_config) + lm_config = descriptor.get_language_model_config(model.config) + all_block_indexes = set(range(lm_config.num_hidden_layers)) + has_first_block = 0 in owned_block_indexes + has_last_block = max(all_block_indexes) in owned_block_indexes + + unowned_block_indexes = all_block_indexes - owned_block_indexes + for block_index in unowned_block_indexes: + decoder_layer_name = descriptor.layer_block_name(block_index) + decoder_layer = model.get_submodule(decoder_layer_name) + set_submodule( + model, + decoder_layer_name, + descriptor.create_dummy_block(decoder_layer, block_index=block_index), + ) + + # If we have the last block with tied embeddings, keep embed_tokens so lm_head works. + # load_sharded_state_dict will load embed_tokens.weight from the first shard's checkpoint file, + # and since they're tied, lm_head.weight gets populated too. + if not has_first_block and not (has_last_block and model.config.tie_word_embeddings): + set_submodule( + model, + descriptor.input_embedding_name(), + DummyWTE(lm_config.hidden_size, dtype=runtime.dtype), + ) + + if not has_last_block: + set_submodule(model, descriptor.final_norm_name(), nn.Identity()) + if not (model.config.tie_word_embeddings and has_first_block): + set_submodule(model, descriptor.output_embedding_name(), DummyLMHead(lm_config)) + + return model + + +def _get_model_class_from_config(config: PretrainedConfig): + """ + Get the model class from config.architectures field. + Works for any model registered in transformers (CausalLM, VL models, etc.). + Falls back to AutoModelForCausalLM if architectures is not available. + """ + if hasattr(config, "architectures") and config.architectures: + model_class_name = config.architectures[0] + if hasattr(transformers, model_class_name): + return getattr(transformers, model_class_name) + mprint( + f"Warning: {model_class_name} not found in transformers, falling back to AutoModelForCausalLM" + ) + return AutoModelForCausalLM + + +def load_and_shard_model( + descriptor, + checkpoint_path: str | Path, + owned_block_indexes: set[int] | Literal["auto"] = "auto", + model_config: PretrainedConfig | None = None, +): + checkpoint_path = Path(checkpoint_path) + runtime = SimpleNamespace( + device=torch.device(dist.local_rank()), + dtype=torch.bfloat16, + global_rank=dist.rank(), + world_size=dist.size(), + is_main_process=dist.is_master(), + is_last_process=dist.is_last_process(), + use_autocast=True, # Default: use autocast; descriptor can override + ) + + with runtime.device: + if model_config is None: + trust_remote_code = descriptor.requires_trust_remote_code() + model_config = load_model_config(checkpoint_path, trust_remote_code=trust_remote_code) + + num_hidden_layers = descriptor.get_language_model_config(model_config).num_hidden_layers + if owned_block_indexes == "auto": + owned_block_indexes = set( + np.array_split(np.arange(num_hidden_layers), runtime.world_size)[ + runtime.global_rank + ] + ) + + mprint("Initializing model shards") + # Pass block_configs explicitly so patcher works for VL models where + # decoder layers receive nested config (e.g., text_config) without block_configs + from ..anymodel.puzzformer import deci_x_patcher + + with deci_x_patcher( + model_descriptor=descriptor, block_configs=getattr(model_config, "block_configs", None) + ): + model_shard = create_sharded_model( + runtime=runtime, + descriptor=descriptor, + model_config=model_config, + owned_block_indexes=owned_block_indexes, + ) + + if (checkpoint_path / SAFE_WEIGHTS_NAME).exists() or ( + checkpoint_path / SAFE_WEIGHTS_INDEX_NAME + ).exists(): + mprint("Loading shard state_dict from safetensors") + shard_keys = [ + *[name for name, _ in model_shard.named_parameters()], + *[name for name, _ in model_shard.named_buffers()], + ] + shard_state_dict = load_sharded_state_dict( + model_name_or_path=str(checkpoint_path), + keys_to_load=shard_keys, + device=runtime.device, + ) + + # strict=False: allows missing lm_head.weight when tie_word_embeddings=True (e.g., Llama 3.2 3B) + model_shard.load_state_dict(shard_state_dict, strict=False, assign=True) + + del shard_state_dict + + # Re-tie weights after load_state_dict with assign=True, which severs the tie. + # Needed on first rank (owns embed_tokens) and last rank (owns lm_head). + has_first_block = 0 in owned_block_indexes + has_last_block = (num_hidden_layers - 1) in owned_block_indexes + if model_config.tie_word_embeddings and (has_first_block or has_last_block): + model_shard.tie_weights() + + # On the last rank with tied embeddings, we kept embed_tokens in create_local_shard_() + # just to load the weight and tie it to lm_head. Now replace it with a dummy so it + # doesn't interfere with the pipeline forward pass (only rank 0 should run embed_tokens). + if model_config.tie_word_embeddings and has_last_block and not has_first_block: + set_submodule( + model_shard, + descriptor.input_embedding_name(), + DummyWTE(model_config.hidden_size, dtype=runtime.dtype), + ) + else: + mprint("Loading state_dict in main process") + state_dict = load_state_dict(checkpoint_path) if runtime.is_main_process else None + + mprint("Distributing model to shards") + load_state_dict_to_shards(model_shard=model_shard, loaded_state_dict=state_dict) + del state_dict + + descriptor.init_rotary_embedding(model_shard, runtime) + + model_shard.type(runtime.dtype) + + # Configure autocast based on model descriptor (some models like Qwen3-VL MoE + # have dtype bugs under autocast) + runtime.use_autocast = descriptor.uses_autocast() + + params_on_meta_device = [ + param_name + for param_name, param in model_shard.named_parameters() + if param.device == torch.device("meta") + ] + assert len(params_on_meta_device) == 0, ( + f"[global_rank={runtime.global_rank}] Couldn't load params {params_on_meta_device}" + ) + + return model_shard + + +def create_sharded_model( + runtime, + descriptor, + model_config: PretrainedConfig, + owned_block_indexes: set[int], + device: str | torch.device | None = "meta", + dtype: torch.dtype | None = torch.float32, +): + if isinstance(device, str): + device = torch.device(device) + + dist.barrier() + + with EmptyInitOnDevice(device="meta", dtype=dtype): + # Get model class from config.architectures (works for CausalLM, VL models, etc.) + model_class = _get_model_class_from_config(model_config) + trust_remote_code = descriptor.requires_trust_remote_code() + if trust_remote_code: + auto_cls = _get_auto_class_for_trust_remote_code(model_config) + model = auto_cls.from_config(model_config, trust_remote_code=trust_remote_code) + elif model_class is AutoModelForCausalLM: + model = AutoModelForCausalLM.from_config(model_config) + else: + model = model_class._from_config(model_config) + create_local_shard_( + model=model, + owned_block_indexes=owned_block_indexes, + descriptor=descriptor, + runtime=runtime, + ) + + if device != torch.device("meta"): + local_shard_state_dict = { + k: torch.empty_like(v, device=device) for k, v in model.state_dict().items() + } + model.load_state_dict(local_shard_state_dict, assign=True) + + return model + + +def load_state_dict_to_shards( + model_shard: torch.nn.Module, loaded_state_dict: dict | None = None +) -> None: + from ..sewing_kit.utils import distributed_isend_obj, distributed_recv_obj + + model_shard.to("meta") + local_state_dict_keys = list(model_shard.state_dict().keys()) + + if dist.is_master(): + gathered_state_dict_keys = [None] * dist.size() + torch.distributed.gather_object(local_state_dict_keys, gathered_state_dict_keys) + + assert loaded_state_dict is not None + loaded_state_dict = {k.replace("_orig_mod.", ""): v for k, v in loaded_state_dict.items()} + + works: list[torch.distributed.Work] = [] + for i, shard_keys in enumerate(gathered_state_dict_keys[1:]): + process_id = i + 1 + shard_state_dict = {k: v for k, v in loaded_state_dict.items() if k in shard_keys} + process_works = distributed_isend_obj(shard_state_dict, process_id) + works.extend(process_works) + + for work in works: + work.wait() + + shard_state_dict = { + k: v for k, v in loaded_state_dict.items() if k in local_state_dict_keys + } + else: + torch.distributed.gather_object(local_state_dict_keys) + shard_state_dict = distributed_recv_obj() + + print(f"{dist.rank()} loaded state_dict shard") + + missing_keys, unexpected_keys = model_shard.load_state_dict( + shard_state_dict, strict=False, assign=True + ) + assert len(unexpected_keys) == 0 + assert all("dummy_param" in key for key in missing_keys) + + model_shard.cuda(dist.local_rank()) + + dist.barrier() + + +def save_sharded_model( + model_shard: torch.nn.Module | dict[str, torch.Tensor], out_path: str | Path +): + """ + out_path is usually output_checkpoint_path / "model.safetensors" + """ + dist.barrier() + + if isinstance(model_shard, torch.nn.Module): + shard_state_dict = model_shard.state_dict() + elif isinstance(model_shard, dict): + shard_state_dict = model_shard + else: + raise ValueError(f"Unrecognized model shard type: {type(model_shard)}") + + shard_state_dict = {k: v.cpu() for k, v in shard_state_dict.items()} + total_shard_size = sum( + weight.numel() * weight.element_size() for weight in shard_state_dict.values() + ) + + num_shards = dist.size() + idx = dist.rank() + + out_path = Path(out_path) + shard_file = out_path.with_stem(f"{out_path.stem}-{idx + 1:05d}-of-{num_shards:05d}") + + shard_metadata = { + "total_shard_size": total_shard_size, + "shard_keys": list(shard_state_dict.keys()), + "shard_file": str(shard_file), + } + + if dist.is_master(): + shard_metadatas = [{} for _ in range(dist.size())] + torch.distributed.gather_object(shard_metadata, shard_metadatas, dst=0) + total_size = sum(x["total_shard_size"] for x in shard_metadatas) + metadata = {"total_size": total_size} + weight_map: dict[str, str] = {} + for shard_metadata in shard_metadatas: + weight_map.update( + {k: Path(shard_metadata["shard_file"]).name for k in shard_metadata["shard_keys"]} + ) + + index = {"metadata": metadata, "weight_map": weight_map} + index_path = Path(str(out_path) + ".index.json") + index_path.write_text(json.dumps(index, indent=2)) + + else: + torch.distributed.gather_object(shard_metadata, dst=0) + + if out_path.suffix == ".safetensors": + safe_save_file(shard_state_dict, shard_file, metadata={"format": "pt"}) + else: + torch.save(shard_state_dict, shard_file) + + dist.barrier() + + +def load_sharded_state_dict( + model_name_or_path: str | Path, + keys_to_load: Iterable[str] | None = None, + device: torch.device | str = "cpu", +) -> dict[str, torch.Tensor]: + """ + keys_to_load: entire state_dict if None, else partial state_dict containing only these keys + """ + shard_paths = _resolve_shard_paths(model_name_or_path) + # print(f"shard_paths: {shard_paths}") + partial_state_dict = {} + for safetensors_path in shard_paths: + if keys_to_load is None: + shard = safe_load_file(safetensors_path) + partial_state_dict.update(shard) + else: + with safe_open(safetensors_path, framework="pt", device=str(device)) as f: + for key in f.keys(): # noqa: SIM118 - safe_open objects require .keys(), not directly iterable + if key in keys_to_load: + partial_state_dict[key] = f.get_tensor(key) + return partial_state_dict + + +def _resolve_shard_paths(model_name_or_path: str) -> list[str]: + try: + unsharded_path = cached_file(model_name_or_path, SAFE_WEIGHTS_NAME) + return [unsharded_path] + except OSError: + index_path = cached_file(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + shard_paths, _ = get_checkpoint_shard_files(model_name_or_path, index_path) + return shard_paths + + +def is_in_safetensors_format(checkpoint_dir: Path) -> bool: + return len(list(checkpoint_dir.glob("*.safetensors"))) > 0 diff --git a/modelopt/torch/puzzletron/tools/validate_model.py b/modelopt/torch/puzzletron/tools/validate_model.py new file mode 100644 index 0000000000..b5d997286f --- /dev/null +++ b/modelopt/torch/puzzletron/tools/validate_model.py @@ -0,0 +1,262 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +""" +Provides a function to validate a model. Runs a model forward pass on a dataset and calculates +the loss, and optionally registers hooks to capture the inputs and the outputs +of pytorch modules that are used for activation scoring for pruning. + +TODO: Consider moving this a separate module dedicated for scoring + +Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations. +""" + +import textwrap +from pathlib import Path +from typing import Type + +import torch +from omegaconf import DictConfig +from torch import nn +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase + +import modelopt.torch.utils.distributed as dist + +from ..activation_scoring.activation_hooks import register_activation_hooks +from ..anymodel.model_descriptor import ModelDescriptor, ModelDescriptorFactory +from ..anymodel.puzzformer import Same +from ..utils.data import create_validation_dataloader +from ..utils.parsing import simple_parse_args_string # noqa: F401 (kept for backwards compat) +from ..utils.validate_runtime_pipeline import HiddenStatesAndLMHead, calculate_losses_pipeline +from .common import resolve_torch_dtype +from .logger import aprint, mprint +from .sharded_checkpoint_utils import load_and_shard_model, set_submodule + +__all__ = ["validate_model", "prepare_model", "prepare_dataloader"] + +""" +Two goals: +1) Calculate lm loss and token accuracy for a model. +May raise lots of NCCL warnings when it finishes, don't be alarmed. +Can be used to validate a HuggingFace model. +Automatically uses pipeline parallelism via device_map="auto". + +2) Register hooks to capture the inputs and the outputs of pytorch modules. +For example, to collect activations scores for various layers (ffn, layer_norm, etc.) +that are used for pruning (ffn_hidden_size, embedding_pruning, etc). +See activations_log_dir and activation_hooks_kwargs arguments. +""" + + +@torch.no_grad() +def validate_model( + args: DictConfig, + model: PreTrainedModel | None = None, + tokenizer: PreTrainedTokenizerBase | None = None, + target_hidden_states_per_batch: list[torch.Tensor] | None = None, + return_hidden_states: bool = False, + calculate_full_score_ablations: bool = False, + val_dataloader: DataLoader | None = None, +) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: + """Validate a language model on a dataset by calculating loss and optionally capturing activations. + + Args: + args: Configuration object containing the following attributes: + + Model Configuration: + - model_name_or_path (str): Path to model checkpoint or HuggingFace model name. Required unless model is passed directly. + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration: + - dataset_path (str): Path to the validation dataset. + - tokenizer_name (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. Uses all if None. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing: + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. Uses seed if None. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + Activation Hooks: + - activations_log_dir (str, optional): Directory to log activation scores. If provided, hooks will be registered to capture activations. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. If string, comma-separated format: "arg1=val1,arg2=val2". + + Execution Options: + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended. + - write_results (bool): Write validation results to file. + + model: Pre-loaded model. If None, will be loaded from args.model_name_or_path. + tokenizer: Pre-loaded tokenizer. If None, will be loaded based on args. + target_hidden_states_per_batch: Target hidden states for pipeline parallel evaluation. + return_hidden_states: Whether to return hidden states from the model. + calculate_full_score_ablations: Calculate comprehensive teacher similarity scores. False calculates only a small suite for efficiency. + val_dataloader: Pre-created validation dataloader. If None, will be created from args. + + Returns: + A tuple containing: + - losses: Dictionary mapping loss names to loss statistics (avg, per_sample). + - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None. + + Returns (None, None) if not on master rank. + """ + descriptor = ModelDescriptorFactory.get(args.descriptor) + + if val_dataloader is None: + val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None + validation_full_iters = ( + args.eval_samples // args.micro_batch_size + ) # model pipeline, single data rank + + model = prepare_model(args, descriptor=descriptor, model=model) + + just_model_forward = False + checkpoint_manager = None + activation_hooks = None + + if args.activations_log_dir is not None: + activation_hooks_kwargs = args.activation_hooks_kwargs or {} + activation_hooks_kwargs["validation_full_iters"] = validation_full_iters + hook_class = args.hook_class + + # Create activation hooks using pruning mixin + activation_hooks = register_activation_hooks( + model=model, + activation_hooks_kwargs=activation_hooks_kwargs, + hook_class=hook_class, + pruning_mixin=args.pruning_mixin, + ) + + # Create checkpoint manager with hooks + from ..utils.checkpoint_manager import ScoringCheckpointManager + + mprint( + f"Creating checkpoint manager with {len(activation_hooks)} hooks for dir: {args.activations_log_dir}" + ) + checkpoint_manager = ScoringCheckpointManager( + checkpoint_dir=args.activations_log_dir, + activation_hooks=activation_hooks, + checkpoint_interval=50, # Save every 50 batches + ) + + # Load existing checkpoint if available + mprint("Attempting to load existing checkpoint...") + checkpoint_data = checkpoint_manager.load_checkpoint() + if checkpoint_data: + mprint(f"Checkpoint loaded successfully: {checkpoint_data}") + else: + mprint("No checkpoint found, starting fresh") + just_model_forward = True + set_submodule(model, descriptor.output_embedding_name(), Same()) + + losses, hidden_states_per_batch = calculate_losses_pipeline( + stitched_model=model, + dataloader=val_dataloader, + target_hidden_states_per_batch=target_hidden_states_per_batch, + return_hidden_states=return_hidden_states, + calculate_full_score_ablations=calculate_full_score_ablations, + calc_on_cpu=args.calc_losses_on_cpu, + just_model_forward=just_model_forward, + checkpoint_manager=checkpoint_manager, + autocast_dtype=resolve_torch_dtype(getattr(args, "autocast_dtype", "torch.bfloat16")), + descriptor=descriptor, + use_autocast=descriptor.uses_autocast(), + ) + + if losses is not None: + avg_losses = {loss_name: loss_log["avg"] for loss_name, loss_log in losses.items()} + + results_str = f""" + validate_model: + {args.model_name_or_path=} + Average losses = {avg_losses} + Actual num samples = {len(next(iter(losses.values()))["per_sample"])} + {args=} + """ + results_str = textwrap.dedent(results_str) + aprint(results_str) + if args.write_results: + Path(f"{args.model_name_or_path}/validate_model_results.txt").write_text(results_str) + + if activation_hooks is not None: + hook_class.dump_activations_logs(activation_hooks, args.activations_log_dir, args) + + return losses, hidden_states_per_batch + + +def prepare_model( + args: DictConfig, + descriptor: Type[ModelDescriptor], + model: PreTrainedModel | None = None, +) -> nn.Module: + if model is None: + assert args.model_name_or_path is not None + model = load_and_shard_model(descriptor=descriptor, checkpoint_path=args.model_name_or_path) + + model.eval() + return model + + +def prepare_dataloader( + args: DictConfig, tokenizer: PreTrainedTokenizerBase | None = None +) -> DataLoader: + if tokenizer is None: + tokenizer_name = getattr(args, "tokenizer_name", None) + assert (tokenizer_name is not None) or (args.model_name_or_path is not None) + # Auto-detect trust_remote_code from the descriptor when not explicitly set. + # Required for models like NemotronH v2 whose configs use characters (e.g. '-') that + # the native transformers NemotronHConfig._pattern_to_list doesn't support. + trust_remote_code = getattr(args, "trust_remote_code", False) + if not trust_remote_code and getattr(args, "descriptor", None): + try: + descriptor_cls = ModelDescriptorFactory.get(args.descriptor) + trust_remote_code = descriptor_cls.requires_trust_remote_code() + except Exception: + pass + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name or args.model_name_or_path, + trust_remote_code=trust_remote_code, + ) + + val_dataloader = create_validation_dataloader( + accelerator=None, + seed=args.seed, + tokenizer=tokenizer, + block_size=args.block_size, + dataset=args.dataset_path, + content_field=args.data_column, + fim_rate=args.fim_rate, + fim_spm_rate=args.fim_spm_rate, + micro_batch_size=args.micro_batch_size, + eval_samples=args.eval_samples, + dataset_name=args.val_dataset_name, + source_datasets_to_discard=args.source_datasets_to_discard, + bos_rate=args.bos_rate, + varlen=args.varlen, + shuffle_seed=args.shuffle_seed, + load_dataset_fn=args.load_dataset_fn, + ) + + return val_dataloader diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py new file mode 100644 index 0000000000..d8471aee23 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py @@ -0,0 +1,290 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validates puzzle solutions by applying layer replacements and evaluating model performance. + +TODO: Consider moving this a separate module dedicated for scoring +""" + +# mypy: ignore-errors + +import json +import warnings +from functools import partial +from pathlib import Path +from typing import Optional + +import torch +from omegaconf import DictConfig +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +import modelopt.torch.utils.distributed as dist + +from ..anymodel.converter import Converter +from ..anymodel.model_descriptor import ModelDescriptorFactory +from ..replacement_library.library import ReplacementLibrary +from ..replacement_library.replacement_utils import parse_layer_replacement +from ..utils.parsing import get_nested_key +from ..utils.validate_runtime_pipeline import perform_pipeline_stitches +from . import validate_model +from .checkpoint_utils import copy_tokenizer +from .checkpoint_utils_hf import save_checkpoint +from .common import resolve_torch_dtype +from .sharded_checkpoint_utils import load_and_shard_model +from .validation_utils import ( + validate_model_and_extract_hidden_states, + validate_model_with_teacher_similarity_metrics, +) + +__all__ = ["validate_puzzle_solutions", "load_puzzle_solutions"] + +""" +Usage Example: +============== + +Validate single_block_replacement_solutions by calling validate_puzzle_solutions() directly +with an args object containing the required attributes. See the function docstring for details. + +""" + + +@torch.no_grad() +def validate_puzzle_solutions(args: DictConfig) -> None: + """Validate puzzle solutions by applying layer replacements and evaluating model performance. + + Args: + args: Configuration object containing the following attributes: + + Puzzle Configuration (Required): + - replacement_library_path (Path): Path to the replacement library JSON file. + - solutions_path (Path): Path to puzzle solutions JSON file or directory containing solution files. + - solutions_to_validate (list[int], optional): Indices of specific solutions to validate. Validates all solutions if None. + - sort_solutions_by (str, optional): JSON field path to sort solutions by before validation. + - bigger_is_better (bool): If True, sort solutions in descending order. Used with sort_solutions_by. + - skip_validation (bool): If True, skip model validation and only save models if requested. + - save_models (bool): If True, save realized model checkpoints for each solution. + + Teacher/Tokenizer Configuration: + - teacher_dir (Path, optional): Path to teacher model directory. Auto-inferred if not provided. + - tokenizer_name (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. + + Model Configuration (Required if skip_validation=False): + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration (Required if skip_validation=False): + - dataset_path (str): Path to the validation dataset. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing (Required if skip_validation=False): + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + Output Configuration: + - output_dir (Path, optional): Directory to save validation results. Auto-generated from solutions_path if not provided. + + Execution Options (Optional if skip_validation=False): + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. + - write_results (bool): Write validation results to file. + - activations_log_dir (str, optional): Directory to log activation scores. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. + + Returns: + None. Saves validation results and optionally model checkpoints to disk. + """ + descriptor = ModelDescriptorFactory.get(args.descriptor) + + puzzle_solutions = load_puzzle_solutions( + args.solutions_path, args.sort_solutions_by, args.bigger_is_better + ) + if args.solutions_to_validate is None: + args.solutions_to_validate = list(range(len(puzzle_solutions))) + puzzle_solutions = [puzzle_solutions[i] for i in args.solutions_to_validate] + + tokenizer = _load_tokenizer(args, trust_remote_code=descriptor.requires_trust_remote_code()) + if not args.skip_validation: + val_dataloader = ( + validate_model.prepare_dataloader(args, tokenizer) if dist.is_master() else None + ) + + output_dir = ( + args.output_dir + if getattr(args, "output_dir", None) is not None + else args.solutions_path.with_name(f"{args.solutions_path.stem}--validation") + ) + + replacement_library = ReplacementLibrary( + args.replacement_library_path, + descriptor=descriptor, + model_config_overrides={"use_cache": False}, + ) + + teacher_hidden_states = None + if (args.teacher_dir is not None) and (not args.skip_validation): + teacher_model = load_and_shard_model( + checkpoint_path=args.teacher_dir, descriptor=descriptor + ) + teacher_model.cuda(dist.local_rank()) + stitched_model = perform_pipeline_stitches(teacher_model, descriptor=descriptor) + teacher_hidden_states = validate_model_and_extract_hidden_states( + args, + stitched_model, + tokenizer, + output_dir, + model_name="teacher", + val_dataloader=val_dataloader, + ) + + # Properly release CUDA memory after teacher validation + teacher_model.cpu() + stitched_model.cpu() + torch.cuda.empty_cache() + torch.cuda.synchronize() + dist.barrier() + + for i_solution, puzzle_solution in tqdm( + list(zip(args.solutions_to_validate, puzzle_solutions)), desc="Validating solutions" + ): + layer_replacements = _extract_layer_replacements_from_puzzle_solution(puzzle_solution) + realizable_as_symlinks = can_realize_as_symlinks(layer_replacements) + # realizable_as_symlinks = False + model_config = replacement_library.create_model_config(layer_replacements) + if (args.save_models and not realizable_as_symlinks) or (not args.skip_validation): + model = replacement_library.load_model(layer_replacements) + model_config = model.config + + if args.save_models: + checkpoint_dir = ( + args.solutions_path.with_name(f"{args.solutions_path.stem}--checkpoints") + / f"solution_{i_solution}" + ) + + model_config.dtype = resolve_torch_dtype(getattr(args, "model_dtype", "torch.bfloat16")) + Converter.copy_checkpoint_files(args.teacher_dir, checkpoint_dir) + if realizable_as_symlinks: + if dist.is_master(): + # TODO: Loo into internal Puzzleron code to see how to save as symlinks + # save_checkpoint_as_symlinks is currently not supported + pass + save_checkpoint(model, checkpoint_dir, descriptor) + + copy_tokenizer( + args.tokenizer_name, + checkpoint_dir, + trust_remote_code=descriptor.requires_trust_remote_code(), + ) + + dist.barrier() + + if not args.skip_validation: + model.cuda(dist.local_rank()) + stitched_model = perform_pipeline_stitches(model, descriptor=descriptor) + validate_model_with_teacher_similarity_metrics( + args, + stitched_model, + tokenizer, + teacher_hidden_states, + output_dir, + model_name=f"solution_{i_solution}", + extra_payload={"i_solution": i_solution, "puzzle_solution": puzzle_solution}, + val_dataloader=val_dataloader, + ) + + # Properly release CUDA memory after solution validation + model.cpu() + stitched_model.cpu() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + dist.barrier() + + +def can_realize_as_symlinks(layer_replacements: list[dict]) -> bool: + for layer_replacement in layer_replacements: + num_parent_layers = len(layer_replacement["parent_layer_indices"]) + num_child_layers = len(layer_replacement["child_block_configs"]) + if num_parent_layers != num_child_layers or num_parent_layers != 1: + return False + return True + + +def _load_tokenizer(args: DictConfig, trust_remote_code: bool = False) -> PreTrainedTokenizerBase: + tokenizer = None + if (tokenizer_name := getattr(args, "tokenizer_name", None)) is not None: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=trust_remote_code + ) + elif args.teacher_dir is not None: + try: + tokenizer = AutoTokenizer.from_pretrained( + args.teacher_dir, trust_remote_code=trust_remote_code + ) + except Exception: + pass + if tokenizer is None: + warnings.warn("Couldn't find a tokenizer, trying to continue without one") + return tokenizer + + +def _extract_layer_replacements_from_puzzle_solution( + puzzle_solution: dict, +) -> list[dict]: + puzzle_solution = puzzle_solution.get("puzzle_solution", puzzle_solution) + layer_replacements = [ + parse_layer_replacement(rep) for rep in puzzle_solution["chosen_replacements"] + ] + return layer_replacements + + +def load_puzzle_solutions( + solutions_path: Path, + sort_solutions_by: Optional[str], + bigger_is_better: bool, +) -> list[dict]: + assert solutions_path.exists(), f"{solutions_path=} does not exist" + + if solutions_path.is_file(): + puzzle_solutions = json.loads(solutions_path.read_text()) + if isinstance(puzzle_solutions, dict): + puzzle_solutions = [puzzle_solutions] + else: + puzzle_solutions = [ + json.loads(p.read_text()) for p in solutions_path.glob("*solution*.json") + ] + + if len(puzzle_solutions) == 0: + raise ValueError(f"No solutions under {solutions_path=}") + + if sort_solutions_by is not None: + puzzle_solutions = sorted( + puzzle_solutions, key=partial(get_nested_key, field=sort_solutions_by) + ) + if bigger_is_better: + puzzle_solutions = puzzle_solutions[::-1] + vals = [get_nested_key(sol, sort_solutions_by) for sol in puzzle_solutions] + print(f"sorted solutions by {sort_solutions_by}. {vals[:10]=} {vals[-10:]=}") + + return puzzle_solutions diff --git a/modelopt/torch/puzzletron/tools/validation_utils.py b/modelopt/torch/puzzletron/tools/validation_utils.py new file mode 100644 index 0000000000..7fd763fbef --- /dev/null +++ b/modelopt/torch/puzzletron/tools/validation_utils.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for validating models and extracting hidden states and similarity metrics. + +TODO: Consider moving this a separate module dedicated for scoring. +""" + +# mypy: ignore-errors + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +import torch +from omegaconf import DictConfig, OmegaConf +from torch import nn +from transformers import PreTrainedTokenizerBase + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.utils import json_dump + +from ..utils.validation import LowMemorySparseTensor +from . import validate_model +from .logger import mprint + +if TYPE_CHECKING: + from ..sewing_kit import StitchedModule + +__all__ = [ + "validate_model_and_extract_hidden_states", + "validate_model_with_teacher_similarity_metrics", + "write_results", +] + + +def validate_model_and_extract_hidden_states( + args: DictConfig, + model: "nn.Module | StitchedModule", + tokenizer: PreTrainedTokenizerBase, + output_dir: str | Path, + model_name: str, + extra_payload: Optional[dict[str, Any]] = None, + val_dataloader=None, +) -> list[torch.Tensor | LowMemorySparseTensor]: + mprint(f""" + +################################################################ +validate_model_and_extract_token_probs({model_name=}) +################################################################ + +""") + losses, hidden_states_per_batch = validate_model.validate_model( + args, + model, + tokenizer, + return_hidden_states=True, + val_dataloader=val_dataloader, + ) + if dist.is_last_process(): + output_dir = output_dir if (output_dir is not None) else args.bypass_dir + extra_payload = extra_payload if (extra_payload is not None) else dict() + write_results(output_dir, model_name, args, {**losses, **extra_payload}) + return hidden_states_per_batch + + +def validate_model_with_teacher_similarity_metrics( + args: DictConfig, + model: "nn.Module | StitchedModule", + tokenizer: PreTrainedTokenizerBase, + target_hidden_states_per_batch: list[torch.Tensor], + output_dir: str | Path, + model_name: str, + extra_payload: Optional[dict[str, Any]] = None, + calculate_full_score_ablations: bool = False, + val_dataloader=None, +) -> None: + is_calc_kl_div = target_hidden_states_per_batch is not None + mprint(f""" + +################################################################ +validate_model_with_kl_div({model_name=}, {is_calc_kl_div=}) +################################################################ + +""") + losses, _ = validate_model.validate_model( + args, + model, + tokenizer, + target_hidden_states_per_batch=target_hidden_states_per_batch, + calculate_full_score_ablations=calculate_full_score_ablations, + val_dataloader=val_dataloader, + ) + if dist.is_last_process(): + extra_payload = extra_payload if (extra_payload is not None) else dict() + write_results(output_dir, model_name, args, {**losses, **extra_payload}) + + +def write_results( + output_dir: str | Path, result_name: str, args: DictConfig, payload: dict[str, Any] +) -> None: + output_path = Path(output_dir) / f"{result_name}.json" + output_path.parent.mkdir(parents=True, exist_ok=True) + results = { + **payload, + "args": OmegaConf.to_container(args, resolve=True) + if isinstance(args, DictConfig) + else args.__dict__, + } + json_dump(results, output_path) diff --git a/modelopt/torch/puzzletron/utils/__init__.py b/modelopt/torch/puzzletron/utils/__init__.py new file mode 100644 index 0000000000..7ad661b4e0 --- /dev/null +++ b/modelopt/torch/puzzletron/utils/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities for Puzzletron.""" + +from .checkpoint_manager import * +from .data import * +from .dummy_modules import * +from .misc import * +from .parsing import * +from .validate_runtime_pipeline import * +from .validation import * diff --git a/modelopt/torch/puzzletron/utils/checkpoint_manager.py b/modelopt/torch/puzzletron/utils/checkpoint_manager.py new file mode 100644 index 0000000000..e0b90deaea --- /dev/null +++ b/modelopt/torch/puzzletron/utils/checkpoint_manager.py @@ -0,0 +1,262 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint manager for activation hook scoring with periodic saves and resume support.""" + +import json +import time +from pathlib import Path +from typing import Any + +import modelopt.torch.utils.distributed as dist + +from ..tools.logger import aprint, mprint + +__all__ = ["ScoringCheckpointManager"] + + +class ScoringCheckpointManager: + """Manages checkpointing for activation hook scoring with periodic saves.""" + + def __init__(self, checkpoint_dir: str, activation_hooks=None, checkpoint_interval: int = 100): + """Initialize checkpoint manager. + + Args: + checkpoint_dir: Directory to save checkpoints + activation_hooks: Dictionary of activation hooks to manage + checkpoint_interval: Save checkpoint every N batches + """ + self.checkpoint_dir = Path(checkpoint_dir) + self.activation_hooks = activation_hooks + self.checkpoint_interval = checkpoint_interval + self.rank = dist.rank() + self.is_main_process = dist.is_master() + + # Debug: Log checkpoint manager initialization + hook_count = len(activation_hooks) if activation_hooks else 0 + aprint( + f"[Rank {self.rank}] Checkpoint manager initialized: {hook_count} hooks, dir: {checkpoint_dir}" + ) + + # Checkpoint files + self.progress_file = self.checkpoint_dir / "scoring_progress.json" + self.hook_states_file = self.checkpoint_dir / f"hook_states_rank_{self.rank}.pth" + + # Progress tracking + self.current_batch = 0 + self.total_batches = 0 + self.start_time = time.time() + + # Ensure directory exists + if self.is_main_process: + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + def load_checkpoint(self) -> dict[str, Any] | None: + """Load existing checkpoint if available, including hook states. + + Returns: + Dict with checkpoint info or None if no checkpoint exists + """ + aprint(f"[Rank {self.rank}] Looking for checkpoint at: {self.progress_file}") + if not self.progress_file.exists(): + aprint(f"[Rank {self.rank}] No checkpoint file found at {self.progress_file}") + return None + + try: + with open(self.progress_file) as f: + checkpoint_data = json.load(f) + + # Validate checkpoint + if "current_batch" in checkpoint_data and "total_batches" in checkpoint_data: + self.current_batch = checkpoint_data["current_batch"] + self.total_batches = checkpoint_data["total_batches"] + + mprint( + f"Found checkpoint: batch {self.current_batch}/{self.total_batches} ({checkpoint_data.get('progress', 0.0):.1%})" + ) + mprint( + f"Will resume from batch {self.current_batch}, skipping batches 0-{self.current_batch - 1}" + ) + + # Load hook states if hooks are available + if self.activation_hooks is not None: + success = self.load_hook_states(self.activation_hooks) + if success: + aprint( + f"[Rank {self.rank}] Successfully loaded hook states from checkpoint" + ) + else: + aprint(f"[Rank {self.rank}] Failed to load hook states - starting fresh") + + return checkpoint_data + else: + aprint( + f"[Rank {self.rank}] Invalid checkpoint format (missing current_batch/total_batches): {checkpoint_data}" + ) + return None + + except (json.JSONDecodeError, KeyError) as e: + mprint(f"Error loading checkpoint: {e}") + + return None + + def load_hook_states(self, activation_hooks) -> bool: + """Load hook states from checkpoint files. + + Args: + activation_hooks: Hook objects to load states into + + Returns: + bool: True if hook states were successfully loaded, False otherwise + """ + import os + + # Each rank loads only its own hook states + current_rank = int(os.environ.get("RANK", 0)) + hook_states_path = self.checkpoint_dir / f"hook_states_rank_{current_rank}.pth" + + if hook_states_path.exists(): + aprint(f"[Rank {current_rank}] Loading hook states from {hook_states_path}") + try: + import torch + + hook_states = torch.load(hook_states_path, map_location="cpu") + + # Load states into corresponding hooks + loaded_count = 0 + for module_name, hook in activation_hooks.items(): + if module_name in hook_states: + hook.load_state_dict(hook_states[module_name]) + loaded_count += 1 + + # Log progress info if available (only for a few hooks to avoid spam) + if loaded_count <= 3: # Only log first few hooks + progress_info = hook.get_progress_info() + if progress_info: + aprint(f"[Rank {current_rank}] {module_name}: {progress_info}") + else: + aprint( + f"[Rank {current_rank}] Warning: No saved state found for hook: {module_name}" + ) + + aprint( + f"[Rank {current_rank}] Successfully loaded states for {loaded_count}/{len(activation_hooks)} hooks" + ) + return True + + except Exception as e: + aprint(f"[Rank {current_rank}] Error loading hook states: {e}") + return False + else: + aprint(f"[Rank {current_rank}] No hook states file found at {hook_states_path}") + return False + + def should_skip_batch(self, batch_idx: int) -> bool: + """Check if we should skip this batch (already processed in previous run).""" + should_skip = batch_idx < self.current_batch + if should_skip and batch_idx % 10 == 0: # Log every 10th skipped batch to avoid spam + mprint(f"Skipping batch {batch_idx} (resume from batch {self.current_batch})") + return should_skip + + def update_progress(self, batch_idx: int, total_batches: int): + """Update progress and potentially save checkpoint. + + Args: + batch_idx: Current batch index + total_batches: Total number of batches + """ + self.current_batch = batch_idx + self.total_batches = total_batches + + # Save checkpoint periodically or on completion + should_save = ( + (batch_idx % self.checkpoint_interval == 0) # Periodic save + or (batch_idx == total_batches - 1) # Final batch + ) + + if should_save: + # All ranks save their hook states + if self.activation_hooks is not None: + try: + from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook + + ForwardHook.save_hook_states(self.activation_hooks, self.checkpoint_dir) + except Exception as e: + mprint(f"Warning: Failed to save hook states: {e}") + + # Only main process saves progress info + if self.is_main_process: + self.save_checkpoint() + + # Synchronize all ranks after checkpointing + dist.barrier() + + def save_checkpoint(self): + """Save current checkpoint to disk (progress info only). + Hook states are saved separately in update_progress. + """ + try: + # Save progress + progress_data = { + "current_batch": self.current_batch, + "total_batches": self.total_batches, + "progress": self.current_batch / self.total_batches + if self.total_batches > 0 + else 0.0, + "timestamp": time.time(), + "elapsed_time": time.time() - self.start_time, + "rank": self.rank, + } + + # Write progress atomically + temp_file = self.progress_file.with_suffix(".tmp") + with open(temp_file, "w") as f: + json.dump(progress_data, f, indent=2) + temp_file.replace(self.progress_file) + + # Hook states are saved at a higher level to ensure all ranks participate + + if self.current_batch % (self.checkpoint_interval) == 0: + progress_pct = progress_data["progress"] * 100 + elapsed = progress_data["elapsed_time"] + mprint( + f"Checkpoint saved: batch {self.current_batch}/{self.total_batches} ({progress_pct:.1f}%), elapsed: {elapsed:.1f}s" + ) + + except Exception as e: + mprint(f"Error saving checkpoint: {e}") + + def finalize(self): + """Mark scoring as completed.""" + # All ranks save their final hook states + if self.activation_hooks is not None: + try: + from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook + + saved_path = ForwardHook.save_hook_states( + self.activation_hooks, self.checkpoint_dir + ) + mprint(f"Final hook states saved to {saved_path}") + except Exception as e: + mprint(f"Warning: Failed to save final hook states: {e}") + + # Only main process saves progress info + if self.is_main_process: + self.current_batch = self.total_batches + self.save_checkpoint() + mprint(f"Scoring completed and finalized: {self.total_batches} batches processed") + + # Synchronize all ranks after finalization + dist.barrier() diff --git a/modelopt/torch/puzzletron/utils/data/__init__.py b/modelopt/torch/puzzletron/utils/data/__init__.py new file mode 100644 index 0000000000..1e0d93744f --- /dev/null +++ b/modelopt/torch/puzzletron/utils/data/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset and dataloader utilities for Puzzletron.""" + +from .dataloaders import * +from .dataset import * diff --git a/modelopt/torch/puzzletron/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py new file mode 100644 index 0000000000..f404653149 --- /dev/null +++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DataLoader utilities for language model training and validation.""" + +from collections.abc import Callable, Mapping, Sequence +from functools import partial +from typing import Protocol, TypeVar + +import datasets +import torch +import torch.distributed +from accelerate import Accelerator +from torch.utils.data import DataLoader, Dataset, IterableDataset +from torch.utils.data._utils.collate import collate, default_collate_fn_map +from tqdm import tqdm +from transformers import PreTrainedTokenizerBase + +from ...tools.logger import mprint +from .dataset import ConstantLengthDataset + +__all__ = ["create_validation_dataloader", "create_padded_tensor"] + + +def collate_none_fn( + batch, *, collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None +): + return None + + +collate_fn_map_with_none_support = {**default_collate_fn_map, type(None): collate_none_fn} +collate_fn_with_none_support = partial(collate, collate_fn_map=collate_fn_map_with_none_support) + + +class LoadDatasetFn(Protocol): + def __call__( + self, dataset_path: str, content_field: str, keep_in_memory: bool = False + ) -> Mapping[str, Dataset]: ... + + +def load_from_disk_fn( + dataset_path: str, content_field: str, keep_in_memory: bool = False +) -> Mapping[str, Dataset]: + return datasets.load_from_disk(dataset_path, keep_in_memory=keep_in_memory) + + +def load_streaming_fn( + dataset_path: str, content_field: str, keep_in_memory: bool = False +) -> Mapping[str, Dataset]: + dataset = datasets.load_dataset( + dataset_path, + streaming=True, + features=datasets.Features( + { + content_field: datasets.Value(dtype="string"), + } + ), + keep_in_memory=keep_in_memory, + ) + + return dataset + + +def create_validation_dataloader( + accelerator: Accelerator | None, + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset: str | Mapping[str, Dataset], + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + eval_samples: int | None = None, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + dataset_name: str = "__auto__", + keep_in_memory: bool = False, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + varlen: bool = True, + shuffle_seed: int | None = None, +): + if accelerator is None: + accelerator = Printer() + + if accelerator.is_main_process: + if isinstance(dataset, str): + dataset = load_dataset_fn(dataset, content_field, keep_in_memory) + + if isinstance(dataset, datasets.Dataset | torch.utils.data.Dataset): + valid_data = dataset + mprint( + "#### Path to specific dataset was given (not DatasetDict), taking it as-is ####" + ) + else: + assert isinstance(dataset, datasets.DatasetDict) + if dataset_name == "__auto__": + val_split_options = [] + for val_key_prefix in ("val", "test"): + if len(val_split_options) == 0: + val_split_options = [ + split + for split in dataset # DatasetDict is dict-like and supports direct iteration + if split.lower().startswith(val_key_prefix) + ] + assert len(val_split_options) == 1, ( + f"Expected exactly one validation split, got {val_split_options=} ({dataset.keys()=})" + ) + val_split = val_split_options[0] + mprint(f"Inferred validation split automatically: '{val_split}'") + else: + val_split = dataset_name + mprint(f"Validation split explicitly chosen: '{val_split}'") + valid_data = dataset[val_split] + + if shuffle_seed is not None: + mprint(f"Shuffling with {shuffle_seed=}") + valid_data = valid_data.shuffle(seed=shuffle_seed) + + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + infinite=False, + seq_length=block_size * micro_batch_size if varlen else block_size, + content_field=content_field, + fim_rate=fim_rate, + fim_spm_rate=fim_spm_rate, + seed=seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + # return_cu_seqlens=varlen, + # seqlen_cap=block_size if varlen else None + ) + if varlen and eval_samples is not None: + eval_samples = eval_samples // micro_batch_size + val_offloaded_dataset = realize_dataset_in_memory(valid_dataset, eval_samples) + + valid_data_len = len(val_offloaded_dataset) + mprint(f"num validation examples = {valid_data_len}") + else: + val_offloaded_dataset = None + + if not isinstance(accelerator, Printer): + obj_list = [val_offloaded_dataset] + torch.distributed.broadcast_object_list(obj_list) + val_offloaded_dataset = obj_list[0] + + # let accelerate prepare to handle distributed sampling + val_dataloader = DataLoader( + val_offloaded_dataset, + batch_size=1 if varlen else micro_batch_size, + pin_memory=True, + collate_fn=collate_fn_with_none_support, + ) + + return val_dataloader + + +def realize_dataset_in_memory(dataset: IterableDataset, eval_samples: int | None) -> list[dict]: + tqdm_desc = f"realize_dataset_in_memory({eval_samples=})" + if eval_samples is None: + offloaded_dataset = list(tqdm(dataset, desc=tqdm_desc)) + else: + val_iter = iter(dataset) + offloaded_dataset = [next(val_iter) for _ in tqdm(range(eval_samples), desc=tqdm_desc)] + return offloaded_dataset + + +TensorT = TypeVar("TensorT", bound=torch.Tensor) + + +@torch.no_grad() +def create_padded_tensor( + tensor: TensorT, desired_shape: Sequence[int], padding_value: float = 0 +) -> TensorT: + if tensor.shape == torch.Size(desired_shape): + return tensor + + padded_tensor = torch.full( + desired_shape, fill_value=padding_value, dtype=tensor.dtype, device=tensor.device + ) + indices = torch.where(torch.ones_like(tensor, dtype=torch.bool)) + padded_tensor[indices] = tensor.view(-1) + return padded_tensor + + +class Printer: + is_main_process = True + process_index = None + + @staticmethod + def print(*args, **kwargs) -> None: + print(*args, **kwargs) diff --git a/modelopt/torch/puzzletron/utils/data/dataset.py b/modelopt/torch/puzzletron/utils/data/dataset.py new file mode 100644 index 0000000000..f88e44a234 --- /dev/null +++ b/modelopt/torch/puzzletron/utils/data/dataset.py @@ -0,0 +1,322 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import functools +from collections.abc import Sequence + +import numpy as np +import torch +from torch.utils.data import IterableDataset + +__all__ = [ + "FIM_TOKEN_START", + "CODEGEN_FIM_TOKENS", + "ConstantLengthDataset", + "permute", + "get_fim_token_ids", +] + +FIM_TOKEN_START = "", "middle>", "suffix>", "pad>"] +CODEGEN_FIM_TOKENS = ["", "<|endoftext|>", ""] + + +class ConstantLengthDataset(IterableDataset): + """Iterable dataset that returns constant length chunks of tokens from stream of text files. + + Args: + tokenizer (Tokenizer): The processor used for proccessing the data. + dataset (dataset.Dataset): Dataset with text files. + infinite (bool): If True the iterator is reset after dataset reaches end else stops. + seq_length (int): Length of token sequences to return. + num_of_sequences (int): Number of token sequences to keep in buffer. + chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer. + fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM. + fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permuations that will use SPM. + seed (int): Seed for random number generator. + label_shift (bool): Whether to shift labels by 1 or not. + """ + + def __init__( + self, + tokenizer, + dataset, + infinite=False, + seq_length=1024, + num_of_sequences=1024, + chars_per_token=3.6, + content_field="content", + fim_rate=0.5, + fim_spm_rate=0.5, + seed=0, + label_shift=True, + max_sample_length=200_000, + tokens_field="token_ids", + source_datasets_to_discard: Sequence[str] | None = tuple(), + bos_rate: float = 1.0, + return_cu_seqlens: bool = False, + seqlen_cap: int | None = None, + ): + self.tokenizer = tokenizer + self.concat_token_id = tokenizer.eos_token_id + # self.concat_token_id = tokenizer.eos_id # for lit-lamma tokenizer + self.dataset = dataset + self.is_dataset_already_tokenized = tokens_field in self.dataset.column_names + self.seq_length = seq_length + self.infinite = infinite + self.current_size = 0 + if not self.is_dataset_already_tokenized: + self.max_buffer_size = seq_length * chars_per_token * num_of_sequences + self.max_sample_length = max_sample_length + else: + self.max_buffer_size = seq_length * num_of_sequences + # self.max_sample_length = int(max_sample_length / chars_per_token) + self.max_sample_length = max_sample_length # we don't know the exact chars_per_token + self.content_field = content_field + self.tokens_field = tokens_field + self.fim_rate = fim_rate + self.fim_spm_rate = fim_spm_rate + self.seed = seed + self.max_sample_length = max_sample_length + + self.fim_token_ids = get_fim_token_ids(self.tokenizer) + if None in self.fim_token_ids.values() and self.fim_rate > 0: + self.fim_rate = 0 + self.label_shift = label_shift + self.bos_rate = bos_rate + self.source_datasets_to_discard = ( + source_datasets_to_discard if source_datasets_to_discard is not None else tuple() + ) + self.return_cu_seqlens = return_cu_seqlens + self.seqlen_cap = seqlen_cap + self.np_rng = np.random.RandomState(seed=self.seed) + + def __iter__(self) -> dict[str, torch.Tensor]: + iterator = iter(self.dataset) + more_examples = True + while more_examples: + buffer, buffer_len = [], 0 + while True: + if buffer_len >= self.max_buffer_size: + break + try: + sample = next(iterator) + if ( + len(self.source_datasets_to_discard) > 0 + and sample["dataset_name"] in self.source_datasets_to_discard + ): + continue + if not self.is_dataset_already_tokenized: + sample = sample[self.content_field] + if ( + isinstance(sample, list) + and isinstance(sample[0], dict) + and {"content", "role"}.issubset(sample[0]) + ): + if len(sample) > 1: + sample = self.tokenizer.apply_chat_template(sample, tokenize=False) + else: + sample = sample[0]["content"] + else: + sample = sample[self.tokens_field] + sample = sample[: self.max_sample_length] + buffer.append(sample) + buffer_len += len(sample) + except StopIteration: + if self.infinite: + iterator = iter(self.dataset) + else: + more_examples = False + break + + if not self.is_dataset_already_tokenized: + tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"] + else: + tokenized_inputs = buffer + + all_token_ids = [] + + for tokenized_input in tokenized_inputs: + if ( + self.bos_rate < 1.0 + and not self.np_rng.binomial(1, self.bos_rate) + and self.tokenizer.bos_token_id is not None + and tokenized_input[0] == self.tokenizer.bos_token_id + ): + tokenized_input = tokenized_input[1:] + # optionally do FIM permutations + if self.fim_rate > 0: + tokenized_input, np_rng = permute( + sample=tokenized_input, + np_rng=self.np_rng, + fim_token_ids=self.fim_token_ids, + fim_rate=self.fim_rate, + fim_spm_rate=self.fim_spm_rate, + truncate_or_pad=False, + ) + + all_token_ids.extend(tokenized_input + [self.concat_token_id]) + + examples = [] + # cuts code snippets in the middle to yield constant length instances + for i in range(0, len(all_token_ids), self.seq_length): + input_ids = all_token_ids[i : i + self.seq_length] + labels = all_token_ids[ + i + int(self.label_shift) : i + int(self.label_shift) + self.seq_length + ] + # ignores last short example in the buffer + if len(labels) == self.seq_length: + examples.append((input_ids, labels)) + + shuffling_indices = self.np_rng.permutation(len(examples)) + examples = [examples[i] for i in shuffling_indices] + + for input_ids, labels in examples: + self.current_size += 1 + input_ids = torch.LongTensor(input_ids) + if self.return_cu_seqlens: + cu_seqlens = self.prepare_cu_seqlens(input_ids) + yield { + "input_ids": input_ids, + "targets": torch.LongTensor(labels), + "cu_seqlens": cu_seqlens, + } + else: + yield { + "input_ids": input_ids, + "targets": torch.LongTensor(labels), + } + + def prepare_cu_seqlens(self, input_ids): + if not self.return_cu_seqlens: + return None + # seqlens is of shape (num_seqs+1,) and with the property that + # the i-th sequnce is input_ids[seqlens[i-1]:seqlens[i]] + cu_seqlens = (input_ids == self.concat_token_id).nonzero().squeeze(-1).int() + 1 + cu_seqlens = torch.cat( + ( + torch.IntTensor([0]), + cu_seqlens, + torch.IntTensor([len(input_ids)]), + ) + ) + if self.seqlen_cap is not None: + i = 1 + while i < len(cu_seqlens): + curr_seqlen = cu_seqlens[i] - cu_seqlens[i - 1] + if curr_seqlen > self.seqlen_cap: + cu_seqlens = torch.cat( + (cu_seqlens[:i], cu_seqlens[[i - 1]] + self.seqlen_cap, cu_seqlens[i:]) + ) + i += 1 + if cu_seqlens[-1] == cu_seqlens[-2]: + cu_seqlens = cu_seqlens[:-1] + return cu_seqlens + + +## Adapted from https://github.com/NVIDIA/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py +def permute( + sample, + np_rng, + fim_token_ids, + fim_rate=0.5, + fim_spm_rate=0.5, + truncate_or_pad=False, +): + """Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes: + PSM and SPM (with a probability of fim_spm_rate). + """ + if np_rng.binomial(1, fim_rate): + boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2)) + boundaries.sort() + + prefix = np.array(sample[: boundaries[0]], dtype=np.int64) + middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64) + suffix = np.array(sample[boundaries[1] :], dtype=np.int64) + + if truncate_or_pad: + raise NotImplementedError + + if "" in fim_token_ids: # use codegen FIM pattern + assert fim_spm_rate == 0 + new_sample = np.concatenate( + [ + prefix, + [fim_token_ids[""]], + suffix, + [fim_token_ids["<|endoftext|>"]], + [fim_token_ids[""]], + [fim_token_ids[""]], + middle, + ] + ) + elif np_rng.binomial(1, fim_spm_rate): + # SPM (variant 2 from FIM paper) + new_sample = np.concatenate( + [ + [fim_token_ids["prefix_tok_id"], fim_token_ids["suffix_tok_id"]], + suffix, + [fim_token_ids["middle_tok_id"]], + prefix, + middle, + ] + ) + else: + # PSM + new_sample = np.concatenate( + [ + [fim_token_ids["prefix_tok_id"]], + prefix, + [fim_token_ids["suffix_tok_id"]], + suffix, + [fim_token_ids["middle_tok_id"]], + middle, + ] + ) + else: + # don't do FIM preproc + new_sample = sample + + return list(new_sample), np_rng + + +# this is expensive so we cache it +@functools.lru_cache(maxsize=None) +def get_fim_token_ids(tokenizer): + # ugly fix for Salesforce/codegen25-7b-multi tokenizer + if hasattr(tokenizer, "encoder"): + search_vocab = tokenizer.encoder._special_tokens + fim_token_ids = {tok: search_vocab.get(tok, None) for tok in CODEGEN_FIM_TOKENS} + else: + search_vocab = tokenizer.vocab + if (FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_STAR + FIM_TOKEN_END_LIST[0]) in search_vocab: + prefix_tok_id, middle_tok_id, suffix_tok_id, pad_tok_id = ( + search_vocab.get(FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_STAR + tok, None) + for tok in FIM_TOKEN_END_LIST + ) + else: + prefix_tok_id, middle_tok_id, suffix_tok_id, pad_tok_id = ( + search_vocab.get(FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_SANTA + tok, None) + for tok in FIM_TOKEN_END_LIST + ) + fim_token_ids = { + "suffix_tok_id": suffix_tok_id, + "prefix_tok_id": prefix_tok_id, + "middle_tok_id": middle_tok_id, + "pad_tok_id": pad_tok_id, + } + return fim_token_ids diff --git a/modelopt/torch/puzzletron/utils/dummy_modules.py b/modelopt/torch/puzzletron/utils/dummy_modules.py new file mode 100644 index 0000000000..14f13b08f5 --- /dev/null +++ b/modelopt/torch/puzzletron/utils/dummy_modules.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from typing_extensions import override + +__all__ = ["DummyModule", "DummyBlock", "DummyWTE", "DummyLMHead"] + + +class DummyModule(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.register_load_state_dict_post_hook(self.load_state_dict_post_hook) + + @staticmethod + def load_state_dict_post_hook( + module: torch.nn.Module, + incompatible_keys: torch.nn.modules.module._IncompatibleKeys, + ) -> None: + incompatible_keys.missing_keys.clear() + incompatible_keys.unexpected_keys.clear() + + +class DummyBlock(DummyModule): + def __init__(self, block_index: int): + super().__init__() + self.block_index = block_index + + @override + def forward( + self, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, None]: + return x + + +class DummyWTE(DummyModule): + def __init__(self, hidden_size: int, dtype: Optional[torch.dtype] = None): + super().__init__() + self.n_embd = hidden_size + self.dtype = dtype + + @override + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + B, T = input_ids.shape + result = torch.ones((B, T, self.n_embd), dtype=self.dtype, device=input_ids.device) + return result + + +class DummyLMHead(DummyModule): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.vocab_size = config.vocab_size + + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, C = x.shape + result = torch.ones((B, T, self.vocab_size), dtype=x.dtype, device=x.device) + return result diff --git a/modelopt/torch/puzzletron/utils/misc.py b/modelopt/torch/puzzletron/utils/misc.py new file mode 100644 index 0000000000..68751d1e07 --- /dev/null +++ b/modelopt/torch/puzzletron/utils/misc.py @@ -0,0 +1,261 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import json +import os +from copy import deepcopy +from typing import Any + +import torch + +from ..block_config import AttentionConfig, BlockConfig, FFNConfig + +__all__ = [ + "calculate_kv_dim", + "raise_unknown_subblock_config_error", + "sizeof_dtype", + "load_json", + "solution_to_str", + "block_config_to_str", + "subblock_config_to_str", + "EmptyInitOnDevice", +] + + +def calculate_kv_dim(num_key_value_heads: int, n_head: int, n_embd: int) -> int: + """Calculate the key-value dimension for grouped-query attention. + + Args: + num_key_value_heads: Number of key-value heads. + n_head: Total number of attention heads. + n_embd: Embedding dimension. + + Returns: + Combined dimension for key and value tensors (2 * num_key_value_heads * head_size). + """ + if num_key_value_heads is None: + return 0 + head_size = n_embd // n_head + kv_dim = 2 * num_key_value_heads * head_size + return kv_dim + + +def raise_unknown_subblock_config_error(subblock_config: Any) -> None: + """Raise an error for invalid subblock configuration types. + + TODO: Consider a better place for this function. + Args: + subblock_config: The invalid subblock configuration object. + + Raises: + ValueError: Always raised with a message indicating the expected types. + """ + raise ValueError( + f"subblock_config should be an instance of FFNConfig or AttentionConfig, instead got {type(subblock_config)}" + ) + + +def sizeof_dtype(dtype: torch.dtype) -> int | float: + """Return the size in bytes of the given data type. + + TODO: Consider a better place for this function. + Args: + dtype: PyTorch data type or custom type string (e.g., 'nvfp4'). + + Returns: + Size in bytes of the data type. Special case: 'nvfp4' returns ~0.588 bytes. + """ + if dtype == "nvfp4": + return 1 / 1.7 + return torch.tensor([], dtype=dtype).element_size() + + +def load_json(file_path: str): + """Load and parse a JSON file. + + TODO: Consider a better place for this function. + + Args: + file_path: Path to the JSON file to load. + + Returns: + Parsed JSON data as a Python object, or None if the file doesn't exist. + """ + if not os.path.exists(file_path): + print("file does not exist {file_path}") + return None + + with open(file=file_path) as f: + return json.load(f) + + +def solution_to_str(block_configs: list[dict[str, Any] | BlockConfig]) -> str: + """Convert a list of block configurations to a human-readable string representation. + + TODO: Consider a better place for this function. + Better place for this and subsequent related function would be in __repr__ function in class + BlockConfig so when we print it or do str(block_config), it automatically + prints in this custom formatted string + + Args: + block_configs: List of BlockConfig dataclasses or dicts containing layer configurations. + + Returns: + Multi-line string with each block's configuration on a separate line. + """ + block_configs = deepcopy(block_configs) + reps = [] + for block_idx, block_config in enumerate(block_configs): + rep = f"block_{block_idx}:".ljust(9) + rep += block_config_to_str(block_config) + reps.append(rep) + rep = "\n".join(reps) + "\n" + return rep + + +def block_config_to_str(block_config: BlockConfig | dict[str, Any] | None) -> str | None: + """ + Convert a BlockConfig to a human-readable string representation. + + TODO: Consider a better place for this function. + Args: + block_config: BlockConfig dataclass or dict containing attention and ffn configs. + + Returns: + Formatted string with attention and FFN information, or None if input is None. + """ + if block_config is None: + return None + rep = "" + if dataclasses.is_dataclass(block_config): + block_config = dataclasses.asdict(block_config) + for subblock_name in ["attention", "ffn"]: + subblock_config = block_config[subblock_name] + rep += subblock_config_to_str(subblock_config, subblock_name) + return rep + + +# TODO: Consider a better place for this function. +def subblock_config_to_str( + subblock_config: FFNConfig | AttentionConfig | dict[str, Any] | None, + subblock_name: None | str = None, +) -> str | None: + """Convert a subblock config (FFN, Attention, Mamba, or MoE) to string. + + Args: + subblock_config: FFNConfig, AttentionConfig dataclass or dict. + subblock_name: Name of subblock ('ffn', 'attention', 'mamba', 'moe'). + Auto-detected if subblock_config is a dataclass. + + Returns: + Formatted string showing subblock type and key parameters (e.g., intermediate_size, + num_key_value_heads), or None if input is None. + """ + if subblock_config is None: + return None + subblock_name = ( + "ffn" + if isinstance(subblock_config, FFNConfig) + else "mamba" + if isinstance(subblock_config, AttentionConfig) and subblock_config.is_mamba + else "attention" + if isinstance(subblock_config, AttentionConfig) + else subblock_name + ) + assert subblock_name is not None, "Must provide subblock_name if subblock_config is a dict." + + if dataclasses.is_dataclass(subblock_config): + subblock_config = dataclasses.asdict(subblock_config) + + if subblock_name == "attention" and subblock_config.get("mamba") is not None: + subblock_name = "mamba" + + if subblock_name == "ffn" and subblock_config.get("moe") is not None: + subblock_name = "moe" + + rep = f" {subblock_name}" + if subblock_config.get("no_op"): + rep += " no_op".ljust(8) + elif subblock_config.get("replace_with_linear"): + rep += " linear".ljust(8) + elif subblock_name == "ffn": + intermediate_size = subblock_config["intermediate_size"] + rep += f" intermediate_{intermediate_size}".ljust(8) + elif subblock_name == "attention": + num_key_value_heads = subblock_config["num_key_value_heads"] + rep += f" kv_heads_{num_key_value_heads}".ljust(8) + elif subblock_name == "mamba": + mamba_num_heads = subblock_config["mamba"]["num_heads"] + mamba_head_dim = subblock_config["mamba"]["head_dim"] + rep += f" num_heads_{mamba_num_heads} head_dim_{mamba_head_dim}".ljust(8) + elif subblock_name == "moe": + moe_num_local_experts = subblock_config["moe"]["num_local_experts"] + moe_expert_intermediate_dim = subblock_config["moe"]["expert_intermediate_dim"] + shared_expert_intermediate_dim = subblock_config["moe"]["shared_expert_intermediate_dim"] + num_experts_per_tok = subblock_config["moe"]["num_experts_per_tok"] + rep += f" num_experts_{moe_num_local_experts} expert_intermediate_dim_{moe_expert_intermediate_dim} shared_expert_intermediate_dim_{shared_expert_intermediate_dim} num_experts_per_tok_{num_experts_per_tok}".ljust( + 8 + ) + else: + raise ValueError(f"subblock_config_to_str: unrecognized subblock_name: {subblock_name}.") + + return rep + + +class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): + def __init__(self, device=None, dtype=None): + """Create tensors with given device and dtype using uninitialized memory. + + Args: + device: ``torch.device`` to work with. + dtype: ``torch.dtype`` to work with. + + Example:: + + with EmptyInitOnDevice("cuda", dtype=torch.bfloat16): + model = LLaMA(model_config) + model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth")) + """ + + self.device = device + self.dtype = dtype + + def __enter__(self): + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return super().__exit__(exc_type, exc_val, exc_tb) + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if getattr(func, "__module__", None) == "torch.nn.init": + if "tensor" in kwargs: + return kwargs["tensor"] + else: + return args[0] + if ( + self.device is not None + and func in torch.utils._device._device_constructors() + and kwargs.get("device") is None + ): + kwargs["device"] = self.device + if ( + self.dtype is not None + and func in torch.utils._device._device_constructors() + and kwargs.get("dtype") is None + ): + kwargs["dtype"] = self.dtype + return func(*args, **kwargs) diff --git a/modelopt/torch/puzzletron/utils/parsing.py b/modelopt/torch/puzzletron/utils/parsing.py new file mode 100644 index 0000000000..149563b432 --- /dev/null +++ b/modelopt/torch/puzzletron/utils/parsing.py @@ -0,0 +1,461 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Parsing and formatting utilities for configuration handling in model compression. + +This module provides utilities for: +- Parsing command-line arguments and configuration strings +- Formatting and displaying model configurations (block configs, attention, FFN) +- Formatting loss metrics for logging and visualization +""" +# mypy: ignore-errors + +import json +from pathlib import Path +from typing import Any + +import torch +from omegaconf import DictConfig + +__all__ = [ + "handle_arg_string", + "simple_parse_args_string", + "parse_json", + "parse_path", + "get_nested_key", + "format_global_config", + "format_stitched_losses", +] + + +def handle_arg_string(arg): + if arg.lower() == "true": + return True + elif arg.lower() == "false": + return False + elif arg.isnumeric(): + return int(arg) + try: + return float(arg) + except ValueError: + return arg + + +def simple_parse_args_string(args_string): + """Parse ``args1=val1,arg2=val2`` into a dictionary.""" + if args_string is None: + return {} + args_string = args_string.strip() + if not args_string: + return {} + arg_list = [arg for arg in args_string.split(",") if arg] + args_dict = {k: handle_arg_string(v) for k, v in [arg.split("=") for arg in arg_list]} + return args_dict + + +def parse_json(s: str | None) -> Any: + if s is None: + return None + return json.loads(s) + + +def parse_path(s: str | None) -> Path | None: + if s is None or s == "": + return None + return Path(s) + + +def parse_dtype(dtype_name: str) -> torch.dtype: + dtype = { + "bf16": torch.bfloat16, + "bfloat16": torch.bfloat16, + "fp32": torch.float32, + "float32": torch.float32, + "fp16": torch.float16, + "float16": torch.float16, + }[dtype_name] + return dtype + + +def get_nested_key(dictionary: dict[str, Any], nested_key: str) -> Any: + """ + If nested_key is "a.b.c" returns dictionary["a"]["b"]["c"] + """ + value = dictionary + for key in nested_key.split("."): + value = value[key] + return value + + +def format_block_configs(config) -> str: + """ + Formats block_configs from a model configuration into a beautiful, readable string. + + Each line represents a layer with attention and FFN configuration. + + Args: + config: PretrainedConfig object containing block_configs + + Returns: + Formatted string with layer configurations + + Example output: + ╭─────────────────────── Model Architecture ────────────────────────╮ + │ Layer 1 │ Attention: no_op │ FFN: mult = 4.95 │ + │ Layer 2 │ Attention: 4 heads in group │ FFN: mult = 4.95 │ + │ Layer 3 │ Attention: 4 heads in group │ FFN: no_op │ + ╰────────────────────────────────────────────────────────────────────╯ + """ + if not hasattr(config, "block_configs") or not config.block_configs: + return "❌ No block configs found" + + lines = [] + + # Header + header = "╭─────────────────────────────────────── Model Architecture ────────────────────────────────────────╮" + lines.append(header) + + # Format each layer + for i, block in enumerate(config.block_configs, 1): + attention_info = _format_attention_config(block.attention) + ffn_info = _format_ffn_config(block.ffn) + + # Create formatted line with proper padding + layer_str = f"Layer {i:2d}" + attention_str = f"Attention: {attention_info}" + ffn_str = f"FFN: {ffn_info}" + + line = f"│ {layer_str:8s} │ {attention_str:30s} │ {ffn_str:18s} │" + lines.append(line) + + # Footer + footer = "╰────────────────────────────────────────────────────────────────────────────────────────────────────╯" + lines.append(footer) + + return "\n".join(lines) + + +def _format_attention_config(attention_config) -> str: + """Format attention configuration for display with visual indicators.""" + if not attention_config: + return "default" + + if attention_config.no_op: + return "❌ no_op" + + num_kv_heads = attention_config.num_key_value_heads + if num_kv_heads is not None: + return f"{num_kv_heads} kv heads" + + if attention_config.replace_with_linear: + return "linear replacement" + + # Check for other attention types + if attention_config.mamba: + return "🐍 mamba" + if attention_config.llama4: + return "🦙 llama4" + + window_length = attention_config.window_length + if window_length is not None: + return f"windowed ({window_length})" + + if attention_config.sparsify: + return "sparse" + + return "default" + + +def _format_ffn_config(ffn_config) -> str: + """Format FFN configuration for display with visual indicators.""" + if not ffn_config: + return "default" + + if ffn_config.no_op: + return "❌ no_op" + + if ffn_config.replace_with_linear: + return "linear" + + ffn_intermediate = ffn_config.intermediate_size + if ffn_intermediate is not None: + return f"ffn_intermediate = {ffn_intermediate}" + + # Check for MoE configuration + moe_config = ffn_config.moe + if moe_config: + return "MoE" + + if ffn_config.sparsify: + return "sparse" + + return "default" + + +def format_global_config(config: DictConfig, title: str = "Global Configuration") -> str: + """ + Pretty prints a global DictConfig with nice formatting and visual indicators. + + Args: + config: DictConfig object to format + title: Title to display at the top of the formatted output + + Returns: + Formatted string with configuration details + + Example output: + ╭─────────────────── Global Configuration ────────────────────╮ + │ Training │ + │ • learning_rate: 1e-4 │ + │ • batch_size: 32 │ + │ • epochs: 100 │ + │ Model │ + │ • hidden_dim: 512 │ + │ • num_layers: 6 │ + │ Data │ + │ • dataset_path: /path/to/data │ + │ • block_size: 2048 │ + ╰──────────────────────────────────────────────────────────────╯ + """ + if not config: + return "❌ No configuration found" + + lines = [] + + # Calculate box width based on title + box_width = max(60, len(title) + 10) + title_padding = (box_width - len(title) - 2) // 2 + + # Header + header = f"\n╭{'─' * (box_width - 2)}╮" + title_line = ( + f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│" + ) + lines.extend([header, title_line]) + + def _format_value(value: Any, indent: int = 0) -> str: + """Format a value with appropriate type indicators.""" + prefix = " " * indent + + if isinstance(value, (bool, int, float)): + return f"{prefix} {value}" + elif isinstance(value, str): + # Show truncated long strings + if len(value) > 50: + return f"{prefix} {value[:47]}..." + return f"{prefix} {value}" + elif isinstance(value, (list, tuple)): + if not value: + return f"{prefix} []" + elif len(value) <= 3: + return f"{prefix} {list(value)}" + else: + return f"{prefix} [{len(value)} items]" + elif value is None: + return f"{prefix} None" + else: + return f"{prefix} {value!s}" + + def _add_config_section(cfg: DictConfig, section_name: str = "", indent: int = 0): + """Recursively add configuration sections.""" + if section_name: + indent_str = " " * indent + section_line = f"│ {indent_str}{section_name}" + # Pad to box width + padding_needed = box_width - len(section_line) - 1 + section_line += " " * padding_needed + "│" + lines.append(section_line) + + for key, value in cfg.items(): + if isinstance(value, DictConfig): + # Nested configuration section + _add_config_section(value, f"{key}", indent + 1) + else: + # Regular key-value pair + indent_str = " " * (indent + 1) + value_str = _format_value(value).replace(" " * 0, "").strip() + line = f"│ {indent_str} {key}: {value_str}" + # Pad to box width + if len(line) >= box_width - 1: + # Truncate long lines + line = line[: box_width - 4] + "..." + padding_needed = box_width - len(line) - 1 + line += " " * padding_needed + "│" + lines.append(line) + + # Add configuration sections + _add_config_section(config) + + # Footer + footer = f"╰{'─' * (box_width - 2)}╯" + lines.append(footer) + + return "\n".join(lines) + + +def format_stitched_losses( + losses_dict: dict[str, float], + best_steps_dict: dict[str, int] | None = None, + best_values_dict: dict[str, float] | None = None, + step_number: int | None = None, + title: str = "Stitched Module Losses", +) -> str: + """ + Pretty prints stitched module losses with comprehensive tracking and visual indicators. + + Args: + losses_dict: Dictionary with block names as keys and current loss values as floats + best_steps_dict: Optional dictionary with block names as keys and best step numbers as values + best_values_dict: Optional dictionary with block names as keys and best loss values as floats + step_number: Optional current step number to include in summary + title: Title to display at the top of the formatted output + + Returns: + Formatted string with loss values in a comprehensive table format + + Example output: + ╭─────────────────── Stitched Module Losses ──────────────────╮ + │ Block │ Loss Value │ Best Step │ Best Value │ Change from avg │ + │───────┼────────────┼───────────┼────────────┼──────────────────│ + │ 00 │ 6.21e-03 │ Step 5 │ 5.95e-03 │ ↑ +2.6e-04 │ + │ 01 │ 5.14e-04 │ Step 12 │ 5.14e-04 │ ↓ -1.2e-04 │ + │ 02 │ 9.84e-05 │ Step 15 │ 9.84e-05 │ ↓ -3.1e-04 │ + ╰──────────────────────────────────────────────────────────────╯ + """ + if not losses_dict: + return "❌ No losses found" + + lines = [] + + # Calculate statistics + loss_values = list(losses_dict.values()) + max_loss = max(loss_values) + min_loss = min(loss_values) + avg_loss = sum(loss_values) / len(loss_values) + + # Calculate box width for new layout (removed Bar column) + box_width = 74 + title_padding = (box_width - len(title) - 2) // 2 + + # Header + header = f"╭{'─' * (box_width - 2)}╮" + title_line = ( + f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│" + ) + separator = ( + f"│ {'Block':<5} │ {'Loss Value':<12} │ {'Best Step':<10} │ " + f"{'Best Value':<12} │ {'Change from avg':<18} │" + ) + divider = f"│{'─' * 7}┼{'─' * 14}┼{'─' * 12}┼{'─' * 14}┼{'─' * 20}│" + + lines.extend([header, title_line, separator, divider]) + + # Format each loss + for block_name, loss_value in losses_dict.items(): + # Format current loss value + loss_str = f"{loss_value:.2e}" + + # Format best step + if best_steps_dict and block_name in best_steps_dict: + best_step_str = f"Step {best_steps_dict[block_name]}" + else: + best_step_str = " --" + + # Format best value + if best_values_dict and block_name in best_values_dict: + best_value = best_values_dict[block_name] + best_value_str = f"{best_value:.2e}" + else: + best_value = loss_value # Assume current is best if no history + best_value_str = f"{best_value:.2e}" + + # Calculate change from average + change_from_avg = loss_value - avg_loss + if abs(change_from_avg) > 1e-8: # Only show if meaningful + change_str = f"{abs(change_from_avg):.1e}" + if change_from_avg > 0: + # Current is above average (worse for loss) + change_display = f"↑ +{change_str}" + else: + # Current is below average (better for loss) + change_display = f"↓ -{change_str}" + else: + # At average value + change_display = "↔ 0.0e+00" + + # Format the line + block_display = block_name.replace("block_", "").zfill(2) + + line = ( + f"│ {block_display:<5} │ {loss_str:<12} │ {best_step_str:<10} │ " + f"{best_value_str:<12} │ {change_display:<18} │" + ) + lines.append(line) + + # Add summary statistics + lines.append(divider) + + # Build summary string with optional step number + summary_parts = [] + if step_number is not None: + summary_parts.append(f"Step {step_number}") + summary_parts.extend([f"Avg={avg_loss:.2e}", f"Max={max_loss:.2e}", f"Min={min_loss:.2e}"]) + + summary_text = ", ".join(summary_parts) + summary = f"│ Summary: {summary_text}" + + # Pad summary to box width + padding_needed = box_width - len(summary) - 1 + summary += " " * padding_needed + "│" + lines.append(summary) + + # Add best step summary if we have best step data + if best_steps_dict and best_values_dict: + # Find the most common best step (modal step) + step_counts = {} + for step in best_steps_dict.values(): + step_counts[step] = step_counts.get(step, 0) + 1 + + if step_counts: + modal_best_step = max(step_counts, key=step_counts.get) + + # Get values at the modal best step for blocks that have it as their best + best_step_values = [] + for block_name, best_step in best_steps_dict.items(): + if best_step == modal_best_step and block_name in best_values_dict: + best_step_values.append(best_values_dict[block_name]) + + if best_step_values: + best_step_avg = sum(best_step_values) / len(best_step_values) + best_step_max = max(best_step_values) + best_step_min = min(best_step_values) + + best_step_summary_text = ( + f"Best: Step {modal_best_step}, Avg={best_step_avg:.2e}, " + f"Max={best_step_max:.2e}, Min={best_step_min:.2e}" + ) + best_step_summary = f"│ {best_step_summary_text}" + + # Pad best step summary to box width + padding_needed = box_width - len(best_step_summary) - 1 + best_step_summary += " " * padding_needed + "│" + lines.append(best_step_summary) + + # Footer + footer = f"╰{'─' * (box_width - 2)}╯" + lines.append(footer) + + return "\n".join(lines) diff --git a/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py new file mode 100644 index 0000000000..9980d1ef6b --- /dev/null +++ b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. + +Coordinates forward passes and loss computation through model shards distributed across GPUs +using sewing_kit's StitchedModule framework. Relies on validation.py for core loss computation. + +Used by validate_model.py during activation scoring for sharded models. +""" + +# mypy: ignore-errors +from __future__ import annotations + +import traceback +from contextlib import nullcontext +from typing import TYPE_CHECKING, Type + +import numpy as np +import torch +from torch import nn +from tqdm import tqdm + +import modelopt.torch.utils.distributed as dist + +from ..sewing_kit.core import ( + ExternalTarget, + InputReducer, + ModuleTarget, + Needle, + RemoteTarget, + StitchedModule, +) +from ..sewing_kit.passage import InputArgs +from ..sewing_kit.utils import distributed_recv_obj, distributed_send_obj, fake_tensor +from ..tools.checkpoint_utils import init_module_with_state_dict +from ..utils.dummy_modules import DummyBlock +from .validation import _organize_outputs, calculate_batch_outputs + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from ..anymodel.model_descriptor import ModelDescriptor + +__all__ = [ + "LMHead", + "HiddenStatesAndLMHead", + "calculate_losses_pipeline", + "perform_pipeline_stitches", +] + + +def _log_forward_error(e: Exception, rank: int, batch_idx: int, num_batches: int) -> None: + """Log detailed error info for distributed forward pass failures. + + When one rank crashes during distributed forward, others may hang waiting for communication. + This logging helps diagnose which rank failed and why. + """ + error_msg = ( + f"\n{'=' * 60}\n" + f"[Rank {rank}] ERROR in stitched_model forward (batch {batch_idx}/{num_batches})\n" + f"Error: {type(e).__name__}: {e}\n" + f"{'=' * 60}\n" + f"{traceback.format_exc()}" + f"{'=' * 60}\n" + ) + print(error_msg, flush=True) + + +class LMHead(nn.Linear): + """Special class to allow FSDP wrapping without affecting other Linear layers in the model. + + Small nn helpers for puzzletron pipeline code. Model configs come from HuggingFace ``AutoConfig`` (AnyModel). + ``LMHead`` is a distinct ``nn.Linear`` subclass so pipeline / FSDP code can target it explicitly + """ + + +class HiddenStatesAndLMHead(list): + def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Tensor): + super().__init__(hidden_states) + self.lm_head_weights = lm_head_weights + + +@torch.no_grad() +def calculate_losses_pipeline( + stitched_model: StitchedModule, + dataloader: DataLoader | None, + target_hidden_states_per_batch: HiddenStatesAndLMHead | None = None, + return_hidden_states: bool = False, + calculate_full_score_ablations: bool = False, + calc_on_cpu: bool = False, + just_model_forward: bool = False, + checkpoint_manager=None, + autocast_dtype: torch.dtype = torch.bfloat16, + descriptor: Type[ModelDescriptor] = None, + use_autocast: bool = True, +) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: + """Do model forward on each batch and calculate LM loss. + + Optionally also calculate kl_div loss and other metrics from given + *target_hidden_states_per_batch*. Optionally return hidden states per batch. + Does not support data-parallel. + *just_model_forward*: skip loss calculation, just forward the model (useful for activation hooks). + + Returns: + Tuple of ``(losses, target_hidden_states_per_batch)``. + + ``losses`` is a dict, e.g.:: + + { + "lm_loss": {"avg": float, "per_sample": [float, ...]}, + ... # more metrics if target_hidden_states_per_batch is provided + } + + ``target_hidden_states_per_batch`` is returned when *return_hidden_states* is True. + """ + if not isinstance(stitched_model, StitchedModule): + stitched_model = perform_pipeline_stitches(stitched_model, descriptor) + + params = list(stitched_model.parameters()) + model_device = params[0].device if params else "cpu" + + # Pre-populate outputs with dummy values for skipped batches + start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0 + if dist.is_last_process(): + outputs = [{"lm_loss": [0.0]}] * start_batch + else: + outputs = None + + if dist.is_master(): + all_input_ids, all_targets = zip( + *[(batch["input_ids"], batch["targets"]) for batch in dataloader] + ) + if dist.size() > 1: + distributed_send_obj(all_targets, dst=dist.size() - 1) + + if dist.is_last_process(): + if dist.size() > 1: + all_targets = distributed_recv_obj(src=0) + + lm_head: LMHead = next( + module + for module_name, module in stitched_model.named_modules() + if "lm_head" in module_name + ) + + if target_hidden_states_per_batch is not None: + lm_head_weights = target_hidden_states_per_batch.lm_head_weights + with torch.device(model_device): + target_lm_head = init_module_with_state_dict( + {"weight": lm_head_weights}, LMHead, *lm_head_weights.shape[::-1], bias=False + ) + + if dist.is_master(): + num_batches = len(all_input_ids) + seq_len = all_input_ids[0].shape[1] + if dist.size() > 1: + torch.distributed.broadcast_object_list([num_batches, seq_len]) + + # Create progress bar with sliced range starting from checkpoint position + desc = ( + f"[rank {dist.rank()}] calculate_losses_pipeline(" + f"{(target_hidden_states_per_batch is None)=}, {return_hidden_states=}, {num_batches=})" + ) + progress_bar = tqdm(range(start_batch, num_batches), desc=desc) + else: + obj_list = [None, None] + if dist.size() > 1: + torch.distributed.broadcast_object_list(obj_list) + num_batches, seq_len = obj_list + progress_bar = range(start_batch, num_batches) + + stitched_model.eval() + + # Use autocast for mixed precision, or nullcontext if disabled + # (some models like Qwen3-VL MoE have dtype bugs under autocast) + autocast_ctx = ( + torch.autocast(device_type="cuda", dtype=autocast_dtype) if use_autocast else nullcontext() + ) + with autocast_ctx: + fake_input_ids = fake_tensor(1, seq_len, dtype=torch.long, device=model_device) + for i_batch in progress_bar: + if dist.is_master(): + input_ids = all_input_ids[i_batch].to(model_device) + else: + input_ids = fake_input_ids + + try: + output = stitched_model({}, {}, input_ids) + except Exception as e: + _log_forward_error(e, dist.rank(), i_batch, num_batches) + raise + + if dist.is_last_process(): + logits = output.captured_outputs.get("model_output") + logits = getattr(logits, "logits", logits) + hidden_states = output.captured_outputs.get("hidden_states") + targets = all_targets[i_batch].to(model_device) + + target_hidden_states = None + target_logits = None + if target_hidden_states_per_batch is not None: + target_hidden_states = target_hidden_states_per_batch[i_batch] + target_hidden_states = target_hidden_states.to(hidden_states.device) + target_logits = target_lm_head(target_hidden_states) + + if just_model_forward: + batch_outputs = {"lm_loss": [-1.0] * len(targets)} + else: + batch_outputs = calculate_batch_outputs( + hidden_states, + target_hidden_states, + logits, + target_logits, + targets, + return_hidden_states, + calculate_full_score_ablations, + calc_on_cpu, + ) + + outputs.append(batch_outputs) + + # Free GPU memory after processing each batch + del logits, hidden_states, targets + if target_hidden_states is not None: + del target_hidden_states + if target_logits is not None: + del target_logits + + # Free output tensor memory on all ranks + del output + + # Update checkpoint progress periodically + if checkpoint_manager: + checkpoint_manager.update_progress(i_batch + 1, num_batches) + + losses, hidden_states_per_batch = ( + _organize_outputs(outputs) if outputs is not None else (None, None) + ) + + if hidden_states_per_batch is not None: + hidden_states_per_batch = HiddenStatesAndLMHead( + hidden_states_per_batch, lm_head.weight.cpu() + ) + + dist.barrier() + return losses, hidden_states_per_batch + + +def perform_pipeline_stitches( + model, + descriptor: Type[ModelDescriptor], +) -> StitchedModule: + """Create pipeline stitches for distributed model evaluation. + + Args: + model: The model to stitch (any HuggingFace model with AnyModel descriptor). + descriptor: ModelDescriptor for layer naming. + """ + target = ModuleTarget("module", model) + stitcher = Needle() + + num_layers = model.config.num_hidden_layers + + is_real_block = np.flatnonzero( + [ + not isinstance(model.get_submodule(descriptor.layer_block_name(i)), DummyBlock) + for i in range(num_layers) + ] + ) + + first_block, last_block = is_real_block.min(), is_real_block.max() + + if dist.rank() != 0: + # receive activations from previous rank + stitcher.stitch( + RemoteTarget(peer_rank=dist.rank() - 1).value( + name="activations", adapter=lambda x: InputArgs(x) + ), + target.input( + name=descriptor.layer_block_name(first_block), + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + + if not dist.is_last_process(): + # send activations to next rank + stitcher.stitch( + target.output(descriptor.layer_block_name(last_block)), + RemoteTarget(peer_rank=dist.rank() + 1).value(name="activations"), + ) + else: + # register model output + stitcher.stitch( + target.output(name=descriptor.output_embedding_name()), + ExternalTarget().output("model_output"), + ) + stitcher.stitch( + target.output(name=descriptor.final_norm_name()), + ExternalTarget().output("hidden_states"), + ) + + stitched_module = stitcher.knot(ignore_extra_overrides=True) + return stitched_module diff --git a/modelopt/torch/puzzletron/utils/validation.py b/modelopt/torch/puzzletron/utils/validation.py new file mode 100644 index 0000000000..987a970aed --- /dev/null +++ b/modelopt/torch/puzzletron/utils/validation.py @@ -0,0 +1,562 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model validation and loss calculation utilities for single-GPU and multi-GPU setups. + +Also provides helper functions for loss metrics, KL divergence, JS divergence, +and similarity losses for knowledge distillation. +""" + +# mypy: ignore-errors +import functools +import math +from enum import Enum + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper +from typing_extensions import Self + +from ..tools import kd_model + +__all__ = [ + "LowMemorySparseTensor", + "calculate_losses", + "calculate_batch_outputs", + "cosine_embedding_loss", + "normalized_mse_loss", + "mse_loss", + "kl_div", +] + + +class UnshardedLowMemorySparseTensor: + def __init__(self, x: torch.Tensor): + inds_dtype = self._infer_inds_dtype(x) + x_sparse = x.to_sparse_coo() + self._values = x_sparse.values() + self._indices = x_sparse.indices().to(inds_dtype) + self._size = x_sparse.size() + + @staticmethod + def _infer_inds_dtype(x: torch.Tensor) -> torch.dtype: + max_dim = max(x.shape) + for inds_dtype in [torch.int16, torch.int32, torch.int64]: + if torch.iinfo(inds_dtype).max >= max_dim: + return inds_dtype + + def to_sparse_coo(self) -> torch.Tensor: + return torch.sparse_coo_tensor(values=self._values, indices=self._indices, size=self._size) + + def to_dense(self) -> torch.Tensor: + return self.to_sparse_coo().to_dense() + + def to(self, *args) -> Self: + self._values = self._values.to(*args) + for arg in args: + if isinstance(arg, torch.device) or isinstance(arg, str): + self._indices = self._indices.to(arg) + return self + + +class LowMemorySparseTensor: + _max_sparse_size = torch.iinfo(torch.int32).max + + def __init__(self, x: torch.Tensor): + num_chunks = math.ceil(x.numel() / self._max_sparse_size) + self._chunk_dim = np.argmax(x.shape) + self._chunks = [ + UnshardedLowMemorySparseTensor(chunk) + for chunk in torch.chunk(x, num_chunks, dim=self._chunk_dim) + ] + + def to(self, *args) -> Self: + for chunk in self._chunks: + chunk.to(*args) + return self + + def to_dense(self) -> torch.Tensor: + return torch.concat([chunk.to_dense() for chunk in self._chunks], dim=self._chunk_dim) + + +@torch.no_grad() +def calculate_losses( + model: nn.Module, + dataloader: DataLoader, + target_probs: None = None, + return_probs: bool = False, + checkpoint_manager=None, +) -> tuple[dict[str, dict], None] | tuple[None, None]: + """Do model forward on each batch and calculate LM loss. + + Works on lit-llama models (single GPU) and HuggingFace models (can be multi-GPU). + Does not support data-parallel. + + .. note:: + Anything related to probs and hidden states is not supported currently. + + Returns: + Tuple of ``(outputs, None)``. ``outputs`` is a dict:: + + { + "lm_loss": [float, ...], + "token_accuracy_top_1": [float, ...], + "token_accuracy_top_5": [float, ...], + "token_accuracy_top_10": [float, ...], + } + """ + if (target_probs is not None) or return_probs: + raise NotImplementedError( + "calculate_losses() isn't updated according to the major refactor in " + "calculate_losses_pipeline() regarding hidden states." + ) + + model_device = next(model.parameters()).device + outputs = [] + + try: + num_batches = len(dataloader) + except: + num_batches = None + + # Adjust progress bar for resume + start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0 + progress_bar = tqdm( + enumerate(dataloader), + total=num_batches, + desc=f"calculate_losses({(target_probs is None)=}, {return_probs=})", + ) + if start_batch > 0: + progress_bar.update(start_batch) + + for i_batch, batch in progress_bar: + # Skip batch if resuming from checkpoint + if checkpoint_manager and checkpoint_manager.should_skip_batch(i_batch): + continue + + input_ids = batch["input_ids"].to(model_device) + logits = model(input_ids) + if hasattr(logits, "logits"): + logits = logits.logits + # logits = logits.float() + + targets = batch["targets"].to(model_device) + + batch_outputs = calculate_batch_outputs( + hidden_states=None, + target_hidden_states=None, + logits=logits, + target_logits=None, + targets=targets, + return_hidden_states=False, + calculate_full_score_ablations=False, + calc_on_cpu=False, + ) + outputs.append(batch_outputs) + + # Update checkpoint progress periodically + if checkpoint_manager: + checkpoint_manager.update_progress(i_batch + 1, num_batches) + + losses, _ = _organize_outputs(outputs) + return losses, None + + +def calculate_batch_outputs( + hidden_states: torch.Tensor | None, + target_hidden_states: torch.Tensor | None, + logits: torch.Tensor, + target_logits: torch.Tensor | None, + targets: torch.Tensor, + return_hidden_states: bool, + calculate_full_score_ablations: bool, + calc_on_cpu: bool, +) -> dict: + if calc_on_cpu: + if hidden_states is not None: + hidden_states = hidden_states.cpu() + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.cpu() + if logits is not None: + logits = logits.cpu() + if target_logits is not None: + target_logits = target_logits.cpu() + if targets is not None: + targets = targets.cpu() + + batch_outputs = _calculate_ground_truth_based_scores(logits, targets) + + if (target_hidden_states is not None) or (target_logits is not None): + batch_outputs.update( + _calculate_teacher_similarity_scores( + hidden_states, + target_hidden_states, + logits, + target_logits, + calculate_full_score_ablations, + ) + ) + + if return_hidden_states: + batch_outputs["hidden_states_per_batch"] = hidden_states.cpu() + + return batch_outputs + + +def _organize_outputs( + outputs_per_batch: list[dict], +) -> tuple[dict[str, dict], list[torch.Tensor] | None]: + outputs = _concatenate_batch_outputs(outputs_per_batch) + hidden_states_per_batch = outputs.pop("hidden_states_per_batch", None) + losses = { + loss_name: { + "avg": sum(loss_per_sample) / len(loss_per_sample), + "per_sample": loss_per_sample, + } + for loss_name, loss_per_sample in outputs.items() + } + return losses, hidden_states_per_batch + + +def _concatenate_batch_outputs(outputs_per_batch: list[dict]) -> dict[str, list]: + outputs = {} + for output_name in outputs_per_batch[0]: # Regular dict is directly iterable + item_list = [] + for batch_outputs in outputs_per_batch: + batch_items = batch_outputs[output_name] + if isinstance(batch_items, list | tuple): + item_list.extend(batch_items) + else: + item_list.append(batch_items) + outputs[output_name] = item_list + return outputs + + +def _calculate_per_sample_lm_loss( + logits: torch.Tensor, + targets: torch.Tensor, +) -> list[float]: + per_sample_lm_loss = ( + torch.nn.functional.cross_entropy( + logits.transpose(1, 2), targets, ignore_index=-1, reduction="none" + ) + .mean(dim=-1) + .tolist() + ) + return per_sample_lm_loss + + +def _calculate_ground_truth_based_scores( + logits: torch.Tensor, + targets: torch.Tensor, +) -> dict[str, list[float]]: + scores = {"lm_loss": _calculate_per_sample_lm_loss(logits, targets)} + + for top_k in (1, 5, 10): + top_k_predictions = logits.topk(top_k, dim=-1).indices # [b, t, top_k] + is_target_in_predictions = (targets.unsqueeze(-1) == top_k_predictions).any( + dim=-1 + ) # [b, t] + fraction_model_predicted_target = is_target_in_predictions.float().mean(dim=-1) # [b] + scores[f"token_accuracy_top_{top_k}"] = fraction_model_predicted_target.tolist() + + return scores + + +def cosine_embedding_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return kd_model.cosine_embedding_loss_batched(hidden_states, target_hidden_states).tolist() + + +def normalized_mse_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return [ + kd_model.normalized_mse_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item() + for i_sample in range(hidden_states.shape[0]) + ] + + +def mse_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return [ + F.mse_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item() + for i_sample in range(hidden_states.shape[0]) + ] + + +def mae_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return [ + F.l1_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item() + for i_sample in range(hidden_states.shape[0]) + ] + + +def _calculate_teacher_similarity_scores( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, + logits: torch.Tensor, + target_logits: torch.Tensor, + calculate_full_score_ablations: bool, +) -> dict[str, list[float]]: + """hidden_states: [batch, tokens, n_embd] + target_hidden_states: [batch, tokens, n_embd] + logits: [batch, tokens, vocab] + target_logits: [batch, tokens, vocab] + """ + + def calc_per_sample(func, logits, target_probs): + return [ + func(logits=logits[i_sample], target_probs=target_probs[i_sample]) + for i_sample in range(logits.shape[0]) + ] + + score_ablations = {} + + if (target_hidden_states is not None) and (hidden_states.shape == target_hidden_states.shape): + for func in (cosine_embedding_loss, normalized_mse_loss, mse_loss, mae_loss): + score_name = f"{func.__name__}_hidden_states" + score_ablations[score_name] = func(hidden_states, target_hidden_states) + + if target_logits is not None: + for func in (cosine_embedding_loss, normalized_mse_loss, mse_loss, mae_loss): + score_name = f"{func.__name__}_logits" + score_ablations[score_name] = func(logits, target_logits) + + for top_p in (0.99, 0.95, None) if calculate_full_score_ablations else (None,): + transformed_logits = ( + logits if (top_p is None) else top_p_top_k(logits, top_p=top_p, top_k=None) + ) + transformed_target_logits = ( + target_logits + if (top_p is None) + else top_p_top_k(target_logits, top_p=top_p, top_k=None) + ) + target_probs = transformed_target_logits.softmax(-1) + + for func in (kl_div, js_div, tv_dist): + for clip_epsilon in ( + ( + ClipEpsilon.NO_CLIP, + ClipEpsilon.CLIP_NO_RENORMALIZE, + ClipEpsilon.CLIP_RENORMALIZE, + ) + if calculate_full_score_ablations + else (ClipEpsilon.NO_CLIP,) + ): + epsilon_factors = ( + (1.0, 0.1, 0.01) if not clip_epsilon == ClipEpsilon.NO_CLIP else (None,) + ) + + for epsilon_factor in epsilon_factors: + score_name = ( + f"{func.__name__}--top_p_{top_p}--clip_epsilon_{clip_epsilon.name}" + f"--epsilon_factor_{epsilon_factor}" + ) + func_with_args = functools.partial( + func, clip_epsilon=clip_epsilon, epsilon_factor=epsilon_factor + ) + score_ablations[score_name] = calc_per_sample( + func_with_args, transformed_logits, target_probs + ) + if (top_p is None) and (clip_epsilon == ClipEpsilon.NO_CLIP): + short_score_name = func.__name__ + score_ablations[short_score_name] = score_ablations[score_name] + + for top_k in (1, 5, 10): + teacher_greedy_prediction = target_logits.argmax(dim=-1, keepdim=True) # [b,t,1] + student_top_k_predictions = logits.topk(top_k, dim=-1).indices # [b,t,k] + is_teacher_prediction_in_student_predictions = ( + teacher_greedy_prediction == student_top_k_predictions + ).any(dim=-1) # [b,t] + fraction_student_predicted_teacher = ( + is_teacher_prediction_in_student_predictions.float().mean(dim=-1) + ) # [b] + score_ablations[f"greedy_teacher_prediction_in_student_top_{top_k}"] = ( + fraction_student_predicted_teacher.tolist() + ) + + if calculate_full_score_ablations: + for top_p in (0.99, 0.95, 0.50, None): + # student + transformed_logits = logits.clone() + + # teacher + transformed_target_logits = ( + target_logits.clone() + if (top_p is None) + else top_p_top_k(target_logits, top_p=top_p, top_k=None) + ) + + target_probs = transformed_target_logits.softmax(-1) + mask = transformed_target_logits == -1000 + if torch.any(mask): + transformed_logits[mask] = 0 + transformed_target_logits[mask] = 0 + target_probs[mask] = 0 + + for func in (mse_loss, mae_loss): + score_name = f"{func.__name__}_logits_top_p_{top_p}" + score_ablations[score_name] = func( + transformed_logits, transformed_target_logits + ) + + if top_p is not None and top_p > 0.9: + func = kl_div + clip_epsilon = ClipEpsilon.NO_CLIP + score_name = ( + f"{func.__name__}--top_p_{top_p}--clip_epsilon_no_clip_student_unfiltered" + ) + func_with_args = functools.partial( + func, clip_epsilon=clip_epsilon, epsilon_factor=epsilon_factor + ) + score_ablations[score_name] = calc_per_sample( + func_with_args, logits, target_probs + ) + # score_name = f"{func.__name__}_abs--top_p_{top_p}--clip_epsilon_no_clip_student_unfiltered" + # score_ablations[score_name] = [s.abs() for s in score_ablations[score_name]] + + return score_ablations + + +class ClipEpsilon(Enum): + NO_CLIP = "NO_CLIP" + CLIP_RENORMALIZE = "CLIP_RENORMALIZE" + CLIP_NO_RENORMALIZE = "CLIP_NO_RENORMALIZE" + + +def _logits_to_logprobs( + logits: torch.Tensor, clip_epsilon: ClipEpsilon, epsilon_factor: float +) -> torch.Tensor: + """logits: [tokens, vocab]""" + logprobs = logits.log_softmax( + -1 + ) # must normalize logits before clipping otherwise log(1/voacb) means nothing + if clip_epsilon == ClipEpsilon.NO_CLIP: + return logprobs + vocab_size = logprobs.shape[-1] + epsilon = math.log(epsilon_factor * 1 / vocab_size) + logprobs = torch.clip(logprobs, min=epsilon) + if clip_epsilon == ClipEpsilon.CLIP_RENORMALIZE: + logprobs = logprobs.log_softmax( + -1 + ) # we do log_softmax again to retain legitimate distributions + return logprobs + + +def kl_div( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """Kullback-Leibler Divergence for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + num_tokens = logits.shape[0] + logprobs = _logits_to_logprobs(logits, clip_epsilon, epsilon_factor) + + _kl_div = ( + F.kl_div(logprobs, target_probs, reduction="sum", log_target=False).item() / num_tokens + ) + return _kl_div + + +def js_div( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """Jensen-Shannon Divergence for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + probs = logits.softmax(-1) + mixture_probs = (probs + target_probs) / 2 + mixture_logprobs = mixture_probs.log().clip(min=-1000) + + pred_kl_div = kl_div(mixture_logprobs, probs, clip_epsilon, epsilon_factor) + target_kl_div = kl_div(mixture_logprobs, target_probs, clip_epsilon, epsilon_factor) + _js_div = 0.5 * (pred_kl_div + target_kl_div) + return _js_div + + +def tv_dist( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """Total Variation Distance (L1-loss) for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + num_tokens, vocab_size = logits.shape + probs = logits.softmax(-1) + + if clip_epsilon != ClipEpsilon.NO_CLIP: + epsilon = epsilon_factor * 1 / vocab_size + probs = probs.clip(min=epsilon) + target_probs = target_probs.clip(min=epsilon) + if clip_epsilon == ClipEpsilon.CLIP_RENORMALIZE: + probs = probs / probs.sum(-1, keepdim=True) + target_probs = target_probs / target_probs.sum(-1, keepdim=True) + + _tv_dist = 0.5 * (probs - target_probs).abs().sum().item() / num_tokens + return _tv_dist + + +DEFAULT_TOP_P = 0.999 +# WestLake model: +# 700 = percentile 0.9 for top_p=0.99 +# 1700 = percentile 0.95 for top_p=0.99 and percentile 0.75 for top_p=0.999 +# For top_p=0.999 and top_k=1700 you take about 75 GB for 2048*8192 tokens +DEFAULT_TOP_K = 1000 + + +def top_p_top_k( + logits: torch.Tensor, + top_p: float | None = DEFAULT_TOP_P, + top_k: int | None = DEFAULT_TOP_K, + filter_value=-1000, +) -> torch.Tensor: + logit_warpers = [] + if top_p is not None: + logit_warpers.append(TopPLogitsWarper(top_p=top_p, filter_value=filter_value)) + if top_k is not None: + logit_warpers.append(TopKLogitsWarper(top_k=top_k, filter_value=filter_value)) + + warped_logits = [] + for sample_logits in logits: + for warper in logit_warpers: + sample_logits = warper(input_ids=None, scores=sample_logits) + warped_logits.append(sample_logits) + warped_logits = torch.stack(warped_logits) + + return warped_logits diff --git a/modelopt/torch/utils/robust_json.py b/modelopt/torch/utils/robust_json.py index c4a72fde83..23a3091637 100644 --- a/modelopt/torch/utils/robust_json.py +++ b/modelopt/torch/utils/robust_json.py @@ -55,8 +55,13 @@ def default(self, o): # User-defined function in main — fallback to just the name return o.__name__ return f"{o.__module__}.{o.__qualname__}" + if inspect.isclass(o): + return f"{o.__module__}.{o.__qualname__}" if isinstance(o, datetime.timedelta): return str(o) + # Fallback for arbitrary objects (e.g. mixins injected into Hydra configs) + if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"): + return f"{o.__class__.__module__}.{o.__class__.__qualname__}" return super().default(o) diff --git a/pyproject.toml b/pyproject.toml index 6170876308..16da6d6dc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,16 @@ hf = [ "transformers>=4.56", # Should match modelopt/torch/__init__.py and tox.ini "wonderwords", ] + +puzzletron = [ # Dependedencies for modelopt.torch.puzzletron subpackage + "fire", + "hydra-core==1.3.2", + "immutabledict", + "lru-dict", + "mip", + "pandas", + "typeguard", +] dev-lint = [ "bandit[toml]==1.7.9", # security/compliance checks "mypy==1.17.1", @@ -115,7 +125,7 @@ dev-test = [ "tox-current-env>=0.0.12", ] # Compound extras via self-references -all = ["nvidia-modelopt[onnx,hf]"] +all = ["nvidia-modelopt[onnx,hf,puzzletron]"] dev = ["nvidia-modelopt[all,dev-docs,dev-lint,dev-test]"] [project.urls] @@ -205,6 +215,17 @@ extend-ignore = [ "D", "E501", ] # Ignore missing docstrings or line length for Jupyter notebooks +"modelopt/torch/puzzletron/*" = [ + "C4", + "D", + "E", + "F", + "N", + "PERF", + "RUF", + "SIM", + "UP", +] # TODO: Disabled for now, will enable later, once all puzzletron code is migrated "modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style "modelopt/torch/sparsity/attention_sparsity/kernels/*" = [ "N803", diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py new file mode 100644 index 0000000000..ea0a6fd219 --- /dev/null +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path + +import torch +from _test_utils.torch.transformers_models import get_tiny_tokenizer +from datasets import Dataset, DatasetDict +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerBase + +import modelopt.torch.puzzletron as mtpz +import modelopt.torch.utils.distributed as dist +from modelopt.torch.export import copy_hf_ckpt_remote_code + + +def setup_test_model_and_data( + tmp_path: Path, rank: int, hf_model_name: str, hybrid_override_pattern: str | None = None +) -> tuple[Path, Path, Path]: + """ + Setup the test model and data for the puzzletron NAS search. + + Args: + tmp_path: the temporary path to use for the test + rank: the rank of the process + hf_model_name: HuggingFace model card name (e.g., "meta-llama/Llama-3.1-8B-Instruct") + hybrid_override_pattern: For NemotronH models, the layer type pattern + + Returns: + tuple[Path, Path, Path]: the puzzle_dir, hf_checkpoint_path, dataset_path + """ + # Register Hydra custom resolvers (needed for config resolution) + mtpz.tools.register_hydra_resolvers() + + puzzle_dir = tmp_path / hf_model_name + hf_checkpoint_path = puzzle_dir / f"hf_models/{hf_model_name}" + dataset_path = puzzle_dir / "dummy_dataset" + + if rank == 0: + save_dummy_dataset(dataset_path) + + # Create a small HF model + tokenizer = get_tiny_tokenizer() + create_and_save_small_hf_model( + output_path=str(hf_checkpoint_path), + tokenizer=tokenizer, + hf_model_name=hf_model_name, + hybrid_override_pattern=hybrid_override_pattern, + ) + dist.barrier() + + return puzzle_dir, hf_checkpoint_path, dataset_path + + +def create_and_save_small_hf_model( + output_path: str, + tokenizer: PreTrainedTokenizerBase, + hf_model_name: str, + hybrid_override_pattern: str | None = None, +): + """Create and save a small HuggingFace model for testing the conversion pipeline. + + Uses real HuggingFace config to preserve model-specific settings (like tie_word_embeddings), + but shrinks size parameters for fast testing. + + Args: + output_path: Where to save the model. + tokenizer: Tokenizer to save alongside the model. + hf_model_name: HuggingFace model card name (e.g., "meta-llama/Llama-3.1-8B-Instruct"). + hybrid_override_pattern: For NemotronH models, the layer type pattern (e.g., "*-" for + Attention+MLP, "M-" for Mamba+MLP). Must match num_hidden_layers. + """ + # Load real HuggingFace config (preserves tie_word_embeddings, rope_scaling, etc.) + config = AutoConfig.from_pretrained(hf_model_name, trust_remote_code=True) + + # Override size-related params to make it small for testing + # Note: intermediate_size must be divisible by 256 per DeciLM config requirements + # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility + + # VL models have nested configs (text_config, vision_config) + if hasattr(config, "text_config") and hasattr(config, "vision_config"): + config.text_config.vocab_size = tokenizer.vocab_size + config.text_config.hidden_size = 256 + config.text_config.intermediate_size = 512 + config.text_config.num_hidden_layers = 2 + config.text_config.num_attention_heads = 32 + config.text_config.num_key_value_heads = 8 + config.text_config.num_experts = 16 # Reduce from 128 + config.text_config.moe_intermediate_size = 256 + config.text_config.max_position_embeddings = 512 + config.vision_config.depth = 2 # Reduce from 27 + config.vision_config.hidden_size = 256 + config.vision_config.intermediate_size = 512 + config.vision_config.out_hidden_size = 256 + # TODO: this is hack, redesign converter to not read config.num_hidden_layers directly. + # set top-level num_hidden_layers for converter compatibility + config.num_hidden_layers = config.text_config.num_hidden_layers + else: + # Regular models have flat config + config.vocab_size = tokenizer.vocab_size + config.hidden_size = 256 + config.intermediate_size = 512 + config.num_hidden_layers = max(2, dist.size()) + config.num_attention_heads = 32 + config.num_key_value_heads = 8 + config.max_position_embeddings = 512 + + # Fix layer_types to match num_hidden_layers (newer transformers validates this) + if hasattr(config, "layer_types") and config.layer_types is not None: + config.layer_types = config.layer_types[: config.num_hidden_layers] + + # Fix rope_scaling to be consistent with max_position_embeddings + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + config.rope_scaling["original_max_position_embeddings"] = 256 + + # NemotronH requires hybrid_override_pattern to match num_hidden_layers + if hasattr(config, "hybrid_override_pattern") and hybrid_override_pattern is not None: + config.hybrid_override_pattern = hybrid_override_pattern + + # Ensure pad_token_id is within vocab_size (nn.Embedding requires padding_idx < num_embeddings) + if ( + getattr(config, "pad_token_id", None) is not None + and config.pad_token_id >= tokenizer.vocab_size + ): + config.pad_token_id = 0 + + # Ensure moe_latent_size is present: the native transformers NemotronH model (>=5.5) + # accesses config.moe_latent_size but older trust_remote_code configs don't define it. + if not hasattr(config, "moe_latent_size"): + config.moe_latent_size = None + + # Set seed for reproducible weight initialization + torch.manual_seed(42) + + # Create and save the model + # Force CPU initialization for deterministic behavior (prevents NaN on RTX GPUs) + original_cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES") + os.environ["CUDA_VISIBLE_DEVICES"] = "" + # TODO: Consider using AutoModel.from_config instead. + if hasattr(config, "text_config") and hasattr(config, "vision_config"): + from transformers import Qwen3VLMoeForConditionalGeneration + + model = Qwen3VLMoeForConditionalGeneration._from_config(config) + else: + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + + # Initialize weights to ensure all parameters are properly initialized + # This prevents NaN values in uninitialized parameters (e.g., backbone.layers.1.mixer.gate.weight + # in nemotron-3-nano-30b-a3b-base-bf16) that can occur with from_config on RTX GPU cards (not on H100) + model.initialize_weights() + + # Fix any remaining NaN/Inf values that initialize_weights() might have missed + for param in model.parameters(): + if torch.isnan(param).any() or torch.isinf(param).any(): + nan_inf_mask = torch.isnan(param) | torch.isinf(param) + param.data = torch.where(nan_inf_mask, torch.zeros_like(param), param) + + # Restore CUDA_VISIBLE_DEVICES after model creation and initialization + if original_cuda_visible is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible + else: + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + + model.to(dtype=torch.bfloat16) + # save_original_format=False: skip transformers' revert_weight_conversion so weights are saved + # with in-memory key names (e.g. backbone.embeddings.weight) rather than the on-disk "original" + # format (e.g. backbone.embedding.weight for NemotronH). This avoids key mismatches in + # load_and_shard_model which looks up shard keys from model.named_parameters(). + try: + model.save_pretrained(output_path, save_original_format=False) + except AttributeError: + # Workaround: some trust_remote_code models define _tied_weights_keys in an older + # format (returning a list) that is incompatible with transformers v5, which + # expects _get_tied_weight_keys to return a dict. Clear tied weight keys and retry. + for submodule in model.modules(): + if getattr(submodule, "_tied_weights_keys", None) is not None: + submodule._tied_weights_keys = None + model.save_pretrained(output_path, save_original_format=False) + + # Save tokenizer, config, and custom code files + tokenizer.save_pretrained(output_path) + config.save_pretrained(output_path) + if hasattr(config, "auto_map") and isinstance(config.auto_map, dict): + copy_hf_ckpt_remote_code(hf_model_name, output_path) + + +def save_dummy_dataset(dataset_path: Path | str): + """ + Save a dummy dataset for testing purposes. + """ + # dummy sample + sample = [ + {"role": "user", "content": "please cite Lorem Ipsum?"}, + { + "role": "assistant", + "content": ( + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. " + "Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, " + "in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, " + "dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, " + "pretium in purus. Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. " + "Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, " + "sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. " + "Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, " + "nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. " + "Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, " + "faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. " + "Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. " + "Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. " + "Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. " + "Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. " + "Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. " + "Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. " + "Donec mollis convallis massa quis iaculis." + ), + }, + ] + + # Prepare train and val splits with sample repeated, 2500 samples are for + # 128 samples with block-size 8192 and LLama3 tokenizer + data = [{"conversation": sample}] * 2500 + + # For train-val splits + data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)}) + data_dict.save_to_disk(str(dataset_path)) diff --git a/tests/_test_utils/torch/tokenizer/special_tokens_map.json b/tests/_test_utils/torch/tokenizer/special_tokens_map.json new file mode 100644 index 0000000000..02ee80b619 --- /dev/null +++ b/tests/_test_utils/torch/tokenizer/special_tokens_map.json @@ -0,0 +1,16 @@ +{ + "bos_token": { + "content": "<|begin_of_text|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "<|eot_id|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/tests/_test_utils/torch/tokenizer/tokenizer.json b/tests/_test_utils/torch/tokenizer/tokenizer.json new file mode 100644 index 0000000000..7cfedbc616 --- /dev/null +++ b/tests/_test_utils/torch/tokenizer/tokenizer.json @@ -0,0 +1,257 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 104, + "content": "<|start_header_id|>", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 105, + "content": "<|end_header_id|>", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Sequence", + "pretokenizers": [ + { + "type": "Split", + "pattern": { + "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + }, + "behavior": "Isolated", + "invert": false + }, + { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": false + } + ] + }, + "post_processor": { + "type": "Sequence", + "processors": [ + { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": false, + "use_regex": true + }, + { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 1 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + } + ], + "special_tokens": { + "<|begin_of_text|>": { + "id": "<|begin_of_text|>", + "ids": [ + 100 + ], + "tokens": [ + "<|begin_of_text|>" + ] + } + } + } + ] + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": true, + "vocab": { + "!": 0, + "\"": 1, + "#": 2, + "$": 3, + "%": 4, + "&": 5, + "'": 6, + "(": 7, + ")": 8, + "*": 9, + "+": 10, + ",": 11, + "-": 12, + ".": 13, + "/": 14, + "0": 15, + "1": 16, + "2": 17, + "3": 18, + "4": 19, + "5": 20, + "6": 21, + "7": 22, + "8": 23, + "9": 24, + ":": 25, + ";": 26, + "<": 27, + "=": 28, + ">": 29, + "?": 30, + "@": 31, + "A": 32, + "B": 33, + "C": 34, + "D": 35, + "E": 36, + "F": 37, + "G": 38, + "H": 39, + "I": 40, + "J": 41, + "K": 42, + "L": 43, + "M": 44, + "N": 45, + "O": 46, + "P": 47, + "Q": 48, + "R": 49, + "S": 50, + "T": 51, + "U": 52, + "V": 53, + "W": 54, + "X": 55, + "Y": 56, + "Z": 57, + "[": 58, + "\\": 59, + "]": 60, + "^": 61, + "_": 62, + "`": 63, + "a": 64, + "b": 65, + "c": 66, + "d": 67, + "e": 68, + "f": 69, + "g": 70, + "h": 71, + "i": 72, + "j": 73, + "k": 74, + "l": 75, + "m": 76, + "n": 77, + "o": 78, + "p": 79, + "q": 80, + "r": 81, + "s": 82, + "t": 83, + "u": 84, + "v": 85, + "w": 86, + "x": 87, + "y": 88, + "z": 89, + "{": 90, + "|": 91, + "}": 92, + "~": 93, + "¡": 94, + "¢": 95, + "£": 96, + "¤": 97, + "¥": 98, + "¦": 99, + "<|begin_of_text|>": 100, + "<|eot_id|>": 101, + "Ġ": 102, + "Ċ": 103, + "<|start_header_id|>": 104, + "<|end_header_id|>": 105, + "Ā": 106, + "ā": 107, + "Ă": 108, + "ă": 109, + "Ą": 110, + "ą": 111, + "Ć": 112, + "ć": 113, + "Ĉ": 114, + "ĉ": 115, + "ċ": 116, + "Č": 117, + "č": 118, + "Ď": 119, + "ď": 120, + "Đ": 121, + "đ": 122, + "Ē": 123, + "ē": 124, + "Ĕ": 125, + "ĕ": 126, + "Ė": 127 + }, + "merges": [] + } +} diff --git a/tests/_test_utils/torch/tokenizer/tokenizer_config.json b/tests/_test_utils/torch/tokenizer/tokenizer_config.json new file mode 100644 index 0000000000..66600edeef --- /dev/null +++ b/tests/_test_utils/torch/tokenizer/tokenizer_config.json @@ -0,0 +1,13 @@ +{ + "bos_token": "<|begin_of_text|>", + "chat_template": "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}Q: {{ message['content'] }}{% elif message['role'] == 'assistant' %}A: {{ message['content'] }}{% endif %}{{ eos_token }}{% endfor %}", + "clean_up_tokenization_spaces": true, + "eos_token": "<|eot_id|>", + "extra_special_tokens": {}, + "model_input_names": [ + "input_ids", + "attention_mask" + ], + "model_max_length": 131072, + "tokenizer_class": "PreTrainedTokenizer" +} diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py index 8fe2f68b32..a6bfc4484a 100644 --- a/tests/_test_utils/torch/transformers_models.py +++ b/tests/_test_utils/torch/transformers_models.py @@ -39,6 +39,12 @@ SEED = 1234 +TINY_TOKENIZER_PATH = Path(__file__).parent / "tokenizer" + + +def get_tiny_tokenizer() -> "transformers.PreTrainedTokenizerBase": + return AutoTokenizer.from_pretrained(TINY_TOKENIZER_PATH) + ##### Qwen3 ##### def get_tiny_qwen3(**config_kwargs) -> PreTrainedModel: @@ -66,9 +72,7 @@ def create_tiny_qwen3_dir( ) -> Path | tuple[Path, PreTrainedModel]: qwen3_dir = Path(tmp_path) / "tiny_qwen3" if with_tokenizer: - tokenizer = AutoTokenizer.from_pretrained( - "hf-internal-testing/tiny-random-LlamaForCausalLM" - ) + tokenizer = get_tiny_tokenizer() tokenizer.save_pretrained(qwen3_dir) config_kwargs["vocab_size"] = tokenizer.vocab_size tiny_qwen3 = get_tiny_qwen3(**config_kwargs) @@ -109,9 +113,7 @@ def create_tiny_qwen3_moe_dir( ) -> Path: qwen3_moe_dir = Path(tmp_path) / "tiny_qwen3_moe" if with_tokenizer: - tokenizer = AutoTokenizer.from_pretrained( - "hf-internal-testing/tiny-random-LlamaForCausalLM" - ) + tokenizer = tokenizer = get_tiny_tokenizer() tokenizer.save_pretrained(qwen3_moe_dir) config_kwargs["vocab_size"] = tokenizer.vocab_size get_tiny_qwen3_moe(**config_kwargs).save_pretrained(qwen3_moe_dir) @@ -144,9 +146,7 @@ def create_tiny_gpt_oss_dir( ) -> Path: gpt_oss_dir = Path(tmp_path) / "tiny_gpt_oss" if with_tokenizer: - tokenizer = AutoTokenizer.from_pretrained( - "hf-internal-testing/tiny-random-LlamaForCausalLM" - ) + tokenizer = tokenizer = get_tiny_tokenizer() tokenizer.save_pretrained(gpt_oss_dir) config_kwargs["vocab_size"] = tokenizer.vocab_size get_tiny_gpt_oss(**config_kwargs).save_pretrained(gpt_oss_dir) @@ -177,9 +177,7 @@ def create_tiny_llama_dir( ) -> Path: llama_dir = Path(tmp_path) / "tiny_llama" if with_tokenizer: - tokenizer = AutoTokenizer.from_pretrained( - "hf-internal-testing/tiny-random-LlamaForCausalLM" - ) + tokenizer = get_tiny_tokenizer() tokenizer.save_pretrained(llama_dir) config_kwargs["vocab_size"] = tokenizer.vocab_size diff --git a/tests/conftest.py b/tests/conftest.py index 53a2330c22..a4e65ff2ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,7 +115,7 @@ def enable_hf_checkpointing(): mto.enable_huggingface_checkpointing() -@pytest.fixture +@pytest.fixture(scope="session") def project_root_path(request: pytest.FixtureRequest) -> Path: """Fixture providing the project root path for tests.""" return Path(request.config.rootpath) diff --git a/tests/examples/megatron_bridge/test_distill.py b/tests/examples/megatron_bridge/test_distill.py index b5a0ca86d6..9f84f50c28 100644 --- a/tests/examples/megatron_bridge/test_distill.py +++ b/tests/examples/megatron_bridge/test_distill.py @@ -18,12 +18,14 @@ from pathlib import Path from _test_utils.examples.run_command import extend_cmd_parts, run_example_command -from _test_utils.torch.transformers_models import create_tiny_qwen3_dir +from _test_utils.torch.puzzletron.utils import create_and_save_small_hf_model +from _test_utils.torch.transformers_models import create_tiny_qwen3_dir, get_tiny_tokenizer + +from modelopt.torch.puzzletron.anymodel import convert_model def test_distill_and_convert(tmp_path: Path, num_gpus): teacher_hf_path = create_tiny_qwen3_dir(tmp_path, with_tokenizer=True) - train_iters = 5 distill_output_dir = tmp_path / "distill_output" distill_cmd_parts = extend_cmd_parts( @@ -32,6 +34,7 @@ def test_distill_and_convert(tmp_path: Path, num_gpus): teacher_hf_path=teacher_hf_path, output_dir=distill_output_dir, tp_size=num_gpus, + pp_size=1, seq_length=32, mbs=1, gbs=4, @@ -63,3 +66,68 @@ def test_distill_and_convert(tmp_path: Path, num_gpus): check=True, ) assert (distilled_hf_path / "config.json").exists() + + +def test_distill_puzzletron_anymodel(tmp_path: Path, num_gpus): + """Integration test for distill.py with Puzzletron AnyModel (heterogeneous) checkpoints. + + Creates Qwen3 models, converts the student to Puzzletron AnyModel format + (heterogeneous layer architectures), runs mbridge distillation, and exports + the distilled checkpoint to HuggingFace format via --hf_export_path. + """ + student_hf_dir, student_anymodel_dir, teacher_hf_dir = ( + _prepare_puzzletron_anymodel_student_and_teacher(tmp_path) + ) + + train_iters = 5 + output_dir = tmp_path / "distill_output" + hf_export_path = tmp_path / "distilled_anymodel_hf" + cmd_parts = extend_cmd_parts( + ["torchrun", f"--nproc_per_node={num_gpus}", "distill.py", "--use_mock_data"], + student_hf_path=student_anymodel_dir, + teacher_hf_path=teacher_hf_dir, + output_dir=output_dir, + tp_size=num_gpus, + pp_size=1, + seq_length=32, + mbs=1, + gbs=4, + train_iters=train_iters, + lr_warmup_iters=2, + eval_interval=5, + eval_iters=1, + log_interval=1, + hf_export_path=hf_export_path, + student_hf_model=student_hf_dir, + ) + run_example_command(cmd_parts, example_path="megatron_bridge") + + run_config_path = output_dir / "checkpoints" / f"iter_{train_iters:07d}" / "run_config.yaml" + assert run_config_path.exists(), f"Expected run_config.yaml at: {run_config_path}" + + assert (hf_export_path / "config.json").exists(), ( + f"Expected HF export at: {hf_export_path}/config.json" + ) + + +def _prepare_puzzletron_anymodel_student_and_teacher(tmp_path: Path) -> tuple[Path, Path, Path]: + """Create Qwen3 models and convert student to Puzzletron AnyModel format.""" + student_hf_dir = tmp_path / "student_hf" + teacher_hf_dir = tmp_path / "teacher_hf" + + tokenizer = get_tiny_tokenizer() + + create_and_save_small_hf_model( + output_path=str(student_hf_dir), tokenizer=tokenizer, hf_model_name="Qwen/Qwen3-0.6B" + ) + + create_and_save_small_hf_model( + output_path=str(teacher_hf_dir), tokenizer=tokenizer, hf_model_name="Qwen/Qwen3-0.6B" + ) + + student_anymodel_dir = tmp_path / "student_anymodel" + convert_model( + input_dir=str(student_hf_dir), output_dir=str(student_anymodel_dir), converter="qwen3" + ) + + return student_hf_dir, student_anymodel_dir, teacher_hf_dir diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index 763679d589..2436204f86 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -92,6 +92,7 @@ def draft_vocab_cache_dir(tmp_path_factory): def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path, draft_vocab_cache_dir): """Test calibration of draft vocabulary.""" + draft_vocab_size = 80 # tiny tokenizer has vocab size 128 only run_example_command( [ "python", @@ -101,7 +102,7 @@ def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path, draft "--data", tiny_daring_anteater_path, "--draft_vocab_size", - "100", + str(draft_vocab_size), "--save_dir", draft_vocab_cache_dir, ], @@ -110,7 +111,9 @@ def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path, draft model_name = os.path.basename(os.path.normpath(tiny_llama_path)) d2t = torch.load(os.path.join(draft_vocab_cache_dir, model_name, "d2t.pt")) - assert d2t.shape[0] == 100, f"Expected draft vocab size 100, got {d2t.shape[0]}" + assert d2t.shape[0] == draft_vocab_size, ( + f"Expected draft vocab size {draft_vocab_size}, got {d2t.shape[0]}" + ) # fmt: off diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml new file mode 100644 index 0000000000..2843f0b97a --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml @@ -0,0 +1,113 @@ +# @package _global_ +defaults: + - /Qwen/Qwen2.5-7B-Instruct/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: qwen2 + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..cf6201080c --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor.Qwen2FFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml new file mode 100644 index 0000000000..cd82a47271 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml @@ -0,0 +1,112 @@ +# @package _global_ +defaults: + - /Qwen/Qwen3-8B/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: qwen3 + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..6bfeec715c --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor.Qwen3FFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml new file mode 100644 index 0000000000..00b21ea979 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml @@ -0,0 +1,113 @@ +# @package _global_ +defaults: + - /Qwen/Qwen3-VL-30B-A3B-Instruct/pruning@pruning: expert_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: qwen3_vl + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + - stats.num_local_experts + + human_constraints: + + mip_constraints: + - stats.num_local_experts: 1472 # same constraint as nemotron-3-nano for test consistency + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml new file mode 100644 index 0000000000..4e0786dc7a --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml @@ -0,0 +1,20 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +eval_samples: 10 +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_model_descriptor.Qwen3VLExpertRemovalLayerDescriptor + target_name: "mlp" + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.Qwen3VLRemoveExpertsIndependentHook} +activation_hooks_kwargs: + +# num_experts_to_keep must be >= num_experts_per_tok (can't route to more experts than exist) +num_experts_to_keep_list: [8] # num_experts in test model is 16, num_experts_per_tok is 8 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks_mse" + layer_prefix_template: "model.language_model.layers.{layer_idx}.mlp" diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml new file mode 100644 index 0000000000..57051431a1 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml @@ -0,0 +1,10 @@ +# @package _global_ +defaults: + - /meta-llama/Llama-3.1-8B-Instruct/pruning@pruning: attn_pruning + - _self_ + +descriptor: llama + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +dataset_path: ??? diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml new file mode 100644 index 0000000000..8e2e0786b3 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml @@ -0,0 +1,106 @@ +# @package _global_ +defaults: + - /meta-llama/Llama-3.1-8B-Instruct/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +descriptor: llama + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..6e8af1f651 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/attn_pruning@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaKVHeadsLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..b30f4a17d9 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml new file mode 100644 index 0000000000..78cb6bd73c --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml @@ -0,0 +1,106 @@ +# @package _global_ +defaults: + - /meta-llama/Llama-3.2-3B-Instruct/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +descriptor: llama + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..b30f4a17d9 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml new file mode 100644 index 0000000000..e042c4bb62 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml @@ -0,0 +1,112 @@ +# @package _global_ +defaults: + - /mistralai/Mistral-Small-24B-Instruct-2501/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: mistral_small + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..37c21fd638 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor.MistralFFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml new file mode 100644 index 0000000000..ab2b09e679 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml @@ -0,0 +1,115 @@ +# @package _global_ +defaults: + - /nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning@pruning: expert_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: nemotron_h + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path}/valid + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + - stats.num_local_experts + + human_constraints: + mip_constraints: + - stats.num_local_experts: 1472 # teacher has: 23 moe-blocks * 128 experts = 2944 total experts use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path}/valid + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml new file mode 100644 index 0000000000..ae20b6d7d2 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +eval_samples: 10 +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_model_descriptor.NemotronHExpertRemovalLayerDescriptor + target_name: "mixer" + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.NemotronHRemoveExpertsIndependentHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +num_experts_to_keep_list: [96, 64, 32, 16, 8] # num_experts in teacher is 128 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks_mse" diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..abc501287d --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml @@ -0,0 +1,14 @@ +defaults: + - /pruning/pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn/${pruning.experiment_id} +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml new file mode 100644 index 0000000000..906b7338d8 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml @@ -0,0 +1,113 @@ +# @package _global_ +defaults: + - /nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: nemotron_h_v2 + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..f68068c3ac --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml @@ -0,0 +1,12 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor.NemotronHV2FFNIntermediateLayerDescriptor + +activation_hooks_kwargs: + method: iterative + target_layer: "mixer.down_proj" + layer_input_descriptors_path: diff --git a/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml new file mode 100644 index 0000000000..78c878aa98 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml @@ -0,0 +1,111 @@ +# @package _global_ +defaults: + - /openai/gpt-oss-20b/pruning@pruning: expert_removal # TODO: Note: Works for unquantized test models, not MXFP4 quantized production models + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - bypass: + - override /hydra/hydra_logging: disabled + - _self_ + +descriptor: gpt_oss + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true # TODO: Works for unquantized test models + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + block_size: 512 # Toy model has max_position_embeddings=512; attention is O(batch*heads*seq^2) + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + - stats.num_local_experts: 48 # teacher has: 2 layers * 32 experts = 64 total experts + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + block_size: 512 # Toy model has max_position_embeddings=512; attention is O(batch*heads*seq^2) + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml new file mode 100644 index 0000000000..50d9f87028 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml @@ -0,0 +1,25 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +eval_samples: 10 +# Toy test model has max_position_embeddings=512 and num_attention_heads=32. +# Attention is O(batch * heads * seq^2), so we must keep batch and seq small. +# pruning_defaults uses micro_batch_size=4 and block_size=8192, which creates +# (4, 32, 8192, 8192) = 16 GiB attn tensors even with a tiny hidden_size. +micro_batch_size: 1 +block_size: 512 +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor.GptOssExpertRemovalLayerDescriptor + target_name: "mlp.router" +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.RankedChoiceVotingHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 128 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks" + layer_prefix_template: "model.layers.{layer_idx}.mlp.router" diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..0dadc20134 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml @@ -0,0 +1,23 @@ +defaults: + - /pruning/pruning_defaults@_here_ + - _self_ + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin.KVHeadsPruningMixIn + layer_descriptor: + _target_: ??? + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IndependentKvHeadContributionHook} +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml new file mode 100644 index 0000000000..c1c951984f --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml @@ -0,0 +1,19 @@ +defaults: + - /pruning/pruning_defaults@_here_ + - _self_ + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: ??? + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] +mlp_init_mode: "PruneByActivationsLog" diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..4033fedf3a --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..f00a86da66 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults@_here_ + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 100 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/tests/gpu/torch/puzzletron/resources/configs/validate_model_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/validate_model_defaults.yaml new file mode 100644 index 0000000000..9dabef7413 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/validate_model_defaults.yaml @@ -0,0 +1,15 @@ +block_size: 8192 +bos_rate: 0.5 +data_column: conversation +val_dataset_name: train +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/tests/gpu/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/tests/gpu/torch/puzzletron/test_nas_convert.py b/tests/gpu/torch/puzzletron/test_nas_convert.py new file mode 100644 index 0000000000..57239f1c0e --- /dev/null +++ b/tests/gpu/torch/puzzletron/test_nas_convert.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from datetime import timedelta +from functools import partial +from pathlib import Path + +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data + +import modelopt.torch.nas as mtn +import modelopt.torch.puzzletron as mtpz +import modelopt.torch.utils.distributed as dist + + +def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_nas_convert_ffn_pruning_multiprocess_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _test_nas_convert_ffn_pruning_multiprocess_job( + project_root_path: Path, tmp_path: Path, rank: int, size: int +): + dist.setup(timeout=timedelta(minutes=10)) + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct" + + # + # Run the mnt.convert() step + # + input_model = mtpz.puzzletron_nas_plugin.PuzzletronModel() + mtn.convert( + input_model, + mode=[ + ( + "puzzletron", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, + ) + ], + ) + + # + # Check assertions + # + if rank == 0: + # assertions for the score_pruning_activations step + rank = int(os.environ["RANK"]) + rank_filepath = ( + f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file() + + # assertions for the pruning_ckpts step + assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + + dist.cleanup() + + +def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_nas_convert_attn_pruning_multiprocess_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _test_nas_convert_attn_pruning_multiprocess_job( + project_root_path: Path, tmp_path: Path, rank: int, size: int +): + dist.setup(timeout=timedelta(minutes=10)) + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning" + + # + # Run the mnt.convert() step + # + input_model = mtpz.puzzletron_nas_plugin.PuzzletronModel() + mtn.convert( + input_model, + mode=[ + ( + "puzzletron", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, + ) + ], + ) + + # + # Check assertions + # + if rank == 0: + # assertions for the score_pruning_activations step + rank = int(os.environ["RANK"]) + rank_filepath = ( + f"pruning/pruning_scores/attn_independent_kv_head_contribution/" + f"100samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file() + + # assertions for the pruning_ckpts step + assert (puzzle_dir / "ckpts/n_heads_in_group8").exists() + assert (puzzle_dir / "ckpts/n_heads_in_group16").exists() + assert (puzzle_dir / "ckpts/n_heads_in_group32").exists() + + dist.cleanup() diff --git a/tests/gpu/torch/puzzletron/test_nas_search.py b/tests/gpu/torch/puzzletron/test_nas_search.py new file mode 100644 index 0000000000..ae0dff47be --- /dev/null +++ b/tests/gpu/torch/puzzletron/test_nas_search.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import timedelta +from functools import partial +from pathlib import Path + +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data + +import modelopt.torch.nas as mtn +import modelopt.torch.puzzletron as mtpz +import modelopt.torch.utils.distributed as dist + + +def test_nas_search(project_root_path: Path, tmp_path: Path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_nas_search_multiprocess_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _test_nas_search_multiprocess_job( + project_root_path: Path, tmp_path: Path, rank: int, size: int +): + dist.setup(timeout=timedelta(minutes=10)) + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct" + + # + # Run the mnt.convert() step + # + input_model = mtpz.puzzletron_nas_plugin.PuzzletronModel() + converted_model = mtn.convert( + input_model, + mode=[ + ( + "puzzletron", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, + ) + ], + ) + + # + # Run the mnt.search() step + # + mtn.search( + converted_model, + constraints={}, # this is not used as the search space is defined in the hydra config + dummy_input=None, # Not used + config={}, # this is not used as the search space is defined in the hydra config + ) + + # + # Check assertions for mtn.search() step + # + if rank == 0: + # assertions for the build_library_and_stats step + assert (puzzle_dir / "replacement_library.json").is_file() + assert (puzzle_dir / "subblock_stats.json").is_file() + + # assertions for the scoring step + solution_0_filepath = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) + + assert solution_0_filepath.exists() + + # assertions for the mip_and_realize_models step + solution_0_ckpt_config_path = ( + puzzle_dir + / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json" + ) + + assert solution_0_ckpt_config_path.exists() + assert (puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json").exists() + + dist.cleanup() diff --git a/tests/gpu/torch/puzzletron/test_nemotron_h_gpu_validation.py b/tests/gpu/torch/puzzletron/test_nemotron_h_gpu_validation.py new file mode 100644 index 0000000000..4c24e6c69c --- /dev/null +++ b/tests/gpu/torch/puzzletron/test_nemotron_h_gpu_validation.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU validation for Nemotron-H hybrid model subblock parameter counting. + +Requires HuggingFace Hub access to nvidia/NVIDIA-Nemotron-Nano-12B-v2-Base (config only, +no weights are downloaded) and mamba_ssm (CUDA). + +Usage: + pytest -v -s -o addopts= tests/gpu/puzzletron/test_nemotron_h_gpu_validation.py +""" + +import copy + +import pytest + +import modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor # noqa: F401 +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.block_config import FFNConfig +from modelopt.torch.puzzletron.subblock_stats.calc_subblock_params_and_memory import ( + calculate_subblock_params, +) +from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config + +MODEL_ID = "nvidia/NVIDIA-Nemotron-Nano-12B-v2-Base" + + +@pytest.fixture +def nemotron_descriptor(): + return ModelDescriptorFactory.get("nemotron_h_v2") + + +@pytest.fixture +def nemotron_config(nemotron_descriptor): + return load_model_config( + MODEL_ID, trust_remote_code=nemotron_descriptor.requires_trust_remote_code() + ) + + +def test_ffn_variants_produce_distinct_params(nemotron_config, nemotron_descriptor): + """FFN subblocks with different intermediate_size must report different param counts. + + On hybrid models, hybrid_override_pattern must be truncated to match the subblock + type; otherwise a single-layer model always builds layer 0 (Mamba) and every FFN + variant reports identical param counts. + """ + lm_config = nemotron_descriptor.get_language_model_config(nemotron_config) + pattern = lm_config.hybrid_override_pattern.replace("|", "") + ffn_indices = [i for i, c in enumerate(pattern) if c in ("-", "E")] + assert ffn_indices, f"No FFN layers in pattern: {pattern}" + + teacher_size = lm_config.intermediate_size + sizes = [teacher_size // 4, teacher_size // 2, teacher_size] + + param_counts = {} + for size in sizes: + layer_config = copy.deepcopy(nemotron_config) + ModelDescriptor.truncate_pattern_for_subblock( + nemotron_descriptor.get_language_model_config(layer_config), ffn_indices[0] + ) + + params = calculate_subblock_params( + layer_config, FFNConfig(intermediate_size=size), nemotron_descriptor + ) + param_counts[size] = params + print(f" intermediate_size={size:>8d} -> params={params:>12,}") + + assert len(set(param_counts.values())) == len(sizes), ( + f"Expected {len(sizes)} distinct param counts, got: {param_counts}" + ) diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py new file mode 100644 index 0000000000..a393e1e086 --- /dev/null +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from datetime import timedelta +from functools import partial +from pathlib import Path + +import pytest +import torch +import transformers +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.misc import set_seed +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data +from packaging.version import Version + +# The puzzletron pipeline imports mip unconditionally at module level. In NeMo containers +# the [puzzletron] extras are not pre-installed, so importing the test file fails with a +# deep ModuleNotFoundError. Skip early with an actionable message instead. +pytest.importorskip("mip", reason="pip install -e '.[puzzletron]' to install MIP solver") + +import modelopt.torch.puzzletron as mtpz +import modelopt.torch.utils.distributed as dist + +# The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) +# using a one-click command. +# +# Note: Bypass is disabled now in the test. +# + +SEED = 1234 + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + [ + ("meta-llama/Llama-3.1-8B-Instruct", "llama", None, False), + ("meta-llama/Llama-3.2-3B-Instruct", "llama", None, False), + ("mistralai/Mistral-Small-24B-Instruct-2501", "mistral_small", None, False), + ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16", "nemotron_h", "*E", True), + ("nvidia/NVIDIA-Nemotron-Nano-12B-v2", "nemotron_h_v2", "*-", False), + ("openai/gpt-oss-20b", "gpt_oss", None, True), + ("Qwen/Qwen2.5-7B-Instruct", "qwen2", None, False), + ("Qwen/Qwen3-8B", "qwen3", None, False), + ("Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True), + ], +) +def test_puzzletron( + project_root_path: Path, + tmp_path: Path, + num_gpus, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str, + has_moe_layers: bool, +): + if "Qwen3-VL" in hf_model_name and Version(transformers.__version__) < Version("4.57.0"): + pytest.skip("Qwen3-VL is not supported with transformers < 4.57.0") + + if "Nemotron" in hf_model_name: + pytest.importorskip("mamba_ssm", reason="mamba_ssm required for Nemotron tests") + + spawn_multiprocess_job( + size=num_gpus, + job=partial( + _test_puzzletron_multiprocess_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_puzzletron_multiprocess_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str, + has_moe_layers: bool, + rank: int, + size: int, +): + # Set seed BEFORE dist.setup() to ensure reproducibility across all processes + set_seed(SEED) + dist.setup(timeout=timedelta(minutes=10)) + + # Setup the test model and data. + puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( + tmp_path, rank, hf_model_name, hybrid_override_pattern + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + model_basename = hf_model_name.split("/")[1] + hydra_config_name = f"{hf_model_name}/{model_basename}" + + # Convert the model using AnyModel converter. + if rank == 0: + mtpz.anymodel.convert_model( + input_dir=str(hf_checkpoint_path), + output_dir=str(puzzle_dir / "ckpts/teacher"), + converter=converter, + ) + dist.barrier() + + # Compress the model using a one-click approach + hydra_cfg = mtpz.entrypoint.puzzletron( + str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path) + ) + + # + # Check assertions + # + if rank == 0: + if has_moe_layers: + # assertions for the score_pruning_activations step 1 (MoE models only) + rank_filepath = ( + f"pruning/pruning_scores/expert_removal/10samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file(), f"Expected {rank_filepath} to exist" + + # assertions for the pruning_ckpts step 2 + assert (puzzle_dir / "ckpts/num_experts_8").exists() + + # assertions for the mip_and_realize_models step 6 + # Find the MIP solution directory dynamically (e.g., stats_num_local_experts_*) + mip_solutions_dir = puzzle_dir / "mip/puzzle_solutions" + solution_dirs = [ + d + for d in mip_solutions_dir.iterdir() + if d.is_dir() and d.name.startswith("stats_num_local_experts_") + ] + assert len(solution_dirs) == 1, ( + f"Expected exactly one stats_num_local_experts_* directory, found: {[d.name for d in solution_dirs]}" + ) + solution_dir = solution_dirs[0] + + solution_0_ckpt_config_path = ( + solution_dir / "solutions--checkpoints/solution_0/config.json" + ) + assert solution_0_ckpt_config_path.exists() + assert (solution_dir / "solutions.json").exists() + + # Validate lm_loss + _assert_lm_loss(puzzle_dir, hf_model_name, tolerance=0.01) + else: + # assertions for the score_pruning_activations step 1 (FFN pruning) + _assert_score_pruning_activations(puzzle_dir, hf_model_name) + + # assertions for the pruning_ckpts step 2 + assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + + # assertions for the mip_and_realize_models step 6 + _assert_mip_solutions(puzzle_dir, hf_model_name) + + # assertions for the build_library_and_stats step 4 + assert (puzzle_dir / "replacement_library.json").is_file() + _assert_subblock_stats_anymodel(hf_model_name, hydra_cfg) + + # assertions for the scoring step 5 + solution_0_filepath = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) + assert solution_0_filepath.exists() + + dist.cleanup() + + +def _assert_subblock_stats_anymodel(hf_model_name: str, hydra_cfg) -> None: + """Minimal subblock_stats checks and teacher memory / param regression values.""" + assert (Path(hydra_cfg.puzzle_dir) / "subblock_stats.json").is_file() + teacher_mem_mib = mtpz.mip.get_teacher_memory_from_subblock_stats(hydra_cfg) + teacher_num_params = mtpz.mip.get_teacher_num_params_from_subblock_stats(hydra_cfg) + + assert abs(teacher_mem_mib - EXPECTED_TEACHER_MEMORY_MIB[hf_model_name]) < 1e-2, ( + f"Teacher memory mismatch for {hf_model_name}: " + f"expected {EXPECTED_TEACHER_MEMORY_MIB[hf_model_name]}, got {teacher_mem_mib}" + ) + assert teacher_num_params == EXPECTED_TEACHER_NUM_PARAMS[hf_model_name], ( + f"Teacher num_params mismatch for {hf_model_name}: " + f"expected {EXPECTED_TEACHER_NUM_PARAMS[hf_model_name]}, got {teacher_num_params}" + ) + + +def _assert_score_pruning_activations(puzzle_dir: Path, hf_model_name: str): + """Assertions for the score_pruning_activations step 1.""" + rank = dist.rank() + rank_filepath = f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" + assert (puzzle_dir / rank_filepath).is_file() + + pruning_scores = torch.load(puzzle_dir / rank_filepath) + layer_names = list(pruning_scores.keys()) + expected = EXPECTED_FFN_PRUNING_VALUES[hf_model_name] + size = dist.size() + + if expected is not None: + # In multi-GPU: layers are distributed across ranks + # Each rank processes len(expected) // size layers + expected_layers_per_rank = len(expected) // size + assert len(layer_names) == expected_layers_per_rank, ( + f"Expected {expected_layers_per_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}" + ) + # Check each layer's values + for i, layer_name in enumerate(layer_names): + layer_data = pruning_scores[layer_name] + # Calculate global layer index from rank and local index + global_idx = rank * expected_layers_per_rank + i + assert layer_data["score"][0].item() == expected[global_idx]["score"], ( + layer_name, + layer_data["score"][0].item(), + expected[global_idx]["score"], + global_idx, + ) + assert ( + layer_data["channels_importance_ascending"][0].item() + == expected[global_idx]["channels"] + ) + else: + observed_values = [] + for layer_name in layer_names: + layer_data = pruning_scores[layer_name] + observed_values.append( + { + "score": layer_data["score"][0].item(), + "channels": layer_data["channels_importance_ascending"][0].item(), + } + ) + pytest.fail(f"Expected pruning values not found for {hf_model_name}!\n{observed_values=}") + + +def _assert_lm_loss(puzzle_dir: Path, hf_model_name: str, tolerance: float = 0.01): + """Validate lm_loss for a model solution.""" + solution_0_path = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) + with open(solution_0_path) as f: + validation = json.load(f) + + actual_lm_loss = validation["lm_loss"]["avg"] + expected_lm_loss = EXPECTED_LM_LOSS.get(hf_model_name) + if expected_lm_loss is not None: + assert abs(actual_lm_loss - expected_lm_loss) < tolerance, ( + f"lm_loss mismatch: expected {expected_lm_loss}, got {actual_lm_loss}" + ) + # TODO: not reproducible in CI, skipping for now + elif hf_model_name != "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": + pytest.fail( + f"Expected lm_loss values not found for {hf_model_name}! Observed value: {actual_lm_loss}" + ) + + +def _assert_mip_solutions(puzzle_dir: Path, hf_model_name: str): + """Assertions for the mip_and_realize_models step.""" + mip_dir = puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB" + + assert (mip_dir / "solutions.json").exists() + assert (mip_dir / "solutions--checkpoints/solution_0/config.json").exists() + + # Validate lm_loss + _assert_lm_loss(puzzle_dir, hf_model_name) + + +# Expected pruning activation values per model +# Each model has a list of (score[0], channels[0]) tuples for each FFN layer +EXPECTED_FFN_PRUNING_VALUES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + {"score": 435, "channels": 94}, + {"score": 82, "channels": 338}, + ], + "meta-llama/Llama-3.2-3B-Instruct": [ + {"score": 440, "channels": 94}, + {"score": 88, "channels": 338}, + ], + "mistralai/Mistral-Small-24B-Instruct-2501": [ + {"score": 410, "channels": 94}, + {"score": 82, "channels": 338}, + ], + # NemotronH with pattern "*-" has only 1 FFN layer (the "-" layer) + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": [ + {"score": 469, "channels": 81}, + ], + "Qwen/Qwen2.5-7B-Instruct": [ + {"score": 374, "channels": 205}, + # NOTE: below score differs as per GPU: set as per CI's RTX Pro 6000 BW. Getting 100 on RTX 6000 Ada + {"score": 102, "channels": 317}, + ], + "Qwen/Qwen3-8B": [ + {"score": 405, "channels": 173}, + {"score": 48, "channels": 376}, + ], +} + + +# Expected lm_loss values per model +EXPECTED_LM_LOSS = { + "meta-llama/Llama-3.1-8B-Instruct": 4.913641, + "meta-llama/Llama-3.2-3B-Instruct": 4.885118, + "mistralai/Mistral-Small-24B-Instruct-2501": 4.913618, + # TODO: not reproducible in CI, skipping for now + # "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 5.068373, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 4.987095, + "openai/gpt-oss-20b": 4.898407, + "Qwen/Qwen2.5-7B-Instruct": 4.890478, + "Qwen/Qwen3-8B": 4.927514, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 5.0625, # 4.828125 for transformers v4.57 +} + + +# Expected teacher memory from subblock_stats (MiB) +EXPECTED_TEACHER_MEMORY_MIB = { + "meta-llama/Llama-3.1-8B-Instruct": 395.63, + "meta-llama/Llama-3.2-3B-Instruct": 395.63, + "mistralai/Mistral-Small-24B-Instruct-2501": 395.63, + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 432.81, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 197.63, + "openai/gpt-oss-20b": 437.33, + "Qwen/Qwen2.5-7B-Instruct": 386.25, + "Qwen/Qwen3-8B": 395.63, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 406.14, +} + + +# Expected total teacher params from subblock_stats +EXPECTED_TEACHER_NUM_PARAMS = { + "meta-llama/Llama-3.1-8B-Instruct": 6096128, + "meta-llama/Llama-3.2-3B-Instruct": 6096128, + "mistralai/Mistral-Small-24B-Instruct-2501": 6096128, + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 126255872, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 2949888, + "openai/gpt-oss-20b": 27959168, + "Qwen/Qwen2.5-7B-Instruct": 1181696, + "Qwen/Qwen3-8B": 6096640, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 11609856, +} diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 8867854737..d183855abb 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -17,8 +17,7 @@ import pytest import torch -from _test_utils.torch.transformers_models import get_tiny_llama -from transformers import AutoTokenizer +from _test_utils.torch.transformers_models import get_tiny_llama, get_tiny_tokenizer import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_hf import _export_quantized_weight @@ -208,7 +207,7 @@ def test_gptq_export_roundtrip(): "quant_cfg", [mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG] ) def test_gptq_e2e_flow(quant_cfg): - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + tokenizer = get_tiny_tokenizer() model = get_tiny_llama(vocab_size=tokenizer.vocab_size).to("cuda") if tokenizer.pad_token is None: diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py index 64177ad89f..80408024a2 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py @@ -137,7 +137,8 @@ def test_forward_decode(self, tinyllama_model): sparse_model = sparse_attn.sparsify(model, config) # Create decode input (seq_len = 1) - input_ids = torch.randint(0, 32000, (1, 1), device="cuda") + vocab_size = model.config.vocab_size + input_ids = torch.randint(0, vocab_size, (1, 1), device="cuda") # Forward pass sparse_model.eval() @@ -147,7 +148,7 @@ def test_forward_decode(self, tinyllama_model): # Verify output assert outputs.logits is not None assert not torch.isnan(outputs.logits).any() - assert outputs.logits.shape == (1, 1, 32000) # batch=1, seq=1, vocab_size + assert outputs.logits.shape == (1, 1, vocab_size) def test_gqa_attention(self, tinyllama_model): """Verify GQA support (num_kv_heads < num_heads).""" @@ -176,7 +177,7 @@ def test_gqa_attention(self, tinyllama_model): sparse_model = sparse_attn.sparsify(model, sparse_config) # Test forward pass with GQA - input_ids = torch.randint(0, 32000, (1, 32), device="cuda") + input_ids = torch.randint(0, config.vocab_size, (1, 32), device="cuda") sparse_model.eval() with torch.no_grad(): diff --git a/tests/gpu_megatron/torch/export/test_hf_checkpoint_utils.py b/tests/gpu_megatron/torch/export/test_hf_checkpoint_utils.py deleted file mode 100644 index c9bc9a6027..0000000000 --- a/tests/gpu_megatron/torch/export/test_hf_checkpoint_utils.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for modelopt/torch/export/plugins/hf_checkpoint_utils.py""" - -from unittest.mock import patch - -from modelopt.torch.export.plugins.hf_checkpoint_utils import copy_remote_code - - -def test_copy_remote_code_local_dir(tmp_path): - """copy_remote_code copies top-level .py files from a local directory.""" - src_dir = tmp_path / "src" - src_dir.mkdir() - (src_dir / "modeling_custom.py").write_text("# custom model") - (src_dir / "configuration_custom.py").write_text("# custom config") - (src_dir / "not_python.txt").write_text("not python") - (src_dir / "subdir").mkdir() - (src_dir / "subdir" / "nested.py").write_text("# nested — should not be copied") - - dst_dir = tmp_path / "dst" - dst_dir.mkdir() - - copy_remote_code(src_dir, dst_dir) - - assert (dst_dir / "modeling_custom.py").read_text() == "# custom model" - assert (dst_dir / "configuration_custom.py").read_text() == "# custom config" - assert not (dst_dir / "not_python.txt").exists(), "non-.py files should not be copied" - assert not (dst_dir / "nested.py").exists(), "nested .py files should not be copied" - - -def test_copy_remote_code_local_dir_no_py_files(tmp_path): - """copy_remote_code is a no-op when the local directory has no .py files.""" - src_dir = tmp_path / "src" - src_dir.mkdir() - (src_dir / "config.json").write_text("{}") - - dst_dir = tmp_path / "dst" - dst_dir.mkdir() - - copy_remote_code(src_dir, dst_dir) # should not raise - - assert list(dst_dir.iterdir()) == [], "no files should be copied" - - -def test_copy_remote_code_hub_id(tmp_path): - """copy_remote_code downloads and copies top-level .py files from a Hub model ID.""" - dst_dir = tmp_path / "dst" - dst_dir.mkdir() - - # Create a fake cached file that hf_hub_download would return - cached_py = tmp_path / "cached_modeling_custom.py" - cached_py.write_text("# custom hub model") - - repo_files = [ - "modeling_custom.py", # top-level .py — should be downloaded - "config.json", # non-.py — skip - "model.safetensors", # non-.py — skip - "subdir/nested.py", # subdirectory .py — skip (contains "/") - ] - - with ( - patch( - "modelopt.torch.export.plugins.hf_checkpoint_utils.list_repo_files", - return_value=repo_files, - ) as mock_list, - patch( - "modelopt.torch.export.plugins.hf_checkpoint_utils.hf_hub_download", - return_value=str(cached_py), - ) as mock_download, - ): - copy_remote_code("meta-llama/Llama-3.2-1B", dst_dir) - - mock_list.assert_called_once_with("meta-llama/Llama-3.2-1B") - # Only the top-level .py should have been downloaded - mock_download.assert_called_once_with( - repo_id="meta-llama/Llama-3.2-1B", filename="modeling_custom.py" - ) - assert (dst_dir / "modeling_custom.py").read_text() == "# custom hub model" - - -def test_copy_remote_code_hub_id_no_py_files(tmp_path): - """copy_remote_code is a no-op when the Hub repo has no top-level .py files.""" - dst_dir = tmp_path / "dst" - dst_dir.mkdir() - - with ( - patch( - "modelopt.torch.export.plugins.hf_checkpoint_utils.list_repo_files", - return_value=["config.json", "model.safetensors"], - ), - patch("modelopt.torch.export.plugins.hf_checkpoint_utils.hf_hub_download") as mock_download, - ): - copy_remote_code("meta-llama/Llama-3.2-1B", dst_dir) - - mock_download.assert_not_called() - assert list(dst_dir.iterdir()) == [] diff --git a/tests/unit/torch/export/test_hf_checkpoint_utils.py b/tests/unit/torch/export/test_hf_checkpoint_utils.py new file mode 100644 index 0000000000..f83cb35574 --- /dev/null +++ b/tests/unit/torch/export/test_hf_checkpoint_utils.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for modelopt/torch/export/plugins/hf_checkpoint_utils.py""" + +from unittest.mock import patch + +import pytest + +pytest.importorskip("huggingface_hub") + +from modelopt.torch.export import copy_hf_ckpt_remote_code + + +def test_copy_hf_ckpt_remote_code_local_dir(tmp_path): + """copy_hf_ckpt_remote_code copies top-level .py files from a local directory.""" + src_dir = tmp_path / "src" + src_dir.mkdir() + (src_dir / "modeling_custom.py").write_text("# custom model") + (src_dir / "configuration_custom.py").write_text("# custom config") + (src_dir / "not_python.txt").write_text("not python") + (src_dir / "subdir").mkdir() + (src_dir / "subdir" / "nested.py").write_text("# nested — should not be copied") + + dst_dir = tmp_path / "dst" + dst_dir.mkdir() + + copy_hf_ckpt_remote_code(src_dir, dst_dir) + + assert (dst_dir / "modeling_custom.py").read_text() == "# custom model" + assert (dst_dir / "configuration_custom.py").read_text() == "# custom config" + assert not (dst_dir / "not_python.txt").exists(), "non-.py files should not be copied" + assert not (dst_dir / "nested.py").exists(), "nested .py files should not be copied" + + +def test_copy_hf_ckpt_remote_code_local_dir_no_py_files(tmp_path): + """copy_hf_ckpt_remote_code is a no-op when the local directory has no .py files.""" + src_dir = tmp_path / "src" + src_dir.mkdir() + (src_dir / "config.json").write_text("{}") + + dst_dir = tmp_path / "dst" + dst_dir.mkdir() + + copy_hf_ckpt_remote_code(src_dir, dst_dir) # should not raise + + assert list(dst_dir.iterdir()) == [], "no files should be copied" + + +def test_copy_hf_ckpt_remote_code_hub_id(tmp_path): + """copy_hf_ckpt_remote_code delegates to snapshot_download for a Hub model ID.""" + dst_dir = tmp_path / "dst" + + with patch("modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download") as mock_sd: + copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-Nano-12B-v2", dst_dir) + + mock_sd.assert_called_once_with( + repo_id="nvidia/NVIDIA-Nemotron-Nano-12B-v2", + local_dir=str(dst_dir), + allow_patterns=["*.py"], + ) diff --git a/tests/unit/torch/puzzletron/conftest.py b/tests/unit/torch/puzzletron/conftest.py new file mode 100644 index 0000000000..25f9e63847 --- /dev/null +++ b/tests/unit/torch/puzzletron/conftest.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform + + +# `import fcntl` fails on Windows +def pytest_ignore_collect(collection_path, config): + return platform.system() == "Windows" diff --git a/tests/unit/torch/puzzletron/test_common.py b/tests/unit/torch/puzzletron/test_common.py new file mode 100644 index 0000000000..89c87ded51 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_common.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for modelopt.torch.puzzletron.tools.common utilities.""" + +import pytest +import torch + +import modelopt.torch.puzzletron as mtpz + + +@pytest.mark.parametrize( + ("input_dtype", "expected"), + [ + ("torch.bfloat16", torch.bfloat16), + ("torch.float16", torch.float16), + ("torch.float32", torch.float32), + ("bfloat16", torch.bfloat16), + ("float16", torch.float16), + ("float32", torch.float32), + (torch.bfloat16, torch.bfloat16), + (torch.float32, torch.float32), + ], + ids=[ + "str-bf16", + "str-fp16", + "str-fp32", + "bare-bf16", + "bare-fp16", + "bare-fp32", + "dtype-bf16", + "dtype-fp32", + ], +) +def test_resolve_torch_dtype(input_dtype, expected): + assert mtpz.tools.resolve_torch_dtype(input_dtype) is expected + + +def test_resolve_torch_dtype_unknown_name(): + with pytest.raises(ValueError, match="Unknown torch dtype"): + mtpz.tools.resolve_torch_dtype("not_a_real_dtype") + + +def test_resolve_torch_dtype_non_dtype_attr(): + with pytest.raises(ValueError, match="is not a dtype"): + mtpz.tools.resolve_torch_dtype("torch.nn") diff --git a/tests/unit/torch/puzzletron/test_convert_anymodel.py b/tests/unit/torch/puzzletron/test_convert_anymodel.py new file mode 100644 index 0000000000..febfd6259d --- /dev/null +++ b/tests/unit/torch/puzzletron/test_convert_anymodel.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +pytest.importorskip("transformers") + +from _test_utils.torch.transformers_models import create_tiny_qwen3_dir +from transformers import AutoModelForCausalLM + +import modelopt.torch.puzzletron as mtpz + + +def test_convert_anymodel(tmp_path): + input_dir = create_tiny_qwen3_dir(tmp_path, with_tokenizer=True) + output_dir = tmp_path / "qwen3-0.6b-anymodel" + mtpz.anymodel.convert_model(input_dir, output_dir, converter="qwen3") + + descriptor = mtpz.anymodel.ModelDescriptorFactory.get("qwen3") + with mtpz.anymodel.deci_x_patcher(descriptor): + _ = AutoModelForCausalLM.from_pretrained(output_dir) diff --git a/tests/unit/torch/puzzletron/test_hybrid_pattern_truncation.py b/tests/unit/torch/puzzletron/test_hybrid_pattern_truncation.py new file mode 100644 index 0000000000..c19626f10b --- /dev/null +++ b/tests/unit/torch/puzzletron/test_hybrid_pattern_truncation.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ModelDescriptor.truncate_pattern_for_subblock. + +Validates that the base descriptor method selects the correct pattern +character when building a 1-layer model for per-subblock param counting. +""" + +from types import SimpleNamespace + +import pytest + +pytest.importorskip("transformers") + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor + +NEMOTRON_H_PATTERN = "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-" + + +class TestTruncatePatternForSubblock: + """Test ModelDescriptor.truncate_pattern_for_subblock.""" + + @pytest.mark.parametrize( + ("index", "expected"), + [ + (0, "M"), + (1, "-"), + (7, "*"), + ], + ids=["mamba", "ffn", "attention"], + ) + def test_index_selects_correct_layer_type(self, index, expected): + """Parent layer index selects the matching character from the pattern.""" + cfg = _make_config() + + ModelDescriptor.truncate_pattern_for_subblock(cfg, parent_layer_index=index) + + assert cfg.hybrid_override_pattern == expected + + @pytest.mark.parametrize( + ("index", "expected"), + [ + (1, "-"), + (2, "*"), + ], + ids=["ffn_after_strip", "attention_after_strip"], + ) + def test_pipe_separators_stripped_before_indexing(self, index, expected): + """Pipe-delimited patterns like 'M|-|*' are normalised to 'M-*' before lookup.""" + cfg = _make_config("M|-|*") + + ModelDescriptor.truncate_pattern_for_subblock(cfg, parent_layer_index=index) + + assert cfg.hybrid_override_pattern == expected + + def test_missing_attribute_is_noop(self): + """Config without hybrid_override_pattern is left unchanged.""" + cfg = SimpleNamespace() + + ModelDescriptor.truncate_pattern_for_subblock(cfg, parent_layer_index=0) + + assert not hasattr(cfg, "hybrid_override_pattern") + + def test_empty_pattern_is_noop(self): + """Empty pattern string is left unchanged.""" + cfg = _make_config("") + + ModelDescriptor.truncate_pattern_for_subblock(cfg, parent_layer_index=0) + + assert cfg.hybrid_override_pattern == "" + + def test_pipes_only_pattern_raises(self): + """Pattern with only pipe separators has no layer-type characters and should error.""" + cfg = _make_config("|||") + + with pytest.raises(ValueError, match="no layer-type characters"): + ModelDescriptor.truncate_pattern_for_subblock(cfg, parent_layer_index=0) + + def test_none_index_defaults_to_first_char(self): + """Without an explicit index, defaults to pattern[0].""" + cfg = _make_config("*-M") + + ModelDescriptor.truncate_pattern_for_subblock(cfg) + + assert cfg.hybrid_override_pattern == "*" + + @pytest.mark.parametrize( + "index", + [999, -1], + ids=["above_range", "negative"], + ) + def test_out_of_range_index_defaults_to_first_char(self, index): + """Out-of-range index defaults to pattern[0].""" + cfg = _make_config("*-M") + + ModelDescriptor.truncate_pattern_for_subblock(cfg, parent_layer_index=index) + + assert cfg.hybrid_override_pattern == "*" + + +def _make_config(pattern=NEMOTRON_H_PATTERN): + return SimpleNamespace(hybrid_override_pattern=pattern) diff --git a/tests/unit/torch/puzzletron/test_resolve_descriptor_caching.py b/tests/unit/torch/puzzletron/test_resolve_descriptor_caching.py new file mode 100644 index 0000000000..5e5781d8e1 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_resolve_descriptor_caching.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for resolve_descriptor_from_pretrained dynamic-module caching. + +Verifies that resolve_descriptor_from_pretrained calls force_cache_dynamic_modules +so that decoder_layer_cls() works for models with custom code (e.g. Nemotron-H). +""" + +from unittest.mock import MagicMock, patch + +import pytest + +pytest.importorskip("transformers") + +import modelopt.torch.puzzletron as mtpz + +MODEL_ID = "nvidia/NVIDIA-Nemotron-Nano-12B-v2-Base" + +FACTORY_MODULE = "modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory" + + +class TestResolveDescriptorCachesDynamicModules: + """resolve_descriptor_from_pretrained must call force_cache_dynamic_modules.""" + + @patch(f"{FACTORY_MODULE}.force_cache_dynamic_modules") + @patch(f"{FACTORY_MODULE}.AutoConfig") + def test_force_cache_called(self, mock_auto_config_cls, mock_force_cache): + mock_config = MagicMock() + mock_config.model_type = "llama" + mock_auto_config_cls.from_pretrained.return_value = mock_config + + mtpz.anymodel.resolve_descriptor_from_pretrained("/fake/path", trust_remote_code=True) + + mock_force_cache.assert_called_once_with(mock_config, "/fake/path", trust_remote_code=True) + + @patch(f"{FACTORY_MODULE}.force_cache_dynamic_modules") + @patch(f"{FACTORY_MODULE}.AutoConfig") + def test_force_cache_called_without_trust_remote_code( + self, mock_auto_config_cls, mock_force_cache + ): + mock_config = MagicMock() + mock_config.model_type = "llama" + mock_auto_config_cls.from_pretrained.return_value = mock_config + + mtpz.anymodel.resolve_descriptor_from_pretrained("/fake/path") + + mock_force_cache.assert_called_once_with(mock_config, "/fake/path", trust_remote_code=False) + + +def test_resolve_descriptor_caches_dynamic_modules(): + """End-to-end: resolve_descriptor_from_pretrained must cache dynamic modules so decoder_layer_cls works.""" + pytest.importorskip("mamba_ssm") + + descriptor = mtpz.anymodel.resolve_descriptor_from_pretrained(MODEL_ID, trust_remote_code=True) + + layer_classes = descriptor.decoder_layer_cls() + assert layer_classes, ( + "decoder_layer_cls() returned empty after resolve_descriptor_from_pretrained" + ) + print( + f" Descriptor: {descriptor.__name__}, decoder classes: {[c.__name__ for c in layer_classes]}" + ) diff --git a/tox.ini b/tox.ini index e06661aa59..6694f7349d 100644 --- a/tox.ini +++ b/tox.ini @@ -50,7 +50,7 @@ deps = torch_deploy: .[onnx,torch,dev-test] commands = onnx: python -m pytest tests/unit/onnx - torch: python -m pytest tests/unit/torch --ignore tests/unit/torch/deploy + torch: python -m pytest tests/unit/torch --ignore tests/unit/torch/deploy --ignore tests/unit/torch/puzzletron torch_deploy: python -m pytest tests/unit/torch/deploy @@ -66,6 +66,10 @@ commands_pre = # Install cupy-cuda13x for INT4 ONNX quantization (default is cupy-cuda12x) pip uninstall -y cupy-cuda12x pip install cupy-cuda13x + + # Install mamba and causal-conv1d for Nemotron tests + pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git + pip install --no-build-isolation git+https://github.com/Dao-AILab/causal-conv1d.git commands = python -m pytest tests/gpu {env:COV_ARGS:}