Skip to content

Commit 44c51e5

Browse files
authored
model : allow causal_attn and pooling_type on all architectures (#20973)
* models : allow causal_attn and pooling_type on all architectures
* fix: move location
1 parent 1922f87 commit 44c51e5

1 file changed

Lines changed: 2 additions & 18 deletions

File tree

src/llama-model.cpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
370370
ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
371371
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
372372
ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false);
373+
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false);
374+
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
373375
ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
374376
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
375377
ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
@@ -748,8 +750,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
748750
case LLM_ARCH_BERT:
749751
{
750752
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
751-
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false);
752-
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
753753

754754
switch (hparams.n_layer) {
755755
case 3:
@@ -781,8 +781,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
781781
}
782782

783783
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
784-
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false);
785-
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
786784

787785
switch (hparams.n_layer) {
788786
case 12:
@@ -797,8 +795,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
797795
case LLM_ARCH_JINA_BERT_V2:
798796
{
799797
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
800-
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false);
801-
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
802798
hparams.f_max_alibi_bias = 8.0f;
803799

804800
switch (hparams.n_layer) {
@@ -810,8 +806,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
810806
case LLM_ARCH_JINA_BERT_V3:
811807
{
812808
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
813-
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false);
814-
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
815809

816810
switch (hparams.n_layer) {
817811
case 24:
@@ -823,8 +817,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
823817
case LLM_ARCH_NOMIC_BERT_MOE:
824818
{
825819
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
826-
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false);
827-
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
828820
ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0);
829821

830822
if (hparams.n_layer == 12 && hparams.n_embd == 768) {
@@ -838,8 +830,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
838830
case LLM_ARCH_NEO_BERT:
839831
{
840832
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
841-
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false);
842-
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
843833

844834
if (hparams.n_layer == 28) {
845835
type = LLM_TYPE_250M;
@@ -848,8 +838,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
848838
case LLM_ARCH_EUROBERT:
849839
{
850840
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
851-
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false);
852-
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
853841

854842
if (hparams.n_layer == 12) {
855843
type = LLM_TYPE_SMALL; // 0.2B
@@ -913,7 +901,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
913901
// fall through
914902
case LLM_ARCH_QWEN2:
915903
{
916-
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
917904
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
918905
switch (hparams.n_layer) {
919906
case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break;
@@ -995,7 +982,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
995982
} break;
996983
case LLM_ARCH_QWEN3:
997984
{
998-
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
999985
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1000986
switch (hparams.n_layer) {
1001987
case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break;
@@ -1287,7 +1273,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
12871273
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
12881274
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
12891275
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1290-
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
12911276

12921277
//applied only if model converted with --sentence-transformers-dense-modules
12931278
ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false);
@@ -2084,7 +2069,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
20842069
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
20852070
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps);
20862071
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
2087-
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false);
20882072
} break;
20892073
case LLM_ARCH_BAILINGMOE:
20902074
{

0 commit comments

Comments
 (0)