Skip to content

Commit f772f6e

Browse files
authored
model : support NVFP4 tensors for Gemma4 (ggml-org#21971)
* support nvfp4 tensors for Gemma4 * add wo_s to build_attn * add wo_s to build_attn * fix glm4
1 parent b572d1e commit f772f6e

105 files changed

Lines changed: 149 additions & 148 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/llama-graph.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2011,6 +2011,7 @@ ggml_tensor * llm_graph_context::build_attn(
20112011
llm_graph_input_attn_no_cache * inp,
20122012
ggml_tensor * wo,
20132013
ggml_tensor * wo_b,
2014+
ggml_tensor * wo_s,
20142015
ggml_tensor * q_cur,
20152016
ggml_tensor * k_cur,
20162017
ggml_tensor * v_cur,
@@ -2044,7 +2045,7 @@ ggml_tensor * llm_graph_context::build_attn(
20442045
cb(cur, "kqv_out", il);
20452046

20462047
if (wo) {
2047-
cur = build_lora_mm(wo, cur);
2048+
cur = build_lora_mm(wo, cur, wo_s);
20482049
}
20492050

20502051
if (wo_b) {
@@ -2095,6 +2096,7 @@ ggml_tensor * llm_graph_context::build_attn(
20952096
llm_graph_input_attn_kv * inp,
20962097
ggml_tensor * wo,
20972098
ggml_tensor * wo_b,
2099+
ggml_tensor * wo_s,
20982100
ggml_tensor * q_cur,
20992101
ggml_tensor * k_cur,
21002102
ggml_tensor * v_cur,
@@ -2146,10 +2148,15 @@ ggml_tensor * llm_graph_context::build_attn(
21462148
}
21472149

21482150
if (wo) {
2149-
cur = build_lora_mm(wo, cur);
21502151
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
21512152
// GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
2153+
cur = build_lora_mm(wo, cur);
21522154
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2155+
if (wo_s) {
2156+
cur = ggml_mul(ctx0, cur, wo_s);
2157+
}
2158+
} else {
2159+
cur = build_lora_mm(wo, cur, wo_s);
21532160
}
21542161
}
21552162

@@ -2193,6 +2200,7 @@ ggml_tensor * llm_graph_context::build_attn(
21932200
llm_graph_input_attn_k * inp,
21942201
ggml_tensor * wo,
21952202
ggml_tensor * wo_b,
2203+
ggml_tensor * wo_s,
21962204
ggml_tensor * q_cur,
21972205
ggml_tensor * k_cur,
21982206
ggml_tensor * v_cur,
@@ -2227,10 +2235,15 @@ ggml_tensor * llm_graph_context::build_attn(
22272235
cb(cur, "kqv_out", il);
22282236

22292237
if (wo) {
2230-
cur = build_lora_mm(wo, cur);
22312238
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
22322239
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
2240+
cur = build_lora_mm(wo, cur);
22332241
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2242+
if (wo_s) {
2243+
cur = ggml_mul(ctx0, cur, wo_s);
2244+
}
2245+
} else {
2246+
cur = build_lora_mm(wo, cur, wo_s);
22342247
}
22352248
}
22362249

@@ -2245,6 +2258,7 @@ ggml_tensor * llm_graph_context::build_attn(
22452258
llm_graph_input_attn_kv_iswa * inp,
22462259
ggml_tensor * wo,
22472260
ggml_tensor * wo_b,
2261+
ggml_tensor * wo_s,
22482262
ggml_tensor * q_cur,
22492263
ggml_tensor * k_cur,
22502264
ggml_tensor * v_cur,
@@ -2313,7 +2327,7 @@ ggml_tensor * llm_graph_context::build_attn(
23132327
}
23142328

23152329
if (wo) {
2316-
cur = build_lora_mm(wo, cur);
2330+
cur = build_lora_mm(wo, cur, wo_s);
23172331
}
23182332

23192333
if (wo_b) {
@@ -2344,6 +2358,7 @@ ggml_tensor * llm_graph_context::build_attn(
23442358
llm_graph_input_attn_cross * inp,
23452359
ggml_tensor * wo,
23462360
ggml_tensor * wo_b,
2361+
ggml_tensor * wo_s,
23472362
ggml_tensor * q_cur,
23482363
ggml_tensor * k_cur,
23492364
ggml_tensor * v_cur,
@@ -2368,7 +2383,7 @@ ggml_tensor * llm_graph_context::build_attn(
23682383
cb(cur, "kqv_out", il);
23692384

23702385
if (wo) {
2371-
cur = build_lora_mm(wo, cur);
2386+
cur = build_lora_mm(wo, cur, wo_s);
23722387
}
23732388

23742389
if (wo_b) {

src/llama-graph.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,7 @@ struct llm_graph_context {
892892
llm_graph_input_attn_no_cache * inp,
893893
ggml_tensor * wo,
894894
ggml_tensor * wo_b,
895+
ggml_tensor * wo_s,
895896
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
896897
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
897898
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
@@ -907,6 +908,7 @@ struct llm_graph_context {
907908
llm_graph_input_attn_kv * inp,
908909
ggml_tensor * wo,
909910
ggml_tensor * wo_b,
911+
ggml_tensor * wo_s,
910912
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
911913
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
912914
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
@@ -922,6 +924,7 @@ struct llm_graph_context {
922924
llm_graph_input_attn_k * inp,
923925
ggml_tensor * wo,
924926
ggml_tensor * wo_b,
927+
ggml_tensor * wo_s,
925928
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
926929
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
927930
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
@@ -938,6 +941,7 @@ struct llm_graph_context {
938941
llm_graph_input_attn_kv_iswa * inp,
939942
ggml_tensor * wo,
940943
ggml_tensor * wo_b,
944+
ggml_tensor * wo_s,
941945
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
942946
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
943947
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
@@ -953,6 +957,7 @@ struct llm_graph_context {
953957
llm_graph_input_attn_cross * inp,
954958
ggml_tensor * wo,
955959
ggml_tensor * wo_b,
960+
ggml_tensor * wo_s,
956961
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
957962
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
958963
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]

src/models/afmoe.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para
8080
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
8181

8282
cur = build_attn(inp_attn,
83-
NULL, NULL, // wo will be applied after gating
83+
NULL, NULL, NULL, // wo will be applied after gating
8484
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
8585
cb(cur, "attn_out", il);
8686

@@ -91,7 +91,7 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para
9191
cb(cur, "attn_gated", il);
9292

9393
// now apply output projection
94-
cur = build_lora_mm(model.layers[il].wo, cur);
94+
cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s);
9595
cb(cur, "attn_o_proj", il);
9696
}
9797

src/models/apertus.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#include "models.h"
22

3-
4-
53
llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
64
const int64_t n_embd_head = hparams.n_embd_head_v();
75

@@ -62,7 +60,7 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_
6260
cb(Vcur, "Vcur_pos", il);
6361

6462
cur = build_attn(inp_attn,
65-
model.layers[il].wo, model.layers[il].bo,
63+
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s,
6664
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
6765
cb(cur, "attn_out", il);
6866
}

src/models/arcee.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include "models.h"
22

3-
43
llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
54
const int64_t n_embd_head = hparams.n_embd_head_v();
65

@@ -78,7 +77,7 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para
7877
cb(Vcur, "Vcur", il);
7978

8079
cur = build_attn(inp_attn,
81-
model.layers[il].wo, model.layers[il].bo,
80+
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s,
8281
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
8382
cb(cur, "attn_out", il);
8483
}

src/models/arctic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa
6060
cb(Vcur, "Vcur", il);
6161

6262
cur = build_attn(inp_attn,
63-
model.layers[il].wo, NULL,
63+
model.layers[il].wo, NULL, model.layers[il].wo_s,
6464
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6565
}
6666

src/models/baichuan.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include "models.h"
22

3-
43
llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
54
const int64_t n_embd_head = hparams.n_embd_head_v();
65

@@ -67,7 +66,7 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap
6766
cb(Vcur, "Vcur", il);
6867

6968
cur = build_attn(inp_attn,
70-
model.layers[il].wo, NULL,
69+
model.layers[il].wo, NULL, model.layers[il].wo_s,
7170
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7271
}
7372

src/models/bailingmoe.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_
7070
cb(Vcur, "Vcur", il);
7171

7272
cur = build_attn(inp_attn,
73-
model.layers[il].wo, model.layers[il].bo,
73+
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s,
7474
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
7575
}
7676

src/models/bailingmoe2.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll
5656
cb(Vcur, "Vcur", il);
5757

5858
cur = build_attn(inp_attn,
59-
model.layers[il].wo, model.layers[il].bo,
59+
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s,
6060
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
6161
}
6262

src/models/bert.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params
100100
cb(Vcur, "Vcur", il);
101101

102102
cur = build_attn(inp_attn,
103-
model.layers[il].wo, model.layers[il].bo,
103+
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s,
104104
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
105105
cb(cur, "kqv_out", il);
106106
}

0 commit comments

Comments
 (0)