-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathQwen2ModelLoader.java
More file actions
163 lines (141 loc) · 8.89 KB
/
Qwen2ModelLoader.java
File metadata and controls
163 lines (141 loc) · 8.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
package org.beehive.gpullama3.model.loader;
import org.beehive.gpullama3.tensor.GGMLType;
import org.beehive.gpullama3.tensor.GGUF;
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor;
import org.beehive.gpullama3.tensor.GGMLTensorEntry;
import org.beehive.gpullama3.auxiliary.Pair;
import org.beehive.gpullama3.inference.operation.RoPE;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens;
import org.beehive.gpullama3.model.qwen2.DeepSeekR1Qwen;
import org.beehive.gpullama3.model.qwen2.Qwen2;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer;
import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tokenizer.Vocabulary;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import java.nio.channels.FileChannel;
import java.util.Map;
import static org.beehive.gpullama3.model.loader.ModelLoader.*;
/**
 * Model loader for Qwen2-architecture GGUF files.
 *
 * <p>Produces either a plain {@link Qwen2} model or, when the file's
 * {@code general.basename} metadata equals {@code "DeepSeek-R1-Distill-Qwen"},
 * a {@link DeepSeekR1Qwen} model. Vocabulary and tokenizer handling reuse the
 * Qwen3 implementations ({@code Vocabulary.loadQwen3Vocabulary} and
 * {@link Qwen3Tokenizer}).
 */
public class Qwen2ModelLoader extends AbstractModelLoader<Qwen2, Qwen2Configuration> {

    /** Value of the {@code general.basename} metadata entry that selects the DeepSeek-R1 distilled variant. */
    private static final String DEEPSEEK_R1_DISTILL_QWEN_BASENAME = "DeepSeek-R1-Distill-Qwen";

    /**
     * Creates a loader; all arguments are forwarded unchanged to the superclass constructor.
     *
     * @param fileChannel   channel of the GGUF file being loaded
     * @param gguf          parsed GGUF container
     * @param contextLength requested context length; a negative value selects the model's own
     *                      context length (see {@link #createConfiguration(Map)})
     * @param useTornadovm  forwarded to the superclass; presumably selects the TornadoVM weight path
     */
    public Qwen2ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
        super(fileChannel, gguf, contextLength, useTornadovm);
    }

    /**
     * Loads the vocabulary using the Qwen3 vocabulary loader.
     *
     * @param metadata GGUF metadata map
     * @return the vocabulary returned by {@code Vocabulary.loadQwen3Vocabulary}
     */
    @Override
    protected Vocabulary loadVocabulary(Map<String, Object> metadata) {
        return Vocabulary.loadQwen3Vocabulary(metadata);
    }

    /**
     * Builds a {@link Qwen3Tokenizer}, telling it whether the file is a DeepSeek-R1 distilled Qwen model.
     *
     * @param metadata   GGUF metadata map
     * @param vocabulary vocabulary previously loaded for this model
     * @return the tokenizer instance
     */
    @Override
    protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
        return new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1Distill(metadata));
    }

    /**
     * Builds the model configuration from the {@code qwen2.*} GGUF metadata entries.
     *
     * <p>The effective context length is the model's own context length when the requested
     * length is negative or exceeds it; otherwise the requested length is used. When
     * {@code qwen2.attention.head_count_kv} is absent, the number of key/value heads falls
     * back to {@code qwen2.attention.head_count}.
     *
     * @param metadata GGUF metadata map
     * @return the populated configuration
     * @throws IllegalArgumentException if a required metadata entry is missing
     */
    // @formatter:off
    @Override
    protected Qwen2Configuration createConfiguration(Map<String, Object> metadata) {
        int modelContextLength = (int) requireMetadata(metadata, "qwen2.context_length");
        int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;

        int numberOfHeads = (int) requireMetadata(metadata, "qwen2.attention.head_count");
        // head_count_kv is optional; without it the model uses one KV head per attention head.
        int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv")
                ? (int) requireMetadata(metadata, "qwen2.attention.head_count_kv")
                : numberOfHeads;
        int vocabSize = vocabulary.size();

        return new Qwen2Configuration(
                getModelQuantization(metadata),
                (int) requireMetadata(metadata, "qwen2.embedding_length"),       // dim
                (int) requireMetadata(metadata, "qwen2.feed_forward_length"),    // hiddendim
                (int) requireMetadata(metadata, "qwen2.block_count"),            // numberOfLayers
                numberOfHeads,                                                   // numberOfHeads
                numberOfKeyValueHeads,                                           // numberOfKeyValueHeads
                numberOfKeyValueHeads,                                           // numberOfHeadsKey
                numberOfKeyValueHeads,                                           // numberOfHeadsValue
                vocabSize,
                modelContextLength,
                finalContextLength,
                false,
                (float) requireMetadata(metadata, "qwen2.attention.layer_norm_rms_epsilon"),
                (float) requireMetadata(metadata, "qwen2.rope.freq_base")
        );
    }
    // @formatter:on

    /**
     * Precomputes the RoPE frequency tables for the model's full context length
     * ({@code config.contextLengthModel()}, not the possibly smaller effective length).
     *
     * @param config model configuration
     * @return the pair of arrays returned by {@code RoPE.precomputeFreqsCis}
     */
    @Override
    protected Pair<float[], float[]> precomputeRopeFrequencies(Qwen2Configuration config) {
        // NOTE(review): the trailing arguments (false, 8, 1, 3, 8192) look like RoPE-scaling
        // settings with scaling switched off - confirm against RoPE.precomputeFreqsCis.
        return RoPE.precomputeFreqsCis(config.contextLengthModel(), config.headSize(), config.ropeTheta(), false, 8, 1, 3, 8192);
    }

    /**
     * Instantiates the model object, choosing the chat tokens and the concrete model class
     * according to whether the file is a DeepSeek-R1 distilled Qwen model.
     *
     * @param config    model configuration
     * @param tokenizer tokenizer created for this model
     * @param weights   loaded weights
     * @return a {@link DeepSeekR1Qwen} for DeepSeek-R1 distilled files, otherwise a {@link Qwen2}
     */
    // @formatter:off
    @Override
    protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weights weights) {
        boolean deepSeekVariant = isDeepSeekR1Distill(gguf.getMetadata());

        if (deepSeekVariant) {
            // NOTE(review): verify these marker strings match the vocabulary stored in the GGUF
            // file character-for-character (bar and underscore glyphs are easy to confuse).
            ChatTokens chatTokens = new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "");
            return new DeepSeekR1Qwen(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
        }

        // Qwen2.5-Coder uses <|endoftext|> as stop-token.
        ChatTokens chatTokens = new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
        return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
    }
    // @formatter:on

    /**
     * Builds the standard (non-TornadoVM) weight set by loading every per-layer tensor by its
     * GGUF name ({@code blk.<i>.<tensor>}) plus the embedding, final norm, RoPE tables and
     * output projection.
     *
     * @param tensorEntries   tensors of the GGUF file, keyed by tensor name
     * @param config          model configuration (supplies the layer count)
     * @param ropeFreqs       precomputed RoPE tables
     * @param tokenEmbeddings token-embedding tensor entry
     * @param outputWeight    output-projection tensor entry; its GGML type is recorded in the weights
     * @return the assembled {@link Qwen2StandardWeights}
     */
    // @formatter:off
    @Override
    protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
            GGMLTensorEntry outputWeight) {
        final int nl = config.numberOfLayers();

        return new Qwen2StandardWeights(
                loadTensor(tokenEmbeddings),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
                loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
                loadTensor(tensorEntries.get("output_norm.weight")),
                new ArrayFloatTensor(ropeFreqs.first()),
                new ArrayFloatTensor(ropeFreqs.second()),
                loadTensor(outputWeight),
                outputWeight.ggmlType()
        );
    }
    // @formatter:on

    /**
     * Builds the TornadoVM weight set. Only {@code F16} and {@code Q8_0} are accepted as the
     * effective GPU weight type (derived from the output-projection tensor's type via
     * {@code effectiveGpuWeightType}).
     *
     * @param tensorEntries   tensors of the GGUF file, keyed by tensor name
     * @param config          model configuration (supplies the layer count)
     * @param ropeFreqs       precomputed RoPE tables
     * @param tokenEmbeddings token-embedding tensor entry
     * @param outputWeight    output-projection tensor entry
     * @return the assembled {@link Qwen2TornadoWeights}
     * @throws UnsupportedOperationException if the effective weight type is neither F16 nor Q8_0
     */
    // @formatter:off
    @Override
    protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
            GGMLTensorEntry outputWeight) {
        GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType());

        // Validate supported types
        if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {
            throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
        }

        final int nl = config.numberOfLayers();

        // Load all tensors uniformly as TornadoTensor hierarchy
        return new Qwen2TornadoWeights(
                loadTornadoTensor(tokenEmbeddings),
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),   // fp32
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
                // Qwen2-specific: qkv bias (always F32)
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.bias")),        // fp32
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.bias")),        // fp32
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.bias")),        // fp32
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),    // fp32
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
                loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
                loadTornadoTensor(tensorEntries.get("output_norm.weight")),                                // fp32
                new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())),
                new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())),
                loadTornadoTensor(outputWeight),
                ggmlType
        );
    }
    // @formatter:on

    /**
     * Tells whether the metadata describes a DeepSeek-R1 distilled Qwen model.
     *
     * @param metadata GGUF metadata map
     * @return {@code true} when {@code general.basename} equals {@code "DeepSeek-R1-Distill-Qwen"}
     */
    private static boolean isDeepSeekR1Distill(Map<String, Object> metadata) {
        return DEEPSEEK_R1_DISTILL_QWEN_BASENAME.equals(metadata.get("general.basename"));
    }

    /**
     * Fetches a mandatory metadata entry, failing with a message that names the key instead of
     * letting a later unboxing cast raise an anonymous {@link NullPointerException}.
     *
     * @param metadata GGUF metadata map
     * @param key      metadata key that must be present with a non-null value
     * @return the non-null value stored under {@code key}
     * @throws IllegalArgumentException if the entry is missing or null
     */
    private static Object requireMetadata(Map<String, Object> metadata, String key) {
        Object value = metadata.get(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required GGUF metadata entry: " + key);
        }
        return value;
    }
}