Skip to content

Commit aa62950

Browse files
committed
generate_hf_golden_logits: propagate trust_remote_code to tokenizer; add --hf-load-dtype=auto
1 parent b3a1832 commit aa62950

1 file changed

Lines changed: 9 additions & 4 deletions

File tree

tests/assets/logits_generation/generate_hf_golden_logits.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,20 @@ def save_golden_logits(
8585

8686
model_class = AutoModelForCausalLM
8787

88-
tokenizer = AutoTokenizer.from_pretrained(model_id)
88+
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
8989
print(f"loading model from {hf_model_path}")
9090

9191
if hf_load_dtype == "float32":
9292
torch_dtype = torch.float32
9393
elif hf_load_dtype == "bfloat16":
9494
torch_dtype = torch.bfloat16
95+
elif hf_load_dtype == "auto":
96+
# Preserve per-tensor dtypes from safetensors metadata, useful for mixed-
97+
# precision checkpoints where forcing a single dtype would corrupt non-
98+
# default tensors.
99+
torch_dtype = None
95100
else:
96-
raise ValueError
101+
raise ValueError(f"unsupported --hf-load-dtype: {hf_load_dtype}")
97102

98103
model = model_class.from_pretrained(
99104
hf_model_path,
@@ -194,9 +199,9 @@ def main(raw_args=None) -> None:
194199
"--hf-load-dtype",
195200
type=str,
196201
required=False,
197-
choices=["float32", "bfloat16"],
202+
choices=["float32", "bfloat16", "auto"],
198203
default="float32",
199-
help="model_class.from_pretrained: dtype",
204+
help="model_class.from_pretrained: dtype. 'auto' preserves per-tensor dtypes from safetensors metadata (useful for mixed-precision checkpoints).",
200205
)
201206
parser.add_argument(
202207
"--trust-remote-code",

0 commit comments

Comments (0)