-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathGraniteState.java
More file actions
83 lines (71 loc) · 4.1 KB
/
GraniteState.java
File metadata and controls
83 lines (71 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
package org.beehive.gpullama3.inference.state;
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import java.util.stream.Stream;
/**
 * Represents the state of the Granite model during inference.
 *
 * <p>This class extends {@link State} and provides the Granite-specific
 * {@link #createStateFields(Configuration)} implementation, which allocates the
 * {@link ArrayFloatTensor} working buffers, the per-layer key/value cache and the
 * TornadoVM array wrappers from the sizes reported by the {@link Configuration}.</p>
 *
 * <p><b>Note:</b> GraniteState tensor shapes are identical to LlamaState
 * since Granite uses the same transformer architecture as Llama,
 * with differences only in the scaling factors applied.</p>
 */
public final class GraniteState extends State {

    /**
     * Creates a Granite inference state by delegating to the {@link State} constructor.
     *
     * @param config    model configuration forwarded to the superclass
     * @param batchsize batch size forwarded to the superclass
     */
    public GraniteState(Configuration config, int batchsize) {
        super(config, batchsize);
    }

    /**
     * Allocates every buffer held in a {@link StateFields} instance for the given configuration.
     *
     * @param config model configuration supplying the dimensions and the quantization format
     * @return a freshly populated {@link StateFields}
     * @throws UnsupportedOperationException if {@code config.quantization()} is neither
     *                                       {@code "FP16"} nor {@code "Q8_0"}
     * @throws IllegalArgumentException      if the flat key/value cache element count
     *                                       ({@code contextLength * kvDim * numberOfLayers})
     *                                       does not fit in an {@code int}
     */
    @Override
    protected StateFields createStateFields(Configuration config) {
        // Read the dimensions once; assumes the Configuration accessors are plain getters.
        final int dim = config.dim();
        final int hiddenDim = config.hiddenDim();
        final int contextLength = config.contextLength();
        final int numberOfHeads = config.numberOfHeads();
        final int numberOfLayers = config.numberOfLayers();
        final int vocabularySize = config.vocabularySize();

        StateFields fields = new StateFields();

        // Allocation with Granite dimensions (identical to Llama)
        fields.x = ArrayFloatTensor.allocate(dim);
        fields.xb = ArrayFloatTensor.allocate(dim);
        fields.xb2 = ArrayFloatTensor.allocate(dim);
        fields.hb = ArrayFloatTensor.allocate(hiddenDim);
        fields.hb2 = ArrayFloatTensor.allocate(hiddenDim);
        fields.q = ArrayFloatTensor.allocate(dim);
        fields.k = ArrayFloatTensor.allocate(dim);
        fields.v = ArrayFloatTensor.allocate(dim);
        fields.att = ArrayFloatTensor.allocate(numberOfHeads, contextLength);
        fields.logits = ArrayFloatTensor.allocate(vocabularySize);

        // Key-value cache with Granite dimensions: one (contextLength x kvDim) tensor per layer
        final int kvDim = (dim * config.numberOfKeyValueHeads()) / numberOfHeads;
        fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim))
                .limit(numberOfLayers)
                .toArray(FloatTensor[]::new);
        fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim))
                .limit(numberOfLayers)
                .toArray(FloatTensor[]::new);

        // TornadoVM wrappers with Granite dimensions
        fields.wrapX = new FloatArray(dim);
        fields.wrapXb = new FloatArray(dim);
        fields.wrapXb2 = new FloatArray(dim);
        fields.wrapHb = new FloatArray(hiddenDim);
        fields.wrapHb2 = new FloatArray(hiddenDim);

        // The activation buffer layout depends on the quantization format.
        switch (config.quantization()) {
            case "FP16" -> fields.createActivationFP16(dim);
            case "Q8_0" -> fields.createActivationQ8_0(dim);
            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
        }

        fields.wrapLogits = new FloatArray(vocabularySize);
        fields.wrapQ = new FloatArray(dim);
        fields.wrapK = new FloatArray(dim);
        fields.wrapV = new FloatArray(dim);
        fields.wrapXFP16 = new HalfFloatArray(dim);
        fields.wrapXbFP16 = new HalfFloatArray(dim);

        // dim vs kvdim: the flat caches use kvDim, not dim, per position.
        // The product is overflow-checked because long contexts can exceed Integer.MAX_VALUE,
        // which would otherwise wrap silently into a wrong (or negative) allocation size.
        final int kvCacheElements = checkedKvCacheElements(contextLength, kvDim, numberOfLayers);
        fields.wrapKeyCache = new FloatArray(kvCacheElements);
        fields.wrapValueCache = new FloatArray(kvCacheElements);
        fields.wrapValueCache.init(0.f);
        fields.wrapKeyCache.init(0.f);

        fields.wrapAtt = new FloatArray(numberOfHeads * contextLength);
        fields.positionHolder = new IntArray(1);

        // Temporary arrays: 1 + ceil(dim / localSize) elements each.
        // NOTE(review): localSize is not declared in this class; presumably inherited from State.
        final int reductionLength = 1 + ((dim + localSize - 1) / localSize);
        fields.temp = new FloatArray(reductionLength);
        fields.tempFFN = new FloatArray(reductionLength);
        fields.tempLogits = new FloatArray(reductionLength);

        return fields;
    }

    /** Returns {@code contextLength * kvDim * numberOfLayers}, rejecting results that overflow {@code int}. */
    private static int checkedKvCacheElements(int contextLength, int kvDim, int numberOfLayers) {
        try {
            return Math.multiplyExact(Math.multiplyExact(contextLength, kvDim), numberOfLayers);
        } catch (ArithmeticException e) {
            throw new IllegalArgumentException(
                    "KV cache element count overflows int: contextLength=" + contextLength
                            + ", kvDim=" + kvDim
                            + ", numberOfLayers=" + numberOfLayers, e);
        }
    }
}