Commit 829e9a7

Support p-tuning v2 for ChatGLM family & fix rope theta for 32k/128k seqlen (#289)
1 parent 04910ce commit 829e9a7

9 files changed

Lines changed: 573 additions & 220 deletions

README.md

Lines changed: 4 additions & 1 deletion
@@ -15,6 +15,7 @@ C++ implementation of [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGL
 Highlights:
 * Pure C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp).
 * Accelerated memory-efficient CPU inference with int4/int8 quantization, optimized KV cache and parallel computing.
+* P-Tuning v2 and LoRA finetuned models support.
 * Streaming generation with typewriter effect.
 * Python binding, web demo, api servers and more possibilities.
 

@@ -68,7 +69,9 @@ You are free to try any of the below quantization types by specifying `-t <type>`
 * `f16`: half precision floating point weights without quantization.
 * `f32`: single precision floating point weights without quantization.
 
-For LoRA model, add `-l <lora_model_name_or_path>` flag to merge your LoRA weights into the base model.
+For LoRA models, add the `-l <lora_model_name_or_path>` flag to merge your LoRA weights into the base model. For example, run `python3 chatglm_cpp/convert.py -i THUDM/chatglm3-6b -t q4_0 -o chatglm3-ggml-lora.bin -l shibing624/chatglm3-6b-csc-chinese-lora` to merge public LoRA weights from Hugging Face.
+
+For P-Tuning v2 models using the [official finetuning script](https://github.com/THUDM/ChatGLM3/tree/main/finetune_demo), additional weights are automatically detected by `convert.py`. If `past_key_values` is in the output weight list, the P-Tuning v2 checkpoint was converted successfully.
 
 **Build & Run**
 
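Note: for P-Tuning v2 checkpoints the conversion command is the same as for a base model, since `convert.py` picks up the prefix weights on its own. A hypothetical invocation, assuming `-i` accepts a local checkpoint directory as well as a Hugging Face model id (the path below is illustrative): `python3 chatglm_cpp/convert.py -i /path/to/ptuning-v2-checkpoint -t q4_0 -o chatglm3-ptuning-ggml.bin`.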

chatglm.cpp

Lines changed: 125 additions & 76 deletions
Large diffs are not rendered by default.

chatglm.h

Lines changed: 101 additions & 35 deletions
@@ -76,10 +76,27 @@ struct ConfigRecordV1 {
 };
 
 // For compatibility
-struct ConfigRecordV2 : public ConfigRecordV1 {
+struct ConfigRecordV1GQA : public ConfigRecordV1 {
     int num_kv_heads;
 };
 
+// TODO: use json to serialize config
+struct ConfigRecordV2 {
+    ggml_type dtype;
+    int vocab_size;
+    int hidden_size;
+    int num_attention_heads;
+    int num_key_value_heads;
+    int num_hidden_layers;
+    int intermediate_size;
+    float norm_eps;
+    int num_virtual_tokens;
+    float rope_theta;
+    int max_length;
+    int eos_token_id;
+    int pad_token_id;
+};
+
 enum class ActivationType {
     GELU,
     SILU,
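
The new `ConfigRecordV2` folds per-model constants that V1 callers had to pass in by hand (`norm_eps`, `rope_theta`, `num_virtual_tokens`) into the serialized header itself. Because the record is a plain-old-data struct, loading it amounts to one raw read at a fixed offset; a minimal sketch of that idea, assuming a simple fstream-based reader rather than the library's actual memory-mapped loader (`read_record` is a hypothetical helper):

#include <fstream>
#include <stdexcept>
#include <type_traits>

// Hypothetical helper: read one POD config record from a binary weight file.
// Every field appended to the struct shifts all later offsets, which is why
// the TODO above suggests switching to JSON for config serialization.
template <typename T>
T read_record(std::ifstream &fin) {
    static_assert(std::is_trivially_copyable_v<T>, "config record must be POD");
    T rec;
    fin.read(reinterpret_cast<char *>(&rec), sizeof(T));
    if (!fin) throw std::runtime_error("truncated config record");
    return rec;
}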
@@ -89,6 +106,7 @@ enum class RopeType {
     GPTJ = 0,
     NEOX = 2,
     CHATGLM = 4,
+    CHATGLM2 = 8,
     DISABLED = 10000,
 };
 
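The new `CHATGLM2` rope type works together with the `rope_theta` config field: per the commit title, 32k/128k long-context checkpoints raise the RoPE base above the default 10000, which a hard-coded theta could not express. A minimal sketch of the standard frequency computation (the real rotation runs inside ggml's rope kernel; this only shows how theta enters):

#include <cmath>
#include <vector>

// Standard RoPE inverse frequencies: inv_freq[i] = rope_theta^(-2i / rope_dim).
// A larger rope_theta stretches the longest positional wavelength, which is
// what 32k/128k-context checkpoints rely on to keep distant positions distinct.
std::vector<float> rope_inv_freq(int rope_dim, float rope_theta) {
    std::vector<float> inv_freq(rope_dim / 2);
    for (int i = 0; i < rope_dim / 2; i++) {
        inv_freq[i] = 1.f / std::pow(rope_theta, 2.f * i / rope_dim);
    }
    return inv_freq;
}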

@@ -105,33 +123,44 @@ class ModelConfig {
     ModelConfig(ModelType model_type, ggml_type dtype, int vocab_size, int hidden_size, int num_attention_heads,
                 int num_kv_heads, int num_hidden_layers, int intermediate_size, float norm_eps,
                 ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi,
-                RopeType rope_type, int rope_dim_scale, AttentionMaskType attn_mask_type, int max_length,
-                int bos_token_id, int eos_token_id, int pad_token_id, int sep_token_id,
-                std::vector<int> extra_eos_token_ids)
+                RopeType rope_type, float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type,
+                int num_virtual_tokens, int max_length, int bos_token_id, int eos_token_id, int pad_token_id,
+                int sep_token_id, std::vector<int> extra_eos_token_ids)
         : model_type(model_type), dtype(dtype), vocab_size(vocab_size), hidden_size(hidden_size),
           num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), num_hidden_layers(num_hidden_layers),
           intermediate_size(intermediate_size), norm_eps(norm_eps), hidden_act(hidden_act), use_qkv_bias(use_qkv_bias),
           use_dense_bias(use_dense_bias), interleaved_qkv(interleaved_qkv), use_alibi(use_alibi), rope_type(rope_type),
-          rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type), max_length(max_length),
-          bos_token_id(bos_token_id), eos_token_id(eos_token_id), pad_token_id(pad_token_id),
-          sep_token_id(sep_token_id), extra_eos_token_ids(std::move(extra_eos_token_ids)) {}
+          rope_theta(rope_theta), rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type),
+          num_virtual_tokens(num_virtual_tokens), max_length(max_length), bos_token_id(bos_token_id),
+          eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id),
+          extra_eos_token_ids(std::move(extra_eos_token_ids)) {}
 
     ModelConfig(ModelType model_type, const ConfigRecordV1 &rec, float norm_eps, ActivationType hidden_act,
                 bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type,
-                int rope_dim_scale, AttentionMaskType attn_mask_type)
+                float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens)
         : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,
                       rec.num_attention_heads, rec.num_hidden_layers, rec.intermediate_size, norm_eps, hidden_act,
-                      use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_dim_scale,
-                      attn_mask_type, rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id,
-                      rec.sep_token_id, {}) {}
+                      use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale,
+                      attn_mask_type, num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id,
+                      rec.pad_token_id, rec.sep_token_id, {}) {}
 
-    ModelConfig(ModelType model_type, const ConfigRecordV2 &rec, float norm_eps, ActivationType hidden_act,
+    ModelConfig(ModelType model_type, const ConfigRecordV1GQA &rec, float norm_eps, ActivationType hidden_act,
                 bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type,
-                int rope_dim_scale, AttentionMaskType attn_mask_type)
+                float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens)
         : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads,
                       rec.num_hidden_layers, rec.intermediate_size, norm_eps, hidden_act, use_qkv_bias, use_dense_bias,
-                      interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type, rec.max_length,
-                      rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {}
+                      interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale, attn_mask_type,
+                      num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id,
+                      rec.sep_token_id, {}) {}
+
+    ModelConfig(ModelType model_type, const ConfigRecordV2 &rec, ActivationType hidden_act, bool use_qkv_bias,
+                bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale,
+                AttentionMaskType attn_mask_type)
+        : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,
+                      rec.num_key_value_heads, rec.num_hidden_layers, rec.intermediate_size, rec.norm_eps, hidden_act,
+                      use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rec.rope_theta,
+                      rope_dim_scale, attn_mask_type, rec.num_virtual_tokens, rec.max_length, -1, rec.eos_token_id,
+                      rec.pad_token_id, -1, {}) {}
 
     std::string model_type_name() const { return to_string(model_type); }
 

@@ -151,8 +180,10 @@ class ModelConfig {
     bool interleaved_qkv;
     bool use_alibi;
     RopeType rope_type;
+    float rope_theta;
     int rope_dim_scale;
     AttentionMaskType attn_mask_type;
+    int num_virtual_tokens;
     int max_length;
     int bos_token_id;
     int eos_token_id;
 
@@ -388,16 +419,17 @@ class BasicAttention {
     BasicAttention() = default;
     BasicAttention(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length,
                    bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type,
-                   int rope_dim_scale, AttentionMaskType attn_mask_type)
+                   float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens)
         : num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), interleaved_qkv(interleaved_qkv),
-          use_alibi(use_alibi), rope_type(rope_type), rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type),
+          use_alibi(use_alibi), rope_type(rope_type), rope_theta(rope_theta), rope_dim_scale(rope_dim_scale),
+          attn_mask_type(attn_mask_type), num_virtual_tokens(num_virtual_tokens),
           query_key_value(ctx, hidden_size, hidden_size + 2 * (hidden_size / num_attention_heads) * num_kv_heads,
                           use_qkv_bias),
           dense(ctx, hidden_size, hidden_size, use_dense_bias),
-          k_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads, max_length,
-                                     num_kv_heads)),
-          v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length, hidden_size / num_attention_heads,
-                                     num_kv_heads)) {}
+          k_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads,
+                                     max_length + num_virtual_tokens, num_kv_heads)),
+          v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length + num_virtual_tokens,
+                                     hidden_size / num_attention_heads, num_kv_heads)) {}
 
     ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past,
                          int n_ctx) const;
 
@@ -408,12 +440,14 @@ class BasicAttention {
     bool interleaved_qkv;
     bool use_alibi;
     RopeType rope_type;
+    float rope_theta;
     int rope_dim_scale;
     AttentionMaskType attn_mask_type;
+    int num_virtual_tokens;
     Linear query_key_value;
     Linear dense;
-    ggml_tensor *k_cache; // [kv_heads, max_len, head_size]
-    ggml_tensor *v_cache; // [kv_heads, head_size, max_len]
+    ggml_tensor *k_cache; // [#kvh, s, d]
+    ggml_tensor *v_cache; // [#kvh, d, s]
 };
 
 template <typename Norm, typename Attention, typename MLP>
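
In the updated shape comments, `#kvh` is the number of KV heads, `d` the per-head dimension, and `s` the cache length, which is now `max_length + num_virtual_tokens`: the P-Tuning v2 prefix permanently occupies the first `num_virtual_tokens` slots of each layer's cache, leaving the full `max_length` for real tokens. A back-of-envelope sizing check, with illustrative ChatGLM2-6B-like numbers (assumed for this sketch, not taken from the commit):

#include <cstdio>

int main() {
    // Assumed dimensions for illustration; adjust to your checkpoint.
    const int num_hidden_layers = 28, num_kv_heads = 2, head_size = 128;
    const int max_length = 2048, num_virtual_tokens = 128; // e.g. PRE_SEQ_LEN=128
    const long long s = max_length + num_virtual_tokens;
    // K and V caches per layer, fp16 (2 bytes), shapes [#kvh, s, d] and [#kvh, d, s].
    const long long bytes = 2LL * num_hidden_layers * num_kv_heads * s * head_size * 2;
    std::printf("KV cache: %.1f MiB\n", bytes / 1024.0 / 1024.0);
    return 0;
}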
@@ -422,11 +456,12 @@ class BasicBlock {
     BasicBlock() = default;
     BasicBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
                int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias,
-               bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale,
-               AttentionMaskType attn_mask_type)
+               bool interleaved_qkv, bool use_alibi, RopeType rope_type, float rope_theta, int rope_dim_scale,
+               AttentionMaskType attn_mask_type, int num_virtual_tokens)
         : input_layernorm(ctx, hidden_size, false, norm_eps),
           attention(ctx, hidden_size, num_attention_heads, num_kv_heads, max_length, use_qkv_bias, use_dense_bias,
-                    interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type),
+                    interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale, attn_mask_type,
+                    num_virtual_tokens),
           post_attention_layernorm(ctx, hidden_size, false, norm_eps),
           mlp(ctx, hidden_size, intermediate_size, hidden_act) {}
 

@@ -517,16 +552,44 @@ class BasicModel {
         return hidden_states;
     }
 
+    void load_prefix_cache(const ModelConfig &config, ggml_tensor *past_key_values) {
+        ggml_cgraph gf{};
+        auto ctx = make_unique_ggml_context(config.num_hidden_layers * 7 * ggml_tensor_overhead(), nullptr, false);
+        const int head_size = config.hidden_size / config.num_attention_heads;
+        for (size_t i = 0; i < layers.size(); i++) {
+            auto &attn = layers[i].attention;
+            ggml_tensor *virtual_key = ggml_view_3d(ctx.get(), past_key_values, head_size, config.num_virtual_tokens,
+                                                    config.num_kv_heads, past_key_values->nb[1], past_key_values->nb[2],
+                                                    i * 2 * past_key_values->nb[3]); // [#h, v, d]
+            ggml_tensor *k_cache_view =
+                ggml_view_3d(ctx.get(), attn.k_cache, head_size, config.num_virtual_tokens, config.num_kv_heads,
+                             attn.k_cache->nb[1], attn.k_cache->nb[2], 0); // [#h, v, d]
+            ggml_build_forward_expand(&gf, ggml_cpy(ctx.get(), virtual_key, k_cache_view));
+
+            ggml_tensor *virtual_value = ggml_view_3d(
+                ctx.get(), past_key_values, head_size, config.num_virtual_tokens, config.num_kv_heads,
+                past_key_values->nb[1], past_key_values->nb[2], (i * 2 + 1) * past_key_values->nb[3]); // [#h, v, d]
+            virtual_value = ggml_permute(ctx.get(), virtual_value, 1, 0, 2, 3); // [#h, d, v]
+            ggml_tensor *v_cache_view =
+                ggml_view_3d(ctx.get(), attn.v_cache, config.num_virtual_tokens, head_size, config.num_kv_heads,
+                             attn.v_cache->nb[1], attn.v_cache->nb[2], 0); // [#h, d, v]
+            ggml_build_forward_expand(&gf, ggml_cpy(ctx.get(), virtual_value, v_cache_view));
+        }
+        CHATGLM_CHECK(ggml_used_mem(ctx.get()) == ggml_get_mem_size(ctx.get())) << "corrupted prefix cache context";
+        std::vector<uninitialized_char> compute_buffer;
+        ggml_graph_compute_helper(compute_buffer, &gf, 0);
+    }
+
   private:
     std::vector<Block> build_layers(ModelContext *ctx, const ModelConfig &config) {
         std::vector<Block> layers;
         layers.reserve(config.num_hidden_layers);
         for (int layer_id = 0; layer_id < config.num_hidden_layers; layer_id++) {
-            // TODO: reduce max length? 32k might be too large for cpu inference
             layers.emplace_back(ctx, config.hidden_size, config.num_attention_heads, config.num_kv_heads,
                                 config.intermediate_size, config.max_length, config.norm_eps, config.hidden_act,
                                 config.use_qkv_bias, config.use_dense_bias, config.interleaved_qkv, config.use_alibi,
-                                config.rope_type, config.rope_dim_scale, config.attn_mask_type);
+                                config.rope_type, config.rope_theta, config.rope_dim_scale, config.attn_mask_type,
+                                config.num_virtual_tokens);
         }
         return layers;
     }
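
Reading the views above, `load_prefix_cache` expects `past_key_values` as a single 4-D tensor whose outermost dimension interleaves per-layer keys (even indices) and values (odd indices), each slice shaped `[#kvh, num_virtual_tokens, head_size]`; keys are copied straight into the head of `k_cache`, while values are permuted into `v_cache`'s `[#kvh, d, v]` layout. A hypothetical call site (`model` and the tensor-loading step are assumptions for illustration; the tensor would come from the converted file's `past_key_values` weight):

// Sketch: install the P-Tuning v2 prefix once, before the first forward pass.
// `prefix_cache` has shape [num_hidden_layers * 2, #kvh, num_virtual_tokens, d],
// with layer i's keys at index 2*i and its values at index 2*i + 1.
ggml_tensor *prefix_cache = /* loaded from the converted model file */ nullptr;
model->load_prefix_cache(prefix_cache); // fills the first num_virtual_tokens KV slots per layer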
@@ -745,6 +808,8 @@ class BasicModelForCausalLM : public BaseModelForCausalLM {
         return lm_logits;
     }
 
+    void load_prefix_cache(ggml_tensor *past_key_values) { transformer.load_prefix_cache(config, past_key_values); }
+
   protected:
     void to_cpu() {
        for (auto &item : state_dict_) {
 
@@ -818,13 +883,14 @@ class GLMBlock : public BasicBlock<LayerNorm, BasicAttention, BasicMLP> {
     GLMBlock() = default;
     GLMBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
              int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias,
-             bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale,
-             AttentionMaskType attn_mask_type)
-        : BasicBlock(
-              LayerNorm(ctx, hidden_size, false, norm_eps),
-              BasicAttention(ctx, hidden_size, num_attention_heads, num_attention_heads, max_length, use_qkv_bias,
-                             use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type),
-              LayerNorm(ctx, hidden_size, false, norm_eps), BasicMLP(ctx, hidden_size, intermediate_size, hidden_act)),
+             bool interleaved_qkv, bool use_alibi, RopeType rope_type, float rope_theta, int rope_dim_scale,
+             AttentionMaskType attn_mask_type, int num_virtual_tokens)
+        : BasicBlock(LayerNorm(ctx, hidden_size, false, norm_eps),
+                     BasicAttention(ctx, hidden_size, num_attention_heads, num_attention_heads, max_length,
+                                    use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_theta,
+                                    rope_dim_scale, attn_mask_type, num_virtual_tokens),
+                     LayerNorm(ctx, hidden_size, false, norm_eps),
+                     BasicMLP(ctx, hidden_size, intermediate_size, hidden_act)),
           alpha_value(std::sqrt(2.f * 28)) {}
 
     ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past,

chatglm_cpp/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import chatglm_cpp._C as _C
 from chatglm_cpp._C import ChatMessage
 
-__version__ = "0.3.1"
+__version__ = "0.3.2"
 
 
 @dataclass
