-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathQwen2State.java
More file actions
76 lines (62 loc) · 3.65 KB
/
Qwen2State.java
File metadata and controls
76 lines (62 loc) · 3.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package org.beehive.gpullama3.inference.state;
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import java.util.stream.Stream;
/**
 * Inference state for Qwen2 models: CPU-side activation tensors, a per-layer key/value cache and
 * the flattened TornadoVM buffers, all sized from a {@link Qwen2Configuration}.
 */
public class Qwen2State extends State {

    /**
     * Local size used by Qwen2 state. Kept as a compile-time constant so that
     * {@link #createStateFields} can size buffers with it even if that method runs from the
     * superclass constructor, i.e. before this class's constructor body has assigned
     * {@code localSize}.
     */
    private static final int QWEN2_LOCAL_SIZE = 32;

    /**
     * Creates the state for a Qwen2 model.
     *
     * @param config    model configuration; must be a {@link Qwen2Configuration}
     *                  ({@link #createStateFields} casts it)
     * @param batchsize forwarded unchanged to the {@link State} constructor
     */
    public Qwen2State(Configuration config, int batchsize) {
        super(config, batchsize);
        // Replace whatever local size the superclass assigned with the Qwen2-specific value.
        this.localSize = QWEN2_LOCAL_SIZE;
    }

    /**
     * Allocates every state buffer using Qwen2 dimensions.
     *
     * @param configuration model configuration; cast to {@link Qwen2Configuration}
     * @return a freshly populated {@link StateFields}
     * @throws ClassCastException            if {@code configuration} is not a {@link Qwen2Configuration}
     * @throws UnsupportedOperationException if the quantization is neither {@code "FP16"} nor {@code "Q8_0"}
     * @throws ArithmeticException           if a flattened buffer size overflows {@code int}
     */
    @Override
    protected StateFields createStateFields(Configuration configuration) {
        StateFields fields = new StateFields();
        Qwen2Configuration config = (Qwen2Configuration) configuration;

        int dim = config.dim();
        int hiddenDim = config.hiddenDim();
        int kvDim = config.kvDim();
        int numberOfHeads = config.numberOfHeads();
        int contextLength = config.contextLength();
        int numberOfLayers = config.numberOfLayers();
        int vocabularySize = config.vocabularySize();

        // CPU-side activation tensors with Qwen2-specific sizes.
        fields.x = ArrayFloatTensor.allocate(dim);
        fields.xb = ArrayFloatTensor.allocate(dim);
        fields.xb2 = ArrayFloatTensor.allocate(dim);
        fields.hb = ArrayFloatTensor.allocate(hiddenDim);
        fields.hb2 = ArrayFloatTensor.allocate(hiddenDim);
        fields.q = ArrayFloatTensor.allocate(dim);
        fields.k = ArrayFloatTensor.allocate(kvDim);
        fields.v = ArrayFloatTensor.allocate(kvDim);
        fields.att = ArrayFloatTensor.allocate(numberOfHeads, contextLength);
        fields.logits = ArrayFloatTensor.allocate(vocabularySize);

        // Key-value cache: one (contextLength, kvDim) tensor per layer.
        fields.keyCache = allocateLayerCache(contextLength, kvDim, numberOfLayers);
        fields.valueCache = allocateLayerCache(contextLength, kvDim, numberOfLayers);

        // TornadoVM wrappers with Qwen2 dimensions.
        switch (config.quantization()) {
            case "FP16" -> fields.createActivationFP16(dim);
            case "Q8_0" -> fields.createActivationQ8_0(dim);
            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
        }

        fields.wrapX = new FloatArray(dim);
        fields.wrapXb = new FloatArray(dim);
        fields.wrapXbFP16 = new HalfFloatArray(dim);
        fields.wrapXb2 = new FloatArray(dim);
        fields.wrapHb = new FloatArray(hiddenDim);
        fields.wrapHb2 = new FloatArray(hiddenDim);
        fields.wrapLogits = new FloatArray(vocabularySize);
        fields.wrapQ = new FloatArray(dim);
        fields.wrapK = new FloatArray(kvDim);
        fields.wrapV = new FloatArray(kvDim);

        // Flattened caches covering all layers. multiplyExact throws instead of silently
        // wrapping to a wrong (possibly negative) size for very large models.
        int flatCacheSize = Math.multiplyExact(Math.multiplyExact(contextLength, kvDim), numberOfLayers);
        fields.wrapKeyCache = new FloatArray(flatCacheSize);
        fields.wrapValueCache = new FloatArray(flatCacheSize);
        fields.wrapValueCache.init(0.f);
        fields.wrapKeyCache.init(0.f);
        fields.wrapAtt = new FloatArray(Math.multiplyExact(numberOfHeads, contextLength));
        fields.positionHolder = new IntArray(1);

        // Temporary arrays: one slot plus one per QWEN2_LOCAL_SIZE-sized chunk of dim
        // (presumably per-work-group partial results; TODO confirm against the kernels).
        // Uses the constant, not the localSize field, so the size does not depend on whether
        // this method runs before or after the constructor assigns localSize.
        int tempSize = 1 + (dim + QWEN2_LOCAL_SIZE - 1) / QWEN2_LOCAL_SIZE;
        fields.temp = new FloatArray(tempSize);
        fields.tempFFN = new FloatArray(tempSize);
        fields.tempLogits = new FloatArray(tempSize);
        return fields;
    }

    /** Allocates {@code numberOfLayers} independent (contextLength, kvDim) cache tensors. */
    private static FloatTensor[] allocateLayerCache(int contextLength, int kvDim, int numberOfLayers) {
        return Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim))
                .limit(numberOfLayers)
                .toArray(FloatTensor[]::new);
    }
}