|
4 | 4 | import org.beehive.gpullama3.model.Configuration; |
5 | 5 | import uk.ac.manchester.tornado.api.types.HalfFloat; |
6 | 6 | import uk.ac.manchester.tornado.api.types.arrays.*; |
| 7 | +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; |
| 8 | +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; |
| 9 | +import uk.ac.manchester.tornado.api.types.arrays.IntArray; |
7 | 10 |
|
8 | 11 | /** |
9 | 12 | * Represents the base state structure used during LLM inference. |
@@ -58,13 +61,17 @@ public abstract class State { |
58 | 61 | public final IntArray positionHolder; |
59 | 62 |
|
60 | 63 | public TornadoNativeArray embeddingX; |
| 64 | + |
| 65 | + public final HalfFloatArray wrapXbFP16; // HalfFloatArray (FP16) wrapper for xb (residual branch activation), optimized for TornadoVM usage. |
| 66 | + |
61 | 67 | // Intermediate (temporary) buffers and the local workgroup size they are sized for. |
62 | 68 | public int localSize; |
63 | 69 | public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size. |
64 | 70 | public FloatArray tempFFN; // Temporary buffer for feed-forward network calculations, size adjusted for local workgroup size. |
65 | 71 | public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size. |
66 | 72 | public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models. |
67 | 73 |
|
| 74 | + public HalfFloatArray wrapXFP16; // HalfFloatArray (FP16) buffer; presumably the x counterpart of wrapXbFP16 — TODO confirm. NOTE(review): non-final while wrapXbFP16 is final; confirm reassignment is intended. |
68 | 75 | /** last index in previous block */ |
69 | 76 |
|
70 | 77 | protected State(Configuration config, int batchsize) { |
@@ -100,6 +107,9 @@ protected State(Configuration config, int batchsize) { |
100 | 107 | this.wrapK = fields.wrapK; |
101 | 108 | this.wrapV = fields.wrapV; |
102 | 109 |
|
| 110 | + this.wrapXFP16 = fields.wrapXFP16; |
| 111 | + this.wrapXbFP16 = fields.wrapXbFP16; |
| 112 | + |
103 | 113 | // dim vs kvdim |
104 | 114 | this.wrapKeyCache = fields.wrapKeyCache; |
105 | 115 | this.wrapValueCache = fields.wrapValueCache; |
@@ -136,6 +146,7 @@ public void createActivationQ8_0(int size) { |
136 | 146 | int q8BytesNeeded = blocksNeeded * Q8_0_BLOCK_BYTES; |
137 | 147 | this.embeddingX = new ByteArray(q8BytesNeeded); |
138 | 148 | } |
| 149 | + public HalfFloatArray wrapXFP16, wrapXbFP16; // HalfFloatArray (FP16) buffers; presumably what the State constructor reads as fields.wrapXFP16 / fields.wrapXbFP16 — TODO confirm. |
139 | 150 | } |
140 | 151 |
|
141 | 152 | @Override |
|
0 commit comments