# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 16 | +"""Run lm-eval directly on AnyModel (puzzletron) checkpoints without a deployment server. |
| 17 | +
|
| 18 | +AnyModel checkpoints have heterogeneous decoder layers; this script patches |
| 19 | +lm-eval's HFLM to wrap model loading with deci_x_patcher so they load correctly. |
| 20 | +
|
| 21 | +Usage: |
| 22 | + # From the repo root (requires: pip install -e ".[hf,puzzletron]") |
| 23 | + # Descriptor is auto-detected from the checkpoint's config.json model_type. |
| 24 | + python examples/puzzletron/evaluation/lm_eval_anymodel.py \ |
| 25 | + --model hf \ |
| 26 | + --model_args pretrained=/path/to/anymodel_checkpoint,dtype=bfloat16,parallelize=True \ |
| 27 | + --tasks mmlu \ |
| 28 | + --num_fewshot 5 \ |
| 29 | + --batch_size 4 |
| 30 | +
|
| 31 | + # With sample limit for smoke tests |
| 32 | + python examples/puzzletron/evaluation/lm_eval_anymodel.py \ |
| 33 | + --model hf \ |
| 34 | + --model_args pretrained=/path/to/anymodel_checkpoint,dtype=bfloat16,parallelize=True \ |
| 35 | + --tasks mmlu \ |
| 36 | + --limit 10 |
| 37 | +""" |

from lm_eval.__main__ import cli_evaluate
from lm_eval.api.model import T
from lm_eval.models.huggingface import HFLM

# Trigger factory registration for all model descriptors
import modelopt.torch.puzzletron.anymodel.models  # noqa: F401
from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory
from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher

# Map from the HuggingFace config.model_type (in the checkpoint's config.json) to the
# ModelDescriptorFactory name. Local to this script; add entries when supporting new
# model types for auto-detection.
_MODEL_TYPE_TO_DESCRIPTOR = {
    "llama": "llama",
    "mistral": "mistral_small",
    "qwen2": "qwen2",
    "qwen3": "qwen3",
    "nemotron_h": "nemotron_h",
    "nemotron_h_v2": "nemotron_h_v2",
    "gpt_oss_20b": "gpt_oss_20b",
}
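
# Extending auto-detection is a one-line change. For example (hypothetical entry,
# assuming a "phi3" descriptor were registered with ModelDescriptorFactory):
#     _MODEL_TYPE_TO_DESCRIPTOR["phi3"] = "phi3"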


def _resolve_descriptor_from_pretrained(pretrained: str | None):
    """Resolve the model descriptor by loading the checkpoint config and mapping its model_type."""
    if not pretrained:
        raise ValueError(
            "pretrained must be set in --model_args "
            "(e.g. --model_args pretrained=/path/to/checkpoint,dtype=bfloat16)."
        )
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(pretrained, trust_remote_code=True)
    model_type = getattr(config, "model_type", None)

    if model_type and model_type in _MODEL_TYPE_TO_DESCRIPTOR:
        detected = _MODEL_TYPE_TO_DESCRIPTOR[model_type]
        print(
            f"[lm_eval_anymodel] Auto-detected model_type='{model_type}' → descriptor='{detected}'"
        )
        return ModelDescriptorFactory.get(detected)

    known = sorted(_MODEL_TYPE_TO_DESCRIPTOR.keys())
    raise ValueError(
        f"Cannot auto-detect descriptor for model_type='{model_type}'. "
        f"Known model types: {known}. Add this model_type to _MODEL_TYPE_TO_DESCRIPTOR if supported."
    )
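
# Illustrative example (the path is hypothetical): for a checkpoint whose
# config.json contains {"model_type": "qwen2", ...},
#     _resolve_descriptor_from_pretrained("/ckpts/anymodel_qwen2")
# prints the auto-detection message and returns ModelDescriptorFactory.get("qwen2").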


def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T:
    """Override HFLM.create_from_arg_obj to wrap model loading with deci_x_patcher."""
    pretrained = arg_dict.get("pretrained")
    descriptor = _resolve_descriptor_from_pretrained(pretrained)

    additional_config = {} if additional_config is None else additional_config
    additional_config = {k: v for k, v in additional_config.items() if v is not None}

    # The patcher must be active during HFLM.__init__ because that is where
    # AutoModelForCausalLM.from_pretrained() is called internally.
    with deci_x_patcher(model_descriptor=descriptor):
        model_obj = cls(**arg_dict, **additional_config)

    return model_obj
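
# For contrast, a sketch of the unpatched path this replaces (not executed here):
#     HFLM(pretrained="/path/to/anymodel_checkpoint", dtype="bfloat16")
# would reach AutoModelForCausalLM.from_pretrained() without deci_x_patcher active,
# so the heterogeneous decoder layers of an AnyModel checkpoint would not load.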


# Monkey-patch HFLM so lm-eval uses the patched model loading above.
HFLM.create_from_arg_obj = classmethod(create_from_arg_obj)
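
# Note: this assignment shadows the create_from_arg_obj that HFLM inherits from
# lm_eval.api.model.LM (an assumption about lm-eval's class hierarchy), so any
# HFLM built through this entry point picks up the patched loading path.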


if __name__ == "__main__":
    cli_evaluate()