From c00aa825d07d1fed718339ece69dc66a32670440 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 22:42:02 +0200 Subject: [PATCH 01/12] [refactor] Move QuantizedLayerPlanner to layerplanner package root-level --- .../gpullama3/tornadovm/TornadoVMMasterPlan.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index a42dc310..4b752735 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -55,6 +55,8 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); } + tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); + // 2. Perform warmup with extra iterations to ensure JIT compilation is complete tornadoVMPlan.executionPlan.withPreCompilation(); // Force JIT compilation from Java to GPU code @@ -130,6 +132,7 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { // 1. Execute the preprocessing graph (e.g., input preparation, memory initialization) executionPlan.withGraph(getPreprocessingGraphIndex()) .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + .withCUDAGraph() .execute(); // Set the position in the state object (used by attention layers) @@ -142,6 +145,7 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { for (int layer = 0; layer < config.numberOfLayers(); layer++) { executionPlan.withGraph(getLayerGraphIndex(layer)) .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + .withCUDAGraph() .execute(); } state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f @@ -149,6 +153,7 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { // 3. Execute the final graph that projects the last hidden state to output logits executionPlan.withGraph(getFinalLogitsGraphIndex()) .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + .withCUDAGraph() .execute(); // @formatter:on @@ -187,15 +192,15 @@ public void forceCopyInReadOnlyDataLayered() { state.positionHolder.init(0); // Execute activation update graph - executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); // Execute layer processing graphs for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); } // Execute logits graph - executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); } /** From 8be1c05e1b56d0be3aec89456f64f5c973396ad1 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 31 Mar 2026 18:19:02 +0300 Subject: [PATCH 02/12] [prf/dec] Add CLI options for batched prefill and prefill batch size configuration --- llama-tornado | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/llama-tornado b/llama-tornado index 57a50f1c..81349a5e 100755 --- a/llama-tornado +++ b/llama-tornado @@ -87,6 +87,12 @@ class LlamaRunner: if args.verbose_init: cmd.append("-Dllama.EnableTimingForTornadoVMInit=true") + if args.batched_prefill: + cmd.append("-Dllama.batchedPrefill=true") + + if args.prefill_batch_size is not None: + cmd.append(f"-Dllama.prefillBatchSize={args.prefill_batch_size}") + # Debug options debug_config = [] @@ -472,6 +478,22 @@ def create_parser() -> argparse.ArgumentParser: help="Execute the command after showing it (use with --show-command)", ) + # Prefill/Decode optimization + prefill_group = parser.add_argument_group("Prefill/Decode Optimization") + prefill_group.add_argument( + "--batched-prefill", + dest="batched_prefill", + action="store_true", + help="Enable batched prefill/decode separation (llama.batchedPrefill=true)", + ) + prefill_group.add_argument( + "--prefill-batch-size", + dest="prefill_batch_size", + type=int, + default=None, + help="Prefill chunk/batch size (llama.prefillBatchSize=N, default: 32)", + ) + # Advanced options advanced_group = parser.add_argument_group("Advanced Options") advanced_group.add_argument( From ca4744f880ea058308b8bbc648b5b5098f7dde96 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 31 Mar 2026 18:20:28 +0300 Subject: [PATCH 03/12] [prf/dec] Add CPU path for prefill/decode. Separates inference path with InferenceCoreWithPrefillDecode and InferenceEngineWithPrefillDecode --- .../InferenceCoreWithPrefillDecode.java | 124 ++++++++++++++++++ .../InferenceEngineWithPrefillDecode.java | 120 +++++++++++++++++ .../beehive/gpullama3/model/llama/Llama.java | 6 + 3 files changed, 250 insertions(+) create mode 100644 src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java create mode 100644 src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java new file mode 100644 index 00000000..d662afe1 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -0,0 +1,124 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.Parallel; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.standard.StandardWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.standard.FloatTensor; + +/** + * Low-level forward passes for the prefill/decode separated inference path. + * + *

Parallel to {@link InferenceCore} — does NOT modify it.

+ * + *

The key addition is {@link #forwardJavaPrefill}, which runs a full + * transformer forward pass but skips the final RMSNorm and vocabulary + * projection (wcls matmul). This is correct for all prefill positions + * because their logits are discarded anyway; only the KV-cache update + * matters. Skipping the projection saves one large matmul + * (vocabularySize × dim) per prefill token.

+ */ +public final class InferenceCoreWithPrefillDecode { + + private InferenceCoreWithPrefillDecode() {} + + /** + * Prefill-only forward pass for LLaMA (CPU, FP32 weights). + * + *

Identical to {@link InferenceCore#forwardJava} except the final + * RMSNorm and vocabulary projection are omitted. The KV cache is + * populated correctly at {@code position}.

+ * + * @param model the LLaMA model (must carry {@link StandardWeights}) + * @param state mutable inference state (KV cache, activations …) + * @param token input token id + * @param position sequence position being processed + */ + public static void forwardJavaPrefill(Model model, State state, int token, int position) { + final Configuration config = model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); + float sqrtHeadSize = (float) Math.sqrt(headSize); + + // Token embedding + weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + + // Transformer layers + for (int l = 0; l < config.numberOfLayers(); l++) { + // Attention RMSNorm + InferenceCore.rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps()); + + // QKV projections + weights.wq[l].matmul(state.xb, state.q, dim, dim); + weights.wk[l].matmul(state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + + // RoPE + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int v = 0; v < rotn; v++) { + FloatTensor vec = v == 0 ? state.q : state.k; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + + // KV cache update + state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); + state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); + + // Multi-head attention + int curLayer = l; + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= position; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); + score /= sqrtHeadSize; + state.att.setFloat(attOffset + t, score); + } + + state.att.softmaxInPlace(attOffset, position + 1); + + int xbOffset = h * headSize; + state.xb.fillInPlace(xbOffset, headSize, 0f); + for (int t = 0; t <= position; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); + } + }); + + // Attention output projection + residual + weights.wo[l].matmul(state.xb, state.xb2, dim, dim); + state.x.addInPlace(state.xb2); + + // FFN RMSNorm + InferenceCore.rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps()); + + // FFN (SwiGLU) + weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim); + weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim); + state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + state.hb.multiplyInPlace(state.hb2); + weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim()); + + // FFN residual + state.x.addInPlace(state.xb); + } + + // Final RMSNorm and vocab projection intentionally omitted: + // logits are not needed for prefill positions — only the KV cache matters. + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java new file mode 100644 index 00000000..b97f3c72 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -0,0 +1,120 @@ +package org.beehive.gpullama3.inference; + +import org.beehive.gpullama3.auxiliary.LastRunMetrics; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tokenizer.Tokenizer; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +/** + * Token generation entry point for the prefill/decode separated inference path. + * + *

Parallel to {@link InferenceEngine} — does NOT modify it.

+ * + *

The split loop runs two phases:

+ *
    + *
  1. Prefill (positions 0..N-1): calls + * {@link InferenceCoreWithPrefillDecode#forwardJavaPrefill} for every + * prompt token. Vocabulary projection is skipped because these logits + * are discarded. KV cache is populated identically to the baseline.
  2. + *
  3. Decode (position N onward): calls + * {@link InferenceCore#forwardJava} per generated token. + * Behaviour is identical to the baseline decode path.
  4. + *
+ * + *

Activated by {@code -Dllama.batchedPrefill=true} (set via + * {@code --batched-prefill} in the Python launcher).

+ */ +public final class InferenceEngineWithPrefillDecode { + + private InferenceEngineWithPrefillDecode() {} + + /** + * LLaMA token generation with prefill/decode separation (CPU, Phase 1). + * + *

Drop-in replacement for + * {@link InferenceEngine#generateTokensLlama} when the batched-prefill + * flag is enabled. Only the CPU path is implemented here; GPU support + * is added in a later phase.

+ */ + public static List generateTokensLlama( + Model model, State state, int startPosition, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + + long startNanos = System.nanoTime(); + + final Configuration config = model.configuration(); + if (maxTokens < 0 || config.contextLength() < maxTokens) { + maxTokens = config.contextLength(); + } + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; // BOS + int pos = startPosition; + + // ── Phase 1: Prefill ────────────────────────────────────────────────── + // Run all prompt tokens through the forward pass without computing + // logits. The KV cache is populated at each position, which is all + // that matters. After this loop: + // currentToken == promptTokens.getLast() + // pos == startPosition + promptTokens.size() + for (int promptIndex = 0; promptIndex < promptTokens.size(); promptIndex++) { + InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos); + currentToken = promptTokens.get(promptIndex); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(currentToken)))); + } + pos++; + } + + state.latestToken = currentToken; + + // ── Phase 2: Decode ─────────────────────────────────────────────────── + // Standard single-token forward with logits. Behaviour is identical + // to the baseline InferenceEngine decode path. + long inferenceStartNanos = 0; + while (pos < maxTokens) { + if (inferenceStartNanos == 0) { + inferenceStartNanos = System.nanoTime(); + } + + var logits = InferenceCore.forwardJava(model, state, currentToken, pos); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + int totalTokens = promptTokens.size() + generatedTokens.size(); + LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); + + return generatedTokens; + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java index 8c69cb40..8036809e 100644 --- a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java +++ b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java @@ -2,6 +2,7 @@ import org.beehive.gpullama3.inference.InferenceCore; import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.InferenceEngineWithPrefillDecode; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; @@ -19,6 +20,8 @@ public class Llama extends AbstractModel { + static final boolean BATCHED_PREFILL = Boolean.getBoolean("llama.batchedPrefill"); + LlamaConfiguration configuration; public Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) { @@ -63,6 +66,9 @@ public void forward(State state, int token, int position) { @Override public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { + if (BATCHED_PREFILL) { + return InferenceEngineWithPrefillDecode.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); + } return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); } From d5f5629407e95f3752ab1fb65bdb5838db418421 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 31 Mar 2026 19:09:26 +0300 Subject: [PATCH 04/12] [prf/dec] Add GPU path for prefill/decode with TornadoVM integration. Implements `InferenceEngineWithPrefillDecode` and `TornadoVMMasterPlanWithPrefillDecode` for batched token generation. Refactor `Llama` to support the batched prefill flag. --- .../InferenceCoreWithPrefillDecode.java | 40 ++++++++ .../InferenceEngineWithPrefillDecode.java | 94 +++++++++++++++++++ .../beehive/gpullama3/model/llama/Llama.java | 3 + .../TornadoVMMasterPlanWithPrefillDecode.java | 79 ++++++++++++++++ 4 files changed, 216 insertions(+) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java index d662afe1..91bb6f79 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -3,9 +3,13 @@ import org.beehive.gpullama3.auxiliary.Parallel; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.standard.StandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tensor.standard.FloatTensor; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; + +import java.lang.foreign.MemorySegment; /** * Low-level forward passes for the prefill/decode separated inference path. @@ -121,4 +125,40 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p // Final RMSNorm and vocab projection intentionally omitted: // logits are not needed for prefill positions — only the KV cache matters. } + + /** + * GPU prefill-only forward pass for LLaMA (FP16, TornadoVM). + * + *

Copies the token embedding into {@code state.embeddingX} (same as + * {@link InferenceCore#forwardTornadoVM}) then delegates to + * {@link TornadoVMMasterPlanWithPrefillDecode#tornadoVMForwardPrefill}, + * which executes preprocessing + layer graphs but skips the logits graph.

+ * + * @param model the LLaMA model (must carry {@link TornadoWeights}, FP16 only) + * @param state mutable inference state + * @param token input token id + * @param position sequence position being processed + * @param prefillPlan the prefill/decode plan wrapper + * @throws UnsupportedOperationException if the model uses Q8_0 weights + */ + public static void forwardTornadoVMPrefill(Model model, State state, int token, int position, + TornadoVMMasterPlanWithPrefillDecode prefillPlan) { + final Configuration configuration = model.configuration(); + final TornadoWeights weights = (TornadoWeights) model.weights(); + + switch (weights.getWeightType()) { + case F16 -> { + MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + MemorySegment.copy(tokenEmbeddings, (long) token * configuration.dim() * bytes, + state.embeddingX.getSegment(), 0, (long) configuration.dim() * bytes); + } + case Q8_0 -> throw new UnsupportedOperationException( + // TODO Phase 4: implement Q8_0 GPU batched prefill kernels + "GPU prefill/decode path not yet implemented for Q8_0 weights"); + default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType()); + } + + prefillPlan.tornadoVMForwardPrefill(position); + } } diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index b97f3c72..0ea06c84 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -3,9 +3,13 @@ import org.beehive.gpullama3.auxiliary.LastRunMetrics; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; import java.util.ArrayList; import java.util.List; @@ -117,4 +121,94 @@ public static List generateTokensLlama( return generatedTokens; } + + /** + * LLaMA GPU token generation with prefill/decode separation (Phase 2). + * + *

Drop-in replacement for + * {@link InferenceEngine#generateTokensGPULlama} when the batched-prefill + * flag is enabled. FP16 only; Q8_0 throws {@link UnsupportedOperationException}.

+ * + *

Split loop:

+ *
    + *
  • Prefill (0..N-1): {@link InferenceCoreWithPrefillDecode#forwardTornadoVMPrefill} + * — layer graphs execute, logits graph is skipped.
  • + *
  • Decode (N onward): {@link InferenceCore#forwardTornadoVM} + * — identical to the baseline GPU decode path.
  • + *
+ */ + public static List generateTokensGPULlama( + Model model, State state, int startPosition, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + + // Q8_0 GPU prefill not implemented yet + if (((TornadoWeights) model.weights()).getWeightType() == GGMLType.Q8_0) { + // TODO Phase 4: implement Q8_0 GPU batched prefill kernels + throw new UnsupportedOperationException( + "GPU prefill/decode path not yet implemented for Q8_0 weights"); + } + + long startNanos = System.nanoTime(); + + final Configuration config = model.configuration(); + int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) + ? config.contextLength() : maxTokens; + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; // BOS + int pos = startPosition; + + // Thin wrapper: no new TornadoVM plan created, just holds the reference + TornadoVMMasterPlanWithPrefillDecode prefillPlan = + new TornadoVMMasterPlanWithPrefillDecode(tornadoVMPlan, state, model); + + // ── Phase 1: Prefill (GPU, no logits) ──────────────────────────────── + for (int promptIndex = 0; promptIndex < promptTokens.size() && pos < actualMaxTokens; promptIndex++) { + InferenceCoreWithPrefillDecode.forwardTornadoVMPrefill(model, state, currentToken, pos, prefillPlan); + currentToken = promptTokens.get(promptIndex); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(currentToken)))); + } + pos++; + } + + state.latestToken = currentToken; + + // ── Phase 2: Decode (GPU, with logits) ─────────────────────────────── + while (pos < actualMaxTokens) { + var logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMPlan); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + int totalTokens = promptTokens.size() + generatedTokens.size(); + LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); + + return generatedTokens; + } + + } diff --git a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java index 8036809e..12a95070 100644 --- a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java +++ b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java @@ -75,6 +75,9 @@ public List generateTokens(State state, int startPosition, List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + if (BATCHED_PREFILL) { + return InferenceEngineWithPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); + } return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java new file mode 100644 index 00000000..61b81bef --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -0,0 +1,79 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * Wraps {@link TornadoVMMasterPlan} and adds a prefill-only GPU forward pass. + * + *

Parallel to {@link TornadoVMMasterPlan} — does NOT modify it.

+ * + *

The existing execution plan has this graph layout:

+ *
+ *   graph 0         : preprocessing (embedding setup)
+ *   graphs 1..N     : transformer layers
+ *   graph N+1       : logits projection (final RMSNorm + wcls matmul)
+ * 
+ * + *

{@link #tornadoVMForwardPrefill} executes graphs 0..N and deliberately + * skips graph N+1. The KV cache is populated correctly by the layer graphs; + * the logits are not needed for prefill positions so the projection is wasted + * work that we avoid.

+ * + *

For decode, {@link #tornadoVMForwardDecode} delegates to the wrapped + * plan's {@code tornadoVMForwardExecuteLayered}, preserving identical behaviour + * to the baseline GPU path.

+ */ +public class TornadoVMMasterPlanWithPrefillDecode { + + private final TornadoVMMasterPlan plan; + private final State state; + private final Configuration config; + + public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlan plan, State state, Model model) { + this.plan = plan; + this.state = state; + this.config = model.configuration(); + } + + /** + * GPU prefill forward: runs preprocessing + all transformer layers, skips logits. + * + *

Mirrors {@link TornadoVMMasterPlan#tornadoVMForwardExecuteLayered} except + * the final logits graph (graph {@code numberOfLayers + 1}) is not executed.

+ * + * @param position sequence position being processed + */ + public void tornadoVMForwardPrefill(int position) { + // Graph 0: preprocessing + plan.executionPlan.withGraph(0) + .withGridScheduler(plan.tornadoVMLayerPlanner.getGridScheduler()) + .execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + // Graphs 1..N: transformer layers + for (int layer = 1; layer <= config.numberOfLayers(); layer++) { + plan.executionPlan.withGraph(layer) + .withGridScheduler(plan.tornadoVMLayerPlanner.getGridScheduler()) + .execute(); + } + + // Graph N+1 (logits) intentionally skipped — not needed for prefill positions. + } + + /** + * GPU decode forward: full execution including logits. + * Delegates to {@link TornadoVMMasterPlan#tornadoVMForwardExecuteLayered}. + * + * @param position sequence position being processed + * @return logits array for token sampling + */ + public FloatArray tornadoVMForwardDecode(int position) { + return plan.tornadoVMForwardExecuteLayered(position); + } +} From fbbc41ff3041235cf61102d0131e7afec07fdbe5 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 31 Mar 2026 22:34:07 +0300 Subject: [PATCH 05/12] [prf/dec] Batch prompt tokens during prefill phase in CPU path --- .../InferenceCoreWithPrefillDecode.java | 142 ++++++++++++++++++ .../InferenceEngineWithPrefillDecode.java | 96 ++++++++---- 2 files changed, 207 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java index 91bb6f79..460bb9af 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -6,6 +6,7 @@ import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; @@ -126,6 +127,147 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p // logits are not needed for prefill positions — only the KV cache matters. } + /** + * CPU batched prefill forward pass for LLaMA (Phase 3). + * + *

Processes {@code batchSize} prompt tokens simultaneously through all + * transformer layers. For each layer, Q/K/V projections, output projection, + * and FFN projections are computed via batch matmul + * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}), + * which parallelises over both output dimension and batch simultaneously. + * Attention reuses {@code state.att} sequentially per token (parallel per + * head within each token), keeping memory overhead minimal.

+ * + *

The logits layer is intentionally omitted — only the KV cache matters + * for prefill positions.

+ * + * @param model the LLaMA model (must carry {@link StandardWeights}) + * @param state mutable inference state (KV cache, att buffer …) + * @param tokens input token ids, {@code tokens[b]} at position {@code startPos+b} + * @param startPos sequence position of {@code tokens[0]} + * @param batchSize number of tokens in this chunk ({@code tokens.length}) + */ + public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) { + final Configuration config = model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); + float sqrtHeadSize = (float) Math.sqrt(headSize); + + // ── Batch activation tensors (allocated once per chunk) ─────────────── + FloatTensor[] x = new FloatTensor[batchSize]; + FloatTensor[] xb = new FloatTensor[batchSize]; + FloatTensor[] xb2 = new FloatTensor[batchSize]; + FloatTensor[] q = new FloatTensor[batchSize]; + FloatTensor[] k = new FloatTensor[batchSize]; + FloatTensor[] v = new FloatTensor[batchSize]; + FloatTensor[] hb = new FloatTensor[batchSize]; + FloatTensor[] hb2 = new FloatTensor[batchSize]; + for (int b = 0; b < batchSize; b++) { + x[b] = ArrayFloatTensor.allocate(dim); + xb[b] = ArrayFloatTensor.allocate(dim); + xb2[b] = ArrayFloatTensor.allocate(dim); + q[b] = ArrayFloatTensor.allocate(dim); + k[b] = ArrayFloatTensor.allocate(kvDim); + v[b] = ArrayFloatTensor.allocate(kvDim); + hb[b] = ArrayFloatTensor.allocate(config.hiddenDim()); + hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim()); + } + + // ── Token embeddings ────────────────────────────────────────────────── + Parallel.parallelFor(0, batchSize, b -> + weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim)); + + // ── Transformer layers ──────────────────────────────────────────────── + for (int l = 0; l < config.numberOfLayers(); l++) { + final int layer = l; + + // Attention RMSNorm (parallel per b) + Parallel.parallelFor(0, batchSize, b -> + InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps())); + + // QKV projections — batch matmul parallelises over (dim × batchSize) + weights.wq[l].matmul(batchSize, xb, q, dim, dim); + weights.wk[l].matmul(batchSize, xb, k, kvDim, dim); + weights.wv[l].matmul(batchSize, xb, v, kvDim, dim); + + // RoPE + KV cache store (parallel per b — different positions, no conflict) + Parallel.parallelFor(0, batchSize, b -> { + int pos = startPos + b; + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int vv = 0; vv < rotn; vv++) { + FloatTensor vec = vv == 0 ? q[b] : k[b]; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim); + v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim); + }); + + // Attention — sequential per b (state.att is shared), parallel per head + for (int b = 0; b < batchSize; b++) { + final int pos_b = startPos + b; + final int bFinal = b; + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= pos_b; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize; + state.att.setFloat(attOffset + t, score); + } + state.att.softmaxInPlace(attOffset, pos_b + 1); + + int xbOffset = h * headSize; + xb[bFinal].fillInPlace(xbOffset, headSize, 0f); + for (int t = 0; t <= pos_b; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a); + } + }); + } + + // Output projection — batch matmul + weights.wo[l].matmul(batchSize, xb, xb2, dim, dim); + + // Residual + FFN RMSNorm (parallel per b) + Parallel.parallelFor(0, batchSize, b -> { + x[b].addInPlace(xb2[b]); + InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps()); + }); + + // FFN projections — batch matmul + weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim); + weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim); + + // SwiGLU (parallel per b) + Parallel.parallelFor(0, batchSize, b -> { + hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + hb[b].multiplyInPlace(hb2[b]); + }); + + // w2 projection — batch matmul (output reuses xb) + weights.w2[l].matmul(batchSize, hb, xb, dim, config.hiddenDim()); + + // FFN residual (parallel per b) + Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b])); + } + + // Final RMSNorm and vocab projection intentionally omitted — + // logits are not needed for any token in a prefill batch. + } + /** * GPU prefill-only forward pass for LLaMA (FP16, TornadoVM). * diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index 0ea06c84..b581b8e8 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -12,6 +12,7 @@ import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Set; import java.util.function.IntConsumer; @@ -39,13 +40,20 @@ public final class InferenceEngineWithPrefillDecode { private InferenceEngineWithPrefillDecode() {} + /** Prefill chunk size. 1 = sequential (Phase 1 behaviour), >1 = batched (Phase 3). */ + static final int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1); + /** - * LLaMA token generation with prefill/decode separation (CPU, Phase 1). + * LLaMA token generation with prefill/decode separation (CPU). * - *

Drop-in replacement for - * {@link InferenceEngine#generateTokensLlama} when the batched-prefill - * flag is enabled. Only the CPU path is implemented here; GPU support - * is added in a later phase.

+ *

When {@code llama.prefillBatchSize > 1} (Phase 3), prompt tokens are + * processed in chunks of that size using batch matmul, which traverses each + * weight matrix once per chunk instead of once per token.

+ * + *

When {@code llama.prefillBatchSize == 1} (Phase 1), falls back to + * sequential single-token prefill (skip logits only).

+ * + *

Drop-in replacement for {@link InferenceEngine#generateTokensLlama}.

*/ public static List generateTokensLlama( Model model, State state, int startPosition, @@ -56,42 +64,68 @@ public static List generateTokensLlama( long startNanos = System.nanoTime(); final Configuration config = model.configuration(); - if (maxTokens < 0 || config.contextLength() < maxTokens) { - maxTokens = config.contextLength(); - } + int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) + ? config.contextLength() : maxTokens; List generatedTokens = new ArrayList<>(); int currentToken = state.latestToken; // BOS int pos = startPosition; - - // ── Phase 1: Prefill ────────────────────────────────────────────────── - // Run all prompt tokens through the forward pass without computing - // logits. The KV cache is populated at each position, which is all - // that matters. After this loop: - // currentToken == promptTokens.getLast() - // pos == startPosition + promptTokens.size() - for (int promptIndex = 0; promptIndex < promptTokens.size(); promptIndex++) { - InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos); - currentToken = promptTokens.get(promptIndex); - if (echo) { - System.err.print(Tokenizer.replaceControlCharacters( - model.tokenizer().decode(List.of(currentToken)))); + int N = promptTokens.size(); + + // ── Prefill ─────────────────────────────────────────────────────────── + if (N > 0 && pos < actualMaxTokens) { + if (PREFILL_BATCH_SIZE > 1) { + // Phase 3: batch prefill — process tokens in chunks of PREFILL_BATCH_SIZE. + // Build the token sequence at positions [startPosition .. startPosition+N-1]: + // position startPosition+0 : currentToken (BOS) + // position startPosition+1 : promptTokens[0] + // ... + // position startPosition+N-1: promptTokens[N-2] + int[] prefillSeq = new int[N]; + prefillSeq[0] = currentToken; + for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1); + + for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += PREFILL_BATCH_SIZE) { + int chunkEnd = Math.min(Math.min(chunkStart + PREFILL_BATCH_SIZE, N), actualMaxTokens - pos); + int chunkSize = chunkEnd - chunkStart; + int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd); + + if (chunkSize == 1) { + InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, chunk[0], pos + chunkStart); + } else { + InferenceCoreWithPrefillDecode.batchForwardJavaPrefill(model, state, chunk, pos + chunkStart, chunkSize); + } + + if (echo) { + for (int b = 0; b < chunkSize; b++) { + int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1)); + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(echoed)))); + } + } + } + + currentToken = promptTokens.get(N - 1); + pos = startPosition + N; + } else { + // Phase 1: sequential prefill — single token, no logits + for (int promptIndex = 0; promptIndex < N && pos < actualMaxTokens; promptIndex++) { + InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos); + currentToken = promptTokens.get(promptIndex); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(currentToken)))); + } + pos++; + } } - pos++; } state.latestToken = currentToken; - // ── Phase 2: Decode ─────────────────────────────────────────────────── - // Standard single-token forward with logits. Behaviour is identical - // to the baseline InferenceEngine decode path. - long inferenceStartNanos = 0; - while (pos < maxTokens) { - if (inferenceStartNanos == 0) { - inferenceStartNanos = System.nanoTime(); - } - + // ── Decode ──────────────────────────────────────────────────────────── + while (pos < actualMaxTokens) { var logits = InferenceCore.forwardJava(model, state, currentToken, pos); int nextToken = sampler.sampleToken(logits); From f0bca5f56e58db89cec21e9319127fe30413860d Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 2 Apr 2026 18:05:49 +0300 Subject: [PATCH 06/12] [prf/dec][wip] Add GPU-based prefill-decode with batched prefill (working state, with cuda graphs only) --- .../InferenceEngineWithPrefillDecode.java | 82 +++- .../gpullama3/inference/state/LlamaState.java | 42 ++ .../gpullama3/inference/state/State.java | 2 +- .../tornadovm/TornadoVMMasterPlan.java | 240 ++------- .../TornadoVMMasterPlanBatchPrefill.java | 342 +++++++++++++ .../TornadoVMMasterPlanStandard.java | 149 ++++++ .../TornadoVMMasterPlanWithPrefillDecode.java | 8 +- .../TransformerBatchPrefillKernels.java | 461 ++++++++++++++++++ .../fp16/LlamaFP16BatchPrefillLayers.java | 238 +++++++++ 9 files changed, 1364 insertions(+), 200 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index b581b8e8..6517df12 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -2,6 +2,7 @@ import org.beehive.gpullama3.auxiliary.LastRunMetrics; import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.tensor.GGMLType; @@ -9,6 +10,8 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefill; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanStandard; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; import java.util.ArrayList; @@ -40,7 +43,7 @@ public final class InferenceEngineWithPrefillDecode { private InferenceEngineWithPrefillDecode() {} - /** Prefill chunk size. 1 = sequential (Phase 1 behaviour), >1 = batched (Phase 3). */ + /** Prefill chunk size. 1 = sequential (Phase 1 behaviour), >1 = batched (Phase 3/4). */ static final int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1); /** @@ -195,9 +198,82 @@ public static List generateTokensGPULlama( int currentToken = state.latestToken; // BOS int pos = startPosition; + if (PREFILL_BATCH_SIZE > 1) { + // ── Phase 4: Batch GPU Prefill ──────────────────────────────────── + // Plan was pre-initialized in Model.runInstructOnce/runInteractive + // as a TornadoVMMasterPlanBatchPrefill by TornadoVMMasterPlan.initializeTornadoVMPlan. + TornadoVMMasterPlanBatchPrefill plan = (TornadoVMMasterPlanBatchPrefill) tornadoVMPlan; + + int N = promptTokens.size(); + + // Build the token sequence at positions [startPosition .. startPosition+N-1]: + // position startPosition+0 : currentToken (BOS/previous token) + // position startPosition+1 : promptTokens[0] + // ... + // position startPosition+N-1: promptTokens[N-2] + int[] prefillSeq = new int[N]; + prefillSeq[0] = currentToken; + for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1); + + for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += PREFILL_BATCH_SIZE) { + int chunkEnd = Math.min(Math.min(chunkStart + PREFILL_BATCH_SIZE, N), actualMaxTokens - pos); + int chunkSize = chunkEnd - chunkStart; + int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd); + + if (chunkSize == 1) { + // Single-token chunk: use decode path (includes logits skip is not needed + // here, but we need the KV cache populated — use batch prefill with size 1) + plan.tornadoVMForwardBatchPrefill(chunk, pos + chunkStart, model, 1); + } else { + plan.tornadoVMForwardBatchPrefill(chunk, pos + chunkStart, model, chunkSize); + } + + if (echo) { + for (int b = 0; b < chunkSize; b++) { + int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1)); + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(echoed)))); + } + } + } + + currentToken = promptTokens.get(N - 1); + pos = startPosition + N; + state.latestToken = currentToken; + + // ── Phase 4: Decode (GPU, with logits, via unified plan) ────────── + while (pos < actualMaxTokens) { + var logits = plan.tornadoVMForwardDecode(currentToken, pos, model); + int nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters( + model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + } else { + // ── Phase 2: Sequential GPU Prefill + Decode ───────────────────────── + // Thin wrapper: no new TornadoVM plan created, just holds the reference + // Plan is a TornadoVMMasterPlanStandard when PREFILL_BATCH_SIZE == 1. TornadoVMMasterPlanWithPrefillDecode prefillPlan = - new TornadoVMMasterPlanWithPrefillDecode(tornadoVMPlan, state, model); + new TornadoVMMasterPlanWithPrefillDecode( + (TornadoVMMasterPlanStandard) tornadoVMPlan, state, model); // ── Phase 1: Prefill (GPU, no logits) ──────────────────────────────── for (int promptIndex = 0; promptIndex < promptTokens.size() && pos < actualMaxTokens; promptIndex++) { @@ -237,6 +313,8 @@ public static List generateTokensGPULlama( pos++; } + } // end else (Phase 2) + long endNanos = System.nanoTime(); int totalTokens = promptTokens.size() + generatedTokens.size(); LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java index 21344223..d298d388 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java @@ -22,8 +22,50 @@ */ public final class LlamaState extends State { + // ── Batch-prefill GPU buffers ───────────────────────────────────────────── + // Allocated when llama.prefillBatchSize > 1; null otherwise. + // Layout: flat [B × stride], element [b][i] at index b*stride + i. + public final HalfFloatArray embeddingXBatch; // B × dim (FP16 input) + public final FloatArray wrapXBatch; // B × dim (live activations) + public final HalfFloatArray wrapXbFP16Batch; // B × dim (RMSNorm output, FP16) + public final FloatArray wrapQBatch; // B × dim + public final FloatArray wrapKBatch; // B × kvDim + public final FloatArray wrapVBatch; // B × kvDim + public final FloatArray wrapXbBatch; // B × dim (attention output) + public final FloatArray wrapHbBatch; // B × hiddenDim + public final FloatArray attnScaleBatch; // B (per-token RMS scale, attn) + public final FloatArray ffnScaleBatch; // B (per-token RMS scale, FFN) + public final IntArray batchStartPosHolder; // 1 (start position of chunk) + public LlamaState(Configuration config, int batchsize) { super(config, batchsize); + int gpuBatchSize = Integer.getInteger("llama.prefillBatchSize", 1); + if (gpuBatchSize > 1) { + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + this.embeddingXBatch = new HalfFloatArray(gpuBatchSize * config.dim()); + this.wrapXBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapXbFP16Batch = new HalfFloatArray(gpuBatchSize * config.dim()); + this.wrapQBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapKBatch = new FloatArray(gpuBatchSize * kvDim); + this.wrapVBatch = new FloatArray(gpuBatchSize * kvDim); + this.wrapXbBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapHbBatch = new FloatArray(gpuBatchSize * config.hiddenDim()); + this.attnScaleBatch = new FloatArray(gpuBatchSize); + this.ffnScaleBatch = new FloatArray(gpuBatchSize); + this.batchStartPosHolder = new IntArray(1); + } else { + this.embeddingXBatch = null; + this.wrapXBatch = null; + this.wrapXbFP16Batch = null; + this.wrapQBatch = null; + this.wrapKBatch = null; + this.wrapVBatch = null; + this.wrapXbBatch = null; + this.wrapHbBatch = null; + this.attnScaleBatch = null; + this.ffnScaleBatch = null; + this.batchStartPosHolder = null; + } } @Override diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java index f8e9906a..06d448e7 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java @@ -75,7 +75,7 @@ public abstract class State { /** last index in previous block */ protected State(Configuration config, int batchsize) { - this.batchsize = -1; + this.batchsize = batchsize; this.latestToken = -1; this.localSize = 256; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 4b752735..43030fd0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -1,212 +1,66 @@ package org.beehive.gpullama3.tornadovm; +import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -public class TornadoVMMasterPlan { - public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); - - private final State state; - private final Configuration config; - public TornadoExecutionPlan executionPlan; - GenericLayerPlanner tornadoVMLayerPlanner; - - public TornadoVMMasterPlan(State state, Model model) { - this.tornadoVMLayerPlanner = createPlanner(state, model); - this.executionPlan = createExecutionPlan(); - this.state = state; - this.config = model.configuration(); - } +/** + * Common contract for all TornadoVM GPU execution plans. + * + *

Two concrete implementations exist:

+ *
    + *
  • {@link TornadoVMMasterPlanStandard} — single-token forward pass; used for the + * baseline GPU path and Phase 2 sequential prefill/decode.
  • + *
  • {@link TornadoVMMasterPlanBatchPrefill} — unified plan for Phase 4 batched + * prefill + single-token decode within one {@code TornadoExecutionPlan}.
  • + *
+ * + *

The {@link #initializeTornadoVMPlan} factory selects the appropriate implementation + * based on {@code llama.prefillBatchSize}: if {@code > 1}, returns a + * {@link TornadoVMMasterPlanBatchPrefill}; otherwise returns a + * {@link TornadoVMMasterPlanStandard}.

+ */ +public interface TornadoVMMasterPlan { + + boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean( + System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); /** - * Initializes the TornadoVM plan for GPU acceleration with optional timing. This method handles: 1. Creation of the TornadoVM master plan 2. Warming up the JIT compiler for better performance 3. - * Copying read-only model weights to the GPU + * Single-token forward pass returning output logits. * - * @param state - * The model state containing KV cache - * @param model - * The Llama model instance - * @return The initialized TornadoVMMasterPlan ready for inference - */ - public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { - // Initialize timing variables outside conditional blocks to avoid scope issues - long startTime = System.nanoTime(); - long planCreationTime = 0; - long warmupTime = 0; - - // Start a timing message if enabled - if (ENABLE_TORNADOVM_INIT_TIME) { - System.err.println("\nStarting TornadoVM initialization..."); - } - - // 1. Pre-allocate the TornadoVM plan - TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan(state, model); - - // Record time after plan creation - if (ENABLE_TORNADOVM_INIT_TIME) { - planCreationTime = System.nanoTime(); - System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); - } - - tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); - - // 2. Perform warmup with extra iterations to ensure JIT compilation is complete - tornadoVMPlan.executionPlan.withPreCompilation(); // Force JIT compilation from Java to GPU code - - // Record time after warmup - if (ENABLE_TORNADOVM_INIT_TIME) { - warmupTime = System.nanoTime(); - System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); - } - - // 3. Perform copy-in of read-only weights and objects - tornadoVMPlan.forceCopyInReadOnlyDataLayered(); // Force copy-in read-only weights - - // Record final timing information - if (ENABLE_TORNADOVM_INIT_TIME) { - long copyTime = System.nanoTime(); - System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); - System.err.printf("Finished TornadoVM initialization...\n \n"); - } - - model.setTornadoVMPlan(tornadoVMPlan); - - return tornadoVMPlan; - } - - private TornadoExecutionPlan createExecutionPlan() { - var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); - var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); - return new TornadoExecutionPlan(taskGraphArray); - } - - private GenericLayerPlanner createPlanner(State state, Model model) { - // ========== STEP 1: Detect Quantization Type ========== - GGMLType weightType = model.weights().getWeightType(); - - // ========== STEP 2: Route via Factory ========== - // Factory handles all model × quantization combinations - GenericLayerPlanner basePlanner = QuantizationPlannerFactory.create(weightType, state, model); - - return basePlanner; - } - - /** - * Determines whether the NVIDIA-specific scheduler should be used based on the current - * hardware backend and the model type. - *

- * The scheduler is used only if the runtime is targeting an NVIDIA backend and the model is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is {@code MISTRAL}, the - * NVIDIA-specific scheduler should not be used. + *

Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM}) + * and the Phase 2 sequential decode path. Not applicable to + * {@link TornadoVMMasterPlanBatchPrefill} — that plan uses its own typed methods.

* - * @param model - * the model whose type may affect the scheduler decision - * @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise + * @param position sequence position of the current token + * @return logits array for token sampling */ + FloatArray tornadoVMForwardExecuteLayered(int position); - /** - * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. This method processes the transformer layers in sequence for a particular token position in the context - * window. - * - *

The execution happens in three phases: - *

    - *
  1. Initial token embedding lookup (already done before calling this method)
  2. - *
  3. Sequential processing through each transformer layer using TornadoVM
  4. - *
  5. Final projection to logits using TornadoVM
  6. - *
- * - * @param position - * The current position in the sequence being processed - * @return FloatTensor containing the output logits for token prediction - */ - - // int pos, ModelPlanner - public FloatArray tornadoVMForwardExecuteLayered(int position) { - // @formatter:off - // 1. Execute the preprocessing graph (e.g., input preparation, memory initialization) - executionPlan.withGraph(getPreprocessingGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .withCUDAGraph() - .execute(); - - // Set the position in the state object (used by attention layers) - state.positionHolder.set(0, position); - state.temp.clear(); - state.tempFFN.clear(); - - // 2. Execute each transformer layer graph sequentially - // Each graph computes attention and feed-forward transformations for one layer - for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(getLayerGraphIndex(layer)) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .withCUDAGraph() - .execute(); - } - state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f - state.wrapLogits.clear(); // Clear the output logits tensor -> set to 0f - // 3. Execute the final graph that projects the last hidden state to output logits - executionPlan.withGraph(getFinalLogitsGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .withCUDAGraph() - .execute(); - - // @formatter:on - // Return the logits (used for token prediction) - return state.wrapLogits; - } + /** Releases all device memory held by this plan. */ + void freeTornadoExecutionPlan(); /** - * Returns the graph index for the pre-processing step (e.g., token embedding). - */ - private int getPreprocessingGraphIndex() { - return 0; - } - - /** - * Returns the graph index for the given transformer layer. + * Factory: creates, JIT-compiles, and warms up the appropriate plan. * - * @param layerIndex - * Index of the transformer layer (0-based) - */ - private int getLayerGraphIndex(int layerIndex) { - return 1 + layerIndex; - } - - /** - * Returns the graph index for the final projection to logits. + *

When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanBatchPrefill} + * is returned; otherwise a {@link TornadoVMMasterPlanStandard} is returned.

+ * + * @param state the model state (must be {@link LlamaState} when batch size {@code > 1}) + * @param model the model instance + * @return the initialized plan, also stored via {@link Model#setTornadoVMPlan} */ - private int getFinalLogitsGraphIndex() { - return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1; - } - - /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer. - public void forceCopyInReadOnlyDataLayered() { - // Execute all TornadoVM graphs - state.wrapX.clear(); - state.positionHolder.init(0); - - // Execute activation update graph - executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); - - // Execute layer processing graphs - for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { + int batchSize = Integer.getInteger("llama.prefillBatchSize", 1); + TornadoVMMasterPlan plan; + if (batchSize > 1) { + plan = TornadoVMMasterPlanBatchPrefill.initializeUnifiedPlan( + (LlamaState) state, model, batchSize); + } else { + plan = TornadoVMMasterPlanStandard.initialize(state, model); } - - // Execute logits graph - executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); - } - - /** - * Frees the device memory allocated for the TornadoVM execution plan. This method should be called when the execution plan is no longer needed to release resources and avoid memory leaks. - */ - public void freeTornadoExecutionPlan() { - executionPlan.freeDeviceMemory(); + model.setTornadoVMPlan(plan); + return plan; } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java new file mode 100644 index 00000000..b2388bf3 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java @@ -0,0 +1,342 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16BatchPrefillLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +import java.lang.foreign.MemorySegment; +import java.util.ArrayList; +import java.util.List; + +/** + * Unified GPU execution plan for Phase 4: batched prefill + single-token decode. + * + *

A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache + * ({@code wrapKeyCache}, {@code wrapValueCache}) is shared on device via + * {@code persistOnDevice}/{@code consumeFromDevice}. Two separate plans would + * allocate independent device buffers and lose the prefill KV state.

+ * + *

Graph layout (2N+3 graphs total):

+ *
+ *   [0]         batch activation     B×dim FP16 → FP32
+ *   [1..N]      batch layer graphs   B tokens, all transformer ops
+ *   [N+1]       decode activation    single-token FP16 → FP32 + KV-cache pass-through
+ *   [N+2..2N+1] decode layer graphs  single-token, standard kernels
+ *   [2N+2]      logits graph
+ * 
+ * + *

KV cache pointer chain across phases:

+ *
+ *   batchLayer[N-1]  --persistOnDevice(wrapKeyCache)-→
+ *   decodeActivation --consumeFromDevice(wrapKeyCache)-→  (pass-through)
+ *   decodeLayer[0]   --consumeFromDevice(wrapKeyCache)-→  (used by attention)
+ * 
+ */ +public class TornadoVMMasterPlanBatchPrefill implements TornadoVMMasterPlan { + + private static final boolean ENABLE_TIMING = + Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); + + private final LlamaState state; + private final LlamaConfiguration config; + private final int batchSize; + private final int N; // numberOfLayers + private final TornadoExecutionPlan executionPlan; + private final GridScheduler gridScheduler; + + // ── Graph-index helpers ─────────────────────────────────────────────────── + private int batchActivationIdx() { return 0; } + private int batchLayerIdx(int i) { return 1 + i; } + private int decodeActivationIdx() { return N + 1; } + private int decodeLayerIdx(int i) { return N + 2 + i; } + private int logitsIdx() { return 2 * N + 2; } + + // ── Construction ───────────────────────────────────────────────────────── + private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batchSize) { + this.state = state; + this.config = (LlamaConfiguration) model.configuration(); + this.batchSize = batchSize; + this.N = config.numberOfLayers(); + + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); + + List all = new ArrayList<>(2 * N + 3); + GridScheduler scheduler = new GridScheduler(); + + // [0] Batch activation ──────────────────────────────────────────────── + KernelContext batchActCtx = new KernelContext(); + all.add(buildBatchActivationGraph(batchActCtx).snapshot()); + scheduler.addWorkerGrid("batchActivation.batchUpdateX", + WorkerGridFactory.genericWorker(batchSize * config.dim(), 128)); + + // [1..N] Batch layer graphs ─────────────────────────────────────────── + LlamaFP16BatchPrefillLayers batchLayers = + new LlamaFP16BatchPrefillLayers(state, weights, config, batchSize); + all.addAll(batchLayers.getLayerImmutableTaskGraphs()); + batchLayers.updateGridScheduler(scheduler); + + // [N+1] Decode activation (with KV-cache pass-through) ──────────────── + KernelContext decodeActCtx = new KernelContext(); + all.add(buildDecodeActivationGraph(decodeActCtx).snapshot()); + scheduler.addWorkerGrid("activationUpdate.updateX", + WorkerGridFactory.genericWorker(config.dim(), 128)); + + // [N+2..2N+1] Decode layer graphs ──────────────────────────────────── + // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload). + LlamaFP16FFNLayersForUnifiedDecode decodeLayers = + new LlamaFP16FFNLayersForUnifiedDecode( + "llamaFFNDecode", state, weights, config, schedulerType); + all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); + decodeLayers.updateGridScheduler(scheduler); + + // [2N+2] Logits ─────────────────────────────────────────────────────── + LogitsFP16Layer logitsLayer = new LogitsFP16Layer("logits", state, weights, config, + decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + all.add(logitsLayer.getImmutableTaskGraph()); + logitsLayer.updateGridScheduler(scheduler); + + this.gridScheduler = scheduler; + this.executionPlan = new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0])); + } + + // ── Activation graphs ───────────────────────────────────────────────────── + + /** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. */ + private TaskGraph buildBatchActivationGraph(KernelContext ctx) { + return new TaskGraph("batchActivation") + .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) + .task("batchUpdateX", + (KernelContext c, HalfFloatArray src, FloatArray dst) -> + dst.set(c.globalIdx, src.get(c.globalIdx).getFloat32()), + ctx, state.embeddingXBatch, state.wrapXBatch) + .persistOnDevice(state.wrapXBatch); + } + + /** + * Graph N+1: single-token FP16 → FP32. + * + *

Receives the KV-cache device pointer from batch layer N via + * {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so + * that {@code updatePersistedObjectState()} can propagate it to decode layer 0. + * Both halves of the chain are required; without the re-persist the pointer is + * not forwarded in interpreter (non-CUDA-graph) mode.

+ */ + private TaskGraph buildDecodeActivationGraph(KernelContext ctx) { + return new TaskGraph("activationUpdate") + .consumeFromDevice(state.wrapKeyCache, state.wrapValueCache) // KV pass-through + .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", + TransformerComputeKernels::convertFP16toFP32, + ctx, (HalfFloatArray) state.embeddingX, state.wrapX) + // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache + // re-persisted so updatePersistedObjectState() propagates the device + // pointer to decode layer 0's consumeFromDevice without CUDA graphs. + .persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); + } + + // ── Static factory ──────────────────────────────────────────────────────── + + /** + * Creates, JIT-compiles, and warms up the unified plan. + * Mirrors {@link TornadoVMMasterPlan#initializeTornadoVMPlan}. + */ + public static TornadoVMMasterPlanBatchPrefill initializeUnifiedPlan( + LlamaState state, Model model, int batchSize) { + + long t0 = System.nanoTime(); + TornadoVMMasterPlanBatchPrefill plan = + new TornadoVMMasterPlanBatchPrefill(state, model, batchSize); + + if (ENABLE_TIMING) + System.err.printf("[BatchPlan] Graph construction: %.2f ms%n", + (System.nanoTime() - t0) / 1e6); + + plan.executionPlan.withAllGraphs().withCUDAGraph(); + plan.executionPlan.withPreCompilation(); + + if (ENABLE_TIMING) + System.err.printf("[BatchPlan] JIT compilation: %.2f ms%n", + (System.nanoTime() - t0) / 1e6); + + plan.forceCopyInReadOnlyData(); + + if (ENABLE_TIMING) + System.err.printf("[BatchPlan] Init complete: %.2f ms%n", + (System.nanoTime() - t0) / 1e6); + + return plan; + } + + /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */ + private void forceCopyInReadOnlyData() { + state.wrapXBatch.clear(); + state.wrapX.clear(); + state.positionHolder.init(0); + state.batchStartPosHolder.init(0); + + for (int i = 0; i <= logitsIdx(); i++) { + executionPlan.withGraph(i) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + } + } + + // ── Forward passes ──────────────────────────────────────────────────────── + + /** + * Batch prefill: runs graphs 0..N (activation + N layers), skips logits. + * + * @param tokenIds token IDs for this chunk (length == batchSize, or tail) + * @param startPos sequence position of tokenIds[0] + * @param model model (for embedding table) + * @param chunkSize actual number of tokens in this chunk (≤ batchSize) + */ + public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) { + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + int dim = config.dim(); + + // Copy B embeddings into embeddingXBatch + for (int b = 0; b < chunkSize; b++) { + MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes, + state.embeddingXBatch.getSegment(), (long) b * dim * bytes, + (long) dim * bytes); + } + state.batchStartPosHolder.set(0, startPos); + + // Graph 0: batch activation + executionPlan.withGraph(batchActivationIdx()) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + + // Graphs 1..N: batch transformer layers + for (int l = 0; l < N; l++) { + executionPlan.withGraph(batchLayerIdx(l)) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + } + // Logits skipped — not needed for prefill positions. + } + + /** + * Single-token decode: runs graphs N+1..2N+2 (activation + N layers + logits). + * + * @param token token ID to process + * @param position sequence position + * @param model model (for embedding table) + * @return logits array for sampling + */ + public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { + LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + int dim = config.dim(); + + MemorySegment.copy(embTable, (long) token * dim * bytes, + state.embeddingX.getSegment(), 0L, (long) dim * bytes); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + // Graph N+1: decode activation + executionPlan.withGraph(decodeActivationIdx()) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + + // Graphs N+2..2N+1: decode transformer layers + for (int l = 0; l < N; l++) { + executionPlan.withGraph(decodeLayerIdx(l)) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + } + + state.tempLogits.clear(); + state.wrapLogits.clear(); + + // Graph 2N+2: logits + executionPlan.withGraph(logitsIdx()) + .withGridScheduler(gridScheduler) + .withCUDAGraph() + .execute(); + + return state.wrapLogits; + } + + @Override + public FloatArray tornadoVMForwardExecuteLayered(int position) { + throw new UnsupportedOperationException( + "Use tornadoVMForwardBatchPrefill / tornadoVMForwardDecode for batch plan"); + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } + + // ── Inner class: decode layer 0 with consumeFromDevice for KV cache ─────── + + /** + * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses + * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}. + * + *

This ensures decode layer 0 receives the KV-cache device pointer that was + * persisted by the last batch prefill layer and passed through the decode + * activation graph.

+ */ + private static final class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { + + LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come + // from device (passed through by the decode activation graph). + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapAtt, state.wrapHb, state.wrapXbFP16); + // KV cache: consume from device (device pointer supplied by + // decode activation's pass-through from last batch layer). + layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache); + } else { + // Identical to parent for layers 1+ (already uses consumeFromDevice). + return super.configureLayerDataTransfers(layer, layerIndex); + } + return layer; + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java new file mode 100644 index 00000000..91586f2c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java @@ -0,0 +1,149 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * Standard (single-token) GPU execution plan. + * + *

Processes one token at a time through preprocessing + N transformer layers + + * logits projection. Used for both the baseline GPU path and the Phase 2 + * sequential prefill/decode path.

+ */ +public class TornadoVMMasterPlanStandard implements TornadoVMMasterPlan { + + public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); + + private final State state; + private final Configuration config; + public TornadoExecutionPlan executionPlan; + GenericLayerPlanner tornadoVMLayerPlanner; + + public TornadoVMMasterPlanStandard(State state, Model model) { + this.tornadoVMLayerPlanner = createPlanner(state, model); + this.executionPlan = createExecutionPlan(); + this.state = state; + this.config = model.configuration(); + } + + /** + * Initializes and warms up the standard TornadoVM plan. + * + * @param state the model state containing KV cache + * @param model the model instance + * @return the initialized plan ready for inference + */ + static TornadoVMMasterPlanStandard initialize(State state, Model model) { + long startTime = System.nanoTime(); + long planCreationTime = 0; + long warmupTime = 0; + + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + + TornadoVMMasterPlanStandard tornadoVMPlan = new TornadoVMMasterPlanStandard(state, model); + + if (ENABLE_TORNADOVM_INIT_TIME) { + planCreationTime = System.nanoTime(); + System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); + } + + tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); + tornadoVMPlan.executionPlan.withPreCompilation(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + warmupTime = System.nanoTime(); + System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0); + } + + tornadoVMPlan.forceCopyInReadOnlyDataLayered(); + + if (ENABLE_TORNADOVM_INIT_TIME) { + long copyTime = System.nanoTime(); + System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0); + System.err.printf("Finished TornadoVM initialization...\n \n"); + } + + return tornadoVMPlan; + } + + private TornadoExecutionPlan createExecutionPlan() { + var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); + var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); + return new TornadoExecutionPlan(taskGraphArray); + } + + private GenericLayerPlanner createPlanner(State state, Model model) { + GGMLType weightType = model.weights().getWeightType(); + return QuantizationPlannerFactory.create(weightType, state, model); + } + + @Override + public FloatArray tornadoVMForwardExecuteLayered(int position) { + // @formatter:off + executionPlan.withGraph(getPreprocessingGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + .withCUDAGraph() + .execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + for (int layer = 0; layer < config.numberOfLayers(); layer++) { + executionPlan.withGraph(getLayerGraphIndex(layer)) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + //.withCUDAGraph() + .execute(); + } + state.tempLogits.clear(); + state.wrapLogits.clear(); + executionPlan.withGraph(getFinalLogitsGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) + .withCUDAGraph() + .execute(); + // @formatter:on + return state.wrapLogits; + } + + private int getPreprocessingGraphIndex() { + return 0; + } + + private int getLayerGraphIndex(int layerIndex) { + return 1 + layerIndex; + } + + private int getFinalLogitsGraphIndex() { + return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1; + } + + public void forceCopyInReadOnlyDataLayered() { + state.wrapX.clear(); + state.positionHolder.init(0); + + //executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + + for (int layer = 0; layer < config.numberOfLayers(); layer++) { + //executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + } + + //executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); + executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java index 61b81bef..e5262b17 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -6,9 +6,9 @@ import uk.ac.manchester.tornado.api.types.arrays.FloatArray; /** - * Wraps {@link TornadoVMMasterPlan} and adds a prefill-only GPU forward pass. + * Wraps {@link TornadoVMMasterPlanStandard} and adds a prefill-only GPU forward pass. * - *

Parallel to {@link TornadoVMMasterPlan} — does NOT modify it.

+ *

Parallel to {@link TornadoVMMasterPlanStandard} — does NOT modify it.

* *

The existing execution plan has this graph layout:

*
@@ -28,11 +28,11 @@
  */
 public class TornadoVMMasterPlanWithPrefillDecode {
 
-    private final TornadoVMMasterPlan plan;
+    private final TornadoVMMasterPlanStandard plan;
     private final State state;
     private final Configuration config;
 
-    public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlan plan, State state, Model model) {
+    public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlanStandard plan, State state, Model model) {
         this.plan = plan;
         this.state = state;
         this.config = model.configuration();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
new file mode 100644
index 00000000..9bba3860
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
@@ -0,0 +1,461 @@
+package org.beehive.gpullama3.tornadovm.kernels;
+
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.math.TornadoMath;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+
+/**
+ * GPU kernels for batched prefill (Phase 4).
+ *
+ * 

Each kernel processes {@code batchSize} tokens simultaneously. + * Batch tensors are flat: element [b][i] lives at index {@code b*stride + i}. + * Worker-grid sizes are scaled by {@code batchSize} vs the single-token kernels.

+ * + *

These kernels are meant to be registered in {@link TornadoVMMasterPlanWithPrefillDecode} + * batch task graphs; they are NOT invoked directly.

+ */ +public final class TransformerBatchPrefillKernels { + + private TransformerBatchPrefillKernels() {} + + // ── Activation ──────────────────────────────────────────────────────────── + + /** + * Converts B×dim FP16 token embeddings to FP32. + * Worker: B*dim global threads, localSize=128. + */ + public static void batchEmbeddingToFP32(KernelContext context, + HalfFloatArray embeddingXBatch, + FloatArray wrapXBatch) { + int gid = context.globalIdx; + wrapXBatch.set(gid, embeddingXBatch.get(gid).getFloat32()); + } + + // ── RMS Norm (attention) ───────────────────────────────────────────────── + + /** + * Sequential RMS reduction — one thread per batch item. + * + *

Each thread computes the RMS scale factor for its token: + * {@code scale[b] = 1 / sqrt( mean(x[b]²) + eps )}

+ * + * Worker: batchSize global threads, localSize=1. + */ + public static void batchedRmsReduce(KernelContext context, + FloatArray wrapXBatch, + FloatArray attnScaleBatch, + int dim, float eps) { + int b = context.globalIdx; + int base = b * dim; + float ss = 0.0f; + for (int i = 0; i < dim; i++) { + float val = wrapXBatch.get(base + i); + ss += val * val; + } + ss /= dim; + ss += eps; + attnScaleBatch.set(b, 1.0f / TornadoMath.sqrt(ss)); + } + + /** + * Applies RMS normalization and FP16-quantizes the result. + * + *

{@code xbFP16Batch[b*dim+i] = FP16( rmsWeights[i] * scale[b] * x[b*dim+i] )}

+ * + * Worker: B*dim global threads, localSize=256. + */ + public static void batchedRmsApplyFP16(KernelContext context, + HalfFloatArray xbFP16Batch, + FloatArray wrapXBatch, + FloatArray rmsWeights, + FloatArray attnScaleBatch, + int dim) { + int gid = context.globalIdx; + int b = gid / dim; + int i = gid % dim; + float scale = attnScaleBatch.get(b); + float result = rmsWeights.get(i) * scale * wrapXBatch.get(gid); + xbFP16Batch.set(gid, new HalfFloat(result)); + } + + // ── QKV Projection ──────────────────────────────────────────────────────── + + /** + * Fused batched QKV projection (FP16 weights, FP16 input). + * + *

One workgroup per (batchIdx, outputRow) pair. + * globalGroupIdx = batchIdx * (dim + 2*kvDim) + rowIdx.

+ * + * Worker: B*(dim+2*kvDim) workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedQKVMatmul(KernelContext context, + HalfFloatArray xbFP16Batch, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + HalfFloatArray wq, + HalfFloatArray wk, + HalfFloatArray wv, + int dim, int kvDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int totalRows = dim + 2 * kvDim; + int batchIdx = groupId / totalRows; + int rowIdx = groupId % totalRows; + int inputOff = batchIdx * dim; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowIdx < dim) { + int rowOff = rowIdx * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wq.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapQBatch.set(batchIdx * dim + rowIdx, localSum[0]); + + } else if (rowIdx < dim + kvDim) { + int kRow = rowIdx - dim; + int rowOff = kRow * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wk.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]); + + } else { + int vRow = rowIdx - dim - kvDim; + int rowOff = vRow * dim; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + partial += wv.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]); + } + } + + // ── RoPE + KV Cache ─────────────────────────────────────────────────────── + + /** + * Fused batched RoPE rotation + KV cache write. + * + *

globalIdx encodes (batchIdx, pairIdx) as {@code batchIdx*(dim/2) + pairIdx}. + * Position for token b = {@code startPos + b}.

+ * + * Worker: B*(dim/2) global threads, localSize=512 (or less if B*dim/2 < 512). + */ + public static void batchedRopeWithKVCache(KernelContext context, + IntArray batchStartPosHolder, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + FloatArray wrapKeyCache, + FloatArray wrapValueCache, + int kvDim, int headSize, + int layerIndex, int contextLength, int dim) { + int globalIdx = context.globalIdx; + int halfDim = dim / 2; + int batchIdx = globalIdx / halfDim; + int pairIdx = globalIdx % halfDim; + int i = pairIdx * 2; + + int pos = batchStartPosHolder.get(0) + batchIdx; + int qOffset = batchIdx * dim; + int kOffset = batchIdx * kvDim; + + if (i + 1 < dim) { + int head_dim = i % headSize; + float freq = 1.0f / TornadoMath.pow(50000.0f, head_dim / (float) headSize); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q + float v0q = wrapQBatch.get(qOffset + i); + float v1q = wrapQBatch.get(qOffset + i + 1); + wrapQBatch.set(qOffset + i, v0q * fcr - v1q * fci); + wrapQBatch.set(qOffset + i + 1, v0q * fci + v1q * fcr); + + // Rotate K and write K,V to cache + if (i + 1 < kvDim) { + float v0k = wrapKBatch.get(kOffset + i); + float v1k = wrapKBatch.get(kOffset + i + 1); + float rotK0 = v0k * fcr - v1k * fci; + float rotK1 = v0k * fci + v1k * fcr; + wrapKBatch.set(kOffset + i, rotK0); + wrapKBatch.set(kOffset + i + 1, rotK1); + + int cacheOff = layerIndex * contextLength * kvDim + pos * kvDim; + wrapKeyCache.set(cacheOff + i, rotK0); + wrapKeyCache.set(cacheOff + i + 1, rotK1); + wrapValueCache.set(cacheOff + i, wrapVBatch.get(kOffset + i)); + wrapValueCache.set(cacheOff + i + 1, wrapVBatch.get(kOffset + i + 1)); + } + } + } + + // ── Attention ───────────────────────────────────────────────────────────── + + /** + * Batched causal flash attention. + * + *

One workgroup per (batchIdx, headIdx) pair: + * {@code groupIdx = batchIdx * nHeads + headIdx}. + * Token b attends to positions 0..{@code startPos + b} (causal).

+ * + * Worker: B*nHeads workgroups × optimalLocalSize threads. + */ + public static void batchedFlashAttention(KernelContext context, + IntArray batchStartPosHolder, + FloatArray wrapQBatch, + FloatArray wrapKeyCache, + FloatArray wrapValueCache, + FloatArray wrapXbBatch, + int nHeads, int headSize, + int kvDim, int kvMul, + int layerIndex, int contextLength, int dim) { + int tid = context.localIdx; + int groupId = context.groupIdx; + int localSz = context.localGroupSizeX; + + int batchIdx = groupId / nHeads; + int h = groupId % nHeads; + int pos = batchStartPosHolder.get(0) + batchIdx; + int loff = layerIndex * contextLength * kvDim; + int kvHeadIdx = h / kvMul; + int BLOCK_C = 16; + + float[] qShared = context.allocateFloatLocalArray(headSize); + float[] kTile = context.allocateFloatLocalArray(BLOCK_C * headSize); + float[] vTile = context.allocateFloatLocalArray(BLOCK_C * headSize); + float[] sTile = context.allocateFloatLocalArray(BLOCK_C); + float[] maxHolder = context.allocateFloatLocalArray(1); + + // Load Q into shared memory + int qOffset = batchIdx * dim + h * headSize; + for (int i = tid; i < headSize; i += localSz) { + qShared[i] = wrapQBatch.get(qOffset + i); + } + context.localBarrier(); + + float maxScore = Float.NEGATIVE_INFINITY; + float sumExp = 0.0f; + float[] output = new float[headSize]; + for (int i = 0; i < headSize; i++) output[i] = 0.0f; + + for (int tileC = 0; tileC <= pos; tileC += BLOCK_C) { + int tileEnd = Math.min(tileC + BLOCK_C - 1, pos); + + // Load K/V tile + for (int t = tileC + tid; t <= tileEnd; t += localSz) { + int tInTile = t - tileC; + int tileMOff = tInTile * headSize; + for (int d = 0; d < headSize; d++) { + int kvOff = loff + t * kvDim + kvHeadIdx * headSize + d; + kTile[tileMOff + d] = wrapKeyCache.get(kvOff); + vTile[tileMOff + d] = wrapValueCache.get(kvOff); + } + } + context.localBarrier(); + + // Compute attention scores + for (int t = tileC + tid; t <= tileEnd; t += localSz) { + int tInTile = t - tileC; + float score = 0.0f; + for (int d = 0; d < headSize; d++) { + score += qShared[d] * kTile[tInTile * headSize + d]; + } + sTile[tInTile] = score / TornadoMath.sqrt(headSize); + } + context.localBarrier(); + + // Tile max + float tileMax = Float.NEGATIVE_INFINITY; + for (int t = 0; t <= tileEnd - tileC; t++) { + if (sTile[t] > tileMax) tileMax = sTile[t]; + } + if (tid == 0) maxHolder[0] = tileMax; + context.localBarrier(); + float curTileMax = maxHolder[0]; + + float newMax = Math.max(maxScore, curTileMax); + if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) { + float scale = TornadoMath.exp(maxScore - newMax); + sumExp *= scale; + for (int d = 0; d < headSize; d++) output[d] *= scale; + } + maxScore = newMax; + + for (int t = 0; t <= tileEnd - tileC; t++) { + float expScore = TornadoMath.exp(sTile[t] - maxScore); + sumExp += expScore; + for (int d = 0; d < headSize; d++) { + output[d] += expScore * vTile[t * headSize + d]; + } + } + context.localBarrier(); + } + + float norm = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; + int xbOffset = batchIdx * dim + h * headSize; + for (int d = tid; d < headSize; d += localSz) { + wrapXbBatch.set(xbOffset + d, output[d] * norm); + } + } + + // ── Output / FFN Projections ───────────────────────────────────────────── + + /** + * Batched matrix-vector multiply with residual add. + * + *

Used for both the attention output projection (Wo) and the FFN down + * projection (W2). One workgroup per (batchIdx, outputRow): + * {@code groupIdx = batchIdx * d + rowIdx}.

+ * + *
    + *
  • Wo: inputBatch=xbBatch (B×dim), outputBatch=xBatch (B×dim), n=dim, d=dim
  • + *
  • W2: inputBatch=hbBatch (B×hiddenDim), outputBatch=xBatch (B×dim), n=hiddenDim, d=dim
  • + *
+ * + * Worker: B*d workgroups × localWorkGroupSize threads. + */ + public static void batchedMatVecWithResidual(KernelContext context, + FloatArray inputBatch, + FloatArray outputBatch, + HalfFloatArray w, + int n, int d, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int batchIdx = groupId / d; + int rowIdx = groupId % d; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + int inputOff = batchIdx * n; + int rowOff = rowIdx * n; + + float partial = 0.0f; + for (int j = localId; j < n; j += localWorkGroupSize) { + partial += w.get(rowOff + j).getFloat32() * inputBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) { + int outIdx = batchIdx * d + rowIdx; + outputBatch.set(outIdx, outputBatch.get(outIdx) + localSum[0]); + } + } + + // ── FFN RMS Norm ───────────────────────────────────────────────────────── + + /** + * Sequential FFN RMS reduction — one thread per batch item. + * Worker: batchSize global threads, localSize=1. + */ + public static void batchedFFNRmsReduce(KernelContext context, + FloatArray wrapXBatch, + FloatArray ffnScaleBatch, + int dim, float eps) { + int b = context.globalIdx; + int base = b * dim; + float ss = 0.0f; + for (int i = 0; i < dim; i++) { + float val = wrapXBatch.get(base + i); + ss += val * val; + } + ss /= dim; + ss += eps; + ffnScaleBatch.set(b, 1.0f / TornadoMath.sqrt(ss)); + } + + // ── FFN SwiGLU ─────────────────────────────────────────────────────────── + + /** + * Batched fused RMS-apply + W1/W3 gate-up projections + SiLU + GLU. + * + *

One workgroup per (batchIdx, hiddenRow): + * {@code groupIdx = batchIdx * hiddenDim + rowIdx}.

+ * + * Worker: B*hiddenDim workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedRmsNormFFNGateUp(KernelContext context, + FloatArray wrapXBatch, + FloatArray wrapHbBatch, + FloatArray rmsFFNWeights, + FloatArray ffnScaleBatch, + HalfFloatArray w1, + HalfFloatArray w3, + int dim, int hiddenDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int batchIdx = groupId / hiddenDim; + int rowIdx = groupId % hiddenDim; + + float scale = ffnScaleBatch.get(batchIdx); + int inputOff = batchIdx * dim; + int rowOff = rowIdx * dim; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + // W1 matmul with inline RMS apply + float sum1 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum1 += w1.get(rowOff + j).getFloat32() * normed; + } + localSum[localId] = sum1; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result1 = localSum[0]; + + // W3 matmul with inline RMS apply + float sum3 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum3 += w3.get(rowOff + j).getFloat32() * normed; + } + localSum[localId] = sum3; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result3 = localSum[0]; + + // SiLU(W1·x) × (W3·x) + if (localId == 0) { + float silu = result1 / (1.0f + TornadoMath.exp(-result1)); + wrapHbBatch.set(batchIdx * hiddenDim + rowIdx, silu * result3); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java new file mode 100644 index 00000000..e28ecb19 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java @@ -0,0 +1,238 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +/** + * Builds per-layer batch prefill TaskGraphs for Phase 4 GPU batched prefill. + * + *

One {@link ImmutableTaskGraph} per transformer layer, each processing + * {@code batchSize} tokens simultaneously via {@link TransformerBatchPrefillKernels}.

+ * + *

KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device + * after every layer so the subsequent single-token decode layers can consume it.

+ */ +public class LlamaFP16BatchPrefillLayers { + + // Matches the local workgroup size used by the single-token kernels. + static final int LOCAL_WORK_GROUP_SIZE = 32; + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final LlamaConfiguration config; + private final KernelContext context = new KernelContext(); + private final int batchSize; + private final List layerITGs; + private String lastLayerTaskGraphID; + + public LlamaFP16BatchPrefillLayers(LlamaState state, LlamaTornadoWeights weights, + LlamaConfiguration config, int batchSize) { + this.state = state; + this.weights = weights; + this.config = config; + this.batchSize = batchSize; + this.layerITGs = IntStream.range(0, config.numberOfLayers()) + .mapToObj(this::createBatchLayerTaskGraph) + .map(TaskGraph::snapshot) + .toList(); + } + + // @formatter:off + private TaskGraph createBatchLayerTaskGraph(int layerIndex) { + String graphName = "batchLayer_" + layerIndex; + if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName; + + TaskGraph layer = new TaskGraph(graphName); + + // ── Data Transfers ───────────────────────────────────────────────────── + if (layerIndex == 0) { + // batchStartPosHolder is set by host before each chunk → EVERY_EXECUTION + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder); + // Allocate persistent GPU-side intermediates once + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.attnScaleBatch, state.ffnScaleBatch, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapXbBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache); + // wrapXBatch produced by the batch activation graph + layer.consumeFromDevice(state.wrapXBatch); + } else { + layer.consumeFromDevice( + context, + state.wrapXBatch, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapXbBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache, + state.batchStartPosHolder, + state.attnScaleBatch, state.ffnScaleBatch); + } + + // Per-layer weights: upload once + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + + int dim = config.dim(); + int kvDim = config.kvDim(); + int hidDim = config.hiddenDim(); + + // ── Attention Block ──────────────────────────────────────────────────── + layer.task("batch_attn_rms", + TransformerBatchPrefillKernels::batchedRmsReduce, + context, state.wrapXBatch, state.attnScaleBatch, + dim, config.rmsNormEps()); + + layer.task("batch_attn_rms_apply", + TransformerBatchPrefillKernels::batchedRmsApplyFP16, + context, state.wrapXbFP16Batch, state.wrapXBatch, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + state.attnScaleBatch, dim); + + layer.task("batch_qkv", + TransformerBatchPrefillKernels::batchedFusedQKVMatmul, + context, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + dim, kvDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_rope_kv", + TransformerBatchPrefillKernels::batchedRopeWithKVCache, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapKeyCache, state.wrapValueCache, + kvDim, config.headSize(), layerIndex, config.contextLength(), dim); + + layer.task("batch_attention", + TransformerBatchPrefillKernels::batchedFlashAttention, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache, + state.wrapXbBatch, + config.numberOfHeads(), config.headSize(), + kvDim, config.kvMul(), layerIndex, config.contextLength(), dim); + + layer.task("batch_attn_out", + TransformerBatchPrefillKernels::batchedMatVecWithResidual, + context, state.wrapXbBatch, state.wrapXBatch, + weights.woLayered[layerIndex].asHalfFloatArray(), + dim, dim, LOCAL_WORK_GROUP_SIZE); + + // ── FFN Block ────────────────────────────────────────────────────────── + layer.task("batch_ffn_rms", + TransformerBatchPrefillKernels::batchedFFNRmsReduce, + context, state.wrapXBatch, state.ffnScaleBatch, + dim, config.rmsNormEps()); + + layer.task("batch_ffn_gate_up", + TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUp, + context, state.wrapXBatch, state.wrapHbBatch, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.ffnScaleBatch, + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), + dim, hidDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_ffn_down", + TransformerBatchPrefillKernels::batchedMatVecWithResidual, + context, state.wrapHbBatch, state.wrapXBatch, + weights.w2Layered[layerIndex].asHalfFloatArray(), + hidDim, dim, LOCAL_WORK_GROUP_SIZE); + + // Persist wrapXBatch for the next layer, and KV cache so the decode + // layers can consume it via the activation graph pass-through. + layer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache); + + return layer; + } + // @formatter:on + + /** Registers all batch layer workers in the shared {@link GridScheduler}. */ + public void updateGridScheduler(GridScheduler scheduler) { + int dim = config.dim(); + int kvDim = config.kvDim(); + int hidDim = config.hiddenDim(); + int nHeads = config.numberOfHeads(); + int headSz = config.headSize(); + + // RMS: one thread per batch token + WorkerGrid rmsWorker = WorkerGridFactory.genericWorker(batchSize, 1); + + // RMS apply: B*dim threads, local=256 (dim is always a multiple of 256 for LLaMA) + WorkerGrid rmsApplyWorker = WorkerGridFactory.genericWorker(batchSize * dim, 256); + + // QKV: B*(dim+2*kvDim) workgroups × LOCAL_WORK_GROUP_SIZE + int qkvRows = dim + 2 * kvDim; + WorkerGrid qkvWorker = WorkerGridFactory.genericWorker( + batchSize * qkvRows * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + // RoPE+KV cache: B*(dim/2) threads, local=512 + int ropeGlobal = batchSize * (dim / 2); + int ropeLocal = Math.min(512, ropeGlobal); + while (ropeLocal > 1 && ropeGlobal % ropeLocal != 0) ropeLocal--; + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(ropeGlobal, ropeLocal); + + // Attention (flash): B*nHeads workgroups × optimalLocalSize + int optLocal = findOptimalLocalSize(headSz); + WorkerGrid attnWorker = WorkerGridFactory.genericWorker( + batchSize * nHeads * optLocal, optLocal); + + // Mat-vec (Wo, W2): B*d workgroups × LOCAL_WORK_GROUP_SIZE + WorkerGrid matVecDimWorker = WorkerGridFactory.genericWorker( + batchSize * dim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + WorkerGrid matVecHidWorker = WorkerGridFactory.genericWorker( + batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + for (int i = 0; i < config.numberOfLayers(); i++) { + String p = "batchLayer_" + i + "."; + scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker); + scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker); + scheduler.addWorkerGrid(p + "batch_rope_kv", ropeWorker); + scheduler.addWorkerGrid(p + "batch_attention", attnWorker); + scheduler.addWorkerGrid(p + "batch_attn_out", matVecDimWorker); + scheduler.addWorkerGrid(p + "batch_ffn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_ffn_gate_up", matVecHidWorker); + scheduler.addWorkerGrid(p + "batch_ffn_down", matVecDimWorker); + } + } + + private static int findOptimalLocalSize(int size) { + int optimal = Math.min(size, 64); + if (size % optimal != 0) { + for (int s = 64; s >= 1; s--) { + if (size % s == 0) { optimal = s; break; } + } + } + return optimal; + } + + public List getLayerImmutableTaskGraphs() { return layerITGs; } + public String getLastLayerTaskGraphID() { return lastLayerTaskGraphID; } + public KernelContext getContext() { return context; } +} From 4152edb48829dd07df4e7203a91785adc05c1acc Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 3 Apr 2026 20:22:19 +0300 Subject: [PATCH 07/12] [prf/dec][refactor] Restructure prefill-decode ExecutionPlan components to dedicated classes and packages Move `LlamaFP16BatchPrefillLayers` to `tornadovm.layers.type.fp16.prefll` and `LlamaFP16FFNLayersForUnifiedDecode` to `tornadovm.layers.type.fp16.decode` --- .../TornadoVMMasterPlanBatchPrefill.java | 49 +++---------------- .../LlamaFP16FFNLayersForUnifiedDecode.java | 47 ++++++++++++++++++ .../LlamaFP16BatchPrefillLayers.java | 2 +- 3 files changed, 56 insertions(+), 42 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java rename src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/{ => prefill}/LlamaFP16BatchPrefillLayers.java (99%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java index b2388bf3..258bb9fe 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java @@ -8,8 +8,8 @@ import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16BatchPrefillLayers; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersForUnifiedDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16BatchPrefillLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; @@ -300,43 +300,10 @@ public void freeTornadoExecutionPlan() { } // ── Inner class: decode layer 0 with consumeFromDevice for KV cache ─────── - - /** - * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses - * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}. - * - *

This ensures decode layer 0 receives the KV-cache device pointer that was - * persisted by the last batch prefill layer and passed through the decode - * activation graph.

- */ - private static final class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { - - LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state, - LlamaTornadoWeights weights, LlamaConfiguration config, - SchedulerType schedulerType) { - super(taskGraph, state, weights, config, schedulerType); - } - - @Override - protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { - if (layerIndex == 0) { - // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come - // from device (passed through by the decode activation graph). - layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, - state.positionHolder, state.temp, state.tempFFN); - layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, - state.wrapXb, state.wrapXb2, - state.wrapQ, state.wrapK, state.wrapV, - state.wrapAtt, state.wrapHb, state.wrapXbFP16); - // KV cache: consume from device (device pointer supplied by - // decode activation's pass-through from last batch layer). - layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache); - } else { - // Identical to parent for layers 1+ (already uses consumeFromDevice). - return super.configureLayerDataTransfers(layer, layerIndex); - } - return layer; - } - } +// moved to package +// +// private static final class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { +// +// +// } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java new file mode 100644 index 00000000..b1cd063f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java @@ -0,0 +1,47 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses + * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}. + * + *

This ensures decode layer 0 receives the KV-cache device pointer that was + * persisted by the last batch prefill layer and passed through the decode + * activation graph.

+ */ +public class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { + public LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come + // from device (passed through by the decode activation graph). + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapAtt, state.wrapHb, state.wrapXbFP16); + // KV cache: consume from device (device pointer supplied by + // decode activation's pass-through from last batch layer). + layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache); + } else { + // Identical to parent for layers 1+ (already uses consumeFromDevice). + return super.configureLayerDataTransfers(layer, layerIndex); + } + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java similarity index 99% rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java index e28ecb19..8414be72 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.tornadovm.layers.type.fp16; +package org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill; import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; From 8d2d15dd2ced670f753e3b827b96bd1575f47246 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 3 Apr 2026 20:27:34 +0300 Subject: [PATCH 08/12] [prf/dec][dbg] Guard CUDA Graphs enable/disable behind `--no-cuda-graphs` option to ease debugging --- llama-tornado | 9 ++++ .../tornadovm/TornadoVMMasterPlan.java | 4 ++ .../TornadoVMMasterPlanBatchPrefill.java | 44 ++++++++----------- .../TornadoVMMasterPlanStandard.java | 18 ++++---- 4 files changed, 41 insertions(+), 34 deletions(-) diff --git a/llama-tornado b/llama-tornado index 81349a5e..4900a100 100755 --- a/llama-tornado +++ b/llama-tornado @@ -93,6 +93,9 @@ class LlamaRunner: if args.prefill_batch_size is not None: cmd.append(f"-Dllama.prefillBatchSize={args.prefill_batch_size}") + if args.no_cuda_graphs: + cmd.append("-Dllama.cudaGraphs=false") + # Debug options debug_config = [] @@ -493,6 +496,12 @@ def create_parser() -> argparse.ArgumentParser: default=None, help="Prefill chunk/batch size (llama.prefillBatchSize=N, default: 32)", ) + prefill_group.add_argument( + "--no-cuda-graphs", + dest="no_cuda_graphs", + action="store_true", + help="Disable CUDA graph capture/replay (llama.cudaGraphs=false); useful for debugging", + ) # Advanced options advanced_group = parser.add_argument_group("Advanced Options") diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 43030fd0..c81ba92c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -26,6 +26,10 @@ public interface TornadoVMMasterPlan { boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean( System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); + /** When {@code false}, {@code withCUDAGraph()} is never called — useful for debugging. */ + boolean CUDA_GRAPHS = Boolean.parseBoolean( + System.getProperty("llama.cudaGraphs", "true")); + /** * Single-token forward pass returning output logits. * diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java index 258bb9fe..cc6591c2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java @@ -170,7 +170,7 @@ public static TornadoVMMasterPlanBatchPrefill initializeUnifiedPlan( System.err.printf("[BatchPlan] Graph construction: %.2f ms%n", (System.nanoTime() - t0) / 1e6); - plan.executionPlan.withAllGraphs().withCUDAGraph(); + if (CUDA_GRAPHS) plan.executionPlan.withAllGraphs().withCUDAGraph(); plan.executionPlan.withPreCompilation(); if (ENABLE_TIMING) @@ -194,10 +194,9 @@ private void forceCopyInReadOnlyData() { state.batchStartPosHolder.init(0); for (int i = 0; i <= logitsIdx(); i++) { - executionPlan.withGraph(i) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) g.withCUDAGraph(); + g.execute(); } } @@ -226,17 +225,15 @@ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model mod state.batchStartPosHolder.set(0, startPos); // Graph 0: batch activation - executionPlan.withGraph(batchActivationIdx()) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var batchAct = executionPlan.withGraph(batchActivationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) batchAct.withCUDAGraph(); + batchAct.execute(); // Graphs 1..N: batch transformer layers for (int l = 0; l < N; l++) { - executionPlan.withGraph(batchLayerIdx(l)) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var batchLayer = executionPlan.withGraph(batchLayerIdx(l)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) batchLayer.withCUDAGraph(); + batchLayer.execute(); } // Logits skipped — not needed for prefill positions. } @@ -263,27 +260,24 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { state.tempFFN.clear(); // Graph N+1: decode activation - executionPlan.withGraph(decodeActivationIdx()) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var decodeAct = executionPlan.withGraph(decodeActivationIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); + decodeAct.execute(); // Graphs N+2..2N+1: decode transformer layers for (int l = 0; l < N; l++) { - executionPlan.withGraph(decodeLayerIdx(l)) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var decodeLayer = executionPlan.withGraph(decodeLayerIdx(l)).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) decodeLayer.withCUDAGraph(); + decodeLayer.execute(); } state.tempLogits.clear(); state.wrapLogits.clear(); // Graph 2N+2: logits - executionPlan.withGraph(logitsIdx()) - .withGridScheduler(gridScheduler) - .withCUDAGraph() - .execute(); + var logits = executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler); + if (CUDA_GRAPHS) logits.withCUDAGraph(); + logits.execute(); return state.wrapLogits; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java index 91586f2c..c9d816ee 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java @@ -56,7 +56,7 @@ static TornadoVMMasterPlanStandard initialize(State state, Model model) { System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0); } - tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); + if (CUDA_GRAPHS) tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph(); tornadoVMPlan.executionPlan.withPreCompilation(); if (ENABLE_TORNADOVM_INIT_TIME) { @@ -89,10 +89,10 @@ private GenericLayerPlanner createPlanner(State state, Model model) { @Override public FloatArray tornadoVMForwardExecuteLayered(int position) { // @formatter:off - executionPlan.withGraph(getPreprocessingGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .withCUDAGraph() - .execute(); + var preGraph = executionPlan.withGraph(getPreprocessingGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + if (CUDA_GRAPHS) preGraph.withCUDAGraph(); + preGraph.execute(); state.positionHolder.set(0, position); state.temp.clear(); @@ -106,10 +106,10 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { } state.tempLogits.clear(); state.wrapLogits.clear(); - executionPlan.withGraph(getFinalLogitsGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - .withCUDAGraph() - .execute(); + var logitsGraph = executionPlan.withGraph(getFinalLogitsGraphIndex()) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + if (CUDA_GRAPHS) logitsGraph.withCUDAGraph(); + logitsGraph.execute(); // @formatter:on return state.wrapLogits; } From 2e51fc29a354e0a5c92f66e673c8f7e24e2be4a3 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 3 Apr 2026 20:32:24 +0300 Subject: [PATCH 09/12] [prf/dec][refactor] Rename `LlamaFP16FFNLayersForUnifiedDecode` to `LlamaFP16FFNLayersDecode` --- .../tornadovm/TornadoVMMasterPlanBatchPrefill.java | 6 +++--- ...orUnifiedDecode.java => LlamaFP16FFNLayersDecode.java} | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) rename src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/{LlamaFP16FFNLayersForUnifiedDecode.java => LlamaFP16FFNLayersDecode.java} (85%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java index cc6591c2..6ce9c23e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java @@ -8,7 +8,7 @@ import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersForUnifiedDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode; import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16BatchPrefillLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import uk.ac.manchester.tornado.api.GridScheduler; @@ -100,8 +100,8 @@ private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batch // [N+2..2N+1] Decode layer graphs ──────────────────────────────────── // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload). - LlamaFP16FFNLayersForUnifiedDecode decodeLayers = - new LlamaFP16FFNLayersForUnifiedDecode( + LlamaFP16FFNLayersDecode decodeLayers = + new LlamaFP16FFNLayersDecode( "llamaFFNDecode", state, weights, config, schedulerType); all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); decodeLayers.updateGridScheduler(scheduler); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java similarity index 85% rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java index b1cd063f..f781f08a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java @@ -16,10 +16,10 @@ * persisted by the last batch prefill layer and passed through the decode * activation graph.

*/ -public class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers { - public LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state, - LlamaTornadoWeights weights, LlamaConfiguration config, - SchedulerType schedulerType) { +public class LlamaFP16FFNLayersDecode extends LlamaFP16FFNLayers { + public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); } From 885473b2eddb20d26329d182c0d42d8a20ef3233 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 3 Apr 2026 20:42:32 +0300 Subject: [PATCH 10/12] [prf/dec][refactor] Rename `LlamaFP16BatchPrefillLayers` to `LlamaFP16LayersBatchPrefill` --- ...doVMMasterPlanWithBatchPrefillDecode.java} | 22 ++++++++++++------- ....java => LlamaFP16LayersBatchPrefill.java} | 10 ++++----- 2 files changed, 19 insertions(+), 13 deletions(-) rename src/main/java/org/beehive/gpullama3/tornadovm/{TornadoVMMasterPlanBatchPrefill.java => TornadoVMMasterPlanWithBatchPrefillDecode.java} (91%) rename src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/{LlamaFP16BatchPrefillLayers.java => LlamaFP16LayersBatchPrefill.java} (97%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java similarity index 91% rename from src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java rename to src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java index 6ce9c23e..c8772e61 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -9,7 +9,7 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16BatchPrefillLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; @@ -80,15 +80,15 @@ private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batch List all = new ArrayList<>(2 * N + 3); GridScheduler scheduler = new GridScheduler(); - // [0] Batch activation ──────────────────────────────────────────────── + // [0] Batch prefill activation ──────────────────────────────────────────────── KernelContext batchActCtx = new KernelContext(); - all.add(buildBatchActivationGraph(batchActCtx).snapshot()); + all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot()); scheduler.addWorkerGrid("batchActivation.batchUpdateX", WorkerGridFactory.genericWorker(batchSize * config.dim(), 128)); - // [1..N] Batch layer graphs ─────────────────────────────────────────── - LlamaFP16BatchPrefillLayers batchLayers = - new LlamaFP16BatchPrefillLayers(state, weights, config, batchSize); + // [1..N] Batch prefill layer graphs ─────────────────────────────────────────── + LlamaFP16LayersBatchPrefill batchLayers = + new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); all.addAll(batchLayers.getLayerImmutableTaskGraphs()); batchLayers.updateGridScheduler(scheduler); @@ -116,10 +116,10 @@ private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batch this.executionPlan = new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0])); } - // ── Activation graphs ───────────────────────────────────────────────────── + // ── Batch Prefill Activation graphs ───────────────────────────────────────────────────── /** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. */ - private TaskGraph buildBatchActivationGraph(KernelContext ctx) { + private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { return new TaskGraph("batchActivation") .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) @@ -142,6 +142,9 @@ private TaskGraph buildBatchActivationGraph(KernelContext ctx) { private TaskGraph buildDecodeActivationGraph(KernelContext ctx) { return new TaskGraph("activationUpdate") .consumeFromDevice(state.wrapKeyCache, state.wrapValueCache) // KV pass-through +// .transferToDevice(DataTransferMode.EVERY_EXECUTION, +// state.wrapKeyCache, +// state.wrapValueCache) .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) .task("updateX", @@ -235,6 +238,7 @@ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model mod if (CUDA_GRAPHS) batchLayer.withCUDAGraph(); batchLayer.execute(); } + //System.err.println("[DEBUG] last batch layer done, about to return from prefill"); // Logits skipped — not needed for prefill positions. } @@ -262,12 +266,14 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { // Graph N+1: decode activation var decodeAct = executionPlan.withGraph(decodeActivationIdx()).withGridScheduler(gridScheduler); if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); + //System.err.println("[DEBUG] about to execute decode activation (graph " + decodeActivationIdx() + "--)"); decodeAct.execute(); // Graphs N+2..2N+1: decode transformer layers for (int l = 0; l < N; l++) { var decodeLayer = executionPlan.withGraph(decodeLayerIdx(l)).withGridScheduler(gridScheduler); if (CUDA_GRAPHS) decodeLayer.withCUDAGraph(); + //System.err.println("[DEBUG] about to execute decode transformer layer (graph " + decodeLayerIdx(l) + "--)"); decodeLayer.execute(); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java similarity index 97% rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java index 8414be72..30e3267e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -24,7 +24,7 @@ *

KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device * after every layer so the subsequent single-token decode layers can consume it.

*/ -public class LlamaFP16BatchPrefillLayers { +public class LlamaFP16LayersBatchPrefill { // Matches the local workgroup size used by the single-token kernels. static final int LOCAL_WORK_GROUP_SIZE = 32; @@ -37,20 +37,20 @@ public class LlamaFP16BatchPrefillLayers { private final List layerITGs; private String lastLayerTaskGraphID; - public LlamaFP16BatchPrefillLayers(LlamaState state, LlamaTornadoWeights weights, - LlamaConfiguration config, int batchSize) { + public LlamaFP16LayersBatchPrefill(LlamaState state, LlamaTornadoWeights weights, + LlamaConfiguration config, int batchSize) { this.state = state; this.weights = weights; this.config = config; this.batchSize = batchSize; this.layerITGs = IntStream.range(0, config.numberOfLayers()) - .mapToObj(this::createBatchLayerTaskGraph) + .mapToObj(this::createBatchPrefillLayerTaskGraph) .map(TaskGraph::snapshot) .toList(); } // @formatter:off - private TaskGraph createBatchLayerTaskGraph(int layerIndex) { + private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { String graphName = "batchLayer_" + layerIndex; if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName; From 61287934317c674b81860ae20282a822854a07f0 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 3 Apr 2026 21:23:09 +0300 Subject: [PATCH 11/12] [prf/dec][refactor] Rename `TornadoVMMasterPlanBatchPrefill` to `TornadoVMMasterPlanWithBatchPrefillDecode` --- .../inference/InferenceEngineWithPrefillDecode.java | 4 ++-- .../gpullama3/tornadovm/TornadoVMMasterPlan.java | 10 +++++----- .../TornadoVMMasterPlanWithBatchPrefillDecode.java | 10 +++++----- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index 6517df12..74c474e9 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -10,7 +10,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefill; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanStandard; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; @@ -202,7 +202,7 @@ public static List generateTokensGPULlama( // ── Phase 4: Batch GPU Prefill ──────────────────────────────────── // Plan was pre-initialized in Model.runInstructOnce/runInteractive // as a TornadoVMMasterPlanBatchPrefill by TornadoVMMasterPlan.initializeTornadoVMPlan. - TornadoVMMasterPlanBatchPrefill plan = (TornadoVMMasterPlanBatchPrefill) tornadoVMPlan; + TornadoVMMasterPlanWithBatchPrefillDecode plan = (TornadoVMMasterPlanWithBatchPrefillDecode) tornadoVMPlan; int N = promptTokens.size(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index c81ba92c..1acd8f24 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -12,13 +12,13 @@ *
    *
  • {@link TornadoVMMasterPlanStandard} — single-token forward pass; used for the * baseline GPU path and Phase 2 sequential prefill/decode.
  • - *
  • {@link TornadoVMMasterPlanBatchPrefill} — unified plan for Phase 4 batched + *
  • {@link TornadoVMMasterPlanWithBatchPrefillDecode} — unified plan for Phase 4 batched * prefill + single-token decode within one {@code TornadoExecutionPlan}.
  • *
* *

The {@link #initializeTornadoVMPlan} factory selects the appropriate implementation * based on {@code llama.prefillBatchSize}: if {@code > 1}, returns a - * {@link TornadoVMMasterPlanBatchPrefill}; otherwise returns a + * {@link TornadoVMMasterPlanWithBatchPrefillDecode}; otherwise returns a * {@link TornadoVMMasterPlanStandard}.

*/ public interface TornadoVMMasterPlan { @@ -35,7 +35,7 @@ public interface TornadoVMMasterPlan { * *

Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM}) * and the Phase 2 sequential decode path. Not applicable to - * {@link TornadoVMMasterPlanBatchPrefill} — that plan uses its own typed methods.

+ * {@link TornadoVMMasterPlanWithBatchPrefillDecode} — that plan uses its own typed methods.

* * @param position sequence position of the current token * @return logits array for token sampling @@ -48,7 +48,7 @@ public interface TornadoVMMasterPlan { /** * Factory: creates, JIT-compiles, and warms up the appropriate plan. * - *

When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanBatchPrefill} + *

When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanWithBatchPrefillDecode} * is returned; otherwise a {@link TornadoVMMasterPlanStandard} is returned.

* * @param state the model state (must be {@link LlamaState} when batch size {@code > 1}) @@ -59,7 +59,7 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { int batchSize = Integer.getInteger("llama.prefillBatchSize", 1); TornadoVMMasterPlan plan; if (batchSize > 1) { - plan = TornadoVMMasterPlanBatchPrefill.initializeUnifiedPlan( + plan = TornadoVMMasterPlanWithBatchPrefillDecode.initializeUnifiedPlan( (LlamaState) state, model, batchSize); } else { plan = TornadoVMMasterPlanStandard.initialize(state, model); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java index c8772e61..f8b2cf63 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -48,7 +48,7 @@ * decodeLayer[0] --consumeFromDevice(wrapKeyCache)-→ (used by attention) *
*/ -public class TornadoVMMasterPlanBatchPrefill implements TornadoVMMasterPlan { +public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan { private static final boolean ENABLE_TIMING = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); @@ -68,7 +68,7 @@ public class TornadoVMMasterPlanBatchPrefill implements TornadoVMMasterPlan { private int logitsIdx() { return 2 * N + 2; } // ── Construction ───────────────────────────────────────────────────────── - private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batchSize) { + private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model, int batchSize) { this.state = state; this.config = (LlamaConfiguration) model.configuration(); this.batchSize = batchSize; @@ -162,12 +162,12 @@ private TaskGraph buildDecodeActivationGraph(KernelContext ctx) { * Creates, JIT-compiles, and warms up the unified plan. * Mirrors {@link TornadoVMMasterPlan#initializeTornadoVMPlan}. */ - public static TornadoVMMasterPlanBatchPrefill initializeUnifiedPlan( + public static TornadoVMMasterPlanWithBatchPrefillDecode initializeUnifiedPlan( LlamaState state, Model model, int batchSize) { long t0 = System.nanoTime(); - TornadoVMMasterPlanBatchPrefill plan = - new TornadoVMMasterPlanBatchPrefill(state, model, batchSize); + TornadoVMMasterPlanWithBatchPrefillDecode plan = + new TornadoVMMasterPlanWithBatchPrefillDecode(state, model, batchSize); if (ENABLE_TIMING) System.err.printf("[BatchPlan] Graph construction: %.2f ms%n", From 2b1aababc02eb29982b6cadbf1636176fab3dd62 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 7 Apr 2026 13:34:34 +0300 Subject: [PATCH 12/12] [prf/dec] Fix KV-cache propagation bug from prefill to decode path and refactor task graph consumption logic Introduce `LogitsFP16LayerDecode` with KV-cache pass-through. Override `consumeFromDevice` and `persistOnDevice` in LlamaFFN layers to fix cross-graph propagation for both CUDA and interpreter modes. --- ...adoVMMasterPlanWithBatchPrefillDecode.java | 45 ++++++++++----- .../layers/type/fp16/LlamaFP16FFNLayers.java | 34 ++++++++++- .../layers/type/fp16/LogitsFP16Layer.java | 15 +++++ .../fp16/decode/LlamaFP16FFNLayersDecode.java | 56 +++++++++++++++---- .../fp16/decode/LogitsFP16LayerDecode.java | 53 ++++++++++++++++++ .../prefill/LlamaFP16LayersBatchPrefill.java | 15 ++++- 6 files changed, 187 insertions(+), 31 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java index f8b2cf63..3df08dfa 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -10,7 +10,7 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode; import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; @@ -94,8 +94,8 @@ private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model, // [N+1] Decode activation (with KV-cache pass-through) ──────────────── KernelContext decodeActCtx = new KernelContext(); - all.add(buildDecodeActivationGraph(decodeActCtx).snapshot()); - scheduler.addWorkerGrid("activationUpdate.updateX", + all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot()); + scheduler.addWorkerGrid("decodeActivationUpdate.updateX", WorkerGridFactory.genericWorker(config.dim(), 128)); // [N+2..2N+1] Decode layer graphs ──────────────────────────────────── @@ -107,7 +107,10 @@ private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model, decodeLayers.updateGridScheduler(scheduler); // [2N+2] Logits ─────────────────────────────────────────────────────── - LogitsFP16Layer logitsLayer = new LogitsFP16Layer("logits", state, weights, config, + // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache) + // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the + // KV-cache pointer survives the logits → decode-activation boundary across tokens. + LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config, decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); all.add(logitsLayer.getImmutableTaskGraph()); logitsLayer.updateGridScheduler(scheduler); @@ -123,9 +126,7 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { return new TaskGraph("batchActivation") .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) - .task("batchUpdateX", - (KernelContext c, HalfFloatArray src, FloatArray dst) -> - dst.set(c.globalIdx, src.get(c.globalIdx).getFloat32()), + .task("batchUpdateX", TransformerComputeKernels::convertFP16toFP32, ctx, state.embeddingXBatch, state.wrapXBatch) .persistOnDevice(state.wrapXBatch); } @@ -139,17 +140,24 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { * Both halves of the chain are required; without the re-persist the pointer is * not forwarded in interpreter (non-CUDA-graph) mode.

*/ - private TaskGraph buildDecodeActivationGraph(KernelContext ctx) { - return new TaskGraph("activationUpdate") - .consumeFromDevice(state.wrapKeyCache, state.wrapValueCache) // KV pass-through -// .transferToDevice(DataTransferMode.EVERY_EXECUTION, -// state.wrapKeyCache, -// state.wrapValueCache) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX) + private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) { +// System.out.println("lastBatchLayerID = " + lastBatchLayerID); +// System.out.println("[buildDecodeActivationGraph] state.wrapX = " + state.wrapX.toString()); +// System.out.println("[buildDecodeActivationGraph] state.wrapKeyCache = " + state.wrapKeyCache.toString()); +// System.out.println("[buildDecodeActivationGraph] state.wrapValueCache = " + state.wrapValueCache.toString()); + return new TaskGraph("decodeActivationUpdate") + .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through + //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX, debugKV) + //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) .task("updateX", TransformerComputeKernels::convertFP16toFP32, ctx, (HalfFloatArray) state.embeddingX, state.wrapX) +// // DEBUG: snapshot first 8 elements of wrapKeyCache and wrapX for host-side probe +// .task("dbgKV", +// TransformerComputeKernels::dbgCopyFirst8, +// state.wrapKeyCache, debugKV) +// .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX, debugKV) // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache // re-persisted so updatePersistedObjectState() propagates the device // pointer to decode layer 0's consumeFromDevice without CUDA graphs. @@ -197,6 +205,7 @@ private void forceCopyInReadOnlyData() { state.batchStartPosHolder.init(0); for (int i = 0; i <= logitsIdx(); i++) { + //System.out.println(i + " " + executionPlan.withGraph(i).toString()); var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler); if (CUDA_GRAPHS) g.withCUDAGraph(); g.execute(); @@ -268,6 +277,14 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); //System.err.println("[DEBUG] about to execute decode activation (graph " + decodeActivationIdx() + "--)"); decodeAct.execute(); + // DEBUG: print first 4 of wrapX (should be non-zero FP32 embedding) and + // first 4 of debugKV (should be non-zero after batch prefill wrote the KV cache) +// if (position <= 290) { +// System.err.printf("[DBG pos=%d] wrapX[0..3] = %.4f %.4f %.4f %.4f%n", +// position, state.wrapX.get(0), state.wrapX.get(1), state.wrapX.get(2), state.wrapX.get(3)); +// System.err.printf("[DBG pos=%d] debugKV[0..3]= %.4f %.4f %.4f %.4f%n", +// position, debugKV.get(0), debugKV.get(1), debugKV.get(2), debugKV.get(3)); +// } // Graphs N+2..2N+1: decode transformer layers for (int l = 0; l < N; l++) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 56f2c0c3..50619cc2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -146,7 +146,17 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); // === Data Setup === - unifiedLayer.consumeFromDevice(state.wrapX); + // consumeFromDevice for wrapX: the no-arg form uses the current graph's own name as the + // source key, which works in CUDA-graph mode (pointers are frozen) but fails in interpreter + // mode (updatePersistedObjectState looks up the predecessor's name, not the current name). + // Subclasses that receive wrapX across a graph boundary override predecessorGraphName() to + // return the correct predecessor graph name so the XPUBuffer is propagated in both modes. + String wrapXSrc = predecessorGraphName(layerIndex); + if (wrapXSrc != null) { + unifiedLayer.consumeFromDevice(wrapXSrc, state.wrapX); + } else { + unifiedLayer.consumeFromDevice(state.wrapX); + } unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqLayered[layerIndex].asHalfFloatArray(), @@ -248,11 +258,31 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.persistOnDevice(state.wrapX); + unifiedLayer.persistOnDevice(state.wrapX, state.wrapKeyCache, + state.wrapValueCache); return unifiedLayer; } + /** + * Returns the name of the predecessor task graph from which {@code wrapX} should be consumed, + * or {@code null} to fall back to the no-arg form (source key = own graph name). + * + *

The no-arg form is safe in CUDA-graph mode (device pointers are frozen at capture time) + * but fails in interpreter mode: {@code updatePersistedObjectState} looks up the predecessor's + * graph name, not the current graph's name, so the XPUBuffer is never propagated and + * {@code executeAlloc} NPEs on a null buffer.

+ * + *

Override in subclasses that receive {@code wrapX} from a named predecessor graph:

+ *
    + *
  • layer 0: return the activation graph name (e.g. {@code "activationUpdate"})
  • + *
  • layer k > 0: return {@code "layer_" + (k-1)}
  • + *
+ */ + protected String predecessorGraphName(int layerIndex) { + return null; + } + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { if (layerIndex == 0) { // First layer: Transfer initial data to device (one-time transfer) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index bf938a0d..1858408e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -22,11 +22,25 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration super(name, state, weights, config, lastTaskGraphID, schedulerType); } + /** + * Hook called before any data transfers or tasks. Override to prepend + * {@code consumeFromDevice} declarations that must precede the bytecode + * (e.g. KV-cache pass-through in the Phase 4 unified plan). + */ + protected void configureAdditionalConsumes(TaskGraph logits) {} + + /** + * Hook called after {@code transferToHost}. Override to append + * {@code persistOnDevice} declarations (e.g. KV-cache pass-through in Phase 4). + */ + protected void configureAdditionalPersists(TaskGraph logits) {} + // @formatter:off @Override protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { var logits = new TaskGraph("logits"); // === Data Setup === + configureAdditionalConsumes(logits); logits.consumeFromDevice(lastTaskGraphID, state.wrapX); logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, @@ -80,6 +94,7 @@ protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration c // === Transfer Results to Host === logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + configureAdditionalPersists(logits); return logits; } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java index f781f08a..4d632425 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java @@ -9,12 +9,22 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses - * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}. + * Decode-path FFN layers for the Phase 4 unified plan. * - *

This ensures decode layer 0 receives the KV-cache device pointer that was - * persisted by the last batch prefill layer and passed through the decode - * activation graph.

+ *

Overrides data-transfer declarations so that all cross-graph boundaries use + * the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by + * the base class) passes the current graph's own name as the source key. + * In CUDA-graph mode this is harmless (device pointers are frozen at capture time), + * but in interpreter mode {@code updatePersistedObjectState} looks up the + * predecessor's name, so the lookup always misses and the XPUBuffer is + * never propagated — causing either a null-pointer crash or a silent re-upload + * from host (zeros), corrupting the hidden state and KV cache.

+ * + *

Two boundaries are fixed here:

+ *
    + *
  • {@code wrapX}: via {@link #predecessorGraphName} hook in the base class.
  • + *
  • All other consumed objects: via the {@link #configureLayerDataTransfers} override.
  • + *
*/ public class LlamaFP16FFNLayersDecode extends LlamaFP16FFNLayers { public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state, @@ -23,11 +33,25 @@ public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state, super(taskGraph, state, weights, config, schedulerType); } + /** + * Supplies the correct predecessor graph name for {@code consumeFromDevice(wrapX)}. + * + *

Layer 0 receives {@code wrapX} from the decode activation graph; + * layers 1+ receive it from the previous decode layer. + * Must match the {@code TaskGraph} names used in + * {@code buildDecodeActivationGraph()} and {@code createFFNLayerTaskGraph()}.

+ */ + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivationUpdate" : "layer_" + (layerIndex - 1); + } + @Override protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { if (layerIndex == 0) { - // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come - // from device (passed through by the decode activation graph). + // Same as parent layer 0, but wrapKeyCache/wrapValueCache come from device + // (passed through by the decode activation graph, which relays them from + // the last batch prefill layer). No FIRST_EXECUTION for KV cache here. layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, @@ -35,12 +59,20 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) state.wrapXb, state.wrapXb2, state.wrapQ, state.wrapK, state.wrapV, state.wrapAtt, state.wrapHb, state.wrapXbFP16); - // KV cache: consume from device (device pointer supplied by - // decode activation's pass-through from last batch layer). - layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache); + // Explicit source — must match the TaskGraph name in buildDecodeActivationGraph(). + layer.consumeFromDevice("decodeActivationUpdate", state.wrapKeyCache, state.wrapValueCache); } else { - // Identical to parent for layers 1+ (already uses consumeFromDevice). - return super.configureLayerDataTransfers(layer, layerIndex); + // Layers 1+: use explicit predecessor name for ALL consumed objects. + // Calling super here would use the no-arg form (source key = own graph name), + // which silently fails in interpreter mode and causes re-upload from host. + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder, state.wrapXbFP16); } return layer; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java new file mode 100644 index 00000000..760be156 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java @@ -0,0 +1,53 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Logits layer for the unified prefill-decode plan (Phase 4). + * + *

Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device + * pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the + * logits → decode-activation boundary across decode tokens.

+ * + *

In interpreter (non-CUDA-graph) mode, {@code updatePersistedObjectState()} + * propagates device pointers from the predecessor graph's persisted set. After the + * last decode token the predecessor of the next decode-activation graph is the + * logits graph. Without the pass-through here, the KV-cache pointer is absent from + * the logits persisted set, cleared to null, and the first decode layer crashes with + * an NPE in {@code executeAlloc}.

+ * + *

Bytecode order matters: {@code consumeFromDevice} must precede task declarations, + * and {@code persistOnDevice} must follow {@code transferToHost}. The hooks in + * {@link LogitsFP16Layer} guarantee this ordering.

+ */ +public class LogitsFP16LayerDecode extends LogitsFP16Layer { + + public LogitsFP16LayerDecode(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); + } + + /** + * Prepends {@code consumeFromDevice(lastTaskGraphID, wrapKeyCache, wrapValueCache)} before all tasks. + * + *

Must use the named-source form so that {@code updatePersistedObjectState()} adds the KV cache + * to the source-keyed map. Without the source name, the fallback in {@code updatePersistedObjectState} + * uses the current graph's general persisted list, which causes the XPUBuffer from the predecessor + * (last decode layer) to never be propagated into the logits graph's device state.

+ */ + @Override + protected void configureAdditionalConsumes(TaskGraph logits) { + logits.consumeFromDevice(lastTaskGraphID, state.wrapKeyCache, state.wrapValueCache); + } + + /** Appends {@code persistOnDevice(wrapKeyCache, wrapValueCache)} after {@code transferToHost}. */ + @Override + protected void configureAdditionalPersists(TaskGraph logits) { + logits.persistOnDevice(state.wrapKeyCache, state.wrapValueCache); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java index 30e3267e..a893623d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -69,10 +69,19 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { state.wrapXbBatch, state.wrapHbBatch, state.wrapKeyCache, state.wrapValueCache); - // wrapXBatch produced by the batch activation graph - layer.consumeFromDevice(state.wrapXBatch); + // wrapXBatch produced by the batch activation graph. + // Explicit source name required: the no-arg form uses the current graph's own + // name ("batchLayer_0") which never matches "batchActivation" in interpreter mode, + // causing wrapXBatch to be re-uploaded from host (zeros) instead of using the + // FP32 embeddings computed by the activation graph's convertFP16toFP32 kernel. + layer.consumeFromDevice("batchActivation", state.wrapXBatch); } else { - layer.consumeFromDevice( + // Explicit predecessor name for all objects. + // The no-arg form would use "batchLayer_k" as the source key, which never matches + // "batchLayer_{k-1}" in interpreter mode — every object would be re-uploaded from + // host (zeros or stale), corrupting the KV cache written by the previous layer. + String pred = "batchLayer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, context, state.wrapXBatch, state.wrapXbFP16Batch,