From c00aa825d07d1fed718339ece69dc66a32670440 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Thu, 26 Mar 2026 22:42:02 +0200
Subject: [PATCH 01/12] [perf] Enable CUDA Graphs on all execution-plan graphs
 in TornadoVMMasterPlan
---
.../gpullama3/tornadovm/TornadoVMMasterPlan.java | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
index a42dc310..4b752735 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
@@ -55,6 +55,8 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0);
}
+ tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph();
+
// 2. Perform warmup with extra iterations to ensure JIT compilation is complete
tornadoVMPlan.executionPlan.withPreCompilation(); // Force JIT compilation from Java to GPU code
@@ -130,6 +132,7 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {
// 1. Execute the preprocessing graph (e.g., input preparation, memory initialization)
executionPlan.withGraph(getPreprocessingGraphIndex())
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
+ .withCUDAGraph()
.execute();
// Set the position in the state object (used by attention layers)
@@ -142,6 +145,7 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {
for (int layer = 0; layer < config.numberOfLayers(); layer++) {
executionPlan.withGraph(getLayerGraphIndex(layer))
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
+ .withCUDAGraph()
.execute();
}
state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f
@@ -149,6 +153,7 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {
// 3. Execute the final graph that projects the last hidden state to output logits
executionPlan.withGraph(getFinalLogitsGraphIndex())
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
+ .withCUDAGraph()
.execute();
// @formatter:on
@@ -187,15 +192,15 @@ public void forceCopyInReadOnlyDataLayered() {
state.positionHolder.init(0);
// Execute activation update graph
- executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
+ executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute();
// Execute layer processing graphs
for (int layer = 0; layer < config.numberOfLayers(); layer++) {
- executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
+ executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute();
}
// Execute logits graph
- executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
+ executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute();
}
/**
From 8be1c05e1b56d0be3aec89456f64f5c973396ad1 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Tue, 31 Mar 2026 18:19:02 +0300
Subject: [PATCH 02/12] [prf/dec] Add CLI options for batched prefill and
prefill batch size configuration
---
llama-tornado | 22 ++++++++++++++++++++++
1 file changed, 22 insertions(+)
diff --git a/llama-tornado b/llama-tornado
index 57a50f1c..81349a5e 100755
--- a/llama-tornado
+++ b/llama-tornado
@@ -87,6 +87,12 @@ class LlamaRunner:
if args.verbose_init:
cmd.append("-Dllama.EnableTimingForTornadoVMInit=true")
+ if args.batched_prefill:
+ cmd.append("-Dllama.batchedPrefill=true")
+
+ if args.prefill_batch_size is not None:
+ cmd.append(f"-Dllama.prefillBatchSize={args.prefill_batch_size}")
+
# Debug options
debug_config = []
@@ -472,6 +478,22 @@ def create_parser() -> argparse.ArgumentParser:
help="Execute the command after showing it (use with --show-command)",
)
+ # Prefill/Decode optimization
+ prefill_group = parser.add_argument_group("Prefill/Decode Optimization")
+ prefill_group.add_argument(
+ "--batched-prefill",
+ dest="batched_prefill",
+ action="store_true",
+ help="Enable batched prefill/decode separation (llama.batchedPrefill=true)",
+ )
+ prefill_group.add_argument(
+ "--prefill-batch-size",
+ dest="prefill_batch_size",
+ type=int,
+ default=None,
+ help="Prefill chunk/batch size (llama.prefillBatchSize=N, default: 32)",
+ )
+
# Advanced options
advanced_group = parser.add_argument_group("Advanced Options")
advanced_group.add_argument(
From ca4744f880ea058308b8bbc648b5b5098f7dde96 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Tue, 31 Mar 2026 18:20:28 +0300
Subject: [PATCH 03/12] [prf/dec] Add CPU path for prefill/decode. Separates
inference path with InferenceCoreWithPrefillDecode and
InferenceEngineWithPrefillDecode
---
.../InferenceCoreWithPrefillDecode.java | 124 ++++++++++++++++++
.../InferenceEngineWithPrefillDecode.java | 120 +++++++++++++++++
.../beehive/gpullama3/model/llama/Llama.java | 6 +
3 files changed, 250 insertions(+)
create mode 100644 src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
create mode 100644 src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
new file mode 100644
index 00000000..d662afe1
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
@@ -0,0 +1,124 @@
+package org.beehive.gpullama3.inference;
+
+import org.beehive.gpullama3.auxiliary.Parallel;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tensor.standard.FloatTensor;
+
+/**
+ * Low-level forward passes for the prefill/decode separated inference path.
+ *
+ * Parallel to {@link InferenceCore} — does NOT modify it.
+ *
+ * The key addition is {@link #forwardJavaPrefill}, which runs a full
+ * transformer forward pass but skips the final RMSNorm and vocabulary
+ * projection (wcls matmul). This is correct for all prefill positions
+ * because their logits are discarded anyway; only the KV-cache update
+ * matters. Skipping the projection saves one large matmul
+ * (vocabularySize × dim) per prefill token.
+ */
+public final class InferenceCoreWithPrefillDecode {
+
+ private InferenceCoreWithPrefillDecode() {}
+
+ /**
+ * Prefill-only forward pass for LLaMA (CPU, FP32 weights).
+ *
+ * Identical to {@link InferenceCore#forwardJava} except the final
+ * RMSNorm and vocabulary projection are omitted. The KV cache is
+ * populated correctly at {@code position}.
+ *
+ * @param model the LLaMA model (must carry {@link StandardWeights})
+ * @param state mutable inference state (KV cache, activations …)
+ * @param token input token id
+ * @param position sequence position being processed
+ */
+ public static void forwardJavaPrefill(Model model, State state, int token, int position) {
+ final Configuration config = model.configuration();
+ final StandardWeights weights = (StandardWeights) model.weights();
+ int dim = config.dim();
+ int headSize = config.headSize();
+ int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
+ int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads();
+ float sqrtHeadSize = (float) Math.sqrt(headSize);
+
+ // Token embedding
+ weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);
+
+ // Transformer layers
+ for (int l = 0; l < config.numberOfLayers(); l++) {
+ // Attention RMSNorm
+ InferenceCore.rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps());
+
+ // QKV projections
+ weights.wq[l].matmul(state.xb, state.q, dim, dim);
+ weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
+ weights.wv[l].matmul(state.xb, state.v, kvDim, dim);
+
+ // RoPE
+ for (int i = 0; i < dim; i += 2) {
+ int head_dim = i % headSize;
+ float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2));
+ float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2));
+ int rotn = i < kvDim ? 2 : 1;
+ for (int v = 0; v < rotn; v++) {
+ FloatTensor vec = v == 0 ? state.q : state.k;
+ float v0 = vec.getFloat(i);
+ float v1 = vec.getFloat(i + 1);
+ vec.setFloat(i, v0 * fcr - v1 * fci);
+ vec.setFloat(i + 1, v0 * fci + v1 * fcr);
+ }
+ }
+
+ // KV cache update
+ state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim);
+ state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim);
+
+ // Multi-head attention
+ int curLayer = l;
+ Parallel.parallelFor(0, config.numberOfHeads(), h -> {
+ int qOffset = h * headSize;
+ int attOffset = h * config.contextLength();
+
+ for (int t = 0; t <= position; t++) {
+ int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
+ float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
+ score /= sqrtHeadSize;
+ state.att.setFloat(attOffset + t, score);
+ }
+
+ state.att.softmaxInPlace(attOffset, position + 1);
+
+ int xbOffset = h * headSize;
+ state.xb.fillInPlace(xbOffset, headSize, 0f);
+ for (int t = 0; t <= position; t++) {
+ int vOffset = t * kvDim + (h / kvMul) * headSize;
+ float a = state.att.getFloat(attOffset + t);
+ state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
+ }
+ });
+
+ // Attention output projection + residual
+ weights.wo[l].matmul(state.xb, state.xb2, dim, dim);
+ state.x.addInPlace(state.xb2);
+
+ // FFN RMSNorm
+ InferenceCore.rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps());
+
+ // FFN (SwiGLU)
+ weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
+ weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);
+ state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
+ state.hb.multiplyInPlace(state.hb2);
+ weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());
+
+ // FFN residual
+ state.x.addInPlace(state.xb);
+ }
+
+ // Final RMSNorm and vocab projection intentionally omitted:
+ // logits are not needed for prefill positions — only the KV cache matters.
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
new file mode 100644
index 00000000..b97f3c72
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
@@ -0,0 +1,120 @@
+package org.beehive.gpullama3.inference;
+
+import org.beehive.gpullama3.auxiliary.LastRunMetrics;
+import org.beehive.gpullama3.inference.sampler.Sampler;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+import java.util.function.IntConsumer;
+
+/**
+ * Token generation entry point for the prefill/decode separated inference path.
+ *
+ * Parallel to {@link InferenceEngine} — does NOT modify it.
+ *
+ * The split loop runs two phases:
+ *
+ * - Prefill (positions 0..N-1): calls
+ * {@link InferenceCoreWithPrefillDecode#forwardJavaPrefill} for every
+ * prompt token. Vocabulary projection is skipped because these logits
+ * are discarded. KV cache is populated identically to the baseline.
+ * - Decode (position N onward): calls
+ * {@link InferenceCore#forwardJava} per generated token.
+ * Behaviour is identical to the baseline decode path.
+ *
+ *
+ * Activated by {@code -Dllama.batchedPrefill=true} (set via
+ * {@code --batched-prefill} in the Python launcher).
+ */
+public final class InferenceEngineWithPrefillDecode {
+
+ private InferenceEngineWithPrefillDecode() {}
+
+ /**
+ * LLaMA token generation with prefill/decode separation (CPU, Phase 1).
+ *
+ * Drop-in replacement for
+ * {@link InferenceEngine#generateTokensLlama} when the batched-prefill
+ * flag is enabled. Only the CPU path is implemented here; GPU support
+ * is added in a later phase.
+ */
+    public static List<Integer> generateTokensLlama(
+            Model model, State state, int startPosition,
+            List<Integer> promptTokens, Set<Integer> stopTokens,
+            int maxTokens, Sampler sampler, boolean echo,
+            IntConsumer onTokenGenerated) {
+
+ long startNanos = System.nanoTime();
+
+ final Configuration config = model.configuration();
+ if (maxTokens < 0 || config.contextLength() < maxTokens) {
+ maxTokens = config.contextLength();
+ }
+
+        List<Integer> generatedTokens = new ArrayList<>();
+
+ int currentToken = state.latestToken; // BOS
+ int pos = startPosition;
+
+ // ── Phase 1: Prefill ──────────────────────────────────────────────────
+ // Run all prompt tokens through the forward pass without computing
+ // logits. The KV cache is populated at each position, which is all
+ // that matters. After this loop:
+ // currentToken == promptTokens.getLast()
+ // pos == startPosition + promptTokens.size()
+ for (int promptIndex = 0; promptIndex < promptTokens.size(); promptIndex++) {
+ InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos);
+ currentToken = promptTokens.get(promptIndex);
+ if (echo) {
+ System.err.print(Tokenizer.replaceControlCharacters(
+ model.tokenizer().decode(List.of(currentToken))));
+ }
+ pos++;
+ }
+
+ state.latestToken = currentToken;
+
+ // ── Phase 2: Decode ───────────────────────────────────────────────────
+ // Standard single-token forward with logits. Behaviour is identical
+ // to the baseline InferenceEngine decode path.
+ long inferenceStartNanos = 0;
+ while (pos < maxTokens) {
+ if (inferenceStartNanos == 0) {
+ inferenceStartNanos = System.nanoTime();
+ }
+
+ var logits = InferenceCore.forwardJava(model, state, currentToken, pos);
+ int nextToken = sampler.sampleToken(logits);
+
+ if (echo) {
+ System.err.print(Tokenizer.replaceControlCharacters(
+ model.tokenizer().decode(List.of(nextToken))));
+ }
+
+ generatedTokens.add(nextToken);
+
+ if (onTokenGenerated != null) {
+ onTokenGenerated.accept(nextToken);
+ }
+
+ if (stopTokens.contains(nextToken)) {
+ break;
+ }
+
+ currentToken = nextToken;
+ state.latestToken = currentToken;
+ pos++;
+ }
+
+ long endNanos = System.nanoTime();
+ int totalTokens = promptTokens.size() + generatedTokens.size();
+ LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0);
+
+ return generatedTokens;
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java
index 8c69cb40..8036809e 100644
--- a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java
+++ b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java
@@ -2,6 +2,7 @@
import org.beehive.gpullama3.inference.InferenceCore;
import org.beehive.gpullama3.inference.InferenceEngine;
+import org.beehive.gpullama3.inference.InferenceEngineWithPrefillDecode;
import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.inference.state.State;
@@ -19,6 +20,8 @@
public class Llama extends AbstractModel {
+ static final boolean BATCHED_PREFILL = Boolean.getBoolean("llama.batchedPrefill");
+
LlamaConfiguration configuration;
public Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) {
@@ -63,6 +66,9 @@ public void forward(State state, int token, int position) {
@Override
public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
IntConsumer onTokenGenerated) {
+ if (BATCHED_PREFILL) {
+ return InferenceEngineWithPrefillDecode.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
+ }
return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
}
From d5f5629407e95f3752ab1fb65bdb5838db418421 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Tue, 31 Mar 2026 19:09:26 +0300
Subject: [PATCH 04/12] [prf/dec] Add GPU path for prefill/decode with
 TornadoVM integration. Extends `InferenceEngineWithPrefillDecode` and adds
 `TornadoVMMasterPlanWithPrefillDecode` for batched token generation. Refactor
 `Llama` to support the batched prefill flag.
---
.../InferenceCoreWithPrefillDecode.java | 40 ++++++++
.../InferenceEngineWithPrefillDecode.java | 94 +++++++++++++++++++
.../beehive/gpullama3/model/llama/Llama.java | 3 +
.../TornadoVMMasterPlanWithPrefillDecode.java | 79 ++++++++++++++++
4 files changed, 216 insertions(+)
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
index d662afe1..91bb6f79 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
@@ -3,9 +3,13 @@
import org.beehive.gpullama3.auxiliary.Parallel;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
+
+import java.lang.foreign.MemorySegment;
/**
* Low-level forward passes for the prefill/decode separated inference path.
@@ -121,4 +125,40 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p
// Final RMSNorm and vocab projection intentionally omitted:
// logits are not needed for prefill positions — only the KV cache matters.
}
+
+ /**
+ * GPU prefill-only forward pass for LLaMA (FP16, TornadoVM).
+ *
+ * Copies the token embedding into {@code state.embeddingX} (same as
+ * {@link InferenceCore#forwardTornadoVM}) then delegates to
+ * {@link TornadoVMMasterPlanWithPrefillDecode#tornadoVMForwardPrefill},
+ * which executes preprocessing + layer graphs but skips the logits graph.
+ *
+ * @param model the LLaMA model (must carry {@link TornadoWeights}, FP16 only)
+ * @param state mutable inference state
+ * @param token input token id
+ * @param position sequence position being processed
+ * @param prefillPlan the prefill/decode plan wrapper
+ * @throws UnsupportedOperationException if the model uses Q8_0 weights
+ */
+ public static void forwardTornadoVMPrefill(Model model, State state, int token, int position,
+ TornadoVMMasterPlanWithPrefillDecode prefillPlan) {
+ final Configuration configuration = model.configuration();
+ final TornadoWeights weights = (TornadoWeights) model.weights();
+
+ switch (weights.getWeightType()) {
+ case F16 -> {
+ MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
+ int bytes = Short.BYTES;
+ MemorySegment.copy(tokenEmbeddings, (long) token * configuration.dim() * bytes,
+ state.embeddingX.getSegment(), 0, (long) configuration.dim() * bytes);
+ }
+ case Q8_0 -> throw new UnsupportedOperationException(
+ // TODO Phase 4: implement Q8_0 GPU batched prefill kernels
+ "GPU prefill/decode path not yet implemented for Q8_0 weights");
+ default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType());
+ }
+
+ prefillPlan.tornadoVMForwardPrefill(position);
+ }
}
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
index b97f3c72..0ea06c84 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
@@ -3,9 +3,13 @@
import org.beehive.gpullama3.auxiliary.LastRunMetrics;
import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
+import org.beehive.gpullama3.tensor.GGMLType;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tokenizer.Tokenizer;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
import java.util.ArrayList;
import java.util.List;
@@ -117,4 +121,94 @@ public static List generateTokensLlama(
return generatedTokens;
}
+
+ /**
+ * LLaMA GPU token generation with prefill/decode separation (Phase 2).
+ *
+ * Drop-in replacement for
+ * {@link InferenceEngine#generateTokensGPULlama} when the batched-prefill
+ * flag is enabled. FP16 only; Q8_0 throws {@link UnsupportedOperationException}.
+ *
+ * Split loop:
+ *
+ * - Prefill (0..N-1): {@link InferenceCoreWithPrefillDecode#forwardTornadoVMPrefill}
+ * — layer graphs execute, logits graph is skipped.
+ * - Decode (N onward): {@link InferenceCore#forwardTornadoVM}
+ * — identical to the baseline GPU decode path.
+ *
+ */
+    public static List<Integer> generateTokensGPULlama(
+            Model model, State state, int startPosition,
+            List<Integer> promptTokens, Set<Integer> stopTokens,
+            int maxTokens, Sampler sampler, boolean echo,
+            IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
+
+ // Q8_0 GPU prefill not implemented yet
+ if (((TornadoWeights) model.weights()).getWeightType() == GGMLType.Q8_0) {
+ // TODO Phase 4: implement Q8_0 GPU batched prefill kernels
+ throw new UnsupportedOperationException(
+ "GPU prefill/decode path not yet implemented for Q8_0 weights");
+ }
+
+ long startNanos = System.nanoTime();
+
+ final Configuration config = model.configuration();
+ int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens)
+ ? config.contextLength() : maxTokens;
+
+        List<Integer> generatedTokens = new ArrayList<>();
+
+ int currentToken = state.latestToken; // BOS
+ int pos = startPosition;
+
+ // Thin wrapper: no new TornadoVM plan created, just holds the reference
+ TornadoVMMasterPlanWithPrefillDecode prefillPlan =
+ new TornadoVMMasterPlanWithPrefillDecode(tornadoVMPlan, state, model);
+
+ // ── Phase 1: Prefill (GPU, no logits) ────────────────────────────────
+ for (int promptIndex = 0; promptIndex < promptTokens.size() && pos < actualMaxTokens; promptIndex++) {
+ InferenceCoreWithPrefillDecode.forwardTornadoVMPrefill(model, state, currentToken, pos, prefillPlan);
+ currentToken = promptTokens.get(promptIndex);
+ if (echo) {
+ System.err.print(Tokenizer.replaceControlCharacters(
+ model.tokenizer().decode(List.of(currentToken))));
+ }
+ pos++;
+ }
+
+ state.latestToken = currentToken;
+
+ // ── Phase 2: Decode (GPU, with logits) ───────────────────────────────
+ while (pos < actualMaxTokens) {
+ var logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMPlan);
+ int nextToken = sampler.sampleToken(logits);
+
+ if (echo) {
+ System.err.print(Tokenizer.replaceControlCharacters(
+ model.tokenizer().decode(List.of(nextToken))));
+ }
+
+ generatedTokens.add(nextToken);
+
+ if (onTokenGenerated != null) {
+ onTokenGenerated.accept(nextToken);
+ }
+
+ if (stopTokens.contains(nextToken)) {
+ break;
+ }
+
+ currentToken = nextToken;
+ state.latestToken = currentToken;
+ pos++;
+ }
+
+ long endNanos = System.nanoTime();
+ int totalTokens = promptTokens.size() + generatedTokens.size();
+ LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0);
+
+ return generatedTokens;
+ }
+
+
}
diff --git a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java
index 8036809e..12a95070 100644
--- a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java
+++ b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java
@@ -75,6 +75,9 @@ public List<Integer> generateTokens(State state, int startPosition, List<Integer
     public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
             IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
+ if (BATCHED_PREFILL) {
+ return InferenceEngineWithPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
+ }
return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
new file mode 100644
index 00000000..61b81bef
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
@@ -0,0 +1,79 @@
+package org.beehive.gpullama3.tornadovm;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+
+/**
+ * Wraps {@link TornadoVMMasterPlan} and adds a prefill-only GPU forward pass.
+ *
+ * Parallel to {@link TornadoVMMasterPlan} — does NOT modify it.
+ *
+ * The existing execution plan has this graph layout:
+ *
+ * graph 0 : preprocessing (embedding setup)
+ * graphs 1..N : transformer layers
+ * graph N+1 : logits projection (final RMSNorm + wcls matmul)
+ *
+ *
+ * {@link #tornadoVMForwardPrefill} executes graphs 0..N and deliberately
+ * skips graph N+1. The KV cache is populated correctly by the layer graphs;
+ * the logits are not needed for prefill positions so the projection is wasted
+ * work that we avoid.
+ *
+ * For decode, {@link #tornadoVMForwardDecode} delegates to the wrapped
+ * plan's {@code tornadoVMForwardExecuteLayered}, preserving identical behaviour
+ * to the baseline GPU path.
+ */
+public class TornadoVMMasterPlanWithPrefillDecode {
+
+ private final TornadoVMMasterPlan plan;
+ private final State state;
+ private final Configuration config;
+
+ public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlan plan, State state, Model model) {
+ this.plan = plan;
+ this.state = state;
+ this.config = model.configuration();
+ }
+
+ /**
+ * GPU prefill forward: runs preprocessing + all transformer layers, skips logits.
+ *
+ * Mirrors {@link TornadoVMMasterPlan#tornadoVMForwardExecuteLayered} except
+ * the final logits graph (graph {@code numberOfLayers + 1}) is not executed.
+ *
+ * @param position sequence position being processed
+ */
+ public void tornadoVMForwardPrefill(int position) {
+ // Graph 0: preprocessing
+ plan.executionPlan.withGraph(0)
+ .withGridScheduler(plan.tornadoVMLayerPlanner.getGridScheduler())
+ .execute();
+
+ state.positionHolder.set(0, position);
+ state.temp.clear();
+ state.tempFFN.clear();
+
+ // Graphs 1..N: transformer layers
+ for (int layer = 1; layer <= config.numberOfLayers(); layer++) {
+ plan.executionPlan.withGraph(layer)
+ .withGridScheduler(plan.tornadoVMLayerPlanner.getGridScheduler())
+ .execute();
+ }
+
+ // Graph N+1 (logits) intentionally skipped — not needed for prefill positions.
+ }
+
+ /**
+ * GPU decode forward: full execution including logits.
+ * Delegates to {@link TornadoVMMasterPlan#tornadoVMForwardExecuteLayered}.
+ *
+ * @param position sequence position being processed
+ * @return logits array for token sampling
+ */
+ public FloatArray tornadoVMForwardDecode(int position) {
+ return plan.tornadoVMForwardExecuteLayered(position);
+ }
+}
From fbbc41ff3041235cf61102d0131e7afec07fdbe5 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Tue, 31 Mar 2026 22:34:07 +0300
Subject: [PATCH 05/12] [prf/dec] Batch prompt tokens during prefill phase in
CPU path
---
.../InferenceCoreWithPrefillDecode.java | 142 ++++++++++++++++++
.../InferenceEngineWithPrefillDecode.java | 96 ++++++++----
2 files changed, 207 insertions(+), 31 deletions(-)
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
index 91bb6f79..460bb9af 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
@@ -6,6 +6,7 @@
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
@@ -126,6 +127,147 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p
// logits are not needed for prefill positions — only the KV cache matters.
}
+ /**
+ * CPU batched prefill forward pass for LLaMA (Phase 3).
+ *
+ * Processes {@code batchSize} prompt tokens simultaneously through all
+ * transformer layers. For each layer, Q/K/V projections, output projection,
+ * and FFN projections are computed via batch matmul
+ * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}),
+ * which parallelises over both output dimension and batch simultaneously.
+ * Attention reuses {@code state.att} sequentially per token (parallel per
+ * head within each token), keeping memory overhead minimal.
+ *
+ * The logits layer is intentionally omitted — only the KV cache matters
+ * for prefill positions.
+ *
+ * @param model the LLaMA model (must carry {@link StandardWeights})
+ * @param state mutable inference state (KV cache, att buffer …)
+ * @param tokens input token ids, {@code tokens[b]} at position {@code startPos+b}
+ * @param startPos sequence position of {@code tokens[0]}
+ * @param batchSize number of tokens in this chunk ({@code tokens.length})
+ */
+ public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) {
+ final Configuration config = model.configuration();
+ final StandardWeights weights = (StandardWeights) model.weights();
+ int dim = config.dim();
+ int headSize = config.headSize();
+ int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
+ int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads();
+ float sqrtHeadSize = (float) Math.sqrt(headSize);
+
+ // ── Batch activation tensors (allocated once per chunk) ───────────────
+ FloatTensor[] x = new FloatTensor[batchSize];
+ FloatTensor[] xb = new FloatTensor[batchSize];
+ FloatTensor[] xb2 = new FloatTensor[batchSize];
+ FloatTensor[] q = new FloatTensor[batchSize];
+ FloatTensor[] k = new FloatTensor[batchSize];
+ FloatTensor[] v = new FloatTensor[batchSize];
+ FloatTensor[] hb = new FloatTensor[batchSize];
+ FloatTensor[] hb2 = new FloatTensor[batchSize];
+ for (int b = 0; b < batchSize; b++) {
+ x[b] = ArrayFloatTensor.allocate(dim);
+ xb[b] = ArrayFloatTensor.allocate(dim);
+ xb2[b] = ArrayFloatTensor.allocate(dim);
+ q[b] = ArrayFloatTensor.allocate(dim);
+ k[b] = ArrayFloatTensor.allocate(kvDim);
+ v[b] = ArrayFloatTensor.allocate(kvDim);
+ hb[b] = ArrayFloatTensor.allocate(config.hiddenDim());
+ hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim());
+ }
+
+ // ── Token embeddings ──────────────────────────────────────────────────
+ Parallel.parallelFor(0, batchSize, b ->
+ weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim));
+
+ // ── Transformer layers ────────────────────────────────────────────────
+ for (int l = 0; l < config.numberOfLayers(); l++) {
+ final int layer = l;
+
+ // Attention RMSNorm (parallel per b)
+ Parallel.parallelFor(0, batchSize, b ->
+ InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps()));
+
+ // QKV projections — batch matmul parallelises over (dim × batchSize)
+ weights.wq[l].matmul(batchSize, xb, q, dim, dim);
+ weights.wk[l].matmul(batchSize, xb, k, kvDim, dim);
+ weights.wv[l].matmul(batchSize, xb, v, kvDim, dim);
+
+ // RoPE + KV cache store (parallel per b — different positions, no conflict)
+ Parallel.parallelFor(0, batchSize, b -> {
+ int pos = startPos + b;
+ for (int i = 0; i < dim; i += 2) {
+ int head_dim = i % headSize;
+ float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2));
+ float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2));
+ int rotn = i < kvDim ? 2 : 1;
+ for (int vv = 0; vv < rotn; vv++) {
+ FloatTensor vec = vv == 0 ? q[b] : k[b];
+ float v0 = vec.getFloat(i);
+ float v1 = vec.getFloat(i + 1);
+ vec.setFloat(i, v0 * fcr - v1 * fci);
+ vec.setFloat(i + 1, v0 * fci + v1 * fcr);
+ }
+ }
+ k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim);
+ v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim);
+ });
+
+ // Attention — sequential per b (state.att is shared), parallel per head
+ for (int b = 0; b < batchSize; b++) {
+ final int pos_b = startPos + b;
+ final int bFinal = b;
+ Parallel.parallelFor(0, config.numberOfHeads(), h -> {
+ int qOffset = h * headSize;
+ int attOffset = h * config.contextLength();
+
+ for (int t = 0; t <= pos_b; t++) {
+ int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
+ float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize;
+ state.att.setFloat(attOffset + t, score);
+ }
+ state.att.softmaxInPlace(attOffset, pos_b + 1);
+
+ int xbOffset = h * headSize;
+ xb[bFinal].fillInPlace(xbOffset, headSize, 0f);
+ for (int t = 0; t <= pos_b; t++) {
+ int vOffset = t * kvDim + (h / kvMul) * headSize;
+ float a = state.att.getFloat(attOffset + t);
+ xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a);
+ }
+ });
+ }
+
+ // Output projection — batch matmul
+ weights.wo[l].matmul(batchSize, xb, xb2, dim, dim);
+
+ // Residual + FFN RMSNorm (parallel per b)
+ Parallel.parallelFor(0, batchSize, b -> {
+ x[b].addInPlace(xb2[b]);
+ InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps());
+ });
+
+ // FFN projections — batch matmul
+ weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim);
+ weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim);
+
+ // SwiGLU (parallel per b)
+ Parallel.parallelFor(0, batchSize, b -> {
+ hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
+ hb[b].multiplyInPlace(hb2[b]);
+ });
+
+ // w2 projection — batch matmul (output reuses xb)
+ weights.w2[l].matmul(batchSize, hb, xb, dim, config.hiddenDim());
+
+ // FFN residual (parallel per b)
+ Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b]));
+ }
+
+ // Final RMSNorm and vocab projection intentionally omitted —
+ // logits are not needed for any token in a prefill batch.
+ }
+
/**
* GPU prefill-only forward pass for LLaMA (FP16, TornadoVM).
*
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
index 0ea06c84..b581b8e8 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
@@ -12,6 +12,7 @@
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.function.IntConsumer;
@@ -39,13 +40,20 @@ public final class InferenceEngineWithPrefillDecode {
private InferenceEngineWithPrefillDecode() {}
+ /** Prefill chunk size. 1 = sequential (Phase 1 behaviour), >1 = batched (Phase 3). */
+ static final int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1);
+
/**
- * LLaMA token generation with prefill/decode separation (CPU, Phase 1).
+ * LLaMA token generation with prefill/decode separation (CPU).
*
- * Drop-in replacement for
- * {@link InferenceEngine#generateTokensLlama} when the batched-prefill
- * flag is enabled. Only the CPU path is implemented here; GPU support
- * is added in a later phase.
+ * When {@code llama.prefillBatchSize > 1} (Phase 3), prompt tokens are
+ * processed in chunks of that size using batch matmul, which traverses each
+ * weight matrix once per chunk instead of once per token.
+ *
+ * When {@code llama.prefillBatchSize == 1} (Phase 1), falls back to
+ * sequential single-token prefill (skip logits only).
+ *
+ * Drop-in replacement for {@link InferenceEngine#generateTokensLlama}.
*/
public static List generateTokensLlama(
Model model, State state, int startPosition,
@@ -56,42 +64,68 @@ public static List generateTokensLlama(
long startNanos = System.nanoTime();
final Configuration config = model.configuration();
- if (maxTokens < 0 || config.contextLength() < maxTokens) {
- maxTokens = config.contextLength();
- }
+ int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens)
+ ? config.contextLength() : maxTokens;
List generatedTokens = new ArrayList<>();
int currentToken = state.latestToken; // BOS
int pos = startPosition;
-
- // ── Phase 1: Prefill ──────────────────────────────────────────────────
- // Run all prompt tokens through the forward pass without computing
- // logits. The KV cache is populated at each position, which is all
- // that matters. After this loop:
- // currentToken == promptTokens.getLast()
- // pos == startPosition + promptTokens.size()
- for (int promptIndex = 0; promptIndex < promptTokens.size(); promptIndex++) {
- InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos);
- currentToken = promptTokens.get(promptIndex);
- if (echo) {
- System.err.print(Tokenizer.replaceControlCharacters(
- model.tokenizer().decode(List.of(currentToken))));
+ int N = promptTokens.size();
+
+ // ── Prefill ───────────────────────────────────────────────────────────
+ if (N > 0 && pos < actualMaxTokens) {
+ if (PREFILL_BATCH_SIZE > 1) {
+ // Phase 3: batch prefill — process tokens in chunks of PREFILL_BATCH_SIZE.
+ // Build the token sequence at positions [startPosition .. startPosition+N-1]:
+ // position startPosition+0 : currentToken (BOS)
+ // position startPosition+1 : promptTokens[0]
+ // ...
+ // position startPosition+N-1: promptTokens[N-2]
+ int[] prefillSeq = new int[N];
+ prefillSeq[0] = currentToken;
+ for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1);
+
+ for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += PREFILL_BATCH_SIZE) {
+ int chunkEnd = Math.min(Math.min(chunkStart + PREFILL_BATCH_SIZE, N), actualMaxTokens - pos);
+ int chunkSize = chunkEnd - chunkStart;
+ int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd);
+
+ if (chunkSize == 1) {
+ InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, chunk[0], pos + chunkStart);
+ } else {
+ InferenceCoreWithPrefillDecode.batchForwardJavaPrefill(model, state, chunk, pos + chunkStart, chunkSize);
+ }
+
+ if (echo) {
+ for (int b = 0; b < chunkSize; b++) {
+ int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1));
+ System.err.print(Tokenizer.replaceControlCharacters(
+ model.tokenizer().decode(List.of(echoed))));
+ }
+ }
+ }
+
+ currentToken = promptTokens.get(N - 1);
+ pos = startPosition + N;
+ } else {
+ // Phase 1: sequential prefill — single token, no logits
+ for (int promptIndex = 0; promptIndex < N && pos < actualMaxTokens; promptIndex++) {
+ InferenceCoreWithPrefillDecode.forwardJavaPrefill(model, state, currentToken, pos);
+ currentToken = promptTokens.get(promptIndex);
+ if (echo) {
+ System.err.print(Tokenizer.replaceControlCharacters(
+ model.tokenizer().decode(List.of(currentToken))));
+ }
+ pos++;
+ }
}
- pos++;
}
state.latestToken = currentToken;
- // ── Phase 2: Decode ───────────────────────────────────────────────────
- // Standard single-token forward with logits. Behaviour is identical
- // to the baseline InferenceEngine decode path.
- long inferenceStartNanos = 0;
- while (pos < maxTokens) {
- if (inferenceStartNanos == 0) {
- inferenceStartNanos = System.nanoTime();
- }
-
+ // ── Decode ────────────────────────────────────────────────────────────
+ while (pos < actualMaxTokens) {
var logits = InferenceCore.forwardJava(model, state, currentToken, pos);
int nextToken = sampler.sampleToken(logits);
From f0bca5f56e58db89cec21e9319127fe30413860d Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Thu, 2 Apr 2026 18:05:49 +0300
Subject: [PATCH 06/12] [prf/dec][wip] Add GPU-based prefill-decode with
batched prefill (working state, with cuda graphs only)
---
.../InferenceEngineWithPrefillDecode.java | 82 +++-
.../gpullama3/inference/state/LlamaState.java | 42 ++
.../gpullama3/inference/state/State.java | 2 +-
.../tornadovm/TornadoVMMasterPlan.java | 240 ++-------
.../TornadoVMMasterPlanBatchPrefill.java | 342 +++++++++++++
.../TornadoVMMasterPlanStandard.java | 149 ++++++
.../TornadoVMMasterPlanWithPrefillDecode.java | 8 +-
.../TransformerBatchPrefillKernels.java | 461 ++++++++++++++++++
.../fp16/LlamaFP16BatchPrefillLayers.java | 238 +++++++++
9 files changed, 1364 insertions(+), 200 deletions(-)
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
index b581b8e8..6517df12 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
@@ -2,6 +2,7 @@
import org.beehive.gpullama3.auxiliary.LastRunMetrics;
import org.beehive.gpullama3.inference.sampler.Sampler;
+import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.tensor.GGMLType;
@@ -9,6 +10,8 @@
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefill;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanStandard;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
import java.util.ArrayList;
@@ -40,7 +43,7 @@ public final class InferenceEngineWithPrefillDecode {
private InferenceEngineWithPrefillDecode() {}
- /** Prefill chunk size. 1 = sequential (Phase 1 behaviour), >1 = batched (Phase 3). */
+ /** Prefill chunk size. 1 = sequential (Phase 1 behaviour), >1 = batched (Phase 3/4). */
static final int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1);
/**
@@ -195,9 +198,82 @@ public static List generateTokensGPULlama(
int currentToken = state.latestToken; // BOS
int pos = startPosition;
+ if (PREFILL_BATCH_SIZE > 1) {
+ // ── Phase 4: Batch GPU Prefill ────────────────────────────────────
+ // Plan was pre-initialized in Model.runInstructOnce/runInteractive
+ // as a TornadoVMMasterPlanBatchPrefill by TornadoVMMasterPlan.initializeTornadoVMPlan.
+ TornadoVMMasterPlanBatchPrefill plan = (TornadoVMMasterPlanBatchPrefill) tornadoVMPlan;
+
+ int N = promptTokens.size();
+
+ // Build the token sequence at positions [startPosition .. startPosition+N-1]:
+ // position startPosition+0 : currentToken (BOS/previous token)
+ // position startPosition+1 : promptTokens[0]
+ // ...
+ // position startPosition+N-1: promptTokens[N-2]
+ int[] prefillSeq = new int[N];
+ prefillSeq[0] = currentToken;
+ for (int i = 1; i < N; i++) prefillSeq[i] = promptTokens.get(i - 1);
+
+ for (int chunkStart = 0; chunkStart < N && pos + chunkStart < actualMaxTokens; chunkStart += PREFILL_BATCH_SIZE) {
+ int chunkEnd = Math.min(Math.min(chunkStart + PREFILL_BATCH_SIZE, N), actualMaxTokens - pos);
+ int chunkSize = chunkEnd - chunkStart;
+ int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd);
+
+ if (chunkSize == 1) {
+                        // Single-token chunk: run batch prefill with batchSize 1 — logits are
+                        // not needed here; only the KV cache must be populated.
+ plan.tornadoVMForwardBatchPrefill(chunk, pos + chunkStart, model, 1);
+ } else {
+ plan.tornadoVMForwardBatchPrefill(chunk, pos + chunkStart, model, chunkSize);
+ }
+
+ if (echo) {
+ for (int b = 0; b < chunkSize; b++) {
+ int echoed = promptTokens.get(Math.min(chunkStart + b, N - 1));
+ System.err.print(Tokenizer.replaceControlCharacters(
+ model.tokenizer().decode(List.of(echoed))));
+ }
+ }
+ }
+
+ currentToken = promptTokens.get(N - 1);
+ pos = startPosition + N;
+ state.latestToken = currentToken;
+
+ // ── Phase 4: Decode (GPU, with logits, via unified plan) ──────────
+ while (pos < actualMaxTokens) {
+ var logits = plan.tornadoVMForwardDecode(currentToken, pos, model);
+ int nextToken = sampler.sampleToken(logits);
+
+ if (echo) {
+ System.err.print(Tokenizer.replaceControlCharacters(
+ model.tokenizer().decode(List.of(nextToken))));
+ }
+
+ generatedTokens.add(nextToken);
+
+ if (onTokenGenerated != null) {
+ onTokenGenerated.accept(nextToken);
+ }
+
+ if (stopTokens.contains(nextToken)) {
+ break;
+ }
+
+ currentToken = nextToken;
+ state.latestToken = currentToken;
+ pos++;
+ }
+
+ } else {
+ // ── Phase 2: Sequential GPU Prefill + Decode ─────────────────────────
+
// Thin wrapper: no new TornadoVM plan created, just holds the reference
+ // Plan is a TornadoVMMasterPlanStandard when PREFILL_BATCH_SIZE == 1.
TornadoVMMasterPlanWithPrefillDecode prefillPlan =
- new TornadoVMMasterPlanWithPrefillDecode(tornadoVMPlan, state, model);
+ new TornadoVMMasterPlanWithPrefillDecode(
+ (TornadoVMMasterPlanStandard) tornadoVMPlan, state, model);
// ── Phase 1: Prefill (GPU, no logits) ────────────────────────────────
for (int promptIndex = 0; promptIndex < promptTokens.size() && pos < actualMaxTokens; promptIndex++) {
@@ -237,6 +313,8 @@ public static List generateTokensGPULlama(
pos++;
}
+ } // end else (Phase 2)
+
long endNanos = System.nanoTime();
int totalTokens = promptTokens.size() + generatedTokens.size();
LastRunMetrics.setMetrics(totalTokens, (endNanos - startNanos) / 1_000_000_000.0);
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
index 21344223..d298d388 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
@@ -22,8 +22,50 @@
*/
public final class LlamaState extends State {
+ // ── Batch-prefill GPU buffers ─────────────────────────────────────────────
+ // Allocated when llama.prefillBatchSize > 1; null otherwise.
+ // Layout: flat [B × stride], element [b][i] at index b*stride + i.
+ public final HalfFloatArray embeddingXBatch; // B × dim (FP16 input)
+ public final FloatArray wrapXBatch; // B × dim (live activations)
+ public final HalfFloatArray wrapXbFP16Batch; // B × dim (RMSNorm output, FP16)
+ public final FloatArray wrapQBatch; // B × dim
+ public final FloatArray wrapKBatch; // B × kvDim
+ public final FloatArray wrapVBatch; // B × kvDim
+ public final FloatArray wrapXbBatch; // B × dim (attention output)
+ public final FloatArray wrapHbBatch; // B × hiddenDim
+ public final FloatArray attnScaleBatch; // B (per-token RMS scale, attn)
+ public final FloatArray ffnScaleBatch; // B (per-token RMS scale, FFN)
+ public final IntArray batchStartPosHolder; // 1 (start position of chunk)
+
public LlamaState(Configuration config, int batchsize) {
super(config, batchsize);
+ int gpuBatchSize = Integer.getInteger("llama.prefillBatchSize", 1);
+ if (gpuBatchSize > 1) {
+ int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
+ this.embeddingXBatch = new HalfFloatArray(gpuBatchSize * config.dim());
+ this.wrapXBatch = new FloatArray(gpuBatchSize * config.dim());
+ this.wrapXbFP16Batch = new HalfFloatArray(gpuBatchSize * config.dim());
+ this.wrapQBatch = new FloatArray(gpuBatchSize * config.dim());
+ this.wrapKBatch = new FloatArray(gpuBatchSize * kvDim);
+ this.wrapVBatch = new FloatArray(gpuBatchSize * kvDim);
+ this.wrapXbBatch = new FloatArray(gpuBatchSize * config.dim());
+ this.wrapHbBatch = new FloatArray(gpuBatchSize * config.hiddenDim());
+ this.attnScaleBatch = new FloatArray(gpuBatchSize);
+ this.ffnScaleBatch = new FloatArray(gpuBatchSize);
+ this.batchStartPosHolder = new IntArray(1);
+ } else {
+ this.embeddingXBatch = null;
+ this.wrapXBatch = null;
+ this.wrapXbFP16Batch = null;
+ this.wrapQBatch = null;
+ this.wrapKBatch = null;
+ this.wrapVBatch = null;
+ this.wrapXbBatch = null;
+ this.wrapHbBatch = null;
+ this.attnScaleBatch = null;
+ this.ffnScaleBatch = null;
+ this.batchStartPosHolder = null;
+ }
}
@Override
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java
index f8e9906a..06d448e7 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/State.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java
@@ -75,7 +75,7 @@ public abstract class State {
/** last index in previous block */
protected State(Configuration config, int batchsize) {
- this.batchsize = -1;
+ this.batchsize = batchsize;
this.latestToken = -1;
this.localSize = 256;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
index 4b752735..43030fd0 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
@@ -1,212 +1,66 @@
package org.beehive.gpullama3.tornadovm;
+import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.inference.state.State;
-import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tensor.GGMLType;
-import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory;
-import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
-import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-public class TornadoVMMasterPlan {
- public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
-
- private final State state;
- private final Configuration config;
- public TornadoExecutionPlan executionPlan;
- GenericLayerPlanner tornadoVMLayerPlanner;
-
- public TornadoVMMasterPlan(State state, Model model) {
- this.tornadoVMLayerPlanner = createPlanner(state, model);
- this.executionPlan = createExecutionPlan();
- this.state = state;
- this.config = model.configuration();
- }
+/**
+ * Common contract for all TornadoVM GPU execution plans.
+ *
+ * Two concrete implementations exist:
+ *
+ * - {@link TornadoVMMasterPlanStandard} — single-token forward pass; used for the
+ * baseline GPU path and Phase 2 sequential prefill/decode.
+ * - {@link TornadoVMMasterPlanBatchPrefill} — unified plan for Phase 4 batched
+ * prefill + single-token decode within one {@code TornadoExecutionPlan}.
+ *
+ *
+ * The {@link #initializeTornadoVMPlan} factory selects the appropriate implementation
+ * based on {@code llama.prefillBatchSize}: if {@code > 1}, returns a
+ * {@link TornadoVMMasterPlanBatchPrefill}; otherwise returns a
+ * {@link TornadoVMMasterPlanStandard}.
+ */
+public interface TornadoVMMasterPlan {
+
+ boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(
+ System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
/**
- * Initializes the TornadoVM plan for GPU acceleration with optional timing. This method handles: 1. Creation of the TornadoVM master plan 2. Warming up the JIT compiler for better performance 3.
- * Copying read-only model weights to the GPU
+ * Single-token forward pass returning output logits.
*
- * @param state
- * The model state containing KV cache
- * @param model
- * The Llama model instance
- * @return The initialized TornadoVMMasterPlan ready for inference
- */
- public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) {
- // Initialize timing variables outside conditional blocks to avoid scope issues
- long startTime = System.nanoTime();
- long planCreationTime = 0;
- long warmupTime = 0;
-
- // Start a timing message if enabled
- if (ENABLE_TORNADOVM_INIT_TIME) {
- System.err.println("\nStarting TornadoVM initialization...");
- }
-
- // 1. Pre-allocate the TornadoVM plan
- TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan(state, model);
-
- // Record time after plan creation
- if (ENABLE_TORNADOVM_INIT_TIME) {
- planCreationTime = System.nanoTime();
- System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0);
- }
-
- tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph();
-
- // 2. Perform warmup with extra iterations to ensure JIT compilation is complete
- tornadoVMPlan.executionPlan.withPreCompilation(); // Force JIT compilation from Java to GPU code
-
- // Record time after warmup
- if (ENABLE_TORNADOVM_INIT_TIME) {
- warmupTime = System.nanoTime();
- System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0);
- }
-
- // 3. Perform copy-in of read-only weights and objects
- tornadoVMPlan.forceCopyInReadOnlyDataLayered(); // Force copy-in read-only weights
-
- // Record final timing information
- if (ENABLE_TORNADOVM_INIT_TIME) {
- long copyTime = System.nanoTime();
- System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0);
- System.err.printf("Finished TornadoVM initialization...\n \n");
- }
-
- model.setTornadoVMPlan(tornadoVMPlan);
-
- return tornadoVMPlan;
- }
-
- private TornadoExecutionPlan createExecutionPlan() {
- var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs();
- var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]);
- return new TornadoExecutionPlan(taskGraphArray);
- }
-
- private GenericLayerPlanner createPlanner(State state, Model model) {
- // ========== STEP 1: Detect Quantization Type ==========
- GGMLType weightType = model.weights().getWeightType();
-
- // ========== STEP 2: Route via Factory ==========
- // Factory handles all model × quantization combinations
- GenericLayerPlanner basePlanner = QuantizationPlannerFactory.create(weightType, state, model);
-
- return basePlanner;
- }
-
- /**
- * Determines whether the NVIDIA-specific scheduler should be used based on the current
- * hardware backend and the model type.
- *
- * The scheduler is used only if the runtime is targeting an NVIDIA backend and the model is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is {@code MISTRAL}, the
- * NVIDIA-specific scheduler should not be used.
+ *
+ * Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM})
+ * and the Phase 2 sequential decode path. Not applicable to
+ * {@link TornadoVMMasterPlanBatchPrefill} — that plan uses its own typed methods.
*
- * @param model
- * the model whose type may affect the scheduler decision
- * @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise
+ * @param position sequence position of the current token
+ * @return logits array for token sampling
*/
+ FloatArray tornadoVMForwardExecuteLayered(int position);
- /**
- * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. This method processes the transformer layers in sequence for a particular token position in the context
- * window.
- *
- * The execution happens in three phases:
- *
- * - Initial token embedding lookup (already done before calling this method)
- * - Sequential processing through each transformer layer using TornadoVM
- * - Final projection to logits using TornadoVM
- *
- *
- * @param position
- * The current position in the sequence being processed
- * @return FloatTensor containing the output logits for token prediction
- */
-
- // int pos, ModelPlanner
- public FloatArray tornadoVMForwardExecuteLayered(int position) {
- // @formatter:off
- // 1. Execute the preprocessing graph (e.g., input preparation, memory initialization)
- executionPlan.withGraph(getPreprocessingGraphIndex())
- .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
- .withCUDAGraph()
- .execute();
-
- // Set the position in the state object (used by attention layers)
- state.positionHolder.set(0, position);
- state.temp.clear();
- state.tempFFN.clear();
-
- // 2. Execute each transformer layer graph sequentially
- // Each graph computes attention and feed-forward transformations for one layer
- for (int layer = 0; layer < config.numberOfLayers(); layer++) {
- executionPlan.withGraph(getLayerGraphIndex(layer))
- .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
- .withCUDAGraph()
- .execute();
- }
- state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f
- state.wrapLogits.clear(); // Clear the output logits tensor -> set to 0f
- // 3. Execute the final graph that projects the last hidden state to output logits
- executionPlan.withGraph(getFinalLogitsGraphIndex())
- .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
- .withCUDAGraph()
- .execute();
-
- // @formatter:on
- // Return the logits (used for token prediction)
- return state.wrapLogits;
- }
+ /** Releases all device memory held by this plan. */
+ void freeTornadoExecutionPlan();
/**
- * Returns the graph index for the pre-processing step (e.g., token embedding).
- */
- private int getPreprocessingGraphIndex() {
- return 0;
- }
-
- /**
- * Returns the graph index for the given transformer layer.
+ * Factory: creates, JIT-compiles, and warms up the appropriate plan.
*
- * @param layerIndex
- * Index of the transformer layer (0-based)
- */
- private int getLayerGraphIndex(int layerIndex) {
- return 1 + layerIndex;
- }
-
- /**
- * Returns the graph index for the final projection to logits.
+ * When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanBatchPrefill}
+ * is returned; otherwise a {@link TornadoVMMasterPlanStandard} is returned.
+ *
+ * @param state the model state (must be {@link LlamaState} when batch size {@code > 1})
+ * @param model the model instance
+ * @return the initialized plan, also stored via {@link Model#setTornadoVMPlan}
*/
- private int getFinalLogitsGraphIndex() {
- return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1;
- }
-
- /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer.
- public void forceCopyInReadOnlyDataLayered() {
- // Execute all TornadoVM graphs
- state.wrapX.clear();
- state.positionHolder.init(0);
-
- // Execute activation update graph
- executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute();
-
- // Execute layer processing graphs
- for (int layer = 0; layer < config.numberOfLayers(); layer++) {
- executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute();
+ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) {
+ int batchSize = Integer.getInteger("llama.prefillBatchSize", 1);
+ TornadoVMMasterPlan plan;
+ if (batchSize > 1) {
+ plan = TornadoVMMasterPlanBatchPrefill.initializeUnifiedPlan(
+ (LlamaState) state, model, batchSize);
+ } else {
+ plan = TornadoVMMasterPlanStandard.initialize(state, model);
}
-
- // Execute logits graph
- executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute();
- }
-
- /**
- * Frees the device memory allocated for the TornadoVM execution plan. This method should be called when the execution plan is no longer needed to release resources and avoid memory leaks.
- */
- public void freeTornadoExecutionPlan() {
- executionPlan.freeDeviceMemory();
+ model.setTornadoVMPlan(plan);
+ return plan;
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
new file mode 100644
index 00000000..b2388bf3
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
@@ -0,0 +1,342 @@
+package org.beehive.gpullama3.tornadovm;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16BatchPrefillLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+
+import java.lang.foreign.MemorySegment;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Unified GPU execution plan for Phase 4: batched prefill + single-token decode.
+ *
+ * A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache
+ * ({@code wrapKeyCache}, {@code wrapValueCache}) is shared on device via
+ * {@code persistOnDevice}/{@code consumeFromDevice}. Two separate plans would
+ * allocate independent device buffers and lose the prefill KV state.
+ *
+ * Graph layout (2N+3 graphs total):
+ *
+ * [0] batch activation B×dim FP16 → FP32
+ * [1..N] batch layer graphs B tokens, all transformer ops
+ * [N+1] decode activation single-token FP16 → FP32 + KV-cache pass-through
+ * [N+2..2N+1] decode layer graphs single-token, standard kernels
+ * [2N+2] logits graph
+ *
+ *
+ * KV cache pointer chain across phases:
+ *
+ * batchLayer[N-1] --persistOnDevice(wrapKeyCache)-→
+ * decodeActivation --consumeFromDevice(wrapKeyCache)-→ (pass-through)
+ * decodeLayer[0] --consumeFromDevice(wrapKeyCache)-→ (used by attention)
+ *
+ */
+public class TornadoVMMasterPlanBatchPrefill implements TornadoVMMasterPlan {
+
+ private static final boolean ENABLE_TIMING =
+ Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
+
+ private final LlamaState state;
+ private final LlamaConfiguration config;
+ private final int batchSize;
+ private final int N; // numberOfLayers
+ private final TornadoExecutionPlan executionPlan;
+ private final GridScheduler gridScheduler;
+
+ // ── Graph-index helpers ───────────────────────────────────────────────────
+ private int batchActivationIdx() { return 0; }
+ private int batchLayerIdx(int i) { return 1 + i; }
+ private int decodeActivationIdx() { return N + 1; }
+ private int decodeLayerIdx(int i) { return N + 2 + i; }
+ private int logitsIdx() { return 2 * N + 2; }
+
+ // ── Construction ─────────────────────────────────────────────────────────
+ /** Builds all 2N+3 task graphs and the shared grid scheduler; no GPU work happens here. */
+ private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batchSize) {
+ this.state = state;
+ this.config = (LlamaConfiguration) model.configuration();
+ this.batchSize = batchSize;
+ this.N = config.numberOfLayers();
+
+ LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
+ SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+
+ // Typed list: a raw List would erase toArray(T[]) and defeat compile-time checks.
+ List<ImmutableTaskGraph> all = new ArrayList<>(2 * N + 3);
+ GridScheduler scheduler = new GridScheduler();
+
+ // [0] Batch activation ────────────────────────────────────────────────
+ KernelContext batchActCtx = new KernelContext();
+ all.add(buildBatchActivationGraph(batchActCtx).snapshot());
+ scheduler.addWorkerGrid("batchActivation.batchUpdateX",
+ WorkerGridFactory.genericWorker(batchSize * config.dim(), 128));
+
+ // [1..N] Batch layer graphs ───────────────────────────────────────────
+ LlamaFP16BatchPrefillLayers batchLayers =
+ new LlamaFP16BatchPrefillLayers(state, weights, config, batchSize);
+ all.addAll(batchLayers.getLayerImmutableTaskGraphs());
+ batchLayers.updateGridScheduler(scheduler);
+
+ // [N+1] Decode activation (with KV-cache pass-through) ────────────────
+ KernelContext decodeActCtx = new KernelContext();
+ all.add(buildDecodeActivationGraph(decodeActCtx).snapshot());
+ scheduler.addWorkerGrid("activationUpdate.updateX",
+ WorkerGridFactory.genericWorker(config.dim(), 128));
+
+ // [N+2..2N+1] Decode layer graphs ────────────────────────────────────
+ // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload).
+ LlamaFP16FFNLayersForUnifiedDecode decodeLayers =
+ new LlamaFP16FFNLayersForUnifiedDecode(
+ "llamaFFNDecode", state, weights, config, schedulerType);
+ all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
+ decodeLayers.updateGridScheduler(scheduler);
+
+ // [2N+2] Logits ───────────────────────────────────────────────────────
+ LogitsFP16Layer logitsLayer = new LogitsFP16Layer("logits", state, weights, config,
+ decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+ all.add(logitsLayer.getImmutableTaskGraph());
+ logitsLayer.updateGridScheduler(scheduler);
+
+ this.gridScheduler = scheduler;
+ this.executionPlan = new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0]));
+ }
+
+ // ── Activation graphs ─────────────────────────────────────────────────────
+
+ /** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. */
+ private TaskGraph buildBatchActivationGraph(KernelContext ctx) {
+ return new TaskGraph("batchActivation")
+ .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch)
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch)
+ // Register the shared named kernel rather than an ad-hoc lambda, keeping
+ // all batch-prefill device code in TransformerBatchPrefillKernels
+ // (mirrors the TransformerComputeKernels::convertFP16toFP32 pattern below).
+ .task("batchUpdateX",
+ TransformerBatchPrefillKernels::batchEmbeddingToFP32,
+ ctx, state.embeddingXBatch, state.wrapXBatch)
+ .persistOnDevice(state.wrapXBatch);
+ }
+
+ /**
+ * Graph N+1: single-token FP16 → FP32.
+ *
+ * Receives the KV-cache device pointer from batch layer N via
+ * {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so
+ * that {@code updatePersistedObjectState()} can propagate it to decode layer 0.
+ * Both halves of the chain are required; without the re-persist the pointer is
+ * not forwarded in interpreter (non-CUDA-graph) mode.
+ */
+ private TaskGraph buildDecodeActivationGraph(KernelContext ctx) {
+ return new TaskGraph("activationUpdate")
+ .consumeFromDevice(state.wrapKeyCache, state.wrapValueCache) // KV pass-through
+ .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX)
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
+ .task("updateX",
+ TransformerComputeKernels::convertFP16toFP32,
+ ctx, (HalfFloatArray) state.embeddingX, state.wrapX)
+ // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache
+ // re-persisted so updatePersistedObjectState() propagates the device
+ // pointer to decode layer 0's consumeFromDevice without CUDA graphs.
+ .persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache);
+ }
+
+ // ── Static factory ────────────────────────────────────────────────────────
+
+ /**
+ * Creates, JIT-compiles, and warms up the unified plan.
+ * Mirrors {@link TornadoVMMasterPlan#initializeTornadoVMPlan}.
+ */
+ public static TornadoVMMasterPlanBatchPrefill initializeUnifiedPlan(
+ LlamaState state, Model model, int batchSize) {
+
+ long t0 = System.nanoTime();
+ TornadoVMMasterPlanBatchPrefill plan =
+ new TornadoVMMasterPlanBatchPrefill(state, model, batchSize);
+
+ if (ENABLE_TIMING)
+ System.err.printf("[BatchPlan] Graph construction: %.2f ms%n",
+ (System.nanoTime() - t0) / 1e6);
+
+ plan.executionPlan.withAllGraphs().withCUDAGraph();
+ plan.executionPlan.withPreCompilation();
+
+ if (ENABLE_TIMING)
+ System.err.printf("[BatchPlan] JIT compilation: %.2f ms%n",
+ (System.nanoTime() - t0) / 1e6);
+
+ plan.forceCopyInReadOnlyData();
+
+ if (ENABLE_TIMING)
+ System.err.printf("[BatchPlan] Init complete: %.2f ms%n",
+ (System.nanoTime() - t0) / 1e6);
+
+ return plan;
+ }
+
+ /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */
+ private void forceCopyInReadOnlyData() {
+ state.wrapXBatch.clear();
+ state.wrapX.clear();
+ state.positionHolder.init(0);
+ state.batchStartPosHolder.init(0);
+
+ for (int i = 0; i <= logitsIdx(); i++) {
+ executionPlan.withGraph(i)
+ .withGridScheduler(gridScheduler)
+ .withCUDAGraph()
+ .execute();
+ }
+ }
+
+ // ── Forward passes ────────────────────────────────────────────────────────
+
+ /**
+ * Batch prefill: runs graphs 0..N (activation + N layers), skips logits.
+ *
+ * @param tokenIds token IDs for this chunk (length == batchSize, or tail)
+ * @param startPos sequence position of tokenIds[0]
+ * @param model model (for embedding table)
+ * @param chunkSize actual number of tokens in this chunk (≤ batchSize)
+ */
+ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) {
+ LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
+ MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
+ int bytes = Short.BYTES;
+ int dim = config.dim();
+
+ // Copy B embeddings into embeddingXBatch
+ for (int b = 0; b < chunkSize; b++) {
+ MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes,
+ state.embeddingXBatch.getSegment(), (long) b * dim * bytes,
+ (long) dim * bytes);
+ }
+ state.batchStartPosHolder.set(0, startPos);
+
+ // Graph 0: batch activation
+ executionPlan.withGraph(batchActivationIdx())
+ .withGridScheduler(gridScheduler)
+ .withCUDAGraph()
+ .execute();
+
+ // Graphs 1..N: batch transformer layers
+ for (int l = 0; l < N; l++) {
+ executionPlan.withGraph(batchLayerIdx(l))
+ .withGridScheduler(gridScheduler)
+ .withCUDAGraph()
+ .execute();
+ }
+ // Logits skipped — not needed for prefill positions.
+ }
+
+ /**
+ * Single-token decode: runs graphs N+1..2N+2 (activation + N layers + logits).
+ *
+ * @param token token ID to process
+ * @param position sequence position
+ * @param model model (for embedding table)
+ * @return logits array for sampling
+ */
+ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) {
+ LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
+ MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
+ int bytes = Short.BYTES;
+ int dim = config.dim();
+
+ MemorySegment.copy(embTable, (long) token * dim * bytes,
+ state.embeddingX.getSegment(), 0L, (long) dim * bytes);
+
+ state.positionHolder.set(0, position);
+ state.temp.clear();
+ state.tempFFN.clear();
+
+ // Graph N+1: decode activation
+ executionPlan.withGraph(decodeActivationIdx())
+ .withGridScheduler(gridScheduler)
+ .withCUDAGraph()
+ .execute();
+
+ // Graphs N+2..2N+1: decode transformer layers
+ for (int l = 0; l < N; l++) {
+ executionPlan.withGraph(decodeLayerIdx(l))
+ .withGridScheduler(gridScheduler)
+ .withCUDAGraph()
+ .execute();
+ }
+
+ state.tempLogits.clear();
+ state.wrapLogits.clear();
+
+ // Graph 2N+2: logits
+ executionPlan.withGraph(logitsIdx())
+ .withGridScheduler(gridScheduler)
+ .withCUDAGraph()
+ .execute();
+
+ return state.wrapLogits;
+ }
+
+ @Override
+ public FloatArray tornadoVMForwardExecuteLayered(int position) {
+ throw new UnsupportedOperationException(
+ "Use tornadoVMForwardBatchPrefill / tornadoVMForwardDecode for batch plan");
+ }
+
+ @Override
+ public void freeTornadoExecutionPlan() {
+ executionPlan.freeDeviceMemory();
+ }
+
+ // ── Inner class: decode layer 0 with consumeFromDevice for KV cache ───────
+
+ /**
+ * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses
+ * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}.
+ *
+ * This ensures decode layer 0 receives the KV-cache device pointer that was
+ * persisted by the last batch prefill layer and passed through the decode
+ * activation graph.
+ */
+ private static final class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers {
+
+ LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state,
+ LlamaTornadoWeights weights, LlamaConfiguration config,
+ SchedulerType schedulerType) {
+ super(taskGraph, state, weights, config, schedulerType);
+ }
+
+ @Override
+ protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) {
+ if (layerIndex == 0) {
+ // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come
+ // from device (passed through by the decode activation graph).
+ layer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
+ state.positionHolder, state.temp, state.tempFFN);
+ layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ context,
+ state.wrapXb, state.wrapXb2,
+ state.wrapQ, state.wrapK, state.wrapV,
+ state.wrapAtt, state.wrapHb, state.wrapXbFP16);
+ // KV cache: consume from device (device pointer supplied by
+ // decode activation's pass-through from last batch layer).
+ layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache);
+ } else {
+ // Identical to parent for layers 1+ (already uses consumeFromDevice).
+ return super.configureLayerDataTransfers(layer, layerIndex);
+ }
+ return layer;
+ }
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java
new file mode 100644
index 00000000..91586f2c
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java
@@ -0,0 +1,149 @@
+package org.beehive.gpullama3.tornadovm;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+
+/**
+ * Standard (single-token) GPU execution plan.
+ *
+ * Processes one token at a time through preprocessing + N transformer layers +
+ * logits projection. Used for both the baseline GPU path and the Phase 2
+ * sequential prefill/decode path.
+ */
+public class TornadoVMMasterPlanStandard implements TornadoVMMasterPlan {
+
+ public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
+
+ private final State state;
+ private final Configuration config;
+ public TornadoExecutionPlan executionPlan;
+ GenericLayerPlanner tornadoVMLayerPlanner;
+
+ /** Builds the quantization-specific layer planner and its execution plan; no GPU work yet. */
+ public TornadoVMMasterPlanStandard(State state, Model model) {
+ this.tornadoVMLayerPlanner = createPlanner(state, model);
+ this.executionPlan = createExecutionPlan();
+ this.state = state;
+ this.config = model.configuration();
+ }
+
+ /**
+ * Initializes and warms up the standard TornadoVM plan.
+ *
+ * @param state the model state containing KV cache
+ * @param model the model instance
+ * @return the initialized plan ready for inference
+ */
+ static TornadoVMMasterPlanStandard initialize(State state, Model model) {
+ long startTime = System.nanoTime();
+ long planCreationTime = 0;
+ long warmupTime = 0;
+
+ if (ENABLE_TORNADOVM_INIT_TIME) {
+ System.err.println("\nStarting TornadoVM initialization...");
+ }
+
+ TornadoVMMasterPlanStandard tornadoVMPlan = new TornadoVMMasterPlanStandard(state, model);
+
+ if (ENABLE_TORNADOVM_INIT_TIME) {
+ planCreationTime = System.nanoTime();
+ System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0);
+ }
+
+ // Enable CUDA-graph capture on every graph, then force Java-to-GPU JIT compilation.
+ tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph();
+ tornadoVMPlan.executionPlan.withPreCompilation();
+
+ if (ENABLE_TORNADOVM_INIT_TIME) {
+ warmupTime = System.nanoTime();
+ System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0);
+ }
+
+ tornadoVMPlan.forceCopyInReadOnlyDataLayered();
+
+ if (ENABLE_TORNADOVM_INIT_TIME) {
+ long copyTime = System.nanoTime();
+ System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0);
+ System.err.printf("Finished TornadoVM initialization...\n \n");
+ }
+
+ return tornadoVMPlan;
+ }
+
+ /** Snapshots the planner's immutable task graphs into a single execution plan. */
+ private TornadoExecutionPlan createExecutionPlan() {
+ var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs();
+ return new TornadoExecutionPlan(taskGraphs.toArray(new ImmutableTaskGraph[0]));
+ }
+
+ /** Picks the planner matching the model's weight quantization type. */
+ private GenericLayerPlanner createPlanner(State state, Model model) {
+ GGMLType weightType = model.weights().getWeightType();
+ return QuantizationPlannerFactory.create(weightType, state, model);
+ }
+
+ /** Runs one single-token forward pass (preprocessing + N layers + logits); returns logits. */
+ @Override
+ public FloatArray tornadoVMForwardExecuteLayered(int position) {
+ // @formatter:off
+ executionPlan.withGraph(getPreprocessingGraphIndex())
+ .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
+ .withCUDAGraph()
+ .execute();
+
+ state.positionHolder.set(0, position);
+ state.temp.clear();
+ state.tempFFN.clear();
+
+ for (int layer = 0; layer < config.numberOfLayers(); layer++) {
+ executionPlan.withGraph(getLayerGraphIndex(layer))
+ .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
+ // NOTE(review): CUDA-graph capture is NOT applied to the per-layer
+ // graphs, unlike the preprocessing and logits graphs above/below —
+ // confirm this asymmetry is intentional and not a leftover experiment.
+ .execute();
+ }
+ state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f
+ state.wrapLogits.clear();
+ executionPlan.withGraph(getFinalLogitsGraphIndex())
+ .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
+ .withCUDAGraph()
+ .execute();
+ // @formatter:on
+ return state.wrapLogits;
+ }
+
+ /** Graph 0: input preparation / activation update. */
+ private int getPreprocessingGraphIndex() {
+ return 0;
+ }
+
+ /** Graphs 1..N: one graph per transformer layer. */
+ private int getLayerGraphIndex(int layerIndex) {
+ return 1 + layerIndex;
+ }
+
+ /** Last graph in the plan: final projection to logits. */
+ private int getFinalLogitsGraphIndex() {
+ return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1;
+ }
+
+ /**
+ * Executes every graph once so FIRST_EXECUTION transfers upload the read-only
+ * weights to the device. CUDA-graph capture is deliberately skipped for this
+ * one-off warm-up pass (unlike the steady-state forward path).
+ */
+ public void forceCopyInReadOnlyDataLayered() {
+ state.wrapX.clear();
+ state.positionHolder.init(0);
+
+ executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
+
+ for (int layer = 0; layer < config.numberOfLayers(); layer++) {
+ executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
+ }
+
+ executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
+ }
+
+ /** Releases all device memory held by the execution plan. */
+ @Override
+ public void freeTornadoExecutionPlan() {
+ executionPlan.freeDeviceMemory();
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
index 61b81bef..e5262b17 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
@@ -6,9 +6,9 @@
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
/**
- * Wraps {@link TornadoVMMasterPlan} and adds a prefill-only GPU forward pass.
+ * Wraps {@link TornadoVMMasterPlanStandard} and adds a prefill-only GPU forward pass.
*
- * Parallel to {@link TornadoVMMasterPlan} — does NOT modify it.
+ * Parallel to {@link TornadoVMMasterPlanStandard} — does NOT modify it.
*
* The existing execution plan has this graph layout:
*
@@ -28,11 +28,11 @@
*/
public class TornadoVMMasterPlanWithPrefillDecode {
- private final TornadoVMMasterPlan plan;
+ private final TornadoVMMasterPlanStandard plan;
private final State state;
private final Configuration config;
- public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlan plan, State state, Model model) {
+ public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlanStandard plan, State state, Model model) {
this.plan = plan;
this.state = state;
this.config = model.configuration();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
new file mode 100644
index 00000000..9bba3860
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
@@ -0,0 +1,461 @@
+package org.beehive.gpullama3.tornadovm.kernels;
+
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.math.TornadoMath;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+
+/**
+ * GPU kernels for batched prefill (Phase 4).
+ *
+ * Each kernel processes {@code batchSize} tokens simultaneously.
+ * Batch tensors are flat: element [b][i] lives at index {@code b*stride + i}.
+ * Worker-grid sizes are scaled by {@code batchSize} vs the single-token kernels.
+ *
+ * These kernels are meant to be registered in the batch task graphs built for
+ * {@link TornadoVMMasterPlanBatchPrefill}; they are NOT invoked directly.
+ */
+public final class TransformerBatchPrefillKernels {
+
+ private TransformerBatchPrefillKernels() {}
+
+ // ── Activation ────────────────────────────────────────────────────────────
+
+ /**
+ * Widens the B×dim block of FP16 token embeddings to FP32, element-wise.
+ * Worker: B*dim global threads, localSize=128.
+ */
+ public static void batchEmbeddingToFP32(KernelContext context,
+ HalfFloatArray embeddingXBatch,
+ FloatArray wrapXBatch) {
+ int idx = context.globalIdx;
+ float widened = embeddingXBatch.get(idx).getFloat32();
+ wrapXBatch.set(idx, widened);
+ }
+
+ // ── RMS Norm (attention) ─────────────────────────────────────────────────
+
+ /**
+ * Per-token RMS-norm reduction — one thread per batch element.
+ *
+ * Thread b scans its own dim-length slice of {@code wrapXBatch} and stores
+ * {@code 1 / sqrt(mean(x[b]²) + eps)} into {@code attnScaleBatch[b]}.
+ *
+ * Worker: batchSize global threads, localSize=1.
+ */
+ public static void batchedRmsReduce(KernelContext context,
+ FloatArray wrapXBatch,
+ FloatArray attnScaleBatch,
+ int dim, float eps) {
+ int token = context.globalIdx;
+ int offset = token * dim;
+ float sumSquares = 0.0f;
+ for (int j = 0; j < dim; j++) {
+ float v = wrapXBatch.get(offset + j);
+ sumSquares += v * v;
+ }
+ float meanPlusEps = sumSquares / dim + eps;
+ attnScaleBatch.set(token, 1.0f / TornadoMath.sqrt(meanPlusEps));
+ }
+
+ /**
+ * RMS-normalizes each batched activation and quantizes the result to FP16.
+ *
+ * {@code xbFP16Batch[b*dim+i] = FP16( rmsWeights[i] * attnScaleBatch[b] * wrapXBatch[b*dim+i] )}
+ *
+ * Worker: B*dim global threads, localSize=256.
+ */
+ public static void batchedRmsApplyFP16(KernelContext context,
+ HalfFloatArray xbFP16Batch,
+ FloatArray wrapXBatch,
+ FloatArray rmsWeights,
+ FloatArray attnScaleBatch,
+ int dim) {
+ int flatIdx = context.globalIdx;
+ int token = flatIdx / dim;
+ int feature = flatIdx % dim;
+ // Same multiply order as the scalar path: weight * scale * activation.
+ float normalized = rmsWeights.get(feature) * attnScaleBatch.get(token) * wrapXBatch.get(flatIdx);
+ xbFP16Batch.set(flatIdx, new HalfFloat(normalized));
+ }
+
+ // ── QKV Projection ────────────────────────────────────────────────────────
+
+ /**
+ * Fused batched QKV projection (FP16 weights, FP16 input).
+ *
+ * One workgroup per (batchIdx, outputRow) pair, where
+ * {@code groupIdx = batchIdx * (dim + 2*kvDim) + rowIdx}. Row ranges map to
+ * the three projections: [0, dim) → Q, [dim, dim+kvDim) → K, the rest → V.
+ *
+ * Worker: B*(dim+2*kvDim) workgroups × localWorkGroupSize threads.
+ *
+ * NOTE(review): the tree reductions below assume localWorkGroupSize is a
+ * power of two — confirm all call sites guarantee this.
+ */
+ public static void batchedFusedQKVMatmul(KernelContext context,
+ HalfFloatArray xbFP16Batch,
+ FloatArray wrapQBatch,
+ FloatArray wrapKBatch,
+ FloatArray wrapVBatch,
+ HalfFloatArray wq,
+ HalfFloatArray wk,
+ HalfFloatArray wv,
+ int dim, int kvDim,
+ int localWorkGroupSize) {
+ int groupId = context.groupIdx;
+ int localId = context.localIdx;
+ int totalRows = dim + 2 * kvDim;
+ int batchIdx = groupId / totalRows;
+ int rowIdx = groupId % totalRows;
+ int inputOff = batchIdx * dim; // start of this token's normalized input slice
+
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ if (rowIdx < dim) {
+ // Q projection: dot(wq[rowIdx, :], xb[batchIdx, :]) via strided partial sums.
+ int rowOff = rowIdx * dim;
+ float partial = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ partial += wq.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32();
+ }
+ localSum[localId] = partial;
+ context.localBarrier();
+ // Pairwise tree reduction in local memory.
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ if (localId == 0) wrapQBatch.set(batchIdx * dim + rowIdx, localSum[0]);
+
+ } else if (rowIdx < dim + kvDim) {
+ // K projection: same reduction over wk, output indexed by kvDim stride.
+ int kRow = rowIdx - dim;
+ int rowOff = kRow * dim;
+ float partial = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ partial += wk.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32();
+ }
+ localSum[localId] = partial;
+ context.localBarrier();
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]);
+
+ } else {
+ // V projection: same reduction over wv.
+ int vRow = rowIdx - dim - kvDim;
+ int rowOff = vRow * dim;
+ float partial = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ partial += wv.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32();
+ }
+ localSum[localId] = partial;
+ context.localBarrier();
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]);
+ }
+ }
+
+ // ── RoPE + KV Cache ───────────────────────────────────────────────────────
+
+ /**
+ * Fused batched RoPE rotation + KV cache write.
+ *
+ * globalIdx encodes (batchIdx, pairIdx) as {@code batchIdx*(dim/2) + pairIdx}.
+ * Position for token b = {@code startPos + b}.
+ *
+ * Worker: B*(dim/2) global threads, localSize=512 (or less if B*dim/2 < 512).
+ *
+ * NOTE(review): the RoPE base frequency is hard-coded to 50000.0f — confirm
+ * this matches the model's configured ropeTheta for every supported checkpoint.
+ */
+ public static void batchedRopeWithKVCache(KernelContext context,
+ IntArray batchStartPosHolder,
+ FloatArray wrapQBatch,
+ FloatArray wrapKBatch,
+ FloatArray wrapVBatch,
+ FloatArray wrapKeyCache,
+ FloatArray wrapValueCache,
+ int kvDim, int headSize,
+ int layerIndex, int contextLength, int dim) {
+ int globalIdx = context.globalIdx;
+ int halfDim = dim / 2;
+ int batchIdx = globalIdx / halfDim;
+ int pairIdx = globalIdx % halfDim;
+ int i = pairIdx * 2; // even index of the (i, i+1) rotation pair
+
+ // Absolute sequence position of this batch element.
+ int pos = batchStartPosHolder.get(0) + batchIdx;
+ int qOffset = batchIdx * dim;
+ int kOffset = batchIdx * kvDim;
+
+ // Defensive guard: always true for pairIdx < dim/2, but protects against
+ // padded worker grids where globalIdx overruns — presumably why it's kept.
+ if (i + 1 < dim) {
+ int head_dim = i % headSize;
+ float freq = 1.0f / TornadoMath.pow(50000.0f, head_dim / (float) headSize);
+ float val = pos * freq;
+ float fcr = TornadoMath.cos(val);
+ float fci = TornadoMath.sin(val);
+
+ // Rotate Q
+ float v0q = wrapQBatch.get(qOffset + i);
+ float v1q = wrapQBatch.get(qOffset + i + 1);
+ wrapQBatch.set(qOffset + i, v0q * fcr - v1q * fci);
+ wrapQBatch.set(qOffset + i + 1, v0q * fci + v1q * fcr);
+
+ // Rotate K and write K,V to cache; only the first kvDim lanes carry K/V (GQA).
+ if (i + 1 < kvDim) {
+ float v0k = wrapKBatch.get(kOffset + i);
+ float v1k = wrapKBatch.get(kOffset + i + 1);
+ float rotK0 = v0k * fcr - v1k * fci;
+ float rotK1 = v0k * fci + v1k * fcr;
+ wrapKBatch.set(kOffset + i, rotK0);
+ wrapKBatch.set(kOffset + i + 1, rotK1);
+
+ // Cache layout: [layer][position][kvDim]; V is stored unrotated.
+ int cacheOff = layerIndex * contextLength * kvDim + pos * kvDim;
+ wrapKeyCache.set(cacheOff + i, rotK0);
+ wrapKeyCache.set(cacheOff + i + 1, rotK1);
+ wrapValueCache.set(cacheOff + i, wrapVBatch.get(kOffset + i));
+ wrapValueCache.set(cacheOff + i + 1, wrapVBatch.get(kOffset + i + 1));
+ }
+ }
+ }
+
+ // ── Attention ─────────────────────────────────────────────────────────────
+
+ /**
+ * Batched causal flash attention with an online (streaming) softmax.
+ *
+ * One workgroup per (batchIdx, headIdx) pair:
+ * {@code groupIdx = batchIdx * nHeads + headIdx}.
+ * Token b attends to positions 0..{@code startPos + b} (causal).
+ *
+ * Worker: B*nHeads workgroups × optimalLocalSize threads.
+ */
+ public static void batchedFlashAttention(KernelContext context,
+ IntArray batchStartPosHolder,
+ FloatArray wrapQBatch,
+ FloatArray wrapKeyCache,
+ FloatArray wrapValueCache,
+ FloatArray wrapXbBatch,
+ int nHeads, int headSize,
+ int kvDim, int kvMul,
+ int layerIndex, int contextLength, int dim) {
+ int tid = context.localIdx;
+ int groupId = context.groupIdx;
+ int localSz = context.localGroupSizeX;
+
+ int batchIdx = groupId / nHeads;
+ int h = groupId % nHeads;
+ int pos = batchStartPosHolder.get(0) + batchIdx; // causal horizon for this token
+ int loff = layerIndex * contextLength * kvDim; // this layer's slice of the KV cache
+ int kvHeadIdx = h / kvMul; // GQA: kvMul query heads share one KV head
+ int BLOCK_C = 16; // KV positions processed per tile
+
+ // Local (shared) memory: query vector, one K/V tile, per-tile scores, max broadcast.
+ float[] qShared = context.allocateFloatLocalArray(headSize);
+ float[] kTile = context.allocateFloatLocalArray(BLOCK_C * headSize);
+ float[] vTile = context.allocateFloatLocalArray(BLOCK_C * headSize);
+ float[] sTile = context.allocateFloatLocalArray(BLOCK_C);
+ float[] maxHolder = context.allocateFloatLocalArray(1);
+
+ // Load Q into shared memory
+ int qOffset = batchIdx * dim + h * headSize;
+ for (int i = tid; i < headSize; i += localSz) {
+ qShared[i] = wrapQBatch.get(qOffset + i);
+ }
+ context.localBarrier();
+
+ // Per-thread online-softmax state; note every thread accumulates the FULL
+ // output vector and only writes its strided share at the end.
+ float maxScore = Float.NEGATIVE_INFINITY;
+ float sumExp = 0.0f;
+ float[] output = new float[headSize];
+ for (int i = 0; i < headSize; i++) output[i] = 0.0f;
+
+ for (int tileC = 0; tileC <= pos; tileC += BLOCK_C) {
+ int tileEnd = Math.min(tileC + BLOCK_C - 1, pos); // clamp last tile to causal limit
+
+ // Load K/V tile
+ for (int t = tileC + tid; t <= tileEnd; t += localSz) {
+ int tInTile = t - tileC;
+ int tileMOff = tInTile * headSize;
+ for (int d = 0; d < headSize; d++) {
+ int kvOff = loff + t * kvDim + kvHeadIdx * headSize + d;
+ kTile[tileMOff + d] = wrapKeyCache.get(kvOff);
+ vTile[tileMOff + d] = wrapValueCache.get(kvOff);
+ }
+ }
+ context.localBarrier();
+
+ // Compute attention scores (scaled dot product), strided across threads.
+ for (int t = tileC + tid; t <= tileEnd; t += localSz) {
+ int tInTile = t - tileC;
+ float score = 0.0f;
+ for (int d = 0; d < headSize; d++) {
+ score += qShared[d] * kTile[tInTile * headSize + d];
+ }
+ sTile[tInTile] = score / TornadoMath.sqrt(headSize);
+ }
+ context.localBarrier();
+
+ // Tile max — computed redundantly by every thread; tid 0 broadcasts it.
+ float tileMax = Float.NEGATIVE_INFINITY;
+ for (int t = 0; t <= tileEnd - tileC; t++) {
+ if (sTile[t] > tileMax) tileMax = sTile[t];
+ }
+ if (tid == 0) maxHolder[0] = tileMax;
+ context.localBarrier();
+ float curTileMax = maxHolder[0];
+
+ // Online-softmax rescale: if the running max grew, downscale previous
+ // accumulators by exp(oldMax - newMax) before folding in this tile.
+ float newMax = Math.max(maxScore, curTileMax);
+ if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
+ float scale = TornadoMath.exp(maxScore - newMax);
+ sumExp *= scale;
+ for (int d = 0; d < headSize; d++) output[d] *= scale;
+ }
+ maxScore = newMax;
+
+ // Accumulate exp-weighted V rows for this tile.
+ for (int t = 0; t <= tileEnd - tileC; t++) {
+ float expScore = TornadoMath.exp(sTile[t] - maxScore);
+ sumExp += expScore;
+ for (int d = 0; d < headSize; d++) {
+ output[d] += expScore * vTile[t * headSize + d];
+ }
+ }
+ // Barrier before the next tile overwrites kTile/vTile/sTile.
+ context.localBarrier();
+ }
+
+ // Final normalization; each thread writes its strided slice of the head output.
+ float norm = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f;
+ int xbOffset = batchIdx * dim + h * headSize;
+ for (int d = tid; d < headSize; d += localSz) {
+ wrapXbBatch.set(xbOffset + d, output[d] * norm);
+ }
+ }
+
+ // ── Output / FFN Projections ─────────────────────────────────────────────
+
+ /**
+ * Batched matrix-vector multiply with residual add.
+ *
+ * Used for both the attention output projection (Wo) and the FFN down
+ * projection (W2). One workgroup per (batchIdx, outputRow):
+ * {@code groupIdx = batchIdx * d + rowIdx}.
+ * <ul>
+ * <li>Wo: inputBatch=xbBatch (B×dim), outputBatch=xBatch (B×dim), n=dim, d=dim</li>
+ * <li>W2: inputBatch=hbBatch (B×hiddenDim), outputBatch=xBatch (B×dim), n=hiddenDim, d=dim</li>
+ * </ul>
+ *
+ *
+ * Worker: B*d workgroups × localWorkGroupSize threads.
+ */
+ public static void batchedMatVecWithResidual(KernelContext context,
+ FloatArray inputBatch,
+ FloatArray outputBatch,
+ HalfFloatArray w,
+ int n, int d,
+ int localWorkGroupSize) {
+ int groupId = context.groupIdx;
+ int localId = context.localIdx;
+ int batchIdx = groupId / d;
+ int rowIdx = groupId % d;
+
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+ int inputOff = batchIdx * n;
+ int rowOff = rowIdx * n;
+
+ float partial = 0.0f;
+ for (int j = localId; j < n; j += localWorkGroupSize) {
+ partial += w.get(rowOff + j).getFloat32() * inputBatch.get(inputOff + j);
+ }
+ localSum[localId] = partial;
+ context.localBarrier();
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ if (localId == 0) {
+ int outIdx = batchIdx * d + rowIdx;
+ outputBatch.set(outIdx, outputBatch.get(outIdx) + localSum[0]);
+ }
+ }
+
+ // ── FFN RMS Norm ─────────────────────────────────────────────────────────
+
+ /**
+ * Sequential FFN RMS reduction — one thread per batch item.
+ * Worker: batchSize global threads, localSize=1.
+ */
+ public static void batchedFFNRmsReduce(KernelContext context,
+ FloatArray wrapXBatch,
+ FloatArray ffnScaleBatch,
+ int dim, float eps) {
+ int b = context.globalIdx;
+ int base = b * dim;
+ float ss = 0.0f;
+ for (int i = 0; i < dim; i++) {
+ float val = wrapXBatch.get(base + i);
+ ss += val * val;
+ }
+ ss /= dim;
+ ss += eps;
+ ffnScaleBatch.set(b, 1.0f / TornadoMath.sqrt(ss));
+ }
+
+ // ── FFN SwiGLU ───────────────────────────────────────────────────────────
+
+ /**
+ * Batched fused RMS-apply + W1/W3 gate-up projections + SiLU + GLU.
+ *
+ * One workgroup per (batchIdx, hiddenRow):
+ * {@code groupIdx = batchIdx * hiddenDim + rowIdx}.
+ *
+ * Worker: B*hiddenDim workgroups × localWorkGroupSize threads.
+ */
+ public static void batchedFusedRmsNormFFNGateUp(KernelContext context,
+ FloatArray wrapXBatch,
+ FloatArray wrapHbBatch,
+ FloatArray rmsFFNWeights,
+ FloatArray ffnScaleBatch,
+ HalfFloatArray w1,
+ HalfFloatArray w3,
+ int dim, int hiddenDim,
+ int localWorkGroupSize) {
+ int groupId = context.groupIdx;
+ int localId = context.localIdx;
+ int batchIdx = groupId / hiddenDim;
+ int rowIdx = groupId % hiddenDim;
+
+ float scale = ffnScaleBatch.get(batchIdx);
+ int inputOff = batchIdx * dim;
+ int rowOff = rowIdx * dim;
+
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ // W1 matmul with inline RMS apply
+ float sum1 = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j);
+ sum1 += w1.get(rowOff + j).getFloat32() * normed;
+ }
+ localSum[localId] = sum1;
+ context.localBarrier();
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ float result1 = localSum[0];
+
+ // W3 matmul with inline RMS apply
+ float sum3 = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j);
+ sum3 += w3.get(rowOff + j).getFloat32() * normed;
+ }
+ localSum[localId] = sum3;
+ context.localBarrier();
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ float result3 = localSum[0];
+
+ // SiLU(W1·x) × (W3·x)
+ if (localId == 0) {
+ float silu = result1 / (1.0f + TornadoMath.exp(-result1));
+ wrapHbBatch.set(batchIdx * hiddenDim + rowIdx, silu * result3);
+ }
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java
new file mode 100644
index 00000000..e28ecb19
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java
@@ -0,0 +1,238 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+import java.util.List;
+import java.util.stream.IntStream;
+
+/**
+ * Builds per-layer batch prefill TaskGraphs for Phase 4 GPU batched prefill.
+ *
+ * One {@link ImmutableTaskGraph} per transformer layer, each processing
+ * {@code batchSize} tokens simultaneously via {@link TransformerBatchPrefillKernels}.
+ *
+ * KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device
+ * after every layer so the subsequent single-token decode layers can consume it.
+ */
+public class LlamaFP16BatchPrefillLayers {
+
+ // Matches the local workgroup size used by the single-token kernels.
+ static final int LOCAL_WORK_GROUP_SIZE = 32;
+
+ private final LlamaState state;
+ private final LlamaTornadoWeights weights;
+ private final LlamaConfiguration config;
+ private final KernelContext context = new KernelContext();
+ private final int batchSize;
+ private final List<ImmutableTaskGraph> layerITGs;
+ private String lastLayerTaskGraphID;
+
+ public LlamaFP16BatchPrefillLayers(LlamaState state, LlamaTornadoWeights weights,
+ LlamaConfiguration config, int batchSize) {
+ this.state = state;
+ this.weights = weights;
+ this.config = config;
+ this.batchSize = batchSize;
+ this.layerITGs = IntStream.range(0, config.numberOfLayers())
+ .mapToObj(this::createBatchLayerTaskGraph)
+ .map(TaskGraph::snapshot)
+ .toList();
+ }
+
+ // @formatter:off
+ private TaskGraph createBatchLayerTaskGraph(int layerIndex) {
+ String graphName = "batchLayer_" + layerIndex;
+ if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName;
+
+ TaskGraph layer = new TaskGraph(graphName);
+
+ // ── Data Transfers ─────────────────────────────────────────────────────
+ if (layerIndex == 0) {
+ // batchStartPosHolder is set by host before each chunk → EVERY_EXECUTION
+ layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder);
+ // Allocate persistent GPU-side intermediates once
+ layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ context,
+ state.attnScaleBatch, state.ffnScaleBatch,
+ state.wrapXbFP16Batch,
+ state.wrapQBatch, state.wrapKBatch, state.wrapVBatch,
+ state.wrapXbBatch,
+ state.wrapHbBatch,
+ state.wrapKeyCache, state.wrapValueCache);
+ // wrapXBatch produced by the batch activation graph
+ layer.consumeFromDevice(state.wrapXBatch);
+ } else {
+ layer.consumeFromDevice(
+ context,
+ state.wrapXBatch,
+ state.wrapXbFP16Batch,
+ state.wrapQBatch, state.wrapKBatch, state.wrapVBatch,
+ state.wrapXbBatch,
+ state.wrapHbBatch,
+ state.wrapKeyCache, state.wrapValueCache,
+ state.batchStartPosHolder,
+ state.attnScaleBatch, state.ffnScaleBatch);
+ }
+
+ // Per-layer weights: upload once
+ layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(),
+ weights.wqLayered[layerIndex].asHalfFloatArray(),
+ weights.wkLayered[layerIndex].asHalfFloatArray(),
+ weights.wvLayered[layerIndex].asHalfFloatArray(),
+ weights.woLayered[layerIndex].asHalfFloatArray(),
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
+ weights.w1Layered[layerIndex].asHalfFloatArray(),
+ weights.w2Layered[layerIndex].asHalfFloatArray(),
+ weights.w3Layered[layerIndex].asHalfFloatArray());
+
+ int dim = config.dim();
+ int kvDim = config.kvDim();
+ int hidDim = config.hiddenDim();
+
+ // ── Attention Block ────────────────────────────────────────────────────
+ layer.task("batch_attn_rms",
+ TransformerBatchPrefillKernels::batchedRmsReduce,
+ context, state.wrapXBatch, state.attnScaleBatch,
+ dim, config.rmsNormEps());
+
+ layer.task("batch_attn_rms_apply",
+ TransformerBatchPrefillKernels::batchedRmsApplyFP16,
+ context, state.wrapXbFP16Batch, state.wrapXBatch,
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(),
+ state.attnScaleBatch, dim);
+
+ layer.task("batch_qkv",
+ TransformerBatchPrefillKernels::batchedFusedQKVMatmul,
+ context,
+ state.wrapXbFP16Batch,
+ state.wrapQBatch, state.wrapKBatch, state.wrapVBatch,
+ weights.wqLayered[layerIndex].asHalfFloatArray(),
+ weights.wkLayered[layerIndex].asHalfFloatArray(),
+ weights.wvLayered[layerIndex].asHalfFloatArray(),
+ dim, kvDim, LOCAL_WORK_GROUP_SIZE);
+
+ layer.task("batch_rope_kv",
+ TransformerBatchPrefillKernels::batchedRopeWithKVCache,
+ context, state.batchStartPosHolder,
+ state.wrapQBatch, state.wrapKBatch, state.wrapVBatch,
+ state.wrapKeyCache, state.wrapValueCache,
+ kvDim, config.headSize(), layerIndex, config.contextLength(), dim);
+
+ layer.task("batch_attention",
+ TransformerBatchPrefillKernels::batchedFlashAttention,
+ context, state.batchStartPosHolder,
+ state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache,
+ state.wrapXbBatch,
+ config.numberOfHeads(), config.headSize(),
+ kvDim, config.kvMul(), layerIndex, config.contextLength(), dim);
+
+ layer.task("batch_attn_out",
+ TransformerBatchPrefillKernels::batchedMatVecWithResidual,
+ context, state.wrapXbBatch, state.wrapXBatch,
+ weights.woLayered[layerIndex].asHalfFloatArray(),
+ dim, dim, LOCAL_WORK_GROUP_SIZE);
+
+ // ── FFN Block ──────────────────────────────────────────────────────────
+ layer.task("batch_ffn_rms",
+ TransformerBatchPrefillKernels::batchedFFNRmsReduce,
+ context, state.wrapXBatch, state.ffnScaleBatch,
+ dim, config.rmsNormEps());
+
+ layer.task("batch_ffn_gate_up",
+ TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUp,
+ context, state.wrapXBatch, state.wrapHbBatch,
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
+ state.ffnScaleBatch,
+ weights.w1Layered[layerIndex].asHalfFloatArray(),
+ weights.w3Layered[layerIndex].asHalfFloatArray(),
+ dim, hidDim, LOCAL_WORK_GROUP_SIZE);
+
+ layer.task("batch_ffn_down",
+ TransformerBatchPrefillKernels::batchedMatVecWithResidual,
+ context, state.wrapHbBatch, state.wrapXBatch,
+ weights.w2Layered[layerIndex].asHalfFloatArray(),
+ hidDim, dim, LOCAL_WORK_GROUP_SIZE);
+
+ // Persist wrapXBatch for the next layer, and KV cache so the decode
+ // layers can consume it via the activation graph pass-through.
+ layer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache);
+
+ return layer;
+ }
+ // @formatter:on
+
+ /** Registers all batch layer workers in the shared {@link GridScheduler}. */
+ public void updateGridScheduler(GridScheduler scheduler) {
+ int dim = config.dim();
+ int kvDim = config.kvDim();
+ int hidDim = config.hiddenDim();
+ int nHeads = config.numberOfHeads();
+ int headSz = config.headSize();
+
+ // RMS: one thread per batch token
+ WorkerGrid rmsWorker = WorkerGridFactory.genericWorker(batchSize, 1);
+
+ // RMS apply: B*dim threads, local=256 (dim is always a multiple of 256 for LLaMA)
+ WorkerGrid rmsApplyWorker = WorkerGridFactory.genericWorker(batchSize * dim, 256);
+
+ // QKV: B*(dim+2*kvDim) workgroups × LOCAL_WORK_GROUP_SIZE
+ int qkvRows = dim + 2 * kvDim;
+ WorkerGrid qkvWorker = WorkerGridFactory.genericWorker(
+ batchSize * qkvRows * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE);
+
+ // RoPE+KV cache: B*(dim/2) threads, local=512
+ int ropeGlobal = batchSize * (dim / 2);
+ int ropeLocal = Math.min(512, ropeGlobal);
+ while (ropeLocal > 1 && ropeGlobal % ropeLocal != 0) ropeLocal--;
+ WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(ropeGlobal, ropeLocal);
+
+ // Attention (flash): B*nHeads workgroups × optimalLocalSize
+ int optLocal = findOptimalLocalSize(headSz);
+ WorkerGrid attnWorker = WorkerGridFactory.genericWorker(
+ batchSize * nHeads * optLocal, optLocal);
+
+ // Mat-vec (Wo, W2): B*d workgroups × LOCAL_WORK_GROUP_SIZE
+ WorkerGrid matVecDimWorker = WorkerGridFactory.genericWorker(
+ batchSize * dim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE);
+ WorkerGrid matVecHidWorker = WorkerGridFactory.genericWorker(
+ batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE);
+
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String p = "batchLayer_" + i + ".";
+ scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker);
+ scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker);
+ scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker);
+ scheduler.addWorkerGrid(p + "batch_rope_kv", ropeWorker);
+ scheduler.addWorkerGrid(p + "batch_attention", attnWorker);
+ scheduler.addWorkerGrid(p + "batch_attn_out", matVecDimWorker);
+ scheduler.addWorkerGrid(p + "batch_ffn_rms", rmsWorker);
+ scheduler.addWorkerGrid(p + "batch_ffn_gate_up", matVecHidWorker);
+ scheduler.addWorkerGrid(p + "batch_ffn_down", matVecDimWorker);
+ }
+ }
+
+ private static int findOptimalLocalSize(int size) {
+ int optimal = Math.min(size, 64);
+ if (size % optimal != 0) {
+ for (int s = 64; s >= 1; s--) {
+ if (size % s == 0) { optimal = s; break; }
+ }
+ }
+ return optimal;
+ }
+
+ public List<ImmutableTaskGraph> getLayerImmutableTaskGraphs() { return layerITGs; }
+ public String getLastLayerTaskGraphID() { return lastLayerTaskGraphID; }
+ public KernelContext getContext() { return context; }
+}
From 4152edb48829dd07df4e7203a91785adc05c1acc Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Fri, 3 Apr 2026 20:22:19 +0300
Subject: [PATCH 07/12] [prf/dec][refactor] Restructure prefill-decode
ExecutionPlan components to dedicated classes and packages
Move `LlamaFP16BatchPrefillLayers` to `tornadovm.layers.type.fp16.prefill` and `LlamaFP16FFNLayersForUnifiedDecode` to `tornadovm.layers.type.fp16.decode`
---
.../TornadoVMMasterPlanBatchPrefill.java | 49 +++----------------
.../LlamaFP16FFNLayersForUnifiedDecode.java | 47 ++++++++++++++++++
.../LlamaFP16BatchPrefillLayers.java | 2 +-
3 files changed, 56 insertions(+), 42 deletions(-)
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java
rename src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/{ => prefill}/LlamaFP16BatchPrefillLayers.java (99%)
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
index b2388bf3..258bb9fe 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
@@ -8,8 +8,8 @@
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16BatchPrefillLayers;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersForUnifiedDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16BatchPrefillLayers;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -300,43 +300,10 @@ public void freeTornadoExecutionPlan() {
}
// ── Inner class: decode layer 0 with consumeFromDevice for KV cache ───────
-
- /**
- * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses
- * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}.
- *
- * This ensures decode layer 0 receives the KV-cache device pointer that was
- * persisted by the last batch prefill layer and passed through the decode
- * activation graph.
- */
- private static final class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers {
-
- LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state,
- LlamaTornadoWeights weights, LlamaConfiguration config,
- SchedulerType schedulerType) {
- super(taskGraph, state, weights, config, schedulerType);
- }
-
- @Override
- protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) {
- if (layerIndex == 0) {
- // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come
- // from device (passed through by the decode activation graph).
- layer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
- state.positionHolder, state.temp, state.tempFFN);
- layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
- context,
- state.wrapXb, state.wrapXb2,
- state.wrapQ, state.wrapK, state.wrapV,
- state.wrapAtt, state.wrapHb, state.wrapXbFP16);
- // KV cache: consume from device (device pointer supplied by
- // decode activation's pass-through from last batch layer).
- layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache);
- } else {
- // Identical to parent for layers 1+ (already uses consumeFromDevice).
- return super.configureLayerDataTransfers(layer, layerIndex);
- }
- return layer;
- }
- }
+// moved to package
+//
+// private static final class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers {
+//
+//
+// }
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java
new file mode 100644
index 00000000..b1cd063f
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java
@@ -0,0 +1,47 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+/**
+ * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses
+ * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}.
+ *
+ * This ensures decode layer 0 receives the KV-cache device pointer that was
+ * persisted by the last batch prefill layer and passed through the decode
+ * activation graph.
+ */
+public class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers {
+ public LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state,
+ LlamaTornadoWeights weights, LlamaConfiguration config,
+ SchedulerType schedulerType) {
+ super(taskGraph, state, weights, config, schedulerType);
+ }
+
+ @Override
+ protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) {
+ if (layerIndex == 0) {
+ // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come
+ // from device (passed through by the decode activation graph).
+ layer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
+ state.positionHolder, state.temp, state.tempFFN);
+ layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ context,
+ state.wrapXb, state.wrapXb2,
+ state.wrapQ, state.wrapK, state.wrapV,
+ state.wrapAtt, state.wrapHb, state.wrapXbFP16);
+ // KV cache: consume from device (device pointer supplied by
+ // decode activation's pass-through from last batch layer).
+ layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache);
+ } else {
+ // Identical to parent for layers 1+ (already uses consumeFromDevice).
+ return super.configureLayerDataTransfers(layer, layerIndex);
+ }
+ return layer;
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java
similarity index 99%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java
index e28ecb19..8414be72 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16BatchPrefillLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tornadovm.layers.type.fp16;
+package org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill;
import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
From 8d2d15dd2ced670f753e3b827b96bd1575f47246 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Fri, 3 Apr 2026 20:27:34 +0300
Subject: [PATCH 08/12] [prf/dec][dbg] Guard CUDA Graphs enable/disable behind
`--no-cuda-graphs` option to ease debugging
---
llama-tornado | 9 ++++
.../tornadovm/TornadoVMMasterPlan.java | 4 ++
.../TornadoVMMasterPlanBatchPrefill.java | 44 ++++++++-----------
.../TornadoVMMasterPlanStandard.java | 18 ++++----
4 files changed, 41 insertions(+), 34 deletions(-)
diff --git a/llama-tornado b/llama-tornado
index 81349a5e..4900a100 100755
--- a/llama-tornado
+++ b/llama-tornado
@@ -93,6 +93,9 @@ class LlamaRunner:
if args.prefill_batch_size is not None:
cmd.append(f"-Dllama.prefillBatchSize={args.prefill_batch_size}")
+ if args.no_cuda_graphs:
+ cmd.append("-Dllama.cudaGraphs=false")
+
# Debug options
debug_config = []
@@ -493,6 +496,12 @@ def create_parser() -> argparse.ArgumentParser:
default=None,
help="Prefill chunk/batch size (llama.prefillBatchSize=N, default: 32)",
)
+ prefill_group.add_argument(
+ "--no-cuda-graphs",
+ dest="no_cuda_graphs",
+ action="store_true",
+ help="Disable CUDA graph capture/replay (llama.cudaGraphs=false); useful for debugging",
+ )
# Advanced options
advanced_group = parser.add_argument_group("Advanced Options")
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
index 43030fd0..c81ba92c 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
@@ -26,6 +26,10 @@ public interface TornadoVMMasterPlan {
boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(
System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
+ /** When {@code false}, {@code withCUDAGraph()} is never called — useful for debugging. */
+ boolean CUDA_GRAPHS = Boolean.parseBoolean(
+ System.getProperty("llama.cudaGraphs", "true"));
+
/**
* Single-token forward pass returning output logits.
*
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
index 258bb9fe..cc6591c2 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
@@ -170,7 +170,7 @@ public static TornadoVMMasterPlanBatchPrefill initializeUnifiedPlan(
System.err.printf("[BatchPlan] Graph construction: %.2f ms%n",
(System.nanoTime() - t0) / 1e6);
- plan.executionPlan.withAllGraphs().withCUDAGraph();
+ if (CUDA_GRAPHS) plan.executionPlan.withAllGraphs().withCUDAGraph();
plan.executionPlan.withPreCompilation();
if (ENABLE_TIMING)
@@ -194,10 +194,9 @@ private void forceCopyInReadOnlyData() {
state.batchStartPosHolder.init(0);
for (int i = 0; i <= logitsIdx(); i++) {
- executionPlan.withGraph(i)
- .withGridScheduler(gridScheduler)
- .withCUDAGraph()
- .execute();
+ var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler);
+ if (CUDA_GRAPHS) g.withCUDAGraph();
+ g.execute();
}
}
@@ -226,17 +225,15 @@ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model mod
state.batchStartPosHolder.set(0, startPos);
// Graph 0: batch activation
- executionPlan.withGraph(batchActivationIdx())
- .withGridScheduler(gridScheduler)
- .withCUDAGraph()
- .execute();
+ var batchAct = executionPlan.withGraph(batchActivationIdx()).withGridScheduler(gridScheduler);
+ if (CUDA_GRAPHS) batchAct.withCUDAGraph();
+ batchAct.execute();
// Graphs 1..N: batch transformer layers
for (int l = 0; l < N; l++) {
- executionPlan.withGraph(batchLayerIdx(l))
- .withGridScheduler(gridScheduler)
- .withCUDAGraph()
- .execute();
+ var batchLayer = executionPlan.withGraph(batchLayerIdx(l)).withGridScheduler(gridScheduler);
+ if (CUDA_GRAPHS) batchLayer.withCUDAGraph();
+ batchLayer.execute();
}
// Logits skipped — not needed for prefill positions.
}
@@ -263,27 +260,24 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) {
state.tempFFN.clear();
// Graph N+1: decode activation
- executionPlan.withGraph(decodeActivationIdx())
- .withGridScheduler(gridScheduler)
- .withCUDAGraph()
- .execute();
+ var decodeAct = executionPlan.withGraph(decodeActivationIdx()).withGridScheduler(gridScheduler);
+ if (CUDA_GRAPHS) decodeAct.withCUDAGraph();
+ decodeAct.execute();
// Graphs N+2..2N+1: decode transformer layers
for (int l = 0; l < N; l++) {
- executionPlan.withGraph(decodeLayerIdx(l))
- .withGridScheduler(gridScheduler)
- .withCUDAGraph()
- .execute();
+ var decodeLayer = executionPlan.withGraph(decodeLayerIdx(l)).withGridScheduler(gridScheduler);
+ if (CUDA_GRAPHS) decodeLayer.withCUDAGraph();
+ decodeLayer.execute();
}
state.tempLogits.clear();
state.wrapLogits.clear();
// Graph 2N+2: logits
- executionPlan.withGraph(logitsIdx())
- .withGridScheduler(gridScheduler)
- .withCUDAGraph()
- .execute();
+ var logits = executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler);
+ if (CUDA_GRAPHS) logits.withCUDAGraph();
+ logits.execute();
return state.wrapLogits;
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java
index 91586f2c..c9d816ee 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java
@@ -56,7 +56,7 @@ static TornadoVMMasterPlanStandard initialize(State state, Model model) {
System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0);
}
- tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph();
+ if (CUDA_GRAPHS) tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph();
tornadoVMPlan.executionPlan.withPreCompilation();
if (ENABLE_TORNADOVM_INIT_TIME) {
@@ -89,10 +89,10 @@ private GenericLayerPlanner createPlanner(State state, Model model) {
@Override
public FloatArray tornadoVMForwardExecuteLayered(int position) {
// @formatter:off
- executionPlan.withGraph(getPreprocessingGraphIndex())
- .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
- .withCUDAGraph()
- .execute();
+ var preGraph = executionPlan.withGraph(getPreprocessingGraphIndex())
+ .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler());
+ if (CUDA_GRAPHS) preGraph.withCUDAGraph();
+ preGraph.execute();
state.positionHolder.set(0, position);
state.temp.clear();
@@ -106,10 +106,10 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {
}
state.tempLogits.clear();
state.wrapLogits.clear();
- executionPlan.withGraph(getFinalLogitsGraphIndex())
- .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
- .withCUDAGraph()
- .execute();
+ var logitsGraph = executionPlan.withGraph(getFinalLogitsGraphIndex())
+ .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler());
+ if (CUDA_GRAPHS) logitsGraph.withCUDAGraph();
+ logitsGraph.execute();
// @formatter:on
return state.wrapLogits;
}
From 2e51fc29a354e0a5c92f66e673c8f7e24e2be4a3 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Fri, 3 Apr 2026 20:32:24 +0300
Subject: [PATCH 09/12] [prf/dec][refactor] Rename
`LlamaFP16FFNLayersForUnifiedDecode` to `LlamaFP16FFNLayersDecode`
---
.../tornadovm/TornadoVMMasterPlanBatchPrefill.java | 6 +++---
...orUnifiedDecode.java => LlamaFP16FFNLayersDecode.java} | 8 ++++----
2 files changed, 7 insertions(+), 7 deletions(-)
rename src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/{LlamaFP16FFNLayersForUnifiedDecode.java => LlamaFP16FFNLayersDecode.java} (85%)
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
index cc6591c2..6ce9c23e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
@@ -8,7 +8,7 @@
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersForUnifiedDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16BatchPrefillLayers;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
import uk.ac.manchester.tornado.api.GridScheduler;
@@ -100,8 +100,8 @@ private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batch
// [N+2..2N+1] Decode layer graphs ────────────────────────────────────
// Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload).
- LlamaFP16FFNLayersForUnifiedDecode decodeLayers =
- new LlamaFP16FFNLayersForUnifiedDecode(
+ LlamaFP16FFNLayersDecode decodeLayers =
+ new LlamaFP16FFNLayersDecode(
"llamaFFNDecode", state, weights, config, schedulerType);
all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
decodeLayers.updateGridScheduler(scheduler);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
similarity index 85%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
index b1cd063f..f781f08a 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersForUnifiedDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
@@ -16,10 +16,10 @@
* persisted by the last batch prefill layer and passed through the decode
* activation graph.
*/
-public class LlamaFP16FFNLayersForUnifiedDecode extends LlamaFP16FFNLayers {
- public LlamaFP16FFNLayersForUnifiedDecode(String taskGraph, LlamaState state,
- LlamaTornadoWeights weights, LlamaConfiguration config,
- SchedulerType schedulerType) {
+public class LlamaFP16FFNLayersDecode extends LlamaFP16FFNLayers {
+ public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state,
+ LlamaTornadoWeights weights, LlamaConfiguration config,
+ SchedulerType schedulerType) {
super(taskGraph, state, weights, config, schedulerType);
}
From 885473b2eddb20d26329d182c0d42d8a20ef3233 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Fri, 3 Apr 2026 20:42:32 +0300
Subject: [PATCH 10/12] [prf/dec][refactor] Rename
`LlamaFP16BatchPrefillLayers` to `LlamaFP16LayersBatchPrefill`
---
...doVMMasterPlanWithBatchPrefillDecode.java} | 22 ++++++++++++-------
....java => LlamaFP16LayersBatchPrefill.java} | 10 ++++-----
2 files changed, 19 insertions(+), 13 deletions(-)
rename src/main/java/org/beehive/gpullama3/tornadovm/{TornadoVMMasterPlanBatchPrefill.java => TornadoVMMasterPlanWithBatchPrefillDecode.java} (91%)
rename src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/{LlamaFP16BatchPrefillLayers.java => LlamaFP16LayersBatchPrefill.java} (97%)
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
similarity index 91%
rename from src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
index 6ce9c23e..c8772e61 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
@@ -9,7 +9,7 @@
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16BatchPrefillLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -80,15 +80,15 @@ private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batch
 List<ImmutableTaskGraph> all = new ArrayList<>(2 * N + 3);
GridScheduler scheduler = new GridScheduler();
- // [0] Batch activation ────────────────────────────────────────────────
+ // [0] Batch prefill activation ────────────────────────────────────────────────
KernelContext batchActCtx = new KernelContext();
- all.add(buildBatchActivationGraph(batchActCtx).snapshot());
+ all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot());
scheduler.addWorkerGrid("batchActivation.batchUpdateX",
WorkerGridFactory.genericWorker(batchSize * config.dim(), 128));
- // [1..N] Batch layer graphs ───────────────────────────────────────────
- LlamaFP16BatchPrefillLayers batchLayers =
- new LlamaFP16BatchPrefillLayers(state, weights, config, batchSize);
+ // [1..N] Batch prefill layer graphs ───────────────────────────────────────────
+ LlamaFP16LayersBatchPrefill batchLayers =
+ new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize);
all.addAll(batchLayers.getLayerImmutableTaskGraphs());
batchLayers.updateGridScheduler(scheduler);
@@ -116,10 +116,10 @@ private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batch
this.executionPlan = new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0]));
}
- // ── Activation graphs ─────────────────────────────────────────────────────
+ // ── Batch Prefill Activation graphs ─────────────────────────────────────────────────────
/** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. */
- private TaskGraph buildBatchActivationGraph(KernelContext ctx) {
+ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
return new TaskGraph("batchActivation")
.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch)
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch)
@@ -142,6 +142,9 @@ private TaskGraph buildBatchActivationGraph(KernelContext ctx) {
private TaskGraph buildDecodeActivationGraph(KernelContext ctx) {
return new TaskGraph("activationUpdate")
.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache) // KV pass-through
+// .transferToDevice(DataTransferMode.EVERY_EXECUTION,
+// state.wrapKeyCache,
+// state.wrapValueCache)
.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX)
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
.task("updateX",
@@ -235,6 +238,7 @@ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model mod
if (CUDA_GRAPHS) batchLayer.withCUDAGraph();
batchLayer.execute();
}
+ //System.err.println("[DEBUG] last batch layer done, about to return from prefill");
// Logits skipped — not needed for prefill positions.
}
@@ -262,12 +266,14 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) {
// Graph N+1: decode activation
var decodeAct = executionPlan.withGraph(decodeActivationIdx()).withGridScheduler(gridScheduler);
if (CUDA_GRAPHS) decodeAct.withCUDAGraph();
+ //System.err.println("[DEBUG] about to execute decode activation (graph " + decodeActivationIdx() + "--)");
decodeAct.execute();
// Graphs N+2..2N+1: decode transformer layers
for (int l = 0; l < N; l++) {
var decodeLayer = executionPlan.withGraph(decodeLayerIdx(l)).withGridScheduler(gridScheduler);
if (CUDA_GRAPHS) decodeLayer.withCUDAGraph();
+ //System.err.println("[DEBUG] about to execute decode transformer layer (graph " + decodeLayerIdx(l) + "--)");
decodeLayer.execute();
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
similarity index 97%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
index 8414be72..30e3267e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16BatchPrefillLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
@@ -24,7 +24,7 @@
* KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device
* after every layer so the subsequent single-token decode layers can consume it.
*/
-public class LlamaFP16BatchPrefillLayers {
+public class LlamaFP16LayersBatchPrefill {
// Matches the local workgroup size used by the single-token kernels.
static final int LOCAL_WORK_GROUP_SIZE = 32;
@@ -37,20 +37,20 @@ public class LlamaFP16BatchPrefillLayers {
 private final List<ImmutableTaskGraph> layerITGs;
private String lastLayerTaskGraphID;
- public LlamaFP16BatchPrefillLayers(LlamaState state, LlamaTornadoWeights weights,
- LlamaConfiguration config, int batchSize) {
+ public LlamaFP16LayersBatchPrefill(LlamaState state, LlamaTornadoWeights weights,
+ LlamaConfiguration config, int batchSize) {
this.state = state;
this.weights = weights;
this.config = config;
this.batchSize = batchSize;
this.layerITGs = IntStream.range(0, config.numberOfLayers())
- .mapToObj(this::createBatchLayerTaskGraph)
+ .mapToObj(this::createBatchPrefillLayerTaskGraph)
.map(TaskGraph::snapshot)
.toList();
}
// @formatter:off
- private TaskGraph createBatchLayerTaskGraph(int layerIndex) {
+ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
String graphName = "batchLayer_" + layerIndex;
if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName;
From 61287934317c674b81860ae20282a822854a07f0 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Fri, 3 Apr 2026 21:23:09 +0300
Subject: [PATCH 11/12] [prf/dec][refactor] Rename
`TornadoVMMasterPlanBatchPrefill` to
`TornadoVMMasterPlanWithBatchPrefillDecode`
---
.../inference/InferenceEngineWithPrefillDecode.java | 4 ++--
.../gpullama3/tornadovm/TornadoVMMasterPlan.java | 10 +++++-----
.../TornadoVMMasterPlanWithBatchPrefillDecode.java | 10 +++++-----
3 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
index 6517df12..74c474e9 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
@@ -10,7 +10,7 @@
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
-import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefill;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanStandard;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
@@ -202,7 +202,7 @@ public static List generateTokensGPULlama(
// ── Phase 4: Batch GPU Prefill ────────────────────────────────────
// Plan was pre-initialized in Model.runInstructOnce/runInteractive
// as a TornadoVMMasterPlanBatchPrefill by TornadoVMMasterPlan.initializeTornadoVMPlan.
- TornadoVMMasterPlanBatchPrefill plan = (TornadoVMMasterPlanBatchPrefill) tornadoVMPlan;
+ TornadoVMMasterPlanWithBatchPrefillDecode plan = (TornadoVMMasterPlanWithBatchPrefillDecode) tornadoVMPlan;
int N = promptTokens.size();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
index c81ba92c..1acd8f24 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
@@ -12,13 +12,13 @@
*
* - {@link TornadoVMMasterPlanStandard} — single-token forward pass; used for the
* baseline GPU path and Phase 2 sequential prefill/decode.
- * - {@link TornadoVMMasterPlanBatchPrefill} — unified plan for Phase 4 batched
+ * - {@link TornadoVMMasterPlanWithBatchPrefillDecode} — unified plan for Phase 4 batched
* prefill + single-token decode within one {@code TornadoExecutionPlan}.
*
*
* The {@link #initializeTornadoVMPlan} factory selects the appropriate implementation
* based on {@code llama.prefillBatchSize}: if {@code > 1}, returns a
- * {@link TornadoVMMasterPlanBatchPrefill}; otherwise returns a
+ * {@link TornadoVMMasterPlanWithBatchPrefillDecode}; otherwise returns a
* {@link TornadoVMMasterPlanStandard}.
*/
public interface TornadoVMMasterPlan {
@@ -35,7 +35,7 @@ public interface TornadoVMMasterPlan {
*
* Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM})
* and the Phase 2 sequential decode path. Not applicable to
- * {@link TornadoVMMasterPlanBatchPrefill} — that plan uses its own typed methods.
+ * {@link TornadoVMMasterPlanWithBatchPrefillDecode} — that plan uses its own typed methods.
*
* @param position sequence position of the current token
* @return logits array for token sampling
@@ -48,7 +48,7 @@ public interface TornadoVMMasterPlan {
/**
* Factory: creates, JIT-compiles, and warms up the appropriate plan.
*
- * When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanBatchPrefill}
+ * When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanWithBatchPrefillDecode}
* is returned; otherwise a {@link TornadoVMMasterPlanStandard} is returned.
*
* @param state the model state (must be {@link LlamaState} when batch size {@code > 1})
@@ -59,7 +59,7 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) {
int batchSize = Integer.getInteger("llama.prefillBatchSize", 1);
TornadoVMMasterPlan plan;
if (batchSize > 1) {
- plan = TornadoVMMasterPlanBatchPrefill.initializeUnifiedPlan(
+ plan = TornadoVMMasterPlanWithBatchPrefillDecode.initializeUnifiedPlan(
(LlamaState) state, model, batchSize);
} else {
plan = TornadoVMMasterPlanStandard.initialize(state, model);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
index c8772e61..f8b2cf63 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
@@ -48,7 +48,7 @@
* decodeLayer[0] --consumeFromDevice(wrapKeyCache)-→ (used by attention)
*
*/
-public class TornadoVMMasterPlanBatchPrefill implements TornadoVMMasterPlan {
+public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan {
private static final boolean ENABLE_TIMING =
Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
@@ -68,7 +68,7 @@ public class TornadoVMMasterPlanBatchPrefill implements TornadoVMMasterPlan {
private int logitsIdx() { return 2 * N + 2; }
// ── Construction ─────────────────────────────────────────────────────────
- private TornadoVMMasterPlanBatchPrefill(LlamaState state, Model model, int batchSize) {
+ private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model, int batchSize) {
this.state = state;
this.config = (LlamaConfiguration) model.configuration();
this.batchSize = batchSize;
@@ -162,12 +162,12 @@ private TaskGraph buildDecodeActivationGraph(KernelContext ctx) {
* Creates, JIT-compiles, and warms up the unified plan.
* Mirrors {@link TornadoVMMasterPlan#initializeTornadoVMPlan}.
*/
- public static TornadoVMMasterPlanBatchPrefill initializeUnifiedPlan(
+ public static TornadoVMMasterPlanWithBatchPrefillDecode initializeUnifiedPlan(
LlamaState state, Model model, int batchSize) {
long t0 = System.nanoTime();
- TornadoVMMasterPlanBatchPrefill plan =
- new TornadoVMMasterPlanBatchPrefill(state, model, batchSize);
+ TornadoVMMasterPlanWithBatchPrefillDecode plan =
+ new TornadoVMMasterPlanWithBatchPrefillDecode(state, model, batchSize);
if (ENABLE_TIMING)
System.err.printf("[BatchPlan] Graph construction: %.2f ms%n",
From 2b1aababc02eb29982b6cadbf1636176fab3dd62 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Tue, 7 Apr 2026 13:34:34 +0300
Subject: [PATCH 12/12] [prf/dec] Fix KV-cache propagation bug from prefill to
decode path and refactor task graph consumption logic
Introduce `LogitsFP16LayerDecode` with KV-cache pass-through. Override `consumeFromDevice` and `persistOnDevice` in LlamaFFN layers to fix cross-graph propagation for both CUDA and interpreter modes.
---
...adoVMMasterPlanWithBatchPrefillDecode.java | 45 ++++++++++-----
.../layers/type/fp16/LlamaFP16FFNLayers.java | 34 ++++++++++-
.../layers/type/fp16/LogitsFP16Layer.java | 15 +++++
.../fp16/decode/LlamaFP16FFNLayersDecode.java | 56 +++++++++++++++----
.../fp16/decode/LogitsFP16LayerDecode.java | 53 ++++++++++++++++++
.../prefill/LlamaFP16LayersBatchPrefill.java | 15 ++++-
6 files changed, 187 insertions(+), 31 deletions(-)
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
index f8b2cf63..3df08dfa 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
@@ -10,7 +10,7 @@
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
@@ -94,8 +94,8 @@ private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model,
// [N+1] Decode activation (with KV-cache pass-through) ────────────────
KernelContext decodeActCtx = new KernelContext();
- all.add(buildDecodeActivationGraph(decodeActCtx).snapshot());
- scheduler.addWorkerGrid("activationUpdate.updateX",
+ all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot());
+ scheduler.addWorkerGrid("decodeActivationUpdate.updateX",
WorkerGridFactory.genericWorker(config.dim(), 128));
// [N+2..2N+1] Decode layer graphs ────────────────────────────────────
@@ -107,7 +107,10 @@ private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model,
decodeLayers.updateGridScheduler(scheduler);
// [2N+2] Logits ───────────────────────────────────────────────────────
- LogitsFP16Layer logitsLayer = new LogitsFP16Layer("logits", state, weights, config,
+ // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache)
+ // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the
+ // KV-cache pointer survives the logits → decode-activation boundary across tokens.
+ LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config,
decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
all.add(logitsLayer.getImmutableTaskGraph());
logitsLayer.updateGridScheduler(scheduler);
@@ -123,9 +126,7 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
return new TaskGraph("batchActivation")
.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch)
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch)
- .task("batchUpdateX",
- (KernelContext c, HalfFloatArray src, FloatArray dst) ->
- dst.set(c.globalIdx, src.get(c.globalIdx).getFloat32()),
+ .task("batchUpdateX", TransformerComputeKernels::convertFP16toFP32,
ctx, state.embeddingXBatch, state.wrapXBatch)
.persistOnDevice(state.wrapXBatch);
}
@@ -139,17 +140,24 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
* Both halves of the chain are required; without the re-persist the pointer is
* not forwarded in interpreter (non-CUDA-graph) mode.
*/
- private TaskGraph buildDecodeActivationGraph(KernelContext ctx) {
- return new TaskGraph("activationUpdate")
- .consumeFromDevice(state.wrapKeyCache, state.wrapValueCache) // KV pass-through
-// .transferToDevice(DataTransferMode.EVERY_EXECUTION,
-// state.wrapKeyCache,
-// state.wrapValueCache)
- .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX)
+ private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) {
+// System.out.println("lastBatchLayerID = " + lastBatchLayerID);
+// System.out.println("[buildDecodeActivationGraph] state.wrapX = " + state.wrapX.toString());
+// System.out.println("[buildDecodeActivationGraph] state.wrapKeyCache = " + state.wrapKeyCache.toString());
+// System.out.println("[buildDecodeActivationGraph] state.wrapValueCache = " + state.wrapValueCache.toString());
+ return new TaskGraph("decodeActivationUpdate")
+ .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through
+ //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX, debugKV)
+ //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX)
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
.task("updateX",
TransformerComputeKernels::convertFP16toFP32,
ctx, (HalfFloatArray) state.embeddingX, state.wrapX)
+// // DEBUG: snapshot first 8 elements of wrapKeyCache and wrapX for host-side probe
+// .task("dbgKV",
+// TransformerComputeKernels::dbgCopyFirst8,
+// state.wrapKeyCache, debugKV)
+// .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX, debugKV)
// wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache
// re-persisted so updatePersistedObjectState() propagates the device
// pointer to decode layer 0's consumeFromDevice without CUDA graphs.
@@ -197,6 +205,7 @@ private void forceCopyInReadOnlyData() {
state.batchStartPosHolder.init(0);
for (int i = 0; i <= logitsIdx(); i++) {
+ //System.out.println(i + " " + executionPlan.withGraph(i).toString());
var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler);
if (CUDA_GRAPHS) g.withCUDAGraph();
g.execute();
@@ -268,6 +277,14 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) {
if (CUDA_GRAPHS) decodeAct.withCUDAGraph();
//System.err.println("[DEBUG] about to execute decode activation (graph " + decodeActivationIdx() + "--)");
decodeAct.execute();
+ // DEBUG: print first 4 of wrapX (should be non-zero FP32 embedding) and
+ // first 4 of debugKV (should be non-zero after batch prefill wrote the KV cache)
+// if (position <= 290) {
+// System.err.printf("[DBG pos=%d] wrapX[0..3] = %.4f %.4f %.4f %.4f%n",
+// position, state.wrapX.get(0), state.wrapX.get(1), state.wrapX.get(2), state.wrapX.get(3));
+// System.err.printf("[DBG pos=%d] debugKV[0..3]= %.4f %.4f %.4f %.4f%n",
+// position, debugKV.get(0), debugKV.get(1), debugKV.get(2), debugKV.get(3));
+// }
// Graphs N+2..2N+1: decode transformer layers
for (int l = 0; l < N; l++) {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
index 56f2c0c3..50619cc2 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
@@ -146,7 +146,17 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) {
TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
// === Data Setup ===
- unifiedLayer.consumeFromDevice(state.wrapX);
+ // consumeFromDevice for wrapX: the no-arg form uses the current graph's own name as the
+ // source key, which works in CUDA-graph mode (pointers are frozen) but fails in interpreter
+ // mode (updatePersistedObjectState looks up the predecessor's name, not the current name).
+ // Subclasses that receive wrapX across a graph boundary override predecessorGraphName() to
+ // return the correct predecessor graph name so the XPUBuffer is propagated in both modes.
+ String wrapXSrc = predecessorGraphName(layerIndex);
+ if (wrapXSrc != null) {
+ unifiedLayer.consumeFromDevice(wrapXSrc, state.wrapX);
+ } else {
+ unifiedLayer.consumeFromDevice(state.wrapX);
+ }
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
weights.rms_att_weightLayered[layerIndex].asFloatArray(),
weights.wqLayered[layerIndex].asHalfFloatArray(),
@@ -248,11 +258,31 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) {
weights.w2Layered[layerIndex].asHalfFloatArray(),
config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
- unifiedLayer.persistOnDevice(state.wrapX);
+ unifiedLayer.persistOnDevice(state.wrapX, state.wrapKeyCache,
+ state.wrapValueCache);
return unifiedLayer;
}
+ /**
+ * Returns the name of the predecessor task graph from which {@code wrapX} should be consumed,
+ * or {@code null} to fall back to the no-arg form (source key = own graph name).
+ *
+ * The no-arg form is safe in CUDA-graph mode (device pointers are frozen at capture time)
+ * but fails in interpreter mode: {@code updatePersistedObjectState} looks up the predecessor's
+ * graph name, not the current graph's name, so the XPUBuffer is never propagated and
+ * {@code executeAlloc} NPEs on a null buffer.
+ *
+ * Override in subclasses that receive {@code wrapX} from a named predecessor graph:
+ *
+ * - layer 0: return the activation graph name (e.g. {@code "activationUpdate"})
+ * - layer k > 0: return {@code "layer_" + (k-1)}
+ *
+ */
+ protected String predecessorGraphName(int layerIndex) {
+ return null;
+ }
+
protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
if (layerIndex == 0) {
// First layer: Transfer initial data to device (one-time transfer)
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
index bf938a0d..1858408e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
@@ -22,11 +22,25 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration
super(name, state, weights, config, lastTaskGraphID, schedulerType);
}
+ /**
+ * Hook called before any data transfers or tasks. Override to prepend
+ * {@code consumeFromDevice} declarations that must precede the bytecode
+ * (e.g. KV-cache pass-through in the Phase 4 unified plan).
+ */
+ protected void configureAdditionalConsumes(TaskGraph logits) {}
+
+ /**
+ * Hook called after {@code transferToHost}. Override to append
+ * {@code persistOnDevice} declarations (e.g. KV-cache pass-through in Phase 4).
+ */
+ protected void configureAdditionalPersists(TaskGraph logits) {}
+
// @formatter:off
@Override
protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
var logits = new TaskGraph("logits");
// === Data Setup ===
+ configureAdditionalConsumes(logits);
logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
@@ -80,6 +94,7 @@ protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration c
// === Transfer Results to Host ===
logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
+ configureAdditionalPersists(logits);
return logits;
}
// @formatter:on
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
index f781f08a..4d632425 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
@@ -9,12 +9,22 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses
- * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}.
+ * Decode-path FFN layers for the Phase 4 unified plan.
*
- * This ensures decode layer 0 receives the KV-cache device pointer that was
- * persisted by the last batch prefill layer and passed through the decode
- * activation graph.
+ * Overrides data-transfer declarations so that all cross-graph boundaries use
+ * the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by
+ * the base class) passes the current graph's own name as the source key.
+ * In CUDA-graph mode this is harmless (device pointers are frozen at capture time),
+ * but in interpreter mode {@code updatePersistedObjectState} looks up the
+ * predecessor's name, so the lookup always misses and the XPUBuffer is
+ * never propagated — causing either a null-pointer crash or a silent re-upload
+ * from host (zeros), corrupting the hidden state and KV cache.
+ *
+ * Two boundaries are fixed here:
+ *
+ * - {@code wrapX}: via {@link #predecessorGraphName} hook in the base class.
+ * - All other consumed objects: via the {@link #configureLayerDataTransfers} override.
+ *
*/
public class LlamaFP16FFNLayersDecode extends LlamaFP16FFNLayers {
public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state,
@@ -23,11 +33,25 @@ public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state,
super(taskGraph, state, weights, config, schedulerType);
}
+ /**
+ * Supplies the correct predecessor graph name for {@code consumeFromDevice(wrapX)}.
+ *
+ * Layer 0 receives {@code wrapX} from the decode activation graph;
+ * layers 1+ receive it from the previous decode layer.
+ * Must match the {@code TaskGraph} names used in
+ * {@code buildDecodeActivationGraph()} and {@code createFFNLayerTaskGraph()}.
+ */
+ @Override
+ protected String predecessorGraphName(int layerIndex) {
+ return (layerIndex == 0) ? "decodeActivationUpdate" : "layer_" + (layerIndex - 1);
+ }
+
@Override
protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) {
if (layerIndex == 0) {
- // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come
- // from device (passed through by the decode activation graph).
+ // Same as parent layer 0, but wrapKeyCache/wrapValueCache come from device
+ // (passed through by the decode activation graph, which relays them from
+ // the last batch prefill layer). No FIRST_EXECUTION for KV cache here.
layer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
state.positionHolder, state.temp, state.tempFFN);
layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
@@ -35,12 +59,20 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex)
state.wrapXb, state.wrapXb2,
state.wrapQ, state.wrapK, state.wrapV,
state.wrapAtt, state.wrapHb, state.wrapXbFP16);
- // KV cache: consume from device (device pointer supplied by
- // decode activation's pass-through from last batch layer).
- layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache);
+ // Explicit source — must match the TaskGraph name in buildDecodeActivationGraph().
+ layer.consumeFromDevice("decodeActivationUpdate", state.wrapKeyCache, state.wrapValueCache);
} else {
- // Identical to parent for layers 1+ (already uses consumeFromDevice).
- return super.configureLayerDataTransfers(layer, layerIndex);
+ // Layers 1+: use explicit predecessor name for ALL consumed objects.
+ // Calling super here would use the no-arg form (source key = own graph name),
+ // which silently fails in interpreter mode and causes re-upload from host.
+ String pred = "layer_" + (layerIndex - 1);
+ layer.consumeFromDevice(pred,
+ context,
+ state.wrapXb, state.wrapXb2,
+ state.wrapQ, state.wrapK, state.wrapV,
+ state.wrapKeyCache, state.wrapValueCache,
+ state.wrapAtt, state.wrapHb,
+ state.positionHolder, state.wrapXbFP16);
}
return layer;
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java
new file mode 100644
index 00000000..760be156
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java
@@ -0,0 +1,53 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import uk.ac.manchester.tornado.api.TaskGraph;
+
+/**
+ * Logits layer for the unified prefill-decode plan (Phase 4).
+ *
+ * Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device
+ * pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the
+ * logits → decode-activation boundary across decode tokens.
+ *
+ * In interpreter (non-CUDA-graph) mode, {@code updatePersistedObjectState()}
+ * propagates device pointers from the predecessor graph's persisted set. After the
+ * last decode token the predecessor of the next decode-activation graph is the
+ * logits graph. Without the pass-through here, the KV-cache pointer is absent from
+ * the logits persisted set, cleared to null, and the first decode layer crashes with
+ * an NPE in {@code executeAlloc}.
+ *
+ * Bytecode order matters: {@code consumeFromDevice} must precede task declarations,
+ * and {@code persistOnDevice} must follow {@code transferToHost}. The hooks in
+ * {@link LogitsFP16Layer} guarantee this ordering.
+ */
+public class LogitsFP16LayerDecode extends LogitsFP16Layer {
+
+ public LogitsFP16LayerDecode(String name, State state, Weights weights, Configuration config,
+ String lastTaskGraphID, SchedulerType schedulerType) {
+ super(name, state, weights, config, lastTaskGraphID, schedulerType);
+ }
+
+ /**
+ * Prepends {@code consumeFromDevice(lastTaskGraphID, wrapKeyCache, wrapValueCache)} before all tasks.
+ *
+ * Must use the named-source form so that {@code updatePersistedObjectState()} adds the KV cache
+ * to the source-keyed map. Without the source name, the fallback in {@code updatePersistedObjectState()}
+ * uses the current graph's general persisted list, which causes the XPUBuffer from the predecessor
+ * (last decode layer) to never be propagated into the logits graph's device state.
+ */
+ @Override
+ protected void configureAdditionalConsumes(TaskGraph logits) {
+ logits.consumeFromDevice(lastTaskGraphID, state.wrapKeyCache, state.wrapValueCache);
+ }
+
+ /** Appends {@code persistOnDevice(wrapKeyCache, wrapValueCache)} after {@code transferToHost}. */
+ @Override
+ protected void configureAdditionalPersists(TaskGraph logits) {
+ logits.persistOnDevice(state.wrapKeyCache, state.wrapValueCache);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
index 30e3267e..a893623d 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
@@ -69,10 +69,19 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
state.wrapXbBatch,
state.wrapHbBatch,
state.wrapKeyCache, state.wrapValueCache);
- // wrapXBatch produced by the batch activation graph
- layer.consumeFromDevice(state.wrapXBatch);
+ // wrapXBatch produced by the batch activation graph.
+ // Explicit source name required: the no-arg form uses the current graph's own
+ // name ("batchLayer_0") which never matches "batchActivation" in interpreter mode,
+ // causing wrapXBatch to be re-uploaded from host (zeros) instead of using the
+ // FP32 embeddings computed by the activation graph's convertFP16toFP32 kernel.
+ layer.consumeFromDevice("batchActivation", state.wrapXBatch);
} else {
- layer.consumeFromDevice(
+ // Explicit predecessor name for all objects.
+ // The no-arg form would use "batchLayer_k" as the source key, which never matches
+ // "batchLayer_{k-1}" in interpreter mode — every object would be re-uploaded from
+ // host (zeros or stale), corrupting the KV cache written by the previous layer.
+ String pred = "batchLayer_" + (layerIndex - 1);
+ layer.consumeFromDevice(pred,
context,
state.wrapXBatch,
state.wrapXbFP16Batch,