diff --git a/llama-tornado b/llama-tornado index 57a50f1c..4900a100 100755 --- a/llama-tornado +++ b/llama-tornado @@ -87,6 +87,15 @@ class LlamaRunner: if args.verbose_init: cmd.append("-Dllama.EnableTimingForTornadoVMInit=true") + if args.batched_prefill: + cmd.append("-Dllama.batchedPrefill=true") + + if args.prefill_batch_size is not None: + cmd.append(f"-Dllama.prefillBatchSize={args.prefill_batch_size}") + + if args.no_cuda_graphs: + cmd.append("-Dllama.cudaGraphs=false") + # Debug options debug_config = [] @@ -472,6 +481,28 @@ def create_parser() -> argparse.ArgumentParser: help="Execute the command after showing it (use with --show-command)", ) + # Prefill/Decode optimization + prefill_group = parser.add_argument_group("Prefill/Decode Optimization") + prefill_group.add_argument( + "--batched-prefill", + dest="batched_prefill", + action="store_true", + help="Enable batched prefill/decode separation (llama.batchedPrefill=true)", + ) + prefill_group.add_argument( + "--prefill-batch-size", + dest="prefill_batch_size", + type=int, + default=None, + help="Prefill chunk/batch size (llama.prefillBatchSize=N, default: 32)", + ) + prefill_group.add_argument( + "--no-cuda-graphs", + dest="no_cuda_graphs", + action="store_true", + help="Disable CUDA graph capture/replay (llama.cudaGraphs=false); useful for debugging", + ) + # Advanced options advanced_group = parser.add_argument_group("Advanced Options") advanced_group.add_argument( diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java new file mode 100644 index 00000000..460bb9af --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -0,0 +1,306 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.Parallel; +import org.beehive.gpullama3.inference.state.State; +import 
org.beehive.gpullama3.inference.weights.standard.StandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; + +import java.lang.foreign.MemorySegment; + +/** + * Low-level forward passes for the prefill/decode separated inference path. + * + *
Parallel to {@link InferenceCore} — does NOT modify it.
+ * + *The key addition is {@link #forwardJavaPrefill}, which runs a full + * transformer forward pass but skips the final RMSNorm and vocabulary + * projection (wcls matmul). This is correct for all prefill positions + * because their logits are discarded anyway; only the KV-cache update + * matters. Skipping the projection saves one large matmul + * (vocabularySize × dim) per prefill token.
+ */ +public final class InferenceCoreWithPrefillDecode { + + private InferenceCoreWithPrefillDecode() {} + + /** + * Prefill-only forward pass for LLaMA (CPU, FP32 weights). + * + *Identical to {@link InferenceCore#forwardJava} except the final + * RMSNorm and vocabulary projection are omitted. The KV cache is + * populated correctly at {@code position}.
+ * + * @param model the LLaMA model (must carry {@link StandardWeights}) + * @param state mutable inference state (KV cache, activations …) + * @param token input token id + * @param position sequence position being processed + */ + public static void forwardJavaPrefill(Model model, State state, int token, int position) { + final Configuration config = model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); + float sqrtHeadSize = (float) Math.sqrt(headSize); + + // Token embedding + weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + + // Transformer layers + for (int l = 0; l < config.numberOfLayers(); l++) { + // Attention RMSNorm + InferenceCore.rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps()); + + // QKV projections + weights.wq[l].matmul(state.xb, state.q, dim, dim); + weights.wk[l].matmul(state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + + // RoPE + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int v = 0; v < rotn; v++) { + FloatTensor vec = v == 0 ? 
state.q : state.k; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + + // KV cache update + state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); + state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); + + // Multi-head attention + int curLayer = l; + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= position; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); + score /= sqrtHeadSize; + state.att.setFloat(attOffset + t, score); + } + + state.att.softmaxInPlace(attOffset, position + 1); + + int xbOffset = h * headSize; + state.xb.fillInPlace(xbOffset, headSize, 0f); + for (int t = 0; t <= position; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); + } + }); + + // Attention output projection + residual + weights.wo[l].matmul(state.xb, state.xb2, dim, dim); + state.x.addInPlace(state.xb2); + + // FFN RMSNorm + InferenceCore.rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps()); + + // FFN (SwiGLU) + weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim); + weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim); + state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + state.hb.multiplyInPlace(state.hb2); + weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim()); + + // FFN residual + state.x.addInPlace(state.xb); + } + + // Final RMSNorm and vocab projection intentionally omitted: + // logits are not needed for prefill positions — only the KV cache matters. 
+ } + + /** + * CPU batched prefill forward pass for LLaMA (Phase 3). + * + *Processes {@code batchSize} prompt tokens simultaneously through all + * transformer layers. For each layer, Q/K/V projections, output projection, + * and FFN projections are computed via batch matmul + * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}), + * which parallelises over both output dimension and batch simultaneously. + * Attention reuses {@code state.att} sequentially per token (parallel per + * head within each token), keeping memory overhead minimal.
+ * + *The logits layer is intentionally omitted — only the KV cache matters + * for prefill positions.
+ * + * @param model the LLaMA model (must carry {@link StandardWeights}) + * @param state mutable inference state (KV cache, att buffer …) + * @param tokens input token ids, {@code tokens[b]} at position {@code startPos+b} + * @param startPos sequence position of {@code tokens[0]} + * @param batchSize number of tokens in this chunk ({@code tokens.length}) + */ + public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) { + final Configuration config = model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); + float sqrtHeadSize = (float) Math.sqrt(headSize); + + // ── Batch activation tensors (allocated once per chunk) ─────────────── + FloatTensor[] x = new FloatTensor[batchSize]; + FloatTensor[] xb = new FloatTensor[batchSize]; + FloatTensor[] xb2 = new FloatTensor[batchSize]; + FloatTensor[] q = new FloatTensor[batchSize]; + FloatTensor[] k = new FloatTensor[batchSize]; + FloatTensor[] v = new FloatTensor[batchSize]; + FloatTensor[] hb = new FloatTensor[batchSize]; + FloatTensor[] hb2 = new FloatTensor[batchSize]; + for (int b = 0; b < batchSize; b++) { + x[b] = ArrayFloatTensor.allocate(dim); + xb[b] = ArrayFloatTensor.allocate(dim); + xb2[b] = ArrayFloatTensor.allocate(dim); + q[b] = ArrayFloatTensor.allocate(dim); + k[b] = ArrayFloatTensor.allocate(kvDim); + v[b] = ArrayFloatTensor.allocate(kvDim); + hb[b] = ArrayFloatTensor.allocate(config.hiddenDim()); + hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim()); + } + + // ── Token embeddings ────────────────────────────────────────────────── + Parallel.parallelFor(0, batchSize, b -> + weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim)); + + // ── Transformer layers 
──────────────────────────────────────────────── + for (int l = 0; l < config.numberOfLayers(); l++) { + final int layer = l; + + // Attention RMSNorm (parallel per b) + Parallel.parallelFor(0, batchSize, b -> + InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps())); + + // QKV projections — batch matmul parallelises over (dim × batchSize) + weights.wq[l].matmul(batchSize, xb, q, dim, dim); + weights.wk[l].matmul(batchSize, xb, k, kvDim, dim); + weights.wv[l].matmul(batchSize, xb, v, kvDim, dim); + + // RoPE + KV cache store (parallel per b — different positions, no conflict) + Parallel.parallelFor(0, batchSize, b -> { + int pos = startPos + b; + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int vv = 0; vv < rotn; vv++) { + FloatTensor vec = vv == 0 ? 
q[b] : k[b]; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim); + v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim); + }); + + // Attention — sequential per b (state.att is shared), parallel per head + for (int b = 0; b < batchSize; b++) { + final int pos_b = startPos + b; + final int bFinal = b; + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= pos_b; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize; + state.att.setFloat(attOffset + t, score); + } + state.att.softmaxInPlace(attOffset, pos_b + 1); + + int xbOffset = h * headSize; + xb[bFinal].fillInPlace(xbOffset, headSize, 0f); + for (int t = 0; t <= pos_b; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a); + } + }); + } + + // Output projection — batch matmul + weights.wo[l].matmul(batchSize, xb, xb2, dim, dim); + + // Residual + FFN RMSNorm (parallel per b) + Parallel.parallelFor(0, batchSize, b -> { + x[b].addInPlace(xb2[b]); + InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps()); + }); + + // FFN projections — batch matmul + weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim); + weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim); + + // SwiGLU (parallel per b) + Parallel.parallelFor(0, batchSize, b -> { + hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + hb[b].multiplyInPlace(hb2[b]); + }); + + // w2 projection — batch matmul (output reuses xb) + weights.w2[l].matmul(batchSize, hb, xb, 
dim, config.hiddenDim()); + + // FFN residual (parallel per b) + Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b])); + } + + // Final RMSNorm and vocab projection intentionally omitted — + // logits are not needed for any token in a prefill batch. + } + + /** + * GPU prefill-only forward pass for LLaMA (FP16, TornadoVM). + * + *Copies the token embedding into {@code state.embeddingX} (same as + * {@link InferenceCore#forwardTornadoVM}) then delegates to + * {@link TornadoVMMasterPlanWithPrefillDecode#tornadoVMForwardPrefill}, + * which executes preprocessing + layer graphs but skips the logits graph.
+ * + * @param model the LLaMA model (must carry {@link TornadoWeights}, FP16 only) + * @param state mutable inference state + * @param token input token id + * @param position sequence position being processed + * @param prefillPlan the prefill/decode plan wrapper + * @throws UnsupportedOperationException if the model uses Q8_0 weights + */ + public static void forwardTornadoVMPrefill(Model model, State state, int token, int position, + TornadoVMMasterPlanWithPrefillDecode prefillPlan) { + final Configuration configuration = model.configuration(); + final TornadoWeights weights = (TornadoWeights) model.weights(); + + switch (weights.getWeightType()) { + case F16 -> { + MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + MemorySegment.copy(tokenEmbeddings, (long) token * configuration.dim() * bytes, + state.embeddingX.getSegment(), 0, (long) configuration.dim() * bytes); + } + case Q8_0 -> throw new UnsupportedOperationException( + // TODO Phase 4: implement Q8_0 GPU batched prefill kernels + "GPU prefill/decode path not yet implemented for Q8_0 weights"); + default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType()); + } + + prefillPlan.tornadoVMForwardPrefill(position); + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java new file mode 100644 index 00000000..74c474e9 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -0,0 +1,326 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.LastRunMetrics; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.state.State; +import 
org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanStandard; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +/** + * Token generation entry point for the prefill/decode separated inference path. + * + *Parallel to {@link InferenceEngine} — does NOT modify it.
+ * + *The split loop runs two phases:
+ *Activated by {@code -Dllama.batchedPrefill=true} (set via + * {@code --batched-prefill} in the Python launcher).
+ */ +public final class InferenceEngineWithPrefillDecode { + + private InferenceEngineWithPrefillDecode() {} + + /** Prefill chunk size. 1 = sequential (Phase 1 behaviour), >1 = batched (Phase 3/4). */ + static final int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1); + + /** + * LLaMA token generation with prefill/decode separation (CPU). + * + *When {@code llama.prefillBatchSize > 1} (Phase 3), prompt tokens are + * processed in chunks of that size using batch matmul, which traverses each + * weight matrix once per chunk instead of once per token.
+ * + *When {@code llama.prefillBatchSize == 1} (Phase 1), falls back to + * sequential single-token prefill (skip logits only).
+ * + *Drop-in replacement for {@link InferenceEngine#generateTokensLlama}.
+ */ + public static ListDrop-in replacement for + * {@link InferenceEngine#generateTokensGPULlama} when the batched-prefill + * flag is enabled. FP16 only; Q8_0 throws {@link UnsupportedOperationException}.
+ * + *Split loop:
+ *Two concrete implementations exist:
+ *The {@link #initializeTornadoVMPlan} factory selects the appropriate implementation + * based on {@code llama.prefillBatchSize}: if {@code > 1}, returns a + * {@link TornadoVMMasterPlanWithBatchPrefillDecode}; otherwise returns a + * {@link TornadoVMMasterPlanStandard}.
+ */ +public interface TornadoVMMasterPlan { + + boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean( + System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); + + /** When {@code false}, {@code withCUDAGraph()} is never called — useful for debugging. */ + boolean CUDA_GRAPHS = Boolean.parseBoolean( + System.getProperty("llama.cudaGraphs", "true")); /** - * Determines whether the NVIDIA-specific scheduler should be used based on the current - * hardware backend and the model type. - *- * The scheduler is used only if the runtime is targeting an NVIDIA backend and the model is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is {@code MISTRAL}, the - * NVIDIA-specific scheduler should not be used. + * Single-token forward pass returning output logits. * - * @param model - * the model whose type may affect the scheduler decision - * @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise - */ - - /** - * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. This method processes the transformer layers in sequence for a particular token position in the context - * window. - * - *
The execution happens in three phases: - *
Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM}) + * and the Phase 2 sequential decode path. Not applicable to + * {@link TornadoVMMasterPlanWithBatchPrefillDecode} — that plan uses its own typed methods.
* - * @param position - * The current position in the sequence being processed - * @return FloatTensor containing the output logits for token prediction + * @param position sequence position of the current token + * @return logits array for token sampling */ + FloatArray tornadoVMForwardExecuteLayered(int position); - // int pos, ModelPlanner - public FloatArray tornadoVMForwardExecuteLayered(int position) { - // @formatter:off - // 1. Execute the preprocessing graph (e.g., input preparation, memory initialization) - executionPlan.withGraph(getPreprocessingGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .execute(); - - // Set the position in the state object (used by attention layers) - state.positionHolder.set(0, position); - state.temp.clear(); - state.tempFFN.clear(); - - // 2. Execute each transformer layer graph sequentially - // Each graph computes attention and feed-forward transformations for one layer - for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(getLayerGraphIndex(layer)) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .execute(); - } - state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f - state.wrapLogits.clear(); // Clear the output logits tensor -> set to 0f - // 3. Execute the final graph that projects the last hidden state to output logits - executionPlan.withGraph(getFinalLogitsGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .execute(); - - // @formatter:on - // Return the logits (used for token prediction) - return state.wrapLogits; - } + /** Releases all device memory held by this plan. */ + void freeTornadoExecutionPlan(); /** - * Returns the graph index for the pre-processing step (e.g., token embedding). - */ - private int getPreprocessingGraphIndex() { - return 0; - } - - /** - * Returns the graph index for the given transformer layer. 
+ * Factory: creates, JIT-compiles, and warms up the appropriate plan. * - * @param layerIndex - * Index of the transformer layer (0-based) - */ - private int getLayerGraphIndex(int layerIndex) { - return 1 + layerIndex; - } - - /** - * Returns the graph index for the final projection to logits. + *When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanWithBatchPrefillDecode} + * is returned; otherwise a {@link TornadoVMMasterPlanStandard} is returned.
+ * + * @param state the model state (must be {@link LlamaState} when batch size {@code > 1}) + * @param model the model instance + * @return the initialized plan, also stored via {@link Model#setTornadoVMPlan} */ - private int getFinalLogitsGraphIndex() { - return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1; - } - - /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer. - public void forceCopyInReadOnlyDataLayered() { - // Execute all TornadoVM graphs - state.wrapX.clear(); - state.positionHolder.init(0); - - // Execute activation update graph - executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); - - // Execute layer processing graphs - for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { + int batchSize = Integer.getInteger("llama.prefillBatchSize", 1); + TornadoVMMasterPlan plan; + if (batchSize > 1) { + plan = TornadoVMMasterPlanWithBatchPrefillDecode.initializeUnifiedPlan( + (LlamaState) state, model, batchSize); + } else { + plan = TornadoVMMasterPlanStandard.initialize(state, model); } - - // Execute logits graph - executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); - } - - /** - * Frees the device memory allocated for the TornadoVM execution plan. This method should be called when the execution plan is no longer needed to release resources and avoid memory leaks. 
- */ - public void freeTornadoExecutionPlan() { - executionPlan.freeDeviceMemory(); + model.setTornadoVMPlan(plan); + return plan; } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java new file mode 100644 index 00000000..c9d816ee --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java @@ -0,0 +1,149 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * Standard (single-token) GPU execution plan. + * + *Processes one token at a time through preprocessing + N transformer layers + + * logits projection. Used for both the baseline GPU path and the Phase 2 + * sequential prefill/decode path.
+ */ +public class TornadoVMMasterPlanStandard implements TornadoVMMasterPlan { + + public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); + + private final State state; + private final Configuration config; + public TornadoExecutionPlan executionPlan; + GenericLayerPlanner tornadoVMLayerPlanner; + + public TornadoVMMasterPlanStandard(State state, Model model) { + this.tornadoVMLayerPlanner = createPlanner(state, model); + this.executionPlan = createExecutionPlan(); + this.state = state; + this.config = model.configuration(); + } + + /** + * Initializes and warms up the standard TornadoVM plan. + * + * @param state the model state containing KV cache + * @param model the model instance + * @return the initialized plan ready for inference + */ + static TornadoVMMasterPlanStandard initialize(State state, Model model) { + long startTime = System.nanoTime(); + long planCreationTime = 0; + long warmupTime = 0; + + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + + TornadoVMMasterPlanStandard tornadoVMPlan = new TornadoVMMasterPlanStandard(state, model); + + if (ENABLE_TORNADOVM_INIT_TIME) { + planCreationTime = System.nanoTime(); + System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); + } + + if (CUDA_GRAPHS) tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); + tornadoVMPlan.executionPlan.withPreCompilation(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + warmupTime = System.nanoTime(); + System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); + } + + tornadoVMPlan.forceCopyInReadOnlyDataLayered(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + long copyTime = System.nanoTime(); + System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); + 
System.err.printf("Finished TornadoVM initialization...\n \n"); + } + + return tornadoVMPlan; + } + + private TornadoExecutionPlan createExecutionPlan() { + var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); + var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); + return new TornadoExecutionPlan(taskGraphArray); + } + + private GenericLayerPlanner createPlanner(State state, Model model) { + GGMLType weightType = model.weights().getWeightType(); + return QuantizationPlannerFactory.create(weightType, state, model); + } + + @Override + public FloatArray tornadoVMForwardExecuteLayered(int position) { + // @formatter:off + var preGraph = executionPlan.withGraph(getPreprocessingGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + if (CUDA_GRAPHS) preGraph.withCUDAGraph(); + preGraph.execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + for (int layer = 0; layer < config.numberOfLayers(); layer++) { + executionPlan.withGraph(getLayerGraphIndex(layer)) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + //.withCUDAGraph() + .execute(); + } + state.tempLogits.clear(); + state.wrapLogits.clear(); + var logitsGraph = executionPlan.withGraph(getFinalLogitsGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + if (CUDA_GRAPHS) logitsGraph.withCUDAGraph(); + logitsGraph.execute(); + // @formatter:on + return state.wrapLogits; + } + + private int getPreprocessingGraphIndex() { + return 0; + } + + private int getLayerGraphIndex(int layerIndex) { + return 1 + layerIndex; + } + + private int getFinalLogitsGraphIndex() { + return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1; + } + + public void forceCopyInReadOnlyDataLayered() { + state.wrapX.clear(); + state.positionHolder.init(0); + + //executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + 
executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + + for (int layer = 0; layer < config.numberOfLayers(); layer++) { + //executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + } + + //executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java new file mode 100644 index 00000000..3df08dfa --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -0,0 +1,326 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill; +import 
org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +import java.lang.foreign.MemorySegment; +import java.util.ArrayList; +import java.util.List; + +/** + * Unified GPU execution plan for Phase 4: batched prefill + single-token decode. + * + *A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache + * ({@code wrapKeyCache}, {@code wrapValueCache}) is shared on device via + * {@code persistOnDevice}/{@code consumeFromDevice}. Two separate plans would + * allocate independent device buffers and lose the prefill KV state.
+ * + *Graph layout (2N+3 graphs total):
+ *+ * [0] batch activation B×dim FP16 → FP32 + * [1..N] batch layer graphs B tokens, all transformer ops + * [N+1] decode activation single-token FP16 → FP32 + KV-cache pass-through + * [N+2..2N+1] decode layer graphs single-token, standard kernels + * [2N+2] logits graph + *+ * + *
KV cache pointer chain across phases:
+ *+ * batchLayer[N-1] --persistOnDevice(wrapKeyCache)-→ + * decodeActivation --consumeFromDevice(wrapKeyCache)-→ (pass-through) + * decodeLayer[0] --consumeFromDevice(wrapKeyCache)-→ (used by attention) + *+ */ +public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan { + + private static final boolean ENABLE_TIMING = + Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); + + private final LlamaState state; + private final LlamaConfiguration config; + private final int batchSize; + private final int N; // numberOfLayers + private final TornadoExecutionPlan executionPlan; + private final GridScheduler gridScheduler; + + // ── Graph-index helpers ─────────────────────────────────────────────────── + private int batchActivationIdx() { return 0; } + private int batchLayerIdx(int i) { return 1 + i; } + private int decodeActivationIdx() { return N + 1; } + private int decodeLayerIdx(int i) { return N + 2 + i; } + private int logitsIdx() { return 2 * N + 2; } + + // ── Construction ───────────────────────────────────────────────────────── + private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model, int batchSize) { + this.state = state; + this.config = (LlamaConfiguration) model.configuration(); + this.batchSize = batchSize; + this.N = config.numberOfLayers(); + + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); + + List
Receives the KV-cache device pointer from batch layer N via + * {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so + * that {@code updatePersistedObjectState()} can propagate it to decode layer 0. + * Both halves of the chain are required; without the re-persist the pointer is + * not forwarded in interpreter (non-CUDA-graph) mode.
+ */ + private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) { +// System.out.println("lastBatchLayerID = " + lastBatchLayerID); +// System.out.println("[buildDecodeActivationGraph] state.wrapX = " + state.wrapX.toString()); +// System.out.println("[buildDecodeActivationGraph] state.wrapKeyCache = " + state.wrapKeyCache.toString()); +// System.out.println("[buildDecodeActivationGraph] state.wrapValueCache = " + state.wrapValueCache.toString()); + return new TaskGraph("decodeActivationUpdate") + .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through + //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX, debugKV) + //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", + TransformerComputeKernels::convertFP16toFP32, + ctx, (HalfFloatArray) state.embeddingX, state.wrapX) +// // DEBUG: snapshot first 8 elements of wrapKeyCache and wrapX for host-side probe +// .task("dbgKV", +// TransformerComputeKernels::dbgCopyFirst8, +// state.wrapKeyCache, debugKV) +// .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX, debugKV) + // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache + // re-persisted so updatePersistedObjectState() propagates the device + // pointer to decode layer 0's consumeFromDevice without CUDA graphs. + .persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); + } + + // ── Static factory ──────────────────────────────────────────────────────── + + /** + * Creates, JIT-compiles, and warms up the unified plan. + * Mirrors {@link TornadoVMMasterPlan#initializeTornadoVMPlan}. 
+ */ + public static TornadoVMMasterPlanWithBatchPrefillDecode initializeUnifiedPlan( + LlamaState state, Model model, int batchSize) { + + long t0 = System.nanoTime(); + TornadoVMMasterPlanWithBatchPrefillDecode plan = + new TornadoVMMasterPlanWithBatchPrefillDecode(state, model, batchSize); + + if (ENABLE_TIMING) + System.err.printf("[BatchPlan] Graph construction: %.2f ms%n", + (System.nanoTime() - t0) / 1e6); + + if (CUDA_GRAPHS) plan.executionPlan.withAllGraphs().withCUDAGraph(); + plan.executionPlan.withPreCompilation(); + + if (ENABLE_TIMING) + System.err.printf("[BatchPlan] JIT compilation: %.2f ms%n", + (System.nanoTime() - t0) / 1e6); + + plan.forceCopyInReadOnlyData(); + + if (ENABLE_TIMING) + System.err.printf("[BatchPlan] Init complete: %.2f ms%n", + (System.nanoTime() - t0) / 1e6); + + return plan; + } + + /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */ + private void forceCopyInReadOnlyData() { + state.wrapXBatch.clear(); + state.wrapX.clear(); + state.positionHolder.init(0); + state.batchStartPosHolder.init(0); + + for (int i = 0; i <= logitsIdx(); i++) { + //System.out.println(i + " " + executionPlan.withGraph(i).toString()); + var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) g.withCUDAGraph(); + g.execute(); + } + } + + // ── Forward passes ──────────────────────────────────────────────────────── + + /** + * Batch prefill: runs graphs 0..N (activation + N layers), skips logits. 
+ * + * @param tokenIds token IDs for this chunk (length == batchSize, or tail) + * @param startPos sequence position of tokenIds[0] + * @param model model (for embedding table) + * @param chunkSize actual number of tokens in this chunk (≤ batchSize) + */ + public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) { + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; // embedding table is FP16: 2 bytes per element + int dim = config.dim(); + + // Copy B embeddings into embeddingXBatch + // NOTE(review): only rows 0..chunkSize-1 are written; when chunkSize < batchSize the + // tail rows of embeddingXBatch keep stale data from the previous chunk — confirm the + // batch kernels cannot write stale-row K/V beyond position startPos + chunkSize - 1. + for (int b = 0; b < chunkSize; b++) { + MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes, + state.embeddingXBatch.getSegment(), (long) b * dim * bytes, + (long) dim * bytes); + } + state.batchStartPosHolder.set(0, startPos); // device kernels derive token b's position as startPos + b + + // Graph 0: batch activation + var batchAct = executionPlan.withGraph(batchActivationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) batchAct.withCUDAGraph(); + batchAct.execute(); + + // Graphs 1..N: batch transformer layers + for (int l = 0; l < N; l++) { + var batchLayer = executionPlan.withGraph(batchLayerIdx(l)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) batchLayer.withCUDAGraph(); + batchLayer.execute(); + } + //System.err.println("[DEBUG] last batch layer done, about to return from prefill"); + // Logits skipped — not needed for prefill positions. + } + + /** + * Single-token decode: runs graphs N+1..2N+2 (activation + N layers + logits). 
+ * + * @param token token ID to process + * @param position sequence position + * @param model model (for embedding table) + * @return logits array for sampling + */ + public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + int dim = config.dim(); + + MemorySegment.copy(embTable, (long) token * dim * bytes, + state.embeddingX.getSegment(), 0L, (long) dim * bytes); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + // Graph N+1: decode activation + var decodeAct = executionPlan.withGraph(decodeActivationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); + //System.err.println("[DEBUG] about to execute decode activation (graph " + decodeActivationIdx() + "--)"); + decodeAct.execute(); + // DEBUG: print first 4 of wrapX (should be non-zero FP32 embedding) and + // first 4 of debugKV (should be non-zero after batch prefill wrote the KV cache) +// if (position <= 290) { +// System.err.printf("[DBG pos=%d] wrapX[0..3] = %.4f %.4f %.4f %.4f%n", +// position, state.wrapX.get(0), state.wrapX.get(1), state.wrapX.get(2), state.wrapX.get(3)); +// System.err.printf("[DBG pos=%d] debugKV[0..3]= %.4f %.4f %.4f %.4f%n", +// position, debugKV.get(0), debugKV.get(1), debugKV.get(2), debugKV.get(3)); +// } + + // Graphs N+2..2N+1: decode transformer layers + for (int l = 0; l < N; l++) { + var decodeLayer = executionPlan.withGraph(decodeLayerIdx(l)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) decodeLayer.withCUDAGraph(); + //System.err.println("[DEBUG] about to execute decode transformer layer (graph " + decodeLayerIdx(l) + "--)"); + decodeLayer.execute(); + } + + state.tempLogits.clear(); + state.wrapLogits.clear(); + + // Graph 2N+2: logits + var logits = 
executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) logits.withCUDAGraph(); + logits.execute(); + + return state.wrapLogits; + } + + @Override + public FloatArray tornadoVMForwardExecuteLayered(int position) { + throw new UnsupportedOperationException( + "Use tornadoVMForwardBatchPrefill / tornadoVMForwardDecode for batch plan"); + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } + + // ── Inner class: decode layer 0 with consumeFromDevice for KV cache ─────── +// moved to package +// +// private static final class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { +// +// +// } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java new file mode 100644 index 00000000..e5262b17 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -0,0 +1,79 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * Wraps {@link TornadoVMMasterPlanStandard} and adds a prefill-only GPU forward pass. + * + *Parallel to {@link TornadoVMMasterPlanStandard} — does NOT modify it.
+ * + * <p>The existing execution plan has this graph layout:
+ * <pre> + * graph 0 : preprocessing (embedding setup) + * graphs 1..N : transformer layers + * graph N+1 : logits projection (final RMSNorm + wcls matmul) + * </pre> + * + * <p>
{@link #tornadoVMForwardPrefill} executes graphs 0..N and deliberately + * skips graph N+1. The KV cache is populated correctly by the layer graphs; + * the logits are not needed for prefill positions so the projection is wasted + * work that we avoid.
+ * + *For decode, {@link #tornadoVMForwardDecode} delegates to the wrapped + * plan's {@code tornadoVMForwardExecuteLayered}, preserving identical behaviour + * to the baseline GPU path.
+ */ +public class TornadoVMMasterPlanWithPrefillDecode { + + private final TornadoVMMasterPlanStandard plan; + private final State state; + private final Configuration config; + + public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlanStandard plan, State state, Model model) { + this.plan = plan; + this.state = state; + this.config = model.configuration(); + } + + /** + * GPU prefill forward: runs preprocessing + all transformer layers, skips logits. + * + *Mirrors {@link TornadoVMMasterPlan#tornadoVMForwardExecuteLayered} except + * the final logits graph (graph {@code numberOfLayers + 1}) is not executed.
+ * + * @param position sequence position being processed + */ + public void tornadoVMForwardPrefill(int position) { + // Graph 0: preprocessing + plan.executionPlan.withGraph(0) + .withGridScheduler(plan.tornadoVMLayerPlanner.getGridScheduler()) + .execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + // Graphs 1..N: transformer layers + for (int layer = 1; layer <= config.numberOfLayers(); layer++) { + plan.executionPlan.withGraph(layer) + .withGridScheduler(plan.tornadoVMLayerPlanner.getGridScheduler()) + .execute(); + } + + // Graph N+1 (logits) intentionally skipped — not needed for prefill positions. + } + + /** + * GPU decode forward: full execution including logits. + * Delegates to {@link TornadoVMMasterPlan#tornadoVMForwardExecuteLayered}. + * + * @param position sequence position being processed + * @return logits array for token sampling + */ + public FloatArray tornadoVMForwardDecode(int position) { + return plan.tornadoVMForwardExecuteLayered(position); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java new file mode 100644 index 00000000..9bba3860 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java @@ -0,0 +1,461 @@ +package org.beehive.gpullama3.tornadovm.kernels; + +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +/** + * GPU kernels for batched prefill (Phase 4). + * + *Each kernel processes {@code batchSize} tokens simultaneously. 
+ * Batch tensors are flat: element [b][i] lives at index {@code b*stride + i}. + * Worker-grid sizes are scaled by {@code batchSize} vs the single-token kernels.
+ * + * <p>These kernels are meant to be registered in {@link TornadoVMMasterPlanWithBatchPrefillDecode} + * batch task graphs; they are NOT invoked directly.
+ */ +public final class TransformerBatchPrefillKernels { + + private TransformerBatchPrefillKernels() {} + + // ── Activation ──────────────────────────────────────────────────────────── + + /** + * Converts B×dim FP16 token embeddings to FP32. + * Worker: B*dim global threads, localSize=128. + */ + public static void batchEmbeddingToFP32(KernelContext context, + HalfFloatArray embeddingXBatch, + FloatArray wrapXBatch) { + int gid = context.globalIdx; + wrapXBatch.set(gid, embeddingXBatch.get(gid).getFloat32()); + } + + // ── RMS Norm (attention) ───────────────────────────────────────────────── + + /** + * Sequential RMS reduction — one thread per batch item. + * + *Each thread computes the RMS scale factor for its token: + * {@code scale[b] = 1 / sqrt( mean(x[b]²) + eps )}
+ * + * Worker: batchSize global threads, localSize=1. + */ + public static void batchedRmsReduce(KernelContext context, + FloatArray wrapXBatch, + FloatArray attnScaleBatch, + int dim, float eps) { + int b = context.globalIdx; // one thread per batch item; loops sequentially over dim + int base = b * dim; // start of token b's hidden vector in the flat batch tensor + float ss = 0.0f; + for (int i = 0; i < dim; i++) { + float val = wrapXBatch.get(base + i); + ss += val * val; + } + ss /= dim; // mean of squares + ss += eps; // numerical-stability epsilon before the inverse sqrt + attnScaleBatch.set(b, 1.0f / TornadoMath.sqrt(ss)); + } + + /** + * Applies RMS normalization and FP16-quantizes the result. + * + *{@code xbFP16Batch[b*dim+i] = FP16( rmsWeights[i] * scale[b] * x[b*dim+i] )}
+ * + * Worker: B*dim global threads, localSize=256. + */ + public static void batchedRmsApplyFP16(KernelContext context, + HalfFloatArray xbFP16Batch, + FloatArray wrapXBatch, + FloatArray rmsWeights, + FloatArray attnScaleBatch, + int dim) { + int gid = context.globalIdx; // flat index over all B*dim elements + int b = gid / dim; // batch item + int i = gid % dim; // element within the token's hidden vector + float scale = attnScaleBatch.get(b); // per-token 1/RMS factor written by batchedRmsReduce + float result = rmsWeights.get(i) * scale * wrapXBatch.get(gid); + xbFP16Batch.set(gid, new HalfFloat(result)); + } + + // ── QKV Projection ──────────────────────────────────────────────────────── + + /** + * Fused batched QKV projection (FP16 weights, FP16 input). + * + *One workgroup per (batchIdx, outputRow) pair. + * globalGroupIdx = batchIdx * (dim + 2*kvDim) + rowIdx.
+ * + * Worker: B*(dim+2*kvDim) workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedQKVMatmul(KernelContext context, + HalfFloatArray xbFP16Batch, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + HalfFloatArray wq, + HalfFloatArray wk, + HalfFloatArray wv, + int dim, int kvDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int totalRows = dim + 2 * kvDim; + int batchIdx = groupId / totalRows; + int rowIdx = groupId % totalRows; + int inputOff = batchIdx * dim; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowIdx < dim) { + int rowOff = rowIdx * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wq.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapQBatch.set(batchIdx * dim + rowIdx, localSum[0]); + + } else if (rowIdx < dim + kvDim) { + int kRow = rowIdx - dim; + int rowOff = kRow * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wk.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]); + + } else { + int vRow = rowIdx - dim - kvDim; + int rowOff = vRow * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wv.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; 
+ context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]); + } + } + + // ── RoPE + KV Cache ─────────────────────────────────────────────────────── + + /** + * Fused batched RoPE rotation + KV cache write. + * + *globalIdx encodes (batchIdx, pairIdx) as {@code batchIdx*(dim/2) + pairIdx}. + * Position for token b = {@code startPos + b}.
+ * + * Worker: B*(dim/2) global threads, localSize=512 (or less if B*dim/2 < 512). + */ + public static void batchedRopeWithKVCache(KernelContext context, + IntArray batchStartPosHolder, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + FloatArray wrapKeyCache, + FloatArray wrapValueCache, + int kvDim, int headSize, + int layerIndex, int contextLength, int dim) { + int globalIdx = context.globalIdx; + int halfDim = dim / 2; + int batchIdx = globalIdx / halfDim; + int pairIdx = globalIdx % halfDim; + int i = pairIdx * 2; + + int pos = batchStartPosHolder.get(0) + batchIdx; + int qOffset = batchIdx * dim; + int kOffset = batchIdx * kvDim; + + if (i + 1 < dim) { + int head_dim = i % headSize; + float freq = 1.0f / TornadoMath.pow(50000.0f, head_dim / (float) headSize); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q + float v0q = wrapQBatch.get(qOffset + i); + float v1q = wrapQBatch.get(qOffset + i + 1); + wrapQBatch.set(qOffset + i, v0q * fcr - v1q * fci); + wrapQBatch.set(qOffset + i + 1, v0q * fci + v1q * fcr); + + // Rotate K and write K,V to cache + if (i + 1 < kvDim) { + float v0k = wrapKBatch.get(kOffset + i); + float v1k = wrapKBatch.get(kOffset + i + 1); + float rotK0 = v0k * fcr - v1k * fci; + float rotK1 = v0k * fci + v1k * fcr; + wrapKBatch.set(kOffset + i, rotK0); + wrapKBatch.set(kOffset + i + 1, rotK1); + + int cacheOff = layerIndex * contextLength * kvDim + pos * kvDim; + wrapKeyCache.set(cacheOff + i, rotK0); + wrapKeyCache.set(cacheOff + i + 1, rotK1); + wrapValueCache.set(cacheOff + i, wrapVBatch.get(kOffset + i)); + wrapValueCache.set(cacheOff + i + 1, wrapVBatch.get(kOffset + i + 1)); + } + } + } + + // ── Attention ───────────────────────────────────────────────────────────── + + /** + * Batched causal flash attention. + * + *One workgroup per (batchIdx, headIdx) pair: + * {@code groupIdx = batchIdx * nHeads + headIdx}. 
+ * Token b attends to positions 0..{@code startPos + b} (causal).
+ * + * Worker: B*nHeads workgroups × optimalLocalSize threads. + */ + public static void batchedFlashAttention(KernelContext context, + IntArray batchStartPosHolder, + FloatArray wrapQBatch, + FloatArray wrapKeyCache, + FloatArray wrapValueCache, + FloatArray wrapXbBatch, + int nHeads, int headSize, + int kvDim, int kvMul, + int layerIndex, int contextLength, int dim) { + int tid = context.localIdx; + int groupId = context.groupIdx; + int localSz = context.localGroupSizeX; + + int batchIdx = groupId / nHeads; + int h = groupId % nHeads; + int pos = batchStartPosHolder.get(0) + batchIdx; + int loff = layerIndex * contextLength * kvDim; + int kvHeadIdx = h / kvMul; + int BLOCK_C = 16; + + float[] qShared = context.allocateFloatLocalArray(headSize); + float[] kTile = context.allocateFloatLocalArray(BLOCK_C * headSize); + float[] vTile = context.allocateFloatLocalArray(BLOCK_C * headSize); + float[] sTile = context.allocateFloatLocalArray(BLOCK_C); + float[] maxHolder = context.allocateFloatLocalArray(1); + + // Load Q into shared memory + int qOffset = batchIdx * dim + h * headSize; + for (int i = tid; i < headSize; i += localSz) { + qShared[i] = wrapQBatch.get(qOffset + i); + } + context.localBarrier(); + + float maxScore = Float.NEGATIVE_INFINITY; + float sumExp = 0.0f; + float[] output = new float[headSize]; + for (int i = 0; i < headSize; i++) output[i] = 0.0f; + + for (int tileC = 0; tileC <= pos; tileC += BLOCK_C) { + int tileEnd = Math.min(tileC + BLOCK_C - 1, pos); + + // Load K/V tile + for (int t = tileC + tid; t <= tileEnd; t += localSz) { + int tInTile = t - tileC; + int tileMOff = tInTile * headSize; + for (int d = 0; d < headSize; d++) { + int kvOff = loff + t * kvDim + kvHeadIdx * headSize + d; + kTile[tileMOff + d] = wrapKeyCache.get(kvOff); + vTile[tileMOff + d] = wrapValueCache.get(kvOff); + } + } + context.localBarrier(); + + // Compute attention scores + for (int t = tileC + tid; t <= tileEnd; t += localSz) { + int tInTile = t - tileC; + float 
score = 0.0f; + for (int d = 0; d < headSize; d++) { + score += qShared[d] * kTile[tInTile * headSize + d]; + } + sTile[tInTile] = score / TornadoMath.sqrt(headSize); + } + context.localBarrier(); + + // Tile max + float tileMax = Float.NEGATIVE_INFINITY; + for (int t = 0; t <= tileEnd - tileC; t++) { + if (sTile[t] > tileMax) tileMax = sTile[t]; + } + if (tid == 0) maxHolder[0] = tileMax; + context.localBarrier(); + float curTileMax = maxHolder[0]; + + float newMax = Math.max(maxScore, curTileMax); + if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) { + float scale = TornadoMath.exp(maxScore - newMax); + sumExp *= scale; + for (int d = 0; d < headSize; d++) output[d] *= scale; + } + maxScore = newMax; + + for (int t = 0; t <= tileEnd - tileC; t++) { + float expScore = TornadoMath.exp(sTile[t] - maxScore); + sumExp += expScore; + for (int d = 0; d < headSize; d++) { + output[d] += expScore * vTile[t * headSize + d]; + } + } + context.localBarrier(); + } + + float norm = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; + int xbOffset = batchIdx * dim + h * headSize; + for (int d = tid; d < headSize; d += localSz) { + wrapXbBatch.set(xbOffset + d, output[d] * norm); + } + } + + // ── Output / FFN Projections ───────────────────────────────────────────── + + /** + * Batched matrix-vector multiply with residual add. + * + *Used for both the attention output projection (Wo) and the FFN down + * projection (W2). One workgroup per (batchIdx, outputRow): + * {@code groupIdx = batchIdx * d + rowIdx}.
+ * + *One workgroup per (batchIdx, hiddenRow): + * {@code groupIdx = batchIdx * hiddenDim + rowIdx}.
+ * + * Worker: B*hiddenDim workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedRmsNormFFNGateUp(KernelContext context, + FloatArray wrapXBatch, + FloatArray wrapHbBatch, + FloatArray rmsFFNWeights, + FloatArray ffnScaleBatch, + HalfFloatArray w1, + HalfFloatArray w3, + int dim, int hiddenDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int batchIdx = groupId / hiddenDim; + int rowIdx = groupId % hiddenDim; + + float scale = ffnScaleBatch.get(batchIdx); + int inputOff = batchIdx * dim; + int rowOff = rowIdx * dim; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + // W1 matmul with inline RMS apply + float sum1 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum1 += w1.get(rowOff + j).getFloat32() * normed; + } + localSum[localId] = sum1; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result1 = localSum[0]; + + // W3 matmul with inline RMS apply + float sum3 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum3 += w3.get(rowOff + j).getFloat32() * normed; + } + localSum[localId] = sum3; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result3 = localSum[0]; + + // SiLU(W1·x) × (W3·x) + if (localId == 0) { + float silu = result1 / (1.0f + TornadoMath.exp(-result1)); + wrapHbBatch.set(batchIdx * hiddenDim + rowIdx, silu * result3); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 56f2c0c3..50619cc2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -146,7 +146,17 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); // === Data Setup === - unifiedLayer.consumeFromDevice(state.wrapX); + // consumeFromDevice for wrapX: the no-arg form uses the current graph's own name as the + // source key, which works in CUDA-graph mode (pointers are frozen) but fails in interpreter + // mode (updatePersistedObjectState looks up the predecessor's name, not the current name). + // Subclasses that receive wrapX across a graph boundary override predecessorGraphName() to + // return the correct predecessor graph name so the XPUBuffer is propagated in both modes. + String wrapXSrc = predecessorGraphName(layerIndex); + if (wrapXSrc != null) { + unifiedLayer.consumeFromDevice(wrapXSrc, state.wrapX); + } else { + unifiedLayer.consumeFromDevice(state.wrapX); + } unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqLayered[layerIndex].asHalfFloatArray(), @@ -248,11 +258,31 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.persistOnDevice(state.wrapX); + unifiedLayer.persistOnDevice(state.wrapX, state.wrapKeyCache, + state.wrapValueCache); return unifiedLayer; } + /** + * Returns the name of the predecessor task graph from which {@code wrapX} should be consumed, + * or {@code null} to fall back to the no-arg form (source key = own graph name). 
+ * + *The no-arg form is safe in CUDA-graph mode (device pointers are frozen at capture time) + * but fails in interpreter mode: {@code updatePersistedObjectState} looks up the predecessor's + * graph name, not the current graph's name, so the XPUBuffer is never propagated and + * {@code executeAlloc} NPEs on a null buffer.
+ * + *Override in subclasses that receive {@code wrapX} from a named predecessor graph:
+ *Overrides data-transfer declarations so that all cross-graph boundaries use + * the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by + * the base class) passes the current graph's own name as the source key. + * In CUDA-graph mode this is harmless (device pointers are frozen at capture time), + * but in interpreter mode {@code updatePersistedObjectState} looks up the + * predecessor's name, so the lookup always misses and the XPUBuffer is + * never propagated — causing either a null-pointer crash or a silent re-upload + * from host (zeros), corrupting the hidden state and KV cache.
+ * + *Two boundaries are fixed here:
+ *Layer 0 receives {@code wrapX} from the decode activation graph; + * layers 1+ receive it from the previous decode layer. + * Must match the {@code TaskGraph} names used in + * {@code buildDecodeActivationGraph()} and {@code createFFNLayerTaskGraph()}.
+ */ + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivationUpdate" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + // Same as parent layer 0, but wrapKeyCache/wrapValueCache come from device + // (passed through by the decode activation graph, which relays them from + // the last batch prefill layer). No FIRST_EXECUTION for KV cache here. + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapAtt, state.wrapHb, state.wrapXbFP16); + // Explicit source — must match the TaskGraph name in buildDecodeActivationGraph(). + layer.consumeFromDevice("decodeActivationUpdate", state.wrapKeyCache, state.wrapValueCache); + } else { + // Layers 1+: use explicit predecessor name for ALL consumed objects. + // Calling super here would use the no-arg form (source key = own graph name), + // which silently fails in interpreter mode and causes re-upload from host. 
+ String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder, state.wrapXbFP16); + } + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java new file mode 100644 index 00000000..760be156 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java @@ -0,0 +1,53 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Logits layer for the unified prefill-decode plan (Phase 4). + * + *Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device + * pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the + * logits → decode-activation boundary across decode tokens.
+ * + *In interpreter (non-CUDA-graph) mode, {@code updatePersistedObjectState()} + * propagates device pointers from the predecessor graph's persisted set. After the + * last decode token the predecessor of the next decode-activation graph is the + * logits graph. Without the pass-through here, the KV-cache pointer is absent from + * the logits persisted set, cleared to null, and the first decode layer crashes with + * an NPE in {@code executeAlloc}.
+ * + *Bytecode order matters: {@code consumeFromDevice} must precede task declarations, + * and {@code persistOnDevice} must follow {@code transferToHost}. The hooks in + * {@link LogitsFP16Layer} guarantee this ordering.
+ */ +public class LogitsFP16LayerDecode extends LogitsFP16Layer { + + public LogitsFP16LayerDecode(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); + } + + /** + * Prepends {@code consumeFromDevice(lastTaskGraphID, wrapKeyCache, wrapValueCache)} before all tasks. + * + *Must use the named-source form so that {@code updatePersistedObjectState()} adds the KV cache + * to the source-keyed map. Without the source name, the fallback in {@code updatePersistedObjectState} + * uses the current graph's general persisted list, which causes the XPUBuffer from the predecessor + * (last decode layer) to never be propagated into the logits graph's device state.
+ */ + @Override + protected void configureAdditionalConsumes(TaskGraph logits) { + logits.consumeFromDevice(lastTaskGraphID, state.wrapKeyCache, state.wrapValueCache); + } + + /** Appends {@code persistOnDevice(wrapKeyCache, wrapValueCache)} after {@code transferToHost}. */ + @Override + protected void configureAdditionalPersists(TaskGraph logits) { + logits.persistOnDevice(state.wrapKeyCache, state.wrapValueCache); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java new file mode 100644 index 00000000..a893623d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -0,0 +1,247 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +/** + * Builds per-layer batch prefill TaskGraphs for Phase 4 GPU batched prefill. + * + *One {@link ImmutableTaskGraph} per transformer layer, each processing + * {@code batchSize} tokens simultaneously via {@link TransformerBatchPrefillKernels}.
+ * + *KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device + * after every layer so the subsequent single-token decode layers can consume it.
+ */ +public class LlamaFP16LayersBatchPrefill { + + // Matches the local workgroup size used by the single-token kernels. + static final int LOCAL_WORK_GROUP_SIZE = 32; + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final LlamaConfiguration config; + private final KernelContext context = new KernelContext(); + private final int batchSize; + private final List