
[deepseek_v4] save_pretrained silently downcasts FP32 tensors to BF16 (hc_*, attn_sink, ffn.gate.bias, compressor.ape, indexer.compressor.ape) #46167

@pasta-paul


Summary

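save_pretrained silently downcasts 417 FP32 tensors to BF16 when saving a DeepSeek-V4 model loaded with torch_dtype=torch.bfloat16. There is no warning and no error. The downcast loses numerical precision on small "plumbing" tensors (hyper-connection, attention-sink, MoE routing-bias, and compressor positional-encoding parameters) that DeepSeek's release deliberately keeps in FP32 for numerical stability.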

Affected tensor groups

Per DeepSeek's release at deepseek-ai/DeepSeek-V4-Flash, the following tensor groups are FP32 in the source safetensors:

| Pattern | Count per model | Role |
| --- | --- | --- |
| layers.X.hc_attn_{base,fn,scale} + mtp.0.hc_attn_{base,fn,scale} | 3 × 44 = 132 | Hyper-connection attention plumbing |
| layers.X.hc_ffn_{base,fn,scale} + mtp.0.hc_ffn_{base,fn,scale} | 3 × 44 = 132 | Hyper-connection FFN plumbing |
| model.hc_head_{base,fn,scale} + mtp.0.hc_head_{base,fn,scale} | 6 | Top-level + MTP hc_head |
| layers.X.attn.attn_sink + mtp.0.attn.attn_sink | ~44 | Attention sink tokens |
| layers.X.ffn.gate.bias (was e_score_correction_bias) | 41 | MoE routing bias |
| layers.X.attn.compressor.ape (was position_bias) | 41 | Compressor positional encoding |
| layers.X.attn.indexer.compressor.ape (was position_bias) | 21 | Indexer positional encoding |
| Total | 417 | |

All 417 are saved as BF16 by save_pretrained when the model dtype is BF16.

Repro (minimal)

import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V4-Flash", torch_dtype=torch.bfloat16, trust_remote_code=True)
model.save_pretrained("/tmp/dsv4_resaved")
# Compare the dtypes in /tmp/dsv4_resaved/model-*.safetensors against the source shards
# Result: 417 keys that were FP32 in the source are now BF16 in the resaved copy
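The comparison step of the repro is only described in comments. One way to do it without loading any weights is to read each shard's header: a .safetensors file starts with an 8-byte little-endian header length followed by a JSON object that maps tensor names to dtype, shape, and data offsets. This is a minimal sketch that assumes both checkpoints are flat directories of *.safetensors shards and use the same tensor names; if key names differ between the two, a name mapping would be needed first. The helper names are hypothetical.

```python
import glob
import json
import os
import struct


def read_header_dtypes(path):
    """Return {tensor_name: dtype} for one .safetensors shard without reading tensor data."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian u64 length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # "__metadata__" holds free-form string metadata, not a tensor entry.
    return {name: info["dtype"] for name, info in header.items() if name != "__metadata__"}


def collect_dtypes(checkpoint_dir):
    """Merge header dtypes across every *.safetensors shard in a directory."""
    dtypes = {}
    for shard in sorted(glob.glob(os.path.join(checkpoint_dir, "*.safetensors"))):
        dtypes.update(read_header_dtypes(shard))
    return dtypes


def find_downcasts(source_dir, resaved_dir, src_dtype="F32", dst_dtype="BF16"):
    """Keys stored as src_dtype in the source but as dst_dtype after resaving (same names assumed)."""
    source = collect_dtypes(source_dir)
    resaved = collect_dtypes(resaved_dir)
    return sorted(
        name
        for name, dtype in source.items()
        if dtype == src_dtype and resaved.get(name) == dst_dtype
    )
```

Pointing find_downcasts at a local copy of the release and at /tmp/dsv4_resaved should list the affected keys; per the counts above, that list would contain 417 entries.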

Workaround (working in production)

Postprocess: after save_pretrained returns, walk the saved safetensors shards, read the FP32 versions of the affected keys from the BF16 source checkpoint, and write them back in place (atomically per shard, via a .tmp file and os.replace). Working example: scripts/fixup_artifact.py in canada-quant/dsv4-flash-w4a16-fp8-mtp.
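The workaround above can be sketched as follows. This is a hypothetical illustration, not the contents of scripts/fixup_artifact.py. It assumes both checkpoints are flat directories of *.safetensors shards with identical tensor names, and it stays dependency-free by copying raw tensor bytes at the file-format level instead of going through torch or the safetensors library. It reads whole shards into memory, which is simple but memory-hungry for multi-GB shards, and it does not update the total_size field in a sharded checkpoint's index JSON. The pattern spells out the hc_* suffixes from the table above.

```python
import glob
import json
import os
import re
import struct

# Illustrative pattern for the FP32 plumbing groups in the table (hypothetical name).
FP32_KEYS = re.compile(
    r".*\.(hc_(attn|ffn|head)_(base|fn|scale)|attn_sink|ffn\.gate\.bias|compressor\.ape)$"
)


def read_shard(path):
    """Parse a .safetensors file into ({name: (dtype, shape, raw_bytes)}, metadata)."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        data = f.read()  # tensor byte buffer; data_offsets are relative to its start
    metadata = header.pop("__metadata__", None)
    tensors = {}
    for name, info in header.items():
        start, end = info["data_offsets"]
        tensors[name] = (info["dtype"], info["shape"], data[start:end])
    return tensors, metadata


def write_shard_atomic(path, tensors, metadata=None):
    """Lay tensors out contiguously, write to path + '.tmp', then os.replace it over path."""
    header, offset = {}, 0
    for name, (dtype, shape, raw) in tensors.items():
        header[name] = {"dtype": dtype, "shape": shape, "data_offsets": [offset, offset + len(raw)]}
        offset += len(raw)
    if metadata is not None:
        header["__metadata__"] = metadata
    blob = json.dumps(header, separators=(",", ":")).encode("utf-8")
    blob += b" " * (-len(blob) % 8)  # pad the JSON header with spaces to an 8-byte boundary
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        f.write(struct.pack("<Q", len(blob)))
        f.write(blob)
        for _, _, raw in tensors.values():  # same order as the offsets computed above
            f.write(raw)
    os.replace(tmp, path)


def restore_fp32(source_dir, resaved_dir, pattern=FP32_KEYS):
    """Copy matching F32 tensors from source shards over downcast copies in resaved shards."""
    originals = {}
    for shard in sorted(glob.glob(os.path.join(source_dir, "*.safetensors"))):
        tensors, _ = read_shard(shard)
        for name, tensor in tensors.items():
            if tensor[0] == "F32" and pattern.match(name):
                originals[name] = tensor
    restored = 0
    for shard in sorted(glob.glob(os.path.join(resaved_dir, "*.safetensors"))):
        tensors, metadata = read_shard(shard)
        hits = [name for name in tensors if name in originals and tensors[name][0] != "F32"]
        if not hits:
            continue  # leave shards without downcast plumbing untouched
        for name in hits:
            if tensors[name][1] != originals[name][1]:
                raise ValueError(f"shape mismatch for {name}")
            tensors[name] = originals[name]
        write_shard_atomic(shard, tensors, metadata)
        restored += len(hits)
    return restored
```

Because already-restored tensors are skipped, running restore_fp32 a second time on the same directory should be a no-op.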

Suggested fix

save_pretrained should preserve each tensor's source dtype rather than coerce everything to the model's torch_dtype. The current behavior is right for most weights (BF16 is what we want for the W4A16 main model) but wrong for plumbing tensors that ship in FP32 for numerical-stability reasons. A per-architecture whitelist of "always preserve source dtype" patterns would work; for DSv4-Flash the regex is roughly r".*\.(hc_(attn|ffn|head)_(base|fn|scale)|attn_sink|ffn\.gate\.bias|compressor\.ape)$" (the compressor\.ape alternative also covers indexer.compressor.ape).
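A minimal sketch of how such a whitelist could drive the save-time dtype decision. The function and constant names are hypothetical and nothing here is an existing transformers API; the pattern spells out the hc_* suffixes from the table.

```python
import re

# Illustrative DSv4-Flash whitelist derived from the tensor table (hypothetical).
PRESERVE_SOURCE_DTYPE = re.compile(
    r".*\.(hc_(attn|ffn|head)_(base|fn|scale)|attn_sink|ffn\.gate\.bias|compressor\.ape)$"
)


def dtype_to_save(name: str, source_dtype: str, model_dtype: str) -> str:
    """Hypothetical save rule: whitelisted tensors keep their source dtype, others follow the model."""
    if PRESERVE_SOURCE_DTYPE.match(name):
        return source_dtype
    return model_dtype
```

Under this rule the main weights still follow the model dtype, while only tensors matching the whitelist keep FP32.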

Alternatively, a per-parameter dtype hint in the saved metadata + opt-in flag on the model config would solve it generically.
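One possible shape for the metadata-hint alternative, sketched under assumptions: the "dtype_hints" key, both helper names, and the honor_hints argument (standing in for the opt-in config flag) are all hypothetical. The only real constraint used is that safetensors __metadata__ values must be strings, so the per-tensor map is stored as a JSON string.

```python
import json

HINT_KEY = "dtype_hints"  # hypothetical metadata key, not an existing convention


def add_dtype_hints(metadata, source_dtypes, model_dtype):
    """Return a copy of metadata recording every tensor whose source dtype differs from model_dtype."""
    hints = {name: dtype for name, dtype in source_dtypes.items() if dtype != model_dtype}
    out = dict(metadata or {})
    out[HINT_KEY] = json.dumps(hints, sort_keys=True)  # metadata values must be str
    return out


def dtype_for(name, metadata, model_dtype, honor_hints):
    """Resolve a tensor's dtype; honor_hints stands in for the proposed opt-in flag."""
    if honor_hints and HINT_KEY in metadata:
        return json.loads(metadata[HINT_KEY]).get(name, model_dtype)
    return model_dtype
```

Keeping the hint opt-in means existing loaders that ignore unknown metadata keys would behave exactly as they do today.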

Why it matters

Without the postprocess restore, the saved artifact has BF16 plumbing → numerical drift on the gating math + LM head logits → measurable quality regression. With the restore, our W4A16+FP8+MTP artifact at https://huggingface.co/canada-quant/DeepSeek-V4-Flash-W4A16-FP8-MTP hits 86.88% MMLU and 93.71% GSM8K (within SE of the un-restored-baseline-impossible-to-produce target). The sibling NVFP4 artifact at canada-quant/DeepSeek-V4-Flash-NVFP4-FP8-MTP applies the same restore postprocess.
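To make the precision loss concrete: BF16 keeps FP32's 8-bit exponent but only 7 explicit mantissa bits (FP32 has 23), so round-to-nearest can shift a value by up to 2^-8 (about 0.4%) relative. The stdlib-only sketch below emulates the FP32 to BF16 cast with round-half-to-even on the dropped low 16 bits; it is an illustration for checking individual values, handles finite inputs only, and the helper names are hypothetical.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to float32, then to the nearest bfloat16 (ties to even), and widen back.

    Finite inputs only: NaN handling is omitted from this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 is the top 16 bits of a float32; the bias makes the shift round to nearest-even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bf16_bits = ((bits + bias) >> 16) & 0xFFFF
    return struct.unpack("<f", struct.pack("<I", bf16_bits << 16))[0]


def relative_error(x: float) -> float:
    """Relative change caused by the emulated BF16 round trip."""
    return abs(to_bf16(x) - x) / abs(x)
```

For example, to_bf16(1.001) returns exactly 1.0: an offset of 0.1% disappears entirely, which is the kind of difference a routing bias or gating scale can depend on.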

This issue was filed by the canada-quant team during W4A16+FP8+MTP quantization work. See FINDINGS_FOR_SIBLING.md §C13 for the diagnosis trace.
