diff --git a/tests/assets/logits_generation/generate_hf_golden_logits.py b/tests/assets/logits_generation/generate_hf_golden_logits.py
index d81a169f55..c57d58c380 100644
--- a/tests/assets/logits_generation/generate_hf_golden_logits.py
+++ b/tests/assets/logits_generation/generate_hf_golden_logits.py
@@ -85,15 +85,20 @@ def save_golden_logits(
 
         model_class = AutoModelForCausalLM
 
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
 
     print(f"loading model from {hf_model_path}")
     if hf_load_dtype == "float32":
         torch_dtype = torch.float32
     elif hf_load_dtype == "bfloat16":
         torch_dtype = torch.bfloat16
+    elif hf_load_dtype == "auto":
+        # Preserve per-tensor dtypes from safetensors metadata, useful for mixed-
+        # precision checkpoints where forcing a single dtype would corrupt non-
+        # default tensors.
+        torch_dtype = None
     else:
-        raise ValueError
+        raise ValueError(f"unsupported --hf-load-dtype: {hf_load_dtype}")
 
     model = model_class.from_pretrained(
         hf_model_path,
@@ -194,9 +199,9 @@ def main(raw_args=None) -> None:
         "--hf-load-dtype",
         type=str,
         required=False,
-        choices=["float32", "bfloat16"],
+        choices=["float32", "bfloat16", "auto"],
         default="float32",
-        help="model_class.from_pretrained: dtype",
+        help="model_class.from_pretrained: dtype. 'auto' preserves per-tensor dtypes from safetensors.",
     )
     parser.add_argument(
         "--trust-remote-code",
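
Aside (not part of the patch): a minimal sketch, assuming a local checkpoint file named model.safetensors, of how one could verify that a checkpoint is mixed precision and therefore benefits from --hf-load-dtype auto rather than a single forced dtype.

# Hypothetical helper, not from the patched script: counts the dtype of every
# tensor stored in a safetensors checkpoint. The file name is a placeholder.
from collections import Counter

from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    # A mixed-precision checkpoint yields more than one entry here,
    # e.g. mostly torch.bfloat16 with a handful of torch.float32 tensors.
    dtype_counts = Counter(str(f.get_tensor(name).dtype) for name in f.keys())

print(dict(dtype_counts))

If this prints more than one dtype, loading with a single forced dtype would silently rewrite the minority tensors, which is the situation the "auto" option is meant to avoid.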