@@ -76,10 +76,27 @@ struct ConfigRecordV1 {
 };
 
 // For compatibility
-struct ConfigRecordV2 : public ConfigRecordV1 {
+struct ConfigRecordV1GQA : public ConfigRecordV1 {
     int num_kv_heads;
 };
 
+// TODO: use json to serialize config
+struct ConfigRecordV2 {
+    ggml_type dtype;
+    int vocab_size;
+    int hidden_size;
+    int num_attention_heads;
+    int num_key_value_heads;
+    int num_hidden_layers;
+    int intermediate_size;
+    float norm_eps;
+    int num_virtual_tokens;
+    float rope_theta;
+    int max_length;
+    int eos_token_id;
+    int pad_token_id;
+};
+
 enum class ActivationType {
     GELU,
     SILU,
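
ConfigRecordV2 above is a plain trivially-copyable struct, so a loader can read it as raw bytes. A minimal sketch of doing so, assuming the record sits verbatim at the current stream offset (read_config_v2 is a hypothetical helper, not part of this commit):

    #include <fstream>

    ConfigRecordV2 read_config_v2(std::ifstream &fin) {
        ConfigRecordV2 rec{};
        // Raw byte copy; assumes the file was written on a platform with the
        // same endianness and struct padding as the reader.
        fin.read(reinterpret_cast<char *>(&rec), sizeof(rec));
        return rec;
    }
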
@@ -89,6 +106,7 @@ enum class RopeType {
     GPTJ = 0,
     NEOX = 2,
     CHATGLM = 4,
+    CHATGLM2 = 8,
     DISABLED = 10000,
 };
 
@@ -105,33 +123,44 @@ class ModelConfig {
     ModelConfig(ModelType model_type, ggml_type dtype, int vocab_size, int hidden_size, int num_attention_heads,
                 int num_kv_heads, int num_hidden_layers, int intermediate_size, float norm_eps,
                 ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi,
-                RopeType rope_type, int rope_dim_scale, AttentionMaskType attn_mask_type, int max_length,
-                int bos_token_id, int eos_token_id, int pad_token_id, int sep_token_id,
-                std::vector<int> extra_eos_token_ids)
+                RopeType rope_type, float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type,
+                int num_virtual_tokens, int max_length, int bos_token_id, int eos_token_id, int pad_token_id,
+                int sep_token_id, std::vector<int> extra_eos_token_ids)
         : model_type(model_type), dtype(dtype), vocab_size(vocab_size), hidden_size(hidden_size),
           num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), num_hidden_layers(num_hidden_layers),
           intermediate_size(intermediate_size), norm_eps(norm_eps), hidden_act(hidden_act), use_qkv_bias(use_qkv_bias),
           use_dense_bias(use_dense_bias), interleaved_qkv(interleaved_qkv), use_alibi(use_alibi), rope_type(rope_type),
-          rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type), max_length(max_length),
-          bos_token_id(bos_token_id), eos_token_id(eos_token_id), pad_token_id(pad_token_id),
-          sep_token_id(sep_token_id), extra_eos_token_ids(std::move(extra_eos_token_ids)) {}
+          rope_theta(rope_theta), rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type),
+          num_virtual_tokens(num_virtual_tokens), max_length(max_length), bos_token_id(bos_token_id),
+          eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id),
+          extra_eos_token_ids(std::move(extra_eos_token_ids)) {}
 
     ModelConfig(ModelType model_type, const ConfigRecordV1 &rec, float norm_eps, ActivationType hidden_act,
                 bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type,
-                int rope_dim_scale, AttentionMaskType attn_mask_type)
+                float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens)
         : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,
                       rec.num_attention_heads, rec.num_hidden_layers, rec.intermediate_size, norm_eps, hidden_act,
-                      use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_dim_scale,
-                      attn_mask_type, rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id,
-                      rec.sep_token_id, {}) {}
+                      use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale,
+                      attn_mask_type, num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id,
+                      rec.pad_token_id, rec.sep_token_id, {}) {}
 
-    ModelConfig(ModelType model_type, const ConfigRecordV2 &rec, float norm_eps, ActivationType hidden_act,
+    ModelConfig(ModelType model_type, const ConfigRecordV1GQA &rec, float norm_eps, ActivationType hidden_act,
                 bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type,
-                int rope_dim_scale, AttentionMaskType attn_mask_type)
+                float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens)
         : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads,
                       rec.num_hidden_layers, rec.intermediate_size, norm_eps, hidden_act, use_qkv_bias, use_dense_bias,
-                      interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type, rec.max_length,
-                      rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {}
+                      interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale, attn_mask_type,
+                      num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id,
+                      rec.sep_token_id, {}) {}
+
+    ModelConfig(ModelType model_type, const ConfigRecordV2 &rec, ActivationType hidden_act, bool use_qkv_bias,
+                bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale,
+                AttentionMaskType attn_mask_type)
+        : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,
+                      rec.num_key_value_heads, rec.num_hidden_layers, rec.intermediate_size, rec.norm_eps, hidden_act,
+                      use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rec.rope_theta,
+                      rope_dim_scale, attn_mask_type, rec.num_virtual_tokens, rec.max_length, -1, rec.eos_token_id,
+                      rec.pad_token_id, -1, {}) {}
 
     std::string model_type_name() const { return to_string(model_type); }
 
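
The new V2 constructor takes norm_eps, rope_theta, num_virtual_tokens, max_length, and the token ids directly from the record, and defaults bos_token_id and sep_token_id to -1 since ConfigRecordV2 no longer stores them. A hypothetical usage sketch; the enum values and rope_dim_scale below are illustrative placeholders, not taken from this commit:

    ConfigRecordV2 rec = read_config_v2(fin); // hypothetical helper from the earlier sketch
    ModelConfig config(ModelType::CHATGLM2, rec, ActivationType::SILU,
                       /*use_qkv_bias=*/true, /*use_dense_bias=*/false,
                       /*interleaved_qkv=*/false, /*use_alibi=*/false,
                       RopeType::CHATGLM2, /*rope_dim_scale=*/2,
                       AttentionMaskType::CAUSAL);
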
@@ -151,8 +180,10 @@ class ModelConfig {
     bool interleaved_qkv;
     bool use_alibi;
     RopeType rope_type;
+    float rope_theta;
     int rope_dim_scale;
     AttentionMaskType attn_mask_type;
+    int num_virtual_tokens;
     int max_length;
     int bos_token_id;
     int eos_token_id;
@@ -388,16 +419,17 @@ class BasicAttention {
     BasicAttention() = default;
     BasicAttention(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length,
                    bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type,
-                   int rope_dim_scale, AttentionMaskType attn_mask_type)
+                   float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens)
         : num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), interleaved_qkv(interleaved_qkv),
-          use_alibi(use_alibi), rope_type(rope_type), rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type),
+          use_alibi(use_alibi), rope_type(rope_type), rope_theta(rope_theta), rope_dim_scale(rope_dim_scale),
+          attn_mask_type(attn_mask_type), num_virtual_tokens(num_virtual_tokens),
           query_key_value(ctx, hidden_size, hidden_size + 2 * (hidden_size / num_attention_heads) * num_kv_heads,
                           use_qkv_bias),
           dense(ctx, hidden_size, hidden_size, use_dense_bias),
-          k_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads, max_length,
-                                     num_kv_heads)),
-          v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length, hidden_size / num_attention_heads,
-                                     num_kv_heads)) {}
+          k_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads,
+                                     max_length + num_virtual_tokens, num_kv_heads)),
+          v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length + num_virtual_tokens,
+                                     hidden_size / num_attention_heads, num_kv_heads)) {}
 
     ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past,
                          int n_ctx) const;
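
Both caches are now over-allocated by num_virtual_tokens so the P-Tuning v2 prefix can occupy the first slots ahead of the regular context (in the shape comments below, #kvh is the number of KV heads, d the head size, and s this extended sequence length). A back-of-envelope sketch of the resulting per-layer footprint; the helper and the example numbers are illustrative, not from this commit:

    #include <cstddef>
    #include <cstdint>

    // Per-layer KV cache bytes: k_cache and v_cache are both F16 (2 bytes/element).
    size_t kv_cache_bytes(int hidden_size, int num_attention_heads, int num_kv_heads,
                          int max_length, int num_virtual_tokens) {
        const int head_size = hidden_size / num_attention_heads;
        const size_t elems_per_tensor =
            (size_t)(max_length + num_virtual_tokens) * head_size * num_kv_heads;
        return 2 * elems_per_tensor * sizeof(uint16_t); // two tensors per layer
    }
    // e.g. kv_cache_bytes(4096, 32, 2, 32768, 128) ≈ 33.7 MB per layer
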
@@ -408,12 +440,14 @@ class BasicAttention {
     bool interleaved_qkv;
     bool use_alibi;
     RopeType rope_type;
+    float rope_theta;
     int rope_dim_scale;
     AttentionMaskType attn_mask_type;
+    int num_virtual_tokens;
     Linear query_key_value;
     Linear dense;
-    ggml_tensor *k_cache; // [kv_heads, max_len, head_size]
-    ggml_tensor *v_cache; // [kv_heads, head_size, max_len]
+    ggml_tensor *k_cache; // [#kvh, s, d]
+    ggml_tensor *v_cache; // [#kvh, d, s]
 };
 
 template <typename Norm, typename Attention, typename MLP>
@@ -422,11 +456,12 @@ class BasicBlock {
     BasicBlock() = default;
     BasicBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
                int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias,
-               bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale,
-               AttentionMaskType attn_mask_type)
+               bool interleaved_qkv, bool use_alibi, RopeType rope_type, float rope_theta, int rope_dim_scale,
+               AttentionMaskType attn_mask_type, int num_virtual_tokens)
         : input_layernorm(ctx, hidden_size, false, norm_eps),
           attention(ctx, hidden_size, num_attention_heads, num_kv_heads, max_length, use_qkv_bias, use_dense_bias,
-                    interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type),
+                    interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale, attn_mask_type,
+                    num_virtual_tokens),
           post_attention_layernorm(ctx, hidden_size, false, norm_eps),
           mlp(ctx, hidden_size, intermediate_size, hidden_act) {}
@@ -517,16 +552,44 @@ class BasicModel {
         return hidden_states;
     }
 
+    void load_prefix_cache(const ModelConfig &config, ggml_tensor *past_key_values) {
+        ggml_cgraph gf{};
+        auto ctx = make_unique_ggml_context(config.num_hidden_layers * 7 * ggml_tensor_overhead(), nullptr, false);
+        const int head_size = config.hidden_size / config.num_attention_heads;
+        for (size_t i = 0; i < layers.size(); i++) {
+            auto &attn = layers[i].attention;
+            ggml_tensor *virtual_key = ggml_view_3d(ctx.get(), past_key_values, head_size, config.num_virtual_tokens,
+                                                    config.num_kv_heads, past_key_values->nb[1], past_key_values->nb[2],
+                                                    i * 2 * past_key_values->nb[3]); // [#h, v, d]
+            ggml_tensor *k_cache_view =
+                ggml_view_3d(ctx.get(), attn.k_cache, head_size, config.num_virtual_tokens, config.num_kv_heads,
+                             attn.k_cache->nb[1], attn.k_cache->nb[2], 0); // [#h, v, d]
+            ggml_build_forward_expand(&gf, ggml_cpy(ctx.get(), virtual_key, k_cache_view));
+
+            ggml_tensor *virtual_value = ggml_view_3d(
+                ctx.get(), past_key_values, head_size, config.num_virtual_tokens, config.num_kv_heads,
+                past_key_values->nb[1], past_key_values->nb[2], (i * 2 + 1) * past_key_values->nb[3]); // [#h, v, d]
+            virtual_value = ggml_permute(ctx.get(), virtual_value, 1, 0, 2, 3); // [#h, d, v]
+            ggml_tensor *v_cache_view =
+                ggml_view_3d(ctx.get(), attn.v_cache, config.num_virtual_tokens, head_size, config.num_kv_heads,
+                             attn.v_cache->nb[1], attn.v_cache->nb[2], 0); // [#h, d, v]
+            ggml_build_forward_expand(&gf, ggml_cpy(ctx.get(), virtual_value, v_cache_view));
+        }
+        CHATGLM_CHECK(ggml_used_mem(ctx.get()) == ggml_get_mem_size(ctx.get())) << "corrupted prefix cache context";
+        std::vector<uninitialized_char> compute_buffer;
+        ggml_graph_compute_helper(compute_buffer, &gf, 0);
+    }
+
   private:
     std::vector<Block> build_layers(ModelContext *ctx, const ModelConfig &config) {
         std::vector<Block> layers;
         layers.reserve(config.num_hidden_layers);
         for (int layer_id = 0; layer_id < config.num_hidden_layers; layer_id++) {
-            // TODO: reduce max length? 32k might be too large for cpu inference
             layers.emplace_back(ctx, config.hidden_size, config.num_attention_heads, config.num_kv_heads,
                                 config.intermediate_size, config.max_length, config.norm_eps, config.hidden_act,
                                 config.use_qkv_bias, config.use_dense_bias, config.interleaved_qkv, config.use_alibi,
-                                config.rope_type, config.rope_dim_scale, config.attn_mask_type);
+                                config.rope_type, config.rope_theta, config.rope_dim_scale, config.attn_mask_type,
+                                config.num_virtual_tokens);
         }
         return layers;
     }
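
The views taken in load_prefix_cache imply a past_key_values layout of ne = [head_size, num_virtual_tokens, num_kv_heads, 2 * num_hidden_layers], with layer i's key at slice 2*i and its value at slice 2*i+1. A hypothetical allocation sketch under that assumption (the dtype and surrounding variables are placeholders; ggml_cpy converts element types during the copy):

    ggml_tensor *past_key_values =
        ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_size, config.num_virtual_tokens,
                           config.num_kv_heads, 2 * config.num_hidden_layers);
    // ... fill with the trained P-Tuning v2 prefix, then copy it into every
    // layer's KV cache (here via the public wrapper added further below):
    model.load_prefix_cache(past_key_values);
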
@@ -745,6 +808,8 @@ class BasicModelForCausalLM : public BaseModelForCausalLM {
         return lm_logits;
     }
 
+    void load_prefix_cache(ggml_tensor *past_key_values) { transformer.load_prefix_cache(config, past_key_values); }
+
   protected:
     void to_cpu() {
         for (auto &item : state_dict_) {
@@ -818,13 +883,14 @@ class GLMBlock : public BasicBlock<LayerNorm, BasicAttention, BasicMLP> {
     GLMBlock() = default;
     GLMBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
              int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias,
-             bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale,
-             AttentionMaskType attn_mask_type)
-        : BasicBlock(
-              LayerNorm(ctx, hidden_size, false, norm_eps),
-              BasicAttention(ctx, hidden_size, num_attention_heads, num_attention_heads, max_length, use_qkv_bias,
-                             use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type),
-              LayerNorm(ctx, hidden_size, false, norm_eps), BasicMLP(ctx, hidden_size, intermediate_size, hidden_act)),
+             bool interleaved_qkv, bool use_alibi, RopeType rope_type, float rope_theta, int rope_dim_scale,
+             AttentionMaskType attn_mask_type, int num_virtual_tokens)
+        : BasicBlock(LayerNorm(ctx, hidden_size, false, norm_eps),
+                     BasicAttention(ctx, hidden_size, num_attention_heads, num_attention_heads, max_length,
+                                    use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_theta,
+                                    rope_dim_scale, attn_mask_type, num_virtual_tokens),
+                     LayerNorm(ctx, hidden_size, false, norm_eps),
+                     BasicMLP(ctx, hidden_size, intermediate_size, hidden_act)),
           alpha_value(std::sqrt(2.f * 28)) {}
 
     ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past,