-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathLlamaState.java
More file actions
84 lines (72 loc) · 4.12 KB
/
LlamaState.java
File metadata and controls
84 lines (72 loc) · 4.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package org.beehive.gpullama3.inference.state;
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import java.util.stream.Stream;
/**
 * Represents the state of the Llama model during inference.
 * This class extends {@link State} to include model-specific functionalities
 * and configurations tailored for the Llama model.
 *
 * <p>Every buffer is sized from the {@link Configuration} handed to the constructor:
 * CPU-side {@link FloatTensor} activations, a per-layer key/value cache, and the
 * TornadoVM arrays ({@link FloatArray}, {@link HalfFloatArray}, {@link IntArray}).</p>
 *
 * <p><b>Note 1:</b> LlamaState contains additional fields for TornadoVM wrappers
 * to enable GPU-accelerated processing of the model.</p>
 *
 * <p><b>Note 2:</b> This state implementation is also used for the Mistral model.</p>
 */
public final class LlamaState extends State {

    /**
     * Creates the inference state for a Llama/Mistral model.
     *
     * @param config    model configuration; forwarded to {@link State}
     * @param batchsize batch size; forwarded unchanged to {@link State}
     */
    public LlamaState(Configuration config, int batchsize) {
        super(config, batchsize);
    }

    /**
     * Allocates all activation, cache and TornadoVM wrapper buffers for a Llama/Mistral model.
     *
     * @param config model configuration providing the buffer dimensions
     * @return a freshly populated {@link StateFields}
     * @throws UnsupportedOperationException if {@code config.quantization()} is neither
     *         {@code "FP16"} nor {@code "Q8_0"}
     * @throws IllegalArgumentException if the flattened (all-layer) KV cache size does not fit in an {@code int}
     */
    @Override
    protected StateFields createStateFields(Configuration config) {
        final int dim = config.dim();
        final int hiddenDim = config.hiddenDim();
        final int contextLength = config.contextLength();
        final int numberOfHeads = config.numberOfHeads();
        final int numberOfLayers = config.numberOfLayers();
        final int vocabularySize = config.vocabularySize();
        // Width of one cached K (or V) row: dim scaled by numberOfKeyValueHeads / numberOfHeads.
        // Equals dim when the two head counts match.
        final int kvDim = (dim * config.numberOfKeyValueHeads()) / numberOfHeads;

        StateFields fields = new StateFields();

        // CPU-side activations with Llama/Mistral dimensions.
        fields.x = ArrayFloatTensor.allocate(dim);
        fields.xb = ArrayFloatTensor.allocate(dim);
        fields.xb2 = ArrayFloatTensor.allocate(dim);
        fields.hb = ArrayFloatTensor.allocate(hiddenDim);
        fields.hb2 = ArrayFloatTensor.allocate(hiddenDim);
        fields.q = ArrayFloatTensor.allocate(dim);
        // k and v are sized to dim rather than kvDim; that is at least kvDim whenever
        // numberOfKeyValueHeads <= numberOfHeads.
        fields.k = ArrayFloatTensor.allocate(dim);
        fields.v = ArrayFloatTensor.allocate(dim);
        fields.att = ArrayFloatTensor.allocate(numberOfHeads, contextLength);
        fields.logits = ArrayFloatTensor.allocate(vocabularySize);

        // Per-layer key/value cache: one (contextLength, kvDim) tensor per layer.
        fields.keyCache = allocatePerLayerCache(numberOfLayers, contextLength, kvDim);
        fields.valueCache = allocatePerLayerCache(numberOfLayers, contextLength, kvDim);

        // TornadoVM wrappers with Llama/Mistral dimensions.
        fields.wrapX = new FloatArray(dim);
        fields.wrapXb = new FloatArray(dim);
        fields.wrapXb2 = new FloatArray(dim);
        fields.wrapHb = new FloatArray(hiddenDim);
        fields.wrapHb2 = new FloatArray(hiddenDim);

        switch (config.quantization()) {
            case "FP16" -> fields.createActivationFP16(dim);
            case "Q8_0" -> fields.createActivationQ8_0(dim);
            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
        }

        fields.wrapLogits = new FloatArray(vocabularySize);
        fields.wrapQ = new FloatArray(dim);
        fields.wrapK = new FloatArray(dim);
        fields.wrapV = new FloatArray(dim);
        // Half-precision buffers, allocated for every quantization format (not only FP16).
        fields.wrapXFP16 = new HalfFloatArray(dim);
        fields.wrapXbFP16 = new HalfFloatArray(dim);

        // Flattened KV cache spanning all layers, each row kvDim wide (not dim).
        // The size is overflow-checked: e.g. contextLength=131072, kvDim=1024 and 32 layers
        // multiply to 2^32, which plain int arithmetic silently wraps to 0.
        final int flatCacheSize = flattenedCacheSize(contextLength, kvDim, numberOfLayers);
        fields.wrapKeyCache = new FloatArray(flatCacheSize);
        fields.wrapValueCache = new FloatArray(flatCacheSize);
        fields.wrapValueCache.init(0.f);
        fields.wrapKeyCache.init(0.f);

        fields.wrapAtt = new FloatArray(numberOfHeads * contextLength);
        fields.positionHolder = new IntArray(1);

        // Temporary reduction arrays: ceil(dim / localSize) slots plus one extra.
        // NOTE(review): the extra slot presumably holds the final reduced value — confirm against the kernels.
        final int reductionSize = 1 + ((dim + localSize - 1) / localSize);
        fields.temp = new FloatArray(reductionSize);
        fields.tempFFN = new FloatArray(reductionSize);
        fields.tempLogits = new FloatArray(reductionSize);

        return fields;
    }

    /**
     * Allocates {@code numberOfLayers} independent tensors of shape (contextLength, kvDim).
     */
    private static FloatTensor[] allocatePerLayerCache(int numberOfLayers, int contextLength, int kvDim) {
        return Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim))
                .limit(numberOfLayers)
                .toArray(FloatTensor[]::new);
    }

    /**
     * Returns {@code contextLength * kvDim * numberOfLayers}, failing fast instead of wrapping on overflow.
     *
     * @throws IllegalArgumentException if the product does not fit in an {@code int}
     */
    private static int flattenedCacheSize(int contextLength, int kvDim, int numberOfLayers) {
        try {
            return Math.multiplyExact(Math.multiplyExact(contextLength, kvDim), numberOfLayers);
        } catch (ArithmeticException overflow) {
            throw new IllegalArgumentException("Flattened KV cache size overflows int: contextLength=" + contextLength
                    + ", kvDim=" + kvDim + ", numberOfLayers=" + numberOfLayers, overflow);
        }
    }
}