
Commit 94a4dfd

feat(dflash): add NVFP4 per-tensor scale2 support
Add support for NVFP4-quantized GGUF models (e.g. LibertAI Qwen3.6-27B-NVFP4) by
loading per-tensor weight scales and applying them in the target graph.

Scale values are read as host-side floats from the GGUF mmap at load time and
applied via ggml_scale() — a compile-time scalar multiply with zero extra kernel
launches. This avoids ggml_mul() with [1]-shaped GPU tensors, which adds 768
kernel launches per forward pass and causes ~30x overhead in batched DDTree
verify mode (1001ms with ggml_mul vs 43ms with ggml_scale per step on RTX 5090).

Supports both naming conventions:
- LibertAI: blk.N.ffn_gate.scale
- Heretic:  blk.N.ffn_gate.weight.scale

Non-NVFP4 models (Q4_K_M, etc.) are unaffected — scale fields default to 1.0f
and apply_scale2() returns early with zero overhead.

Also removes the DFLASH27B_USE_BLACKWELL_CONSUMER_FIX CMake workaround, which
incorrectly assumed consumer Blackwell GPUs (RTX 5090) lack FP4 MMA
instructions. The RTX 5090 fully supports sm_120a and native FP4 tensor cores.

Note: full native FP4 MMA performance requires upstream PR ggml-org#22196 to be
merged into the Luce-Org llama.cpp submodule fork. Without it, NVFP4 models
still work correctly via the generic dequant-to-Q8_1 fallback path.
1 parent abdde79 commit 94a4dfd

4 files changed

Lines changed: 98 additions & 40 deletions
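
As a sketch of the core idea (illustrative only; the commit's actual helper is
apply_scale2() in qwen35_target_graph.cpp below), here are the two graph
constructions the commit message contrasts:

```cpp
// Illustrative sketch, not code from this commit.

// (a) [1]-shaped GPU tensor + ggml_mul: the scale becomes a broadcast
//     elementwise-multiply node in the graph, i.e. one extra kernel launch
//     for every scaled matmul on every forward pass.
ggml_tensor * scale_via_mul(ggml_context * ctx, ggml_tensor * mm,
                            ggml_tensor * scale_1d /* F32, shape [1] */) {
    return ggml_mul(ctx, mm, scale_1d);
}

// (b) Host-side float + ggml_scale: the scalar is baked into the node's op
//     parameters at graph-build time, and a scale of 1.0f is skipped
//     entirely, so non-NVFP4 models build exactly the same graph as before.
ggml_tensor * scale_via_const(ggml_context * ctx, ggml_tensor * mm, float s) {
    return (s == 1.0f) ? mm : ggml_scale(ctx, mm, s);
}
```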


dflash/CMakeLists.txt

Lines changed: 0 additions & 28 deletions
```diff
@@ -69,34 +69,6 @@ else()
   endif()
 endif()
 
-# Consumer Blackwell workaround: skip sm_12x→sm_12xa replacement and FP4
-# mmq kernels that can trigger illegal-instruction faults on consumer chips.
-# By default, auto-enable when the resolved CUDA arch list includes a 12x
-# entry. Set DFLASH27B_USE_BLACKWELL_CONSUMER_FIX=ON to force this behavior
-# explicitly (for cross-compiles or custom arch lists).
-option(DFLASH27B_USE_BLACKWELL_CONSUMER_FIX
-  "Enable ggml consumer-Blackwell workaround (skip sm_12x→sm_12xa, exclude FP4 mmq kernels)" OFF)
-if(DFLASH27B_USE_BLACKWELL_CONSUMER_FIX)
-  set(_dflash_is_consumer_blackwell ON)
-endif()
-
-if(NOT DEFINED _dflash_is_consumer_blackwell)
-  set(_dflash_is_consumer_blackwell OFF)
-  # Iterate the resolved dflash27b arch list, not raw CMAKE_CUDA_ARCHITECTURES,
-  # which is empty on the default path (the project supplies its own list above).
-  foreach(_arch IN LISTS _dflash27b_archs)
-    string(REGEX REPLACE "[^0-9]" "" _dflash_arch_num "${_arch}")
-    if(_dflash_arch_num MATCHES "^12[0-9]$")
-      set(_dflash_is_consumer_blackwell ON)
-      break()
-    endif()
-  endforeach()
-endif()
-
-if(_dflash_is_consumer_blackwell)
-  set(GGML_CUDA_BLACKWELL_CONSUMER ON CACHE BOOL
-    "Skip sm_12X→sm_12Xa for consumer Blackwell (no FP4)" FORCE)
-endif()
 # Use only the ggml subtree of llama.cpp (skip libllama).
 add_subdirectory(deps/llama.cpp/ggml EXCLUDE_FROM_ALL)
```

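Context for the removal (not part of the diff): consumer Blackwell parts report
compute capability 12.0, and the sm_120a target does include native FP4 MMA,
which is why the auto-enabled workaround was wrong. A minimal host-side check
of what a given card reports, as an illustrative sketch:

```cpp
// Illustrative sketch, not part of this commit: print the device's compute
// capability. An RTX 5090 reports 12.0 (the sm_120 family); building with a
// CUDA arch list containing "120a" enables the arch-specific instructions,
// including native FP4 tensor-core MMA.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    cudaDeviceProp prop{};
    if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) {
        std::fprintf(stderr, "no CUDA device found\n");
        return 1;
    }
    std::printf("%s: compute capability %d.%d\n", prop.name, prop.major, prop.minor);
    return 0;
}
```
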
dflash/src/gguf_target_loader.cpp

Lines changed: 58 additions & 0 deletions
```diff
@@ -471,6 +471,8 @@ bool load_target_gguf_partial(const std::string & path,
         L.ssm_norm = fnd("ssm_norm.weight");
         L.ssm_out  = fnd("ssm_out.weight");
 
+        // NVFP4 per-tensor weight scales are read after the mmap is loaded (below).
+
         // Sanity: each layer must be EITHER full-attn OR deltanet, not both, not neither.
         const bool has_attn = L.wq && L.wk && L.wv && L.wo && L.q_norm && L.k_norm;
         const bool has_ssm  = L.wqkv && L.wqkv_gate && L.ssm_conv1d && L.ssm_out;
@@ -572,6 +574,62 @@
         total += sz;
     }
 
+    // ── 4b. Read NVFP4 per-tensor weight scales (optional; 1.0 for non-NVFP4).
+    //
+    // Scale tensors are F32 shape [1] — a single float per matmul weight.
+    // We read the value from mmap into host-side floats so the graph builder
+    // can use ggml_scale() (compile-time scalar, zero kernel launches) instead
+    // of ggml_mul() with a [1]-shaped GPU tensor. The ggml_mul approach adds
+    // 768 kernel launches per forward pass and causes catastrophic overhead
+    // (~1000ms vs ~30ms) in batched DDTree verify mode.
+    //
+    // LibertAI convention: "blk.N.ffn_gate.scale"
+    // Heretic convention:  "blk.N.ffn_gate.weight.scale"
+    {
+        auto read_scale = [&](int il, const char * base) -> float {
+            char sname[128];
+            // Try "base.scale" first (LibertAI), then "base.weight.scale" (Heretic).
+            std::snprintf(sname, sizeof(sname), "blk.%d.%s.scale", il, base);
+            int64_t stid = gguf_find_tensor(gctx, sname);
+            if (stid < 0) {
+                std::snprintf(sname, sizeof(sname), "blk.%d.%s.weight.scale", il, base);
+                stid = gguf_find_tensor(gctx, sname);
+            }
+            if (stid < 0) return 1.0f;
+            const size_t soff = data_start + gguf_get_tensor_offset(gctx, stid);
+            if (soff + sizeof(float) > mm.len) return 1.0f;
+            float val;
+            std::memcpy(&val, (const uint8_t *)mm.addr + soff, sizeof(float));
+            return val;
+        };
+
+        int n_scales = 0;
+        for (int il = 0; il < (int)n_layer; il++) {
+            TargetLayer & L = out.layers[il];
+            L.w_gate_s    = read_scale(il, "ffn_gate");
+            L.w_up_s      = read_scale(il, "ffn_up");
+            L.w_down_s    = read_scale(il, "ffn_down");
+            L.wq_s        = read_scale(il, "attn_q");
+            L.wk_s        = read_scale(il, "attn_k");
+            L.wv_s        = read_scale(il, "attn_v");
+            L.wo_s        = read_scale(il, "attn_output");
+            L.wqkv_s      = read_scale(il, "attn_qkv");
+            L.wqkv_gate_s = read_scale(il, "attn_gate");
+            L.ssm_beta_s  = read_scale(il, "ssm_beta");
+            L.ssm_alpha_s = read_scale(il, "ssm_alpha");
+            L.ssm_out_s   = read_scale(il, "ssm_out");
+            // Count non-trivial scales for the summary message.
+            auto count_s = [&](float s) { if (s != 1.0f) n_scales++; };
+            count_s(L.w_gate_s); count_s(L.w_up_s); count_s(L.w_down_s);
+            count_s(L.wq_s); count_s(L.wk_s); count_s(L.wv_s);
+            count_s(L.wo_s); count_s(L.wqkv_s); count_s(L.wqkv_gate_s);
+            count_s(L.ssm_beta_s); count_s(L.ssm_alpha_s); count_s(L.ssm_out_s);
+        }
+        if (n_scales > 0) {
+            std::printf("[loader] read %d NVFP4 per-tensor scale2 values (host-side, using ggml_scale)\n", n_scales);
+        }
+    }
+
     gguf_free(gctx);
 
     if (tok_embd_off == 0 || tok_embd_type == GGML_TYPE_COUNT) {
```
dflash/src/internal.h

Lines changed: 20 additions & 0 deletions
```diff
@@ -71,6 +71,26 @@ struct TargetLayer {
     ggml_tensor * ssm_dt_bias = nullptr; // [dt_rank] per-head alpha bias
     ggml_tensor * ssm_norm    = nullptr; // [head_v_dim]
     ggml_tensor * ssm_out     = nullptr; // output projection after delta-net
+
+    // NVFP4 per-tensor weight scales (optional; 1.0f = no scaling).
+    // Each corresponds to a weight tensor above: result = mul_mat(w, x) * scale.
+    // Stored as host-side floats (read from the GGUF at load time) and applied
+    // via ggml_scale() — a compile-time scalar multiply with zero extra kernel
+    // launches, unlike ggml_mul() with a [1]-shaped GPU tensor which adds 768
+    // kernel launches per forward pass and causes catastrophic overhead in
+    // batched DDTree verify mode.
+    float w_gate_s    = 1.0f;
+    float w_up_s      = 1.0f;
+    float w_down_s    = 1.0f;
+    float wq_s        = 1.0f;
+    float wk_s        = 1.0f;
+    float wv_s        = 1.0f;
+    float wo_s        = 1.0f;
+    float wqkv_s      = 1.0f;
+    float wqkv_gate_s = 1.0f;
+    float ssm_beta_s  = 1.0f;
+    float ssm_alpha_s = 1.0f;
+    float ssm_out_s   = 1.0f;
 };
 
 // CPU-side embedder: keeps a mmap of the GGUF alive and knows how to
```

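The "result = mul_mat(w, x) * scale" comment works because a per-tensor scalar
commutes with the matmul: scaling the output is mathematically identical to
scaling the weights, not an approximation. A tiny self-contained check of that
identity (illustrative values):

```cpp
#include <cstdio>

// Check that (s * W) @ x == s * (W @ x) for a 2x2 example, the identity that
// lets the scale be applied to the small matmul output instead of the large
// quantized weight tensor.
int main() {
    const float W[2][2] = {{1, 2}, {3, 4}};
    const float x[2]    = {5, 6};
    const float s       = 0.25f;
    for (int i = 0; i < 2; i++) {
        const float wx  = W[i][0] * x[0] + W[i][1] * x[1];             // (W @ x)[i]
        const float swx = (s * W[i][0]) * x[0] + (s * W[i][1]) * x[1]; // ((s*W) @ x)[i]
        std::printf("row %d: s*(Wx) = %g, (s*W)x = %g\n", i, s * wx, swx);
    }
    return 0;
}
```
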
dflash/src/qwen35_target_graph.cpp

Lines changed: 20 additions & 12 deletions
```diff
@@ -678,6 +678,14 @@ bool restore_target_cache_chain(const PrefixSnapshot * thick,
 
 // ─── Helpers ─────────────────────────────────────────────────────────
 
+// NVFP4 scale2: if the weight has a per-tensor scale, multiply the matmul
+// result by that scale. No-op when the scale is 1.0f (non-NVFP4 models).
+static ggml_tensor * apply_scale2(ggml_context * ctx, ggml_tensor * mm_result,
+                                  float scale) {
+    if (scale == 1.0f) return mm_result;
+    return ggml_scale(ctx, mm_result, scale);
+}
+
 static ggml_tensor * rms_norm_mul(ggml_context * ctx, ggml_tensor * x,
                                   ggml_tensor * weight, float eps) {
     ggml_tensor * n = ggml_rms_norm(ctx, x, eps);
@@ -686,10 +694,10 @@ static ggml_tensor * rms_norm_mul(ggml_context * ctx, ggml_tensor * x,
 
 static ggml_tensor * build_swiglu_ffn(ggml_context * ctx, ggml_tensor * cur,
                                       const TargetLayer & L) {
-    ggml_tensor * gate = ggml_mul_mat(ctx, L.w_gate, cur); // [inter, n_tokens]
-    ggml_tensor * up   = ggml_mul_mat(ctx, L.w_up, cur);
+    ggml_tensor * gate = apply_scale2(ctx, ggml_mul_mat(ctx, L.w_gate, cur), L.w_gate_s);
+    ggml_tensor * up   = apply_scale2(ctx, ggml_mul_mat(ctx, L.w_up, cur), L.w_up_s);
     ggml_tensor * gu = ggml_swiglu_split(ctx, gate, up);
-    return ggml_mul_mat(ctx, L.w_down, gu); // [hidden, n_tokens]
+    return apply_scale2(ctx, ggml_mul_mat(ctx, L.w_down, gu), L.w_down_s);
 }
 
 // Full-attention block (matches llama.cpp's build_layer_attn for qwen35)
@@ -721,7 +729,7 @@ static ggml_tensor * build_full_attn_block(
     const int q_dim = n_head * head_dim;
 
     // ── Q projection (packed Q || gate), shape [2*q_dim, n_tokens]
-    ggml_tensor * QG = ggml_mul_mat(ctx, L.wq, cur);
+    ggml_tensor * QG = apply_scale2(ctx, ggml_mul_mat(ctx, L.wq, cur), L.wq_s);
     // Reshape to [head_dim*2, n_head, n_tokens] so we can view the Q and gate halves
     QG = ggml_reshape_3d(ctx, QG, head_dim * 2, n_head, n_tokens);
 
@@ -743,8 +751,8 @@ static ggml_tensor * build_full_attn_block(
     gate = ggml_cont_2d(ctx, gate, q_dim, n_tokens); // [q_dim, n_tokens]
 
     // ── K and V projections
-    ggml_tensor * Kcur = ggml_mul_mat(ctx, L.wk, cur); // [kv_dim, n_tokens]
-    ggml_tensor * Vcur = ggml_mul_mat(ctx, L.wv, cur); // [kv_dim, n_tokens]
+    ggml_tensor * Kcur = apply_scale2(ctx, ggml_mul_mat(ctx, L.wk, cur), L.wk_s);
+    ggml_tensor * Vcur = apply_scale2(ctx, ggml_mul_mat(ctx, L.wv, cur), L.wv_s);
 
     Kcur = ggml_reshape_3d(ctx, Kcur, head_dim, n_head_kv, n_tokens);
     Kcur = rms_norm_mul(ctx, Kcur, L.k_norm, EPS);
@@ -850,7 +858,7 @@ static ggml_tensor * build_full_attn_block(
     attn = ggml_mul(ctx, attn, gate_sig);
 
     // ── Output projection
-    attn = ggml_mul_mat(ctx, L.wo, attn); // [hidden, n_tokens]
+    attn = apply_scale2(ctx, ggml_mul_mat(ctx, L.wo, attn), L.wo_s);
     return attn;
 }
 
@@ -885,22 +893,22 @@ static ggml_tensor * build_delta_net_block(
     const int n_seq_tokens = n_tokens;
 
     // ── qkv_mixed = wqkv @ cur [conv_channels, n_tokens]
-    ggml_tensor * qkv_mixed = ggml_mul_mat(ctx, L.wqkv, cur);
+    ggml_tensor * qkv_mixed = apply_scale2(ctx, ggml_mul_mat(ctx, L.wqkv, cur), L.wqkv_s);
     qkv_mixed = ggml_reshape_3d(ctx, qkv_mixed, conv_channels, n_seq_tokens, n_seqs);
 
     // ── z = wqkv_gate @ cur [inner, n_tokens]
-    ggml_tensor * z = ggml_mul_mat(ctx, L.wqkv_gate, cur);
+    ggml_tensor * z = apply_scale2(ctx, ggml_mul_mat(ctx, L.wqkv_gate, cur), L.wqkv_gate_s);
 
     // ── beta = ssm_beta @ cur [dt_rank, n_tokens]
-    ggml_tensor * beta = ggml_mul_mat(ctx, L.ssm_beta, cur);
+    ggml_tensor * beta = apply_scale2(ctx, ggml_mul_mat(ctx, L.ssm_beta, cur), L.ssm_beta_s);
     beta = ggml_reshape_4d(ctx, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
     beta = ggml_sigmoid(ctx, beta);
 
     // ── alpha = ssm_alpha @ cur [dt_rank, n_tokens]
     //    alpha = alpha + ssm_dt_bias (per-head bias)
     //    alpha = softplus(alpha)
     //    g = alpha * ssm_a (-A_log.exp() * softplus)
-    ggml_tensor * alpha = ggml_mul_mat(ctx, L.ssm_alpha, cur);
+    ggml_tensor * alpha = apply_scale2(ctx, ggml_mul_mat(ctx, L.ssm_alpha, cur), L.ssm_alpha_s);
    alpha = ggml_reshape_3d(ctx, alpha, num_v_heads, n_seq_tokens, n_seqs);
     alpha = ggml_add(ctx, alpha, L.ssm_dt_bias);
     alpha = ggml_softplus(ctx, alpha);
@@ -1131,7 +1139,7 @@ static ggml_tensor * build_delta_net_block(
                               head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
 
     // Output projection
-    ggml_tensor * out = ggml_mul_mat(ctx, L.ssm_out, flat);
+    ggml_tensor * out = apply_scale2(ctx, ggml_mul_mat(ctx, L.ssm_out, flat), L.ssm_out_s);
     out = ggml_reshape_2d(ctx, out, w.n_embd, n_seq_tokens * n_seqs);
     return out;
 }
```

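One way to see the "returns early with zero overhead" claim concretely: with
scale == 1.0f, apply_scale2() adds no node to the compute graph at all, while a
real scale2 adds exactly one GGML_OP_SCALE node per matmul. A standalone
sketch, assuming a recent ggml that exposes ggml_graph_n_nodes():

```cpp
#include "ggml.h"
#include <cstdio>
#include <initializer_list>

// apply_scale2 reproduced from the diff above.
static ggml_tensor * apply_scale2(ggml_context * ctx, ggml_tensor * mm, float s) {
    if (s == 1.0f) return mm;
    return ggml_scale(ctx, mm, s);
}

int main() {
    ggml_init_params params = { 16u * 1024 * 1024, nullptr, /*no_alloc=*/true };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
    ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);

    for (float s : { 1.0f, 0.0312f }) {  // plain model vs an NVFP4-like scale2
        ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, apply_scale2(ctx, ggml_mul_mat(ctx, w, x), s));
        std::printf("scale = %g -> %d node(s)\n", s, ggml_graph_n_nodes(gf));
    }
    ggml_free(ctx);  // expected output: 1 node for s == 1.0f, 2 nodes otherwise
    return 0;
}
```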