11// Builds a ggml compute graph for one forward pass of the DFlash draft
2- // (5-layer non-causal Qwen3-flavored block-diffusion model).
2+ // (5-layer Qwen3-flavored block-diffusion model).
33//
44// Stateless: no KV cache. Each call takes:
5- // - noise_embed [hidden, q_len, 1] bf16 (target.tok_embd on [last_tok, MASK*15])
6- // - target_hidden_cat [5 *hidden, ctx_len, 1] bf16 (5 target layers concat along features)
5+ // - noise_embed [hidden, q_len, 1] f32 (target.tok_embd on [last_tok, MASK*15])
6+ // - target_hidden_cat [N *hidden, ctx_len, 1] f32 (N target layers concat along features)
77// - positions_q [q_len] i32 values [ctx_len..ctx_len+q_len-1]
88// - positions_k [ctx_len+q_len] i32 values [0..ctx_len+q_len-1]
9+ // - causal_mask_swa [kv_pad, q_len] f32 (optional; causal mask for SWA layers)
910// and returns:
10- // - hidden_states [hidden, q_len, 1] bf16 (final RMSNorm; NO lm_head here)
11+ // - hidden_states [hidden, q_len, 1] f32 (final RMSNorm; NO lm_head here)
1112//
1213// The caller projects `hidden_states` through the TARGET's lm_head separately
1314// (the draft has no lm_head of its own, it shares the target's).
1415//
15- // Semantics match megaqwen3_27b_dflash/reference/dflash_reference.py exactly :
16+ // Semantics:
1617// - fc @ target_hidden_cat -> rms_norm with hidden_norm -> target_feat
17- // - Per layer (non-causal) :
18+ // - Per layer:
1819// h_norm = rms_norm(h) * input_layernorm
1920// Q = wq @ h_norm -> per-head q_norm
2021// K_ctx/V_ctx = wk/wv @ target_feat
2122// K_noi/V_noi = wk/wv @ h_norm
2223// K = concat[K_ctx, K_noi] -> per-head k_norm
2324// V = concat[V_ctx, V_noi]
24- // RoPE(Q, positions_q); RoPE(K, positions_k) (NEOX style, theta=10M )
25- // attn = flash_attn_ext(Q, K, V, mask=null , scale=1/sqrt(head_dim)) non-causal
25+ // RoPE(Q, positions_q); RoPE(K, positions_k) (NEOX style)
26+ // attn = flash_attn_ext(Q, K, V, mask, scale) SWA=causal, full= non-causal
2627// h += wo @ attn
2728// h_norm = rms_norm(h) * post_attention_layernorm
2829// h += w_down @ (silu(w_gate @ h_norm) * (w_up @ h_norm))
@@ -46,7 +47,7 @@ DraftGraphOutputs build_draft_graph(
4647 const int n_kv = w.n_head_kv ;
4748 const int head_dim = w.head_dim ;
4849 const float eps = DFLASH27B_RMS_EPS;
49- const float rope_base = DFLASH27B_ROPE_THETA ;
50+ const float rope_base = w. rope_theta ;
5051
5152 // ── 1. Feature fusion: target_feat = rms_norm(fc @ target_hidden_cat, hidden_norm)
5253 // fc: [5*hidden, hidden] (ggml: ne[0]=5*hidden, ne[1]=hidden)
@@ -134,9 +135,10 @@ DraftGraphOutputs build_draft_graph(
134135 V = ggml_permute (ctx, V, 0 , 2 , 1 , 3 ); // [head_dim, eff_total_k, n_kv, 1]
135136 V = ggml_cont (ctx, V);
136137
137- // ── 2f. Non- causal flash attention; GQA broadcast handled internally .
138+ // ── 2f. Attention: causal for SWA layers, non-causal for full layers .
138139 const float scale = 1 .0f / std::sqrt ((float )head_dim);
139- ggml_tensor * attn = ggml_flash_attn_ext (ctx, Q, K, V, /* mask=*/ nullptr ,
140+ ggml_tensor * mask = (L.is_swa && in.causal_mask_swa ) ? in.causal_mask_swa : nullptr ;
141+ ggml_tensor * attn = ggml_flash_attn_ext (ctx, Q, K, V, mask,
140142 scale, /* max_bias=*/ 0 .0f ,
141143 /* logit_softcap=*/ 0 .0f );
142144 // attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1]
0 commit comments