|
42 | 42 |
|
43 | 43 | import datasets |
44 | 44 | from lm_eval import utils |
45 | | -from lm_eval.__main__ import cli_evaluate, parse_eval_args, setup_parser |
| 45 | +from packaging.version import Version |
46 | 46 |
|
47 | | -if not version("lm_eval").startswith("0.4.8"): |
48 | | - warnings.warn( |
49 | | - f"lm_eval_hf.py is tested with lm-eval 0.4.8; found {version('lm_eval')}. " |
50 | | - "Later versions may have incompatible API changes." |
51 | | - ) |
# Fail fast at import time: this script imports lm_eval._cli right after this
# check, so an older lm-eval is rejected here with an explicit message.
if Version(version("lm_eval")) < Version("0.4.10"):
    raise ImportError(
        f"lm_eval_hf.py requires lm-eval >= 0.4.10; found {version('lm_eval')}."
    )
| 49 | + |
| 50 | +from lm_eval._cli import HarnessCLI |
52 | 51 | from lm_eval.api.model import T |
53 | 52 | from lm_eval.models.huggingface import HFLM |
| 53 | +from lm_eval.utils import setup_logging |
54 | 54 | from quantization_utils import quantize_model |
55 | 55 | from sparse_attention_utils import sparsify_model |
56 | 56 |
|
@@ -160,9 +160,24 @@ def create_from_arg_string( |
160 | 160 | HFLM.create_from_arg_string = classmethod(create_from_arg_string) |
161 | 161 |
|
162 | 162 |
|
163 | | -def setup_parser_with_modelopt_args(): |
164 | | - """Extend the lm-eval argument parser with ModelOpt quantization and sparsity options.""" |
165 | | - parser = setup_parser() |
| 163 | +# ModelOpt-specific args that we add to lm-eval's parser. After parsing, these are |
| 164 | +# moved out of the argparse namespace and into args.model_args so they reach |
| 165 | +# HFLM.create_from_arg_obj (and so lm-eval's own arg validation doesn't reject them). |
| 166 | +_MODELOPT_ARG_KEYS = ( |
| 167 | + "quant_cfg", |
| 168 | + "calib_batch_size", |
| 169 | + "calib_size", |
| 170 | + "auto_quantize_bits", |
| 171 | + "auto_quantize_method", |
| 172 | + "auto_quantize_score_size", |
| 173 | + "auto_quantize_checkpoint", |
| 174 | + "compress", |
| 175 | + "sparse_cfg", |
| 176 | +) |
| 177 | + |
| 178 | + |
| 179 | +def _add_modelopt_args(parser): |
| 180 | + """Extend an lm-eval argument parser with ModelOpt quantization and sparsity options.""" |
166 | 181 | parser.add_argument( |
167 | 182 | "--quant_cfg", |
168 | 183 | type=str, |
@@ -221,33 +236,45 @@ def setup_parser_with_modelopt_args(): |
221 | 236 | type=str, |
222 | 237 | help="Sparse attention configuration (e.g., SKIP_SOFTMAX_DEFAULT, SKIP_SOFTMAX_CALIB)", |
223 | 238 | ) |
224 | | - return parser |
225 | 239 |
|
226 | 240 |
|
227 | | -if __name__ == "__main__": |
228 | | - parser = setup_parser_with_modelopt_args() |
229 | | - args = parse_eval_args(parser) |
230 | | - model_args = utils.simple_parse_args_string(args.model_args) |
| 241 | +def _inject_modelopt_args_into_model_args(args): |
| 242 | + """Move ModelOpt args from the argparse namespace into args.model_args. |
| 243 | +
|
| 244 | + args.model_args is a dict (parsed by lm-eval's MergeDictAction). The ModelOpt |
| 245 | + keys must be removed from the namespace so EvaluatorConfig.from_cli doesn't |
| 246 | + reject them as unknown kwargs. |
| 247 | + """ |
| 248 | + model_args = dict(args.model_args) if args.model_args else {} |
231 | 249 |
|
232 | | - if args.trust_remote_code: |
| 250 | + if getattr(args, "trust_remote_code", False): |
| 251 | + # Propagate the user-provided --trust_remote_code flag (not hardcoded). |
233 | 252 | datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True |
234 | 253 | model_args["trust_remote_code"] = True |
235 | 254 | args.trust_remote_code = None |
236 | 255 |
|
237 | | - model_args.update( |
238 | | - { |
239 | | - "quant_cfg": args.quant_cfg, |
240 | | - "auto_quantize_bits": args.auto_quantize_bits, |
241 | | - "auto_quantize_method": args.auto_quantize_method, |
242 | | - "auto_quantize_score_size": args.auto_quantize_score_size, |
243 | | - "auto_quantize_checkpoint": args.auto_quantize_checkpoint, |
244 | | - "calib_batch_size": args.calib_batch_size, |
245 | | - "calib_size": args.calib_size, |
246 | | - "compress": args.compress, |
247 | | - "sparse_cfg": args.sparse_cfg, |
248 | | - } |
249 | | - ) |
| 256 | + for key in _MODELOPT_ARG_KEYS: |
| 257 | + if hasattr(args, key): |
| 258 | + model_args[key] = getattr(args, key) |
| 259 | + delattr(args, key) |
250 | 260 |
|
251 | 261 | args.model_args = model_args |
252 | 262 |
|
253 | | - cli_evaluate(args) |
| 263 | + |
if __name__ == "__main__":
    setup_logging()
    cli = HarnessCLI()

    # Model and task arguments live on the `run` subcommand, so the ModelOpt
    # options are registered on that subparser. Reaching it goes through the
    # private `_subparsers` attribute; a failed lookup is turned into an
    # explicit error that names the installed lm-eval version.
    try:
        subparsers = cli._subparsers
        run_parser = subparsers.choices["run"]
    except (AttributeError, KeyError) as e:
        raise RuntimeError(
            "Cannot locate lm-eval's `run` subparser; the HarnessCLI internals may "
            f"have changed. Installed lm-eval version: {version('lm_eval')}."
        ) from e
    else:
        _add_modelopt_args(run_parser)

    # Parse, fold the ModelOpt options into model_args, then hand off to lm-eval.
    args = cli.parse_args()
    _inject_modelopt_args_into_model_args(args)
    cli.execute(args)
0 commit comments