Commit 9f102a1

models : move the token embedding norms to the first layer (ggml-org#20943)
* models : move the token embedding norms to the first layer
* cont : fix LLM_TENSOR_CONV1D + fix il indexing
1 parent 3fc6f1a commit 9f102a1
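
In llama.cpp, every weight tensor is mapped in LLM_TENSOR_INFOS to the section of the graph it belongs to: the input layer, a repeating (per-block) layer, or the output layer, and that classification is used when assigning tensors and their ops to device buffers. This commit reclassifies the token embedding norms (and LLM_TENSOR_CONV1D) from LLM_TENSOR_LAYER_INPUT to LLM_TENSOR_LAYER_REPEATING, pinned to block 0, so the norms are stored and computed alongside the first layer rather than with the input embeddings. The following is a minimal toy model of that routing idea, not the actual loader code; pick_device and the device names are hypothetical.

    // Toy model of layer-classification-based device routing (illustrative
    // sketch only -- not llama.cpp's actual implementation).
    #include <cstdio>
    #include <string>
    #include <vector>

    enum llm_tensor_layer_demo { LAYER_INPUT, LAYER_REPEATING, LAYER_OUTPUT };

    // Input/output tensors have fixed homes; repeating tensors follow the
    // block index (il) they were created with.
    static std::string pick_device(llm_tensor_layer_demo layer, int il,
                                   const std::vector<std::string> & layer_devs) {
        switch (layer) {
            case LAYER_INPUT:     return "cpu";             // input embeddings stay with the host
            case LAYER_OUTPUT:    return layer_devs.back(); // output head follows the last block
            case LAYER_REPEATING: return layer_devs.at(il); // per-block tensors follow their block
        }
        return "cpu";
    }

    int main() {
        // Hypothetical 4-layer model split across two GPUs.
        const std::vector<std::string> layer_devs = {"gpu0", "gpu0", "gpu1", "gpu1"};

        // Before this commit: token_embd_norm was LLM_TENSOR_LAYER_INPUT -> host.
        std::printf("old: %s\n", pick_device(LAYER_INPUT, -1, layer_devs).c_str());

        // After: LLM_TENSOR_LAYER_REPEATING with il = 0 -> same device as block 0.
        std::printf("new: %s\n", pick_device(LAYER_REPEATING, 0, layer_devs).c_str());
    }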

8 files changed

Lines changed: 26 additions & 26 deletions


src/llama-arch.cpp

Lines changed: 2 additions & 2 deletions
@@ -2564,7 +2564,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TOKEN_EMBD,      {LLM_TENSOR_LAYER_INPUT,     GGML_OP_GET_ROWS}},
     {LLM_TENSOR_POS_EMBD,        {LLM_TENSOR_LAYER_INPUT,     GGML_OP_GET_ROWS}},
     {LLM_TENSOR_TOKEN_TYPES,     {LLM_TENSOR_LAYER_INPUT,     GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT,     GGML_OP_MUL}},
+    {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer)
     {LLM_TENSOR_OUTPUT,          {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS,             {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS_OUT,         {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
@@ -2725,7 +2725,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_LAUREL_POST_NORM,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
-    {LLM_TENSOR_CONV1D,               {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
+    {LLM_TENSOR_CONV1D,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
     {LLM_TENSOR_POS_NET_NORM,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_POS_NET_NORM1,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_POS_NET_NORM2,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

src/llama-model.cpp

Lines changed: 15 additions & 15 deletions
@@ -3217,8 +3217,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
                 }

-                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0);

                 for (int i = 0; i < n_layer; ++i) {
                     auto & layer = layers[i];
@@ -3265,7 +3265,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         case LLM_ARCH_MODERN_BERT:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);

                 output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);

@@ -3348,8 +3348,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);        // word_embeddings
                 type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);  // token_type_embeddings

-                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm
-                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias
+                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); // LayerNorm
+                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // LayerNorm bias

                 cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED);
                 cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {1},         TENSOR_NOT_REQUIRED);
@@ -3400,8 +3400,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         case LLM_ARCH_BLOOM:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0);

                 // output
                 output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
@@ -5780,8 +5780,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                 // Block 0, LN0
-                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0);

                 // output
                 output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
@@ -5895,8 +5895,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                 // Block 0, LN0
-                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0);

                 // output
                 output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
@@ -6067,8 +6067,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
            {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0);

-                conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0);
-                conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0);
+                conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0);
+                conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias", 0), {1, hparams.posnet.n_embd}, 0);

                 // posnet
                 {
@@ -6133,8 +6133,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

                 GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);

-                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0);
-                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {hparams.posnet.n_embd}, 0);
+                tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0);
+                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {hparams.posnet.n_embd}, 0);

                 // convnext
                 {
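
Note how tn(...) now receives an explicit block index 0 as its third argument in all of the loader changes above. Since the name templates for these tensors contain no per-block placeholder, the resulting GGUF tensor name (e.g. token_embd_norm.weight) should be unchanged; the index only tags the tensor with a layer for placement, which keeps existing GGUF files loadable. A minimal sketch of that assumed naming behavior, using a hypothetical fmt_name helper rather than the real tn() implementation:

    // Sketch of the assumed naming behavior: a base name without "%d" ignores
    // the block index, while per-block names consume it (hypothetical helper,
    // not the real tn() implementation).
    #include <cstdio>
    #include <string>

    static std::string fmt_name(const char * base, const char * suffix, int bid) {
        char name[128];
        std::snprintf(name, sizeof(name), base, bid); // excess printf-style args are ignored
        return std::string(name) + "." + suffix;
    }

    int main() {
        std::printf("%s\n", fmt_name("token_embd_norm",  "weight", 0).c_str()); // token_embd_norm.weight
        std::printf("%s\n", fmt_name("blk.%d.attn_norm", "weight", 0).c_str()); // blk.0.attn_norm.weight
    }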

src/models/bert.cpp

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params
     cb(inpL, "inp_embd", -1);

     // embed layer norm
-    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
-    cb(inpL, "inp_norm", -1);
+    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0);
+    cb(inpL, "inp_norm", 0);

     auto * inp_attn = build_attn_inp_no_cache();
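
The same il change recurs in the remaining builder files below. The last argument to build_norm and cb is the layer index il: -1 marks an op as belonging to no particular block (the input/output sections of the graph), while 0 ties the embedding norm to the first block so it is scheduled and offloaded with that layer; this is the "fix il indexing" part of the commit message. A toy illustration of that convention, with a hypothetical cb_demo standing in for the real callback:

    // Toy stand-in for a cb(tensor, name, il)-style graph callback that tags
    // each op with a layer index (illustrative only).
    #include <cstdio>

    static void cb_demo(const char * name, int il) {
        if (il < 0) {
            std::printf("%s: il = -1 -> input/output section of the graph\n", name);
        } else {
            std::printf("%s: il = %d -> scheduled with block %d\n", name, il, il);
        }
    }

    int main() {
        cb_demo("inp_norm", -1); // old: the norm was treated as part of the input
        cb_demo("inp_norm",  0); // new: the norm runs with the first block
    }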

src/models/bloom.cpp

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para
     inpL = build_norm(inpL,
             model.tok_norm,
             model.tok_norm_b,
-            LLM_NORM, -1);
-    cb(inpL, "inp_norm", -1);
+            LLM_NORM, 0);
+    cb(inpL, "inp_norm", 0);

     ggml_tensor * inp_out_ids = build_inp_out_ids();

src/models/modern-bert.cpp

Lines changed: 2 additions & 2 deletions
@@ -15,8 +15,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll
     cb(inpL, "inp_embd", -1);

     // embed layer norm
-    inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
-    cb(inpL, "inp_norm", -1);
+    inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, 0);
+    cb(inpL, "inp_norm", 0);

     ggml_tensor * inp_out_ids = build_inp_out_ids();

src/models/rwkv6.cpp

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ llm_build_rwkv6::llm_build_rwkv6(const llama_model & model, const llm_graph_para
     ggml_tensor * inpL;

     inpL = build_inp_embd(model.tok_embd);
-    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0);

     auto * rs_inp = build_rs_inp();

src/models/rwkv7.cpp

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para
     ggml_tensor * v_first = nullptr;

     inpL = build_inp_embd(model.tok_embd);
-    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0);

     auto * rs_inp = build_rs_inp();

src/models/wavtokenizer-dec.cpp

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ llm_build_wavtokenizer_dec::llm_build_wavtokenizer_dec(const llama_model & model
     cur = build_norm(cur,
             model.tok_norm,
             model.tok_norm_b,
-            LLM_NORM, -1);
+            LLM_NORM, 0);

     cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
