Finetune DeepSeek V4 Flash with NeMo Automodel #2052
khazic started this conversation in Show and tell
NeMo Automodel now supports DeepSeek V4 Flash (deepseek-ai/DeepSeek-V4-Flash) — DeepSeek's latest fine-grained MoE language model with a hybrid attention zoo and hash-routing first layers. The PR (#2039) adds the model definition, state-dict adapter, V4-aware pipeline-parallel forward, a checkpoint loader path for FP4 e2m1fn / FP8 e8m0fnu / FP8 e5m2 dtypes, and finetune recipes.
Key features of DeepSeek V4 Flash
- sqrtsoftplus scoring with the noaux_tc topk method and a clamped SwiGLU activation on routed experts (swiglu_limit=10.0); both are new branches in NeMo Automodel's shared Gate.forward and MoE activation dispatch.
- Per-layer attention variant, selected by compress_ratios:
  - compress_ratio = 0 → pure Sliding-Window Attention (SWA) with a learned per-head attention sink.
  - compress_ratio = 4 → Compressed Sparse Attention (CSA): a Compressor in overlap mode pools 2*ratio raw tokens per compressed token, an Indexer selects the top-k most relevant compressed positions per query, and an explicit additive [B, 1, S, P_total] mask enforces causal correctness.
  - compress_ratio = 128 → Hierarchical Compressed Attention (HCA): Compressor only (no Indexer), non-overlap pooling, deterministic p < (q+1) // ratio causal mask.
- RoPE uses theta=10000 for compress_ratio==0 layers and theta=160000 (with YaRN scaling) for compress_ratio>0 layers; the compress-rope is applied to both the main attention Q/KV and the Compressor sub-module on those layers. Encoded as INTERLEAVED pairs (view_as_complex style) to match the released checkpoint.
- A single KV head (num_key_value_heads=1) broadcast to all 64 attention heads, Q-LoRA (q_lora_rank=1024) and grouped O-LoRA (o_lora_rank=1024, o_groups=8), not MLA. A per-head non-learnable rsqrt on Q after wq_b matches the inference reference.
- The first num_hash_layers (default 3) blocks use a DeepseekV4HashGate with a tid2eid lookup table for token→expert routing, instead of the score-based gate. input_ids is threaded through DeepseekV4Model and the V4-aware pipeline forward; under PP, hash layers live on stage 0 where input_ids is available.
- hc_mult=4 copies of the hidden state, mixed via a learned col-norm-first Sinkhorn router (hc_split_sinkhorn). pre = sigmoid + eps, post = 2 * sigmoid (no +eps), comb = softmax(dim=-1) + eps followed by Sinkhorn, which produces a doubly-stochastic mixing matrix per block.
- num_nextn_predict_layers (disabled by default in the validate harness, configurable for full training).
- max_position_embeddings = 1,048,576 (1M tokens).
Checkpoint format support
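As a side note on the HCA causal rule from the key features list, p < (q+1) // ratio, here is a minimal plain-Python sketch of the mask it implies. The function name hca_causal_mask and the assumption that compressed position p summarizes raw tokens p*ratio through (p+1)*ratio - 1 (non-overlap pooling) are illustrative and not taken from the NeMo Automodel code; the real implementation presumably builds a tensor mask instead of nested lists.

```python
def hca_causal_mask(num_queries, ratio):
    """Boolean mask: mask[q][p] is True when query q may attend compressed position p.

    Hypothetical helper for illustration. Applies the rule p < (q + 1) // ratio,
    which is equivalent to (p + 1) * ratio <= q + 1: a compressed block becomes
    visible only once its last raw token is at or before the query position.
    """
    num_compressed = num_queries // ratio  # assumes non-overlap pooling
    return [
        [p < (q + 1) // ratio for p in range(num_compressed)]
        for q in range(num_queries)
    ]


# Small example with ratio=4 (the released model uses ratio=128 for HCA layers).
mask = hca_causal_mask(num_queries=8, ratio=4)
visible_counts = [sum(row) for row in mask]
print(visible_counts)
```

With ratio=4 and 8 queries, queries 0 to 2 see no compressed position, queries 3 to 6 see one, and query 7 sees both, so no query ever attends a block containing a future token.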
The released DSV4-Flash safetensors mix several quantization formats; the state-dict adapter handles all of them transparently:
- FP4 e2m1fn packed two values per int8 byte, with per-row 32-col FP8 e8m0fnu scales: unpacked on load, re-emitted in matching packed placeholders on to_hf so DCP shape/dtype validation lines up with the on-disk layout.
- FP8 e4m3fn with 128×128 block scales.
- The adapter takes num_hash_layers from the checkpoint's config.json and drops the corresponding bias keys before DCP load.
- Indexer keys (indexer.compressor.{ape,norm,wgate,wkv} + indexer.{wq_b,weights_proj}): the adapter renames these to land at our compressor.indexer.* flat layout.
A new in-tree HuggingFaceStorageReader recognizes F8_E8M0 / F8_E5M2 dtypes (the upstream reader silently dropped them), restoring DCP metadata on every rank for these checkpoints.
Finetuning recipes
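To make the FP4 unpack-on-load step from the checkpoint format list more concrete, here is a minimal plain-Python sketch of decoding e2m1 nibbles with e8m0 scales. It is not the adapter's code: the function names are hypothetical, the nibble order (low nibble first) is an assumption that may not match the released checkpoint, and the e8m0 decoding uses the usual 2^(byte - 127) convention without handling the reserved NaN encoding.

```python
# Magnitudes representable by e2m1 (1 sign bit, 2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble):
    """Decode one 4-bit e2m1 code: top bit is the sign, low 3 bits index the magnitude."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def decode_e8m0(byte):
    """Decode an exponent-only e8m0 scale as a power of two (bias 127)."""
    return 2.0 ** ((byte & 0xFF) - 127)


def unpack_fp4_row(packed, scales, group=32, low_nibble_first=True):
    """Unpack one row: two FP4 values per byte, one scale per `group` columns.

    `low_nibble_first` is an assumption about the packing order.
    """
    values = []
    for b in packed:
        b &= 0xFF  # tolerate signed int8 inputs
        lo, hi = b & 0xF, b >> 4
        pair = (lo, hi) if low_nibble_first else (hi, lo)
        values.extend(decode_e2m1(n) for n in pair)
    return [v * decode_e8m0(scales[i // group]) for i, v in enumerate(values)]


# Two packed bytes hold four values; scale byte 128 means a factor of 2.0.
row = unpack_fp4_row(bytes([0x21, 0xF3]), scales=[128])
print(row)
```

Re-emitting packed placeholders on to_hf would be the inverse direction (quantize, then pack two nibbles per byte), which this sketch does not cover.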
Two recipes in examples/llm_finetune/deepseek_v4/:
- deepseek_v4_flash_validate.yaml: single-node 8×A100 infra validation on a 4-layer truncated harness exercising the full attention zoo (compress_ratios=[0, 0, 4, 128] → SWA / SWA / CSA / HCA), num_hash_layers=2, pp=2 ep=4.
- deepseek_v4_flash_hellaswag.yaml: HellaSwag finetune recipe; the yaml header documents how to scale num_hidden_layers / ep_size for the full 43-layer multi-node run.
Layer-parity validation
The bringup was validated against the official DeepSeek inference reference (dsv4flash/inference/model.py) by per-tensor dump bisection. On the 4-layer parity harness (compress_ratios=[0,0,4,128], num_hash_layers=2, PP=1 EP=8):
Data
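The per-tensor dump comparison used for layer parity can be approximated by the sketch below. The post does not describe the exact procedure, so this is an assumption: a simple linear scan that walks the dumped tensors in forward order and reports the first one whose maximum absolute difference exceeds a tolerance. The function name first_divergence, the dump names, and the flat-list tensor representation are all hypothetical.

```python
def first_divergence(reference, candidate, order, atol=1e-3):
    """Return (name, max_abs_err) for the first tensor that diverges, else None.

    `reference` and `candidate` map dump names to flat lists of floats;
    `order` lists the names in forward-pass order. Hypothetical helper.
    """
    for name in order:
        ref, got = reference[name], candidate[name]
        if len(ref) != len(got):
            return name, float("inf")  # shape mismatch counts as divergence
        err = max((abs(a - b) for a, b in zip(ref, got)), default=0.0)
        if err > atol:
            return name, err
    return None


# Toy dumps: the attention output is the first tensor to drift.
order = ["embed", "layer0.attn", "layer0.moe"]
reference = {"embed": [1.0, 2.0], "layer0.attn": [0.5, 0.5], "layer0.moe": [3.0]}
candidate = {"embed": [1.0, 2.0], "layer0.attn": [0.5, 0.75], "layer0.moe": [9.0]}
print(first_divergence(reference, candidate, order))
```

Reporting only the earliest mismatch matters because every tensor downstream of a real bug will also differ, so later mismatches carry little information.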
We use HellaSwag for the end-to-end full finetune. Below is the loss curve from a 43-layer full-finetune run with the full attention zoo (SWA + CSA + HCA) live:
Many thanks to @HuiyingLi @khazic for all contributions!