|
| 1 | +package org.beehive.gpullama3.inference; |
| 2 | + |
| 3 | +import org.beehive.gpullama3.auxiliary.Parallel; |
| 4 | +import org.beehive.gpullama3.inference.state.State; |
| 5 | +import org.beehive.gpullama3.inference.weights.standard.StandardWeights; |
| 6 | +import org.beehive.gpullama3.model.Configuration; |
| 7 | +import org.beehive.gpullama3.model.Model; |
| 8 | +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; |
| 9 | +import org.beehive.gpullama3.tensor.standard.FloatTensor; |
| 10 | +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; |
| 11 | +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; |
| 12 | + |
/**
 * Low-level forward passes for the batched prefill/decode inference path (Phase 3/4).
 *
 * <p>Parallel to {@link InferenceCoreWithPrefillDecode} — does NOT modify it.</p>
 *
 * <p>Provides three operations:</p>
 * <ul>
 *   <li>{@link #batchForwardJavaPrefill} — CPU batch prefill: processes a chunk of
 *       prompt tokens in one pass using batch matmul, avoiding redundant weight
 *       traversals. Only the KV cache is populated; logits are intentionally omitted.</li>
 *   <li>{@link #batchForwardTornadoVMPrefill} — GPU batch prefill: delegates the chunk
 *       to {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}.</li>
 *   <li>{@link #forwardTornadoVMDecode} — GPU decode: delegates a single decode step to
 *       {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, which
 *       handles the embedding copy and runs the full decode + logits graphs.</li>
 * </ul>
 *
 * <p>Thread-safety: {@code batchForwardJavaPrefill} mutates {@code state} (KV cache and
 * the shared {@code state.att} buffer) and is therefore not safe to call concurrently
 * for the same {@code State} instance.</p>
 */
public final class InferenceCoreBatchPrefillDecode {

    /** Static utility holder — not instantiable. */
    private InferenceCoreBatchPrefillDecode() {}

    /**
     * CPU batched prefill forward pass for LLaMA (Phase 3).
     *
     * <p>Processes {@code batchSize} prompt tokens simultaneously through all
     * transformer layers. For each layer, Q/K/V projections, output projection,
     * and FFN projections are computed via batch matmul
     * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}),
     * which parallelises over both output dimension and batch simultaneously.
     * Attention reuses {@code state.att} sequentially per token (parallel per
     * head within each token), keeping memory overhead minimal.</p>
     *
     * <p>The logits layer is intentionally omitted — only the KV cache matters
     * for prefill positions.</p>
     *
     * @param model     the LLaMA model (must carry {@link StandardWeights};
     *                  the cast below throws {@link ClassCastException} otherwise)
     * @param state     mutable inference state (KV cache, att buffer …)
     * @param tokens    input token ids, {@code tokens[b]} at position {@code startPos+b}
     * @param startPos  sequence position of {@code tokens[0]}
     * @param batchSize number of tokens in this chunk ({@code tokens.length})
     */
    public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) {
        final Configuration config = model.configuration();
        final StandardWeights weights = (StandardWeights) model.weights();
        int dim = config.dim();
        int headSize = config.headSize();
        // Grouped-query attention: K/V are projected into a narrower space.
        // kvDim = headSize * numberOfKeyValueHeads; kvMul = query heads per KV head.
        int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
        int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads();
        float sqrtHeadSize = (float) Math.sqrt(headSize);

        // ── Batch activation tensors (allocated once per chunk) ───────────────
        // One slot per batch token: x = residual stream, xb/xb2 = scratch,
        // q/k/v = attention projections, hb/hb2 = FFN hidden activations.
        FloatTensor[] x = new FloatTensor[batchSize];
        FloatTensor[] xb = new FloatTensor[batchSize];
        FloatTensor[] xb2 = new FloatTensor[batchSize];
        FloatTensor[] q = new FloatTensor[batchSize];
        FloatTensor[] k = new FloatTensor[batchSize];
        FloatTensor[] v = new FloatTensor[batchSize];
        FloatTensor[] hb = new FloatTensor[batchSize];
        FloatTensor[] hb2 = new FloatTensor[batchSize];
        for (int b = 0; b < batchSize; b++) {
            x[b] = ArrayFloatTensor.allocate(dim);
            xb[b] = ArrayFloatTensor.allocate(dim);
            xb2[b] = ArrayFloatTensor.allocate(dim);
            q[b] = ArrayFloatTensor.allocate(dim);
            k[b] = ArrayFloatTensor.allocate(kvDim);
            v[b] = ArrayFloatTensor.allocate(kvDim);
            hb[b] = ArrayFloatTensor.allocate(config.hiddenDim());
            hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim());
        }

        // ── Token embeddings ──────────────────────────────────────────────────
        // Copy each token's embedding row into its residual-stream slot.
        Parallel.parallelFor(0, batchSize, b ->
                weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim));

        // ── Transformer layers ────────────────────────────────────────────────
        for (int l = 0; l < config.numberOfLayers(); l++) {
            final int layer = l;

            // Pre-attention RMSNorm: xb[b] = rmsnorm(x[b]) * rms_att_weight[layer].
            Parallel.parallelFor(0, batchSize, b ->
                    InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps()));

            // Batched Q/K/V projections (parallel over output dim × batch).
            weights.wq[l].matmul(batchSize, xb, q, dim, dim);
            weights.wk[l].matmul(batchSize, xb, k, kvDim, dim);
            weights.wv[l].matmul(batchSize, xb, v, kvDim, dim);

            // RoPE rotation + KV-cache write, parallel over batch tokens.
            // All K/V rows for this chunk are in the cache before the attention
            // pass below reads them (parallelFor completes first).
            Parallel.parallelFor(0, batchSize, b -> {
                int pos = startPos + b;
                for (int i = 0; i < dim; i += 2) {
                    // Position within the head; each (even, odd) pair is one
                    // complex rotation by the precomputed freq_cis tables.
                    int head_dim = i % headSize;
                    float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2));
                    float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2));
                    // Rotate q over the full dim, but k only while i < kvDim
                    // (k is shorter under grouped-query attention).
                    int rotn = i < kvDim ? 2 : 1;
                    for (int vv = 0; vv < rotn; vv++) {
                        FloatTensor vec = vv == 0 ? q[b] : k[b];
                        float v0 = vec.getFloat(i);
                        float v1 = vec.getFloat(i + 1);
                        vec.setFloat(i, v0 * fcr - v1 * fci);
                        vec.setFloat(i + 1, v0 * fci + v1 * fcr);
                    }
                }
                // Persist this position's K/V row for attention (now and for
                // all future positions).
                k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim);
                v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim);
            });

            // Attention: sequential over batch tokens so state.att can be
            // reused; parallel over heads within a token (each head writes a
            // disjoint att slice at h * contextLength).
            // NOTE(review): assumes state.att holds at least
            // numberOfHeads * contextLength floats — implied by the indexing.
            for (int b = 0; b < batchSize; b++) {
                final int pos_b = startPos + b;
                final int bFinal = b;
                Parallel.parallelFor(0, config.numberOfHeads(), h -> {
                    int qOffset = h * headSize;
                    int attOffset = h * config.contextLength();

                    // Causal scores: this token attends to positions 0..pos_b,
                    // scaled by 1/sqrt(headSize). h / kvMul maps the query head
                    // to its shared KV head.
                    for (int t = 0; t <= pos_b; t++) {
                        int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
                        float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize;
                        state.att.setFloat(attOffset + t, score);
                    }
                    state.att.softmaxInPlace(attOffset, pos_b + 1);

                    // Weighted sum of cached values into this head's xb slice.
                    int xbOffset = h * headSize;
                    xb[bFinal].fillInPlace(xbOffset, headSize, 0f);
                    for (int t = 0; t <= pos_b; t++) {
                        int vOffset = t * kvDim + (h / kvMul) * headSize;
                        float a = state.att.getFloat(attOffset + t);
                        xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a);
                    }
                });
            }

            // Attention output projection.
            weights.wo[l].matmul(batchSize, xb, xb2, dim, dim);

            // Residual add, then pre-FFN RMSNorm into xb.
            Parallel.parallelFor(0, batchSize, b -> {
                x[b].addInPlace(xb2[b]);
                InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps());
            });

            // FFN: gate (w1) and up (w3) projections.
            weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim);
            weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim);

            // SwiGLU-style gating: SiLU(w1·x) elementwise-times (w3·x).
            // value / (1 + exp(-value)) is exactly silu(value) = value * sigmoid(value).
            Parallel.parallelFor(0, batchSize, b -> {
                hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
                hb[b].multiplyInPlace(hb2[b]);
            });

            // Down projection (w2) and second residual add.
            weights.w2[l].matmul(batchSize, hb, xb, dim, config.hiddenDim());

            Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b]));
        }
        // Final RMSNorm and vocab projection intentionally omitted —
        // logits are not needed for any token in a prefill batch.
    }

    /**
     * GPU batched prefill forward pass (Phase 4).
     *
     * <p>Delegates the full chunk to
     * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill},
     * which handles embedding lookup and GPU execution internally.</p>
     *
     * @param model     the LLaMA model
     * @param tokens    token ids for this chunk
     * @param startPos  sequence position of {@code tokens[0]}
     * @param chunkSize number of tokens in this chunk
     * @param plan      the batched prefill/decode GPU plan
     */
    public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int startPos, int chunkSize,
            TornadoVMMasterPlanWithBatchPrefillDecode plan) {
        plan.tornadoVMForwardBatchPrefill(tokens, startPos, model, chunkSize);
    }

    /**
     * GPU decode forward pass (Phase 4).
     *
     * <p>Delegates a single-token decode step to
     * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode},
     * which copies the token embedding and runs the decode + logits graphs.</p>
     *
     * @param model    the LLaMA model
     * @param token    current token id
     * @param position sequence position
     * @param plan     the batched prefill/decode GPU plan
     * @return logits array for token sampling
     */
    public static FloatArray forwardTornadoVMDecode(Model model, int token, int position,
            TornadoVMMasterPlanWithBatchPrefillDecode plan) {
        return plan.tornadoVMForwardDecode(token, position, model);
    }
}
0 commit comments