Skip to content

Commit e00fae8

Browse files
[prf/dec][fix] Restructure and fix issues in TornadoVMMasterPlanWithPrefillDecode
This fixes the non-batched GPU prefill-decode path when running without CUDA Graphs
1 parent 2988e7f commit e00fae8

2 files changed

Lines changed: 204 additions & 64 deletions

File tree

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java

Lines changed: 131 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,67 @@
11
package org.beehive.gpullama3.tornadovm;
22

3+
import org.beehive.gpullama3.inference.state.LlamaState;
34
import org.beehive.gpullama3.inference.state.State;
4-
import org.beehive.gpullama3.model.Configuration;
5+
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
56
import org.beehive.gpullama3.model.Model;
6-
import org.beehive.gpullama3.tensor.GGMLType;
7-
import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner;
8-
import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory;
7+
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
8+
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
9+
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
10+
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
11+
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
12+
import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersPrefillDecode;
13+
import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode;
14+
import uk.ac.manchester.tornado.api.GridScheduler;
915
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
16+
import uk.ac.manchester.tornado.api.KernelContext;
17+
import uk.ac.manchester.tornado.api.TaskGraph;
1018
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
19+
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
1120
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
21+
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
22+
23+
import java.util.ArrayList;
24+
import java.util.List;
1225

1326
/**
1427
* GPU execution plan for single-token prefill/decode separation.
1528
*
16-
* <p>Uses the same single-token execution plan as {@link TornadoVMMasterPlanStandard}
17-
* but exposes two distinct forward passes:</p>
18-
* <ul>
19-
* <li>{@link #tornadoVMForwardPrefill} — runs graphs 0..N, skips the logits graph.
20-
* Called for each prompt token; KV cache is populated but logits are discarded.</li>
21-
* <li>{@link #tornadoVMForwardDecode} — full execution including logits.
22-
* Called for each generated token.</li>
23-
* </ul>
29+
* <p>Uses dedicated layer classes that carry correct cross-graph
30+
* {@code consumeFromDevice} source names for both CUDA-graph and interpreter
31+
* (no-CUDA-graph) mode. All graphs are owned by this plan and built from scratch —
32+
* no reuse of the standard execution path.</p>
2433
*
25-
* <p>Graph layout (same as {@link TornadoVMMasterPlanStandard}):</p>
34+
* <p>Graph layout (N+2 graphs total):</p>
2635
* <pre>
27-
* graph 0 : preprocessing (embedding setup)
28-
* graphs 1..N : transformer layers
29-
* graph N+1 : logits projection (final RMSNorm + wcls matmul)
36+
* [0] decodeActivation single-token FP16 → FP32 (wrapX only); KV cache is allocated by layer_0 on first execution
37+
* [1..N] layer_0..layer_N-1 transformer layers (attention + FFN)
38+
* [N+1] logits final RMSNorm + wcls matmul
3039
* </pre>
40+
*
41+
* <p>Two distinct forward passes:</p>
42+
* <ul>
43+
* <li>{@link #tornadoVMForwardPrefill} — runs graphs 0..N, skips logits.
44+
* KV cache is populated for each prompt token; logits are discarded.</li>
45+
* <li>{@link #tornadoVMForwardDecode} — full pass including logits.
46+
* Called for each generated token.</li>
47+
* </ul>
3148
*/
3249
public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan {
3350

34-
private final State state;
35-
private final Model model;
36-
private final Configuration config;
51+
private final LlamaState state;
52+
private final Model model;
53+
private final LlamaConfiguration config;
54+
private final int N; // numberOfLayers
55+
private final TornadoExecutionPlan executionPlan;
56+
private final GridScheduler gridScheduler;
3757

38-
GenericLayerPlanner tornadoVMLayerPlanner;
39-
public TornadoExecutionPlan executionPlan;
58+
// ── Graph-index helpers ───────────────────────────────────────────────────
59+
private int activationIdx() { return 0; }
60+
private int layerIdx(int i) { return 1 + i; }
61+
private int logitsIdx() { return N + 1; }
4062

41-
public TornadoVMMasterPlanWithPrefillDecode(State state, Model model) {
63+
// ── Construction ─────────────────────────────────────────────────────────
64+
TornadoVMMasterPlanWithPrefillDecode(State initialState, Model model) {
4265
long startTime = System.nanoTime();
4366
long planCreationTime = 0;
4467
long warmupTime = 0;
@@ -47,81 +70,127 @@ public TornadoVMMasterPlanWithPrefillDecode(State state, Model model) {
4770
System.err.println("\nStarting TornadoVM initialization...");
4871
}
4972

50-
this.state = state;
51-
this.model = model;
52-
this.config = model.configuration();
73+
this.state = (LlamaState) initialState;
74+
this.model = model;
75+
this.config = (LlamaConfiguration) model.configuration();
76+
this.N = config.numberOfLayers();
77+
this.gridScheduler = new GridScheduler();
5378
this.executionPlan = createExecutionPlan();
5479

5580
if (ENABLE_TORNADOVM_INIT_TIME) {
5681
planCreationTime = System.nanoTime();
57-
System.err.printf("TornadoVM GPU single-token prefill/decode execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0);
82+
System.err.printf("TornadoVM GPU single-token prefill/decode execution plan creation: %.2f ms\n",
83+
(planCreationTime - startTime) / 1_000_000.0);
5884
}
5985

6086
if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph();
6187
executionPlan.withPreCompilation();
6288

6389
if (ENABLE_TORNADOVM_INIT_TIME) {
6490
warmupTime = System.nanoTime();
65-
System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0);
91+
System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n",
92+
(warmupTime - planCreationTime) / 1_000_000.0);
6693
}
6794

6895
forceCopyInReadOnlyData();
6996

7097
if (ENABLE_TORNADOVM_INIT_TIME) {
7198
long copyTime = System.nanoTime();
72-
System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0);
99+
System.err.printf("Transfer read-only weights to GPU: %.2f ms\n",
100+
(copyTime - warmupTime) / 1_000_000.0);
73101
System.err.printf("Finished TornadoVM initialization...\n \n");
74102
}
75103
}
76104

105+
// ── Activation graph ─────────────────────────────────────────────────────
106+
77107
/**
78-
* Creates the {@link TornadoExecutionPlan} for forward pass with *prefill/decode separation*.
79-
* Prefill is token-by-token but does not compute logits.
108+
* Graph 0: single-token FP16 → FP32.
109+
*
110+
* <p>Outputs {@code wrapX} (FP32 hidden state) and persists it on device so that
111+
* decode layer 0 can pick it up via {@code consumeFromDevice("decodeActivation", wrapX)}.
112+
* The KV cache is <em>not</em> managed here — it is allocated on the first forward pass
113+
* by decode layer 0 via {@code FIRST_EXECUTION}.</p>
80114
*/
115+
private TaskGraph buildActivationGraph(KernelContext ctx) {
116+
return new TaskGraph("decodeActivation")
117+
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
118+
.task("updateX", TransformerComputeKernels::convertFP16toFP32,
119+
ctx, (HalfFloatArray) state.embeddingX, state.wrapX)
120+
.persistOnDevice(state.wrapX);
121+
}
122+
123+
// ── Plan construction ─────────────────────────────────────────────────────
124+
81125
@Override
82126
public TornadoExecutionPlan createExecutionPlan() {
83-
GGMLType weightType = model.weights().getWeightType();
84-
this.tornadoVMLayerPlanner = QuantizationPlannerFactory.create(weightType, state, model);
85-
var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs();
86-
var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]);
87-
return new TornadoExecutionPlan(taskGraphArray);
127+
LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
128+
SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model);
129+
130+
List<ImmutableTaskGraph> all = new ArrayList<>(N + 2);
131+
132+
// [0] Activation ──────────────────────────────────────────────────────
133+
KernelContext actCtx = new KernelContext();
134+
all.add(buildActivationGraph(actCtx).snapshot());
135+
gridScheduler.addWorkerGrid("decodeActivation.updateX",
136+
WorkerGridFactory.genericWorker(config.dim(), 128));
137+
138+
// [1..N] Decode layer graphs ──────────────────────────────────────────
139+
// Layer 0: FIRST_EXECUTION for KV cache + consumeFromDevice("decodeActivation", wrapX).
140+
// Layers 1+: consumeFromDevice with explicit predecessor names for interpreter mode.
141+
LlamaFP16FFNLayersPrefillDecode decodeLayers =
142+
new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType);
143+
all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
144+
decodeLayers.updateGridScheduler(gridScheduler);
145+
146+
// [N+1] Logits ────────────────────────────────────────────────────────
147+
// LogitsFP16LayerDecode re-persists the KV cache so the pointer survives
148+
// the logits → layer_0 KV-cache FIRST_EXECUTION boundary across decode tokens.
149+
LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode(
150+
"logits", state, weights, config,
151+
decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
152+
all.add(logitsLayer.getImmutableTaskGraph());
153+
logitsLayer.updateGridScheduler(gridScheduler);
154+
155+
return new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0]));
88156
}
89157

158+
// ── Initialisation ────────────────────────────────────────────────────────
159+
160+
/** Runs all graphs once to trigger FIRST_EXECUTION uploads and, when CUDA graphs are enabled, warm them up. */
90161
@Override
91162
public void forceCopyInReadOnlyData() {
92163
state.wrapX.clear();
93164
state.positionHolder.init(0);
94165

95-
executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
96-
97-
for (int layer = 0; layer < config.numberOfLayers(); layer++) {
98-
executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
166+
for (int i = 0; i <= logitsIdx(); i++) {
167+
var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler);
168+
if (CUDA_GRAPHS) g.withCUDAGraph();
169+
g.execute();
99170
}
100-
101-
executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
102171
}
103172

173+
// ── Forward passes ────────────────────────────────────────────────────────
174+
104175
/**
105-
* GPU prefill forward: runs preprocessing + all transformer layers, skips logits.
106-
* KV cache is populated; logits projection is intentionally omitted.
176+
* GPU prefill forward: activation + all transformer layers, logits skipped.
177+
* KV cache is populated for each prompt token.
107178
*
108179
* @param position sequence position being processed
109180
*/
110181
public void tornadoVMForwardPrefill(int position) {
111-
// Graph 0: preprocessing
112-
executionPlan.withGraph(0)
113-
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
114-
.execute();
182+
var prefillActivation = executionPlan.withGraph(activationIdx()).withGridScheduler(gridScheduler);
183+
if (CUDA_GRAPHS) prefillActivation.withCUDAGraph();
184+
prefillActivation.execute();
115185

116186
state.positionHolder.set(0, position);
117187
state.temp.clear();
118188
state.tempFFN.clear();
119189

120-
// Graphs 1..N: transformer layers (logits graph N+1 intentionally skipped)
121-
for (int layer = 1; layer <= config.numberOfLayers(); layer++) {
122-
executionPlan.withGraph(layer)
123-
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
124-
.execute();
190+
for (int layer = 0; layer < N; layer++) {
191+
var prefillLayer = executionPlan.withGraph(layerIdx(layer)).withGridScheduler(gridScheduler);
192+
if (CUDA_GRAPHS) prefillLayer.withCUDAGraph();
193+
prefillLayer.execute();
125194
}
126195
}
127196

@@ -137,27 +206,25 @@ public FloatArray tornadoVMForwardDecode(int position) {
137206

138207
@Override
139208
public FloatArray tornadoVMForwardExecuteLayered(int position) {
140-
var preGraph = executionPlan.withGraph(0)
141-
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler());
142-
if (CUDA_GRAPHS) preGraph.withCUDAGraph();
143-
preGraph.execute();
209+
var act = executionPlan.withGraph(activationIdx()).withGridScheduler(gridScheduler);
210+
if (CUDA_GRAPHS) act.withCUDAGraph();
211+
act.execute();
144212

145213
state.positionHolder.set(0, position);
146214
state.temp.clear();
147215
state.tempFFN.clear();
148216

149-
for (int layer = 0; layer < config.numberOfLayers(); layer++) {
150-
executionPlan.withGraph(1 + layer)
151-
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
152-
.execute();
217+
for (int layer = 0; layer < N; layer++) {
218+
var l = executionPlan.withGraph(layerIdx(layer)).withGridScheduler(gridScheduler);
219+
if (CUDA_GRAPHS) l.withCUDAGraph();
220+
l.execute();
153221
}
154222

155223
state.tempLogits.clear();
156224
state.wrapLogits.clear();
157-
var logitsGraph = executionPlan.withGraph(config.numberOfLayers() + 1)
158-
.withGridScheduler(tornadoVMLayerPlanner.getGridScheduler());
159-
if (CUDA_GRAPHS) logitsGraph.withCUDAGraph();
160-
logitsGraph.execute();
225+
var logits = executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler);
226+
if (CUDA_GRAPHS) logits.withCUDAGraph();
227+
logits.execute();
161228

162229
return state.wrapLogits;
163230
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode;
2+
3+
import org.beehive.gpullama3.inference.state.LlamaState;
4+
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
5+
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
6+
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
7+
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
8+
import uk.ac.manchester.tornado.api.TaskGraph;
9+
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
10+
11+
/**
12+
* Decode FFN layers for the single-token prefill/decode plan
13+
* ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}).
14+
*
15+
* <p>Combines two concerns:</p>
16+
* <ol>
17+
* <li><b>Correct predecessor names</b> — overrides {@link #predecessorGraphName} so that
18+
* every cross-graph {@code consumeFromDevice} uses the explicit-source form required
19+
* by TornadoVM's interpreter (non-CUDA-graph) mode. Layer 0 names {@code "decodeActivation"};
20+
* layers 1+ name {@code "layer_"+(k-1)}.</li>
21+
* <li><b>KV-cache allocation</b> — layer 0 delegates to the base-class
22+
* {@link #configureLayerDataTransfers} which includes {@code FIRST_EXECUTION} for
23+
* {@code wrapKeyCache} and {@code wrapValueCache}. This allocates the KV-cache device
24+
* buffers on the very first forward pass; subsequent passes skip the re-upload and the
25+
* GPU accumulates entries in place. Layers 1+ use {@code consumeFromDevice} with an
26+
* explicit predecessor name for all objects, matching {@link LlamaFP16FFNLayersDecode}.</li>
27+
* </ol>
28+
*
29+
* <p>The activation graph ("decodeActivation") only persists {@code wrapX} — it does not
30+
* touch the KV cache. Layer 0 is therefore the sole allocator of the KV cache, which avoids
31+
* the NPE in {@code executeAlloc} that occurs when {@code consumeFromDevice} targets an object
32+
* whose device buffer was never properly allocated via {@code FIRST_EXECUTION}.</p>
33+
*/
34+
public class LlamaFP16FFNLayersPrefillDecode extends LlamaFP16FFNLayers {
35+
36+
public LlamaFP16FFNLayersPrefillDecode(String taskGraph, LlamaState state,
37+
LlamaTornadoWeights weights, LlamaConfiguration config,
38+
SchedulerType schedulerType) {
39+
super(taskGraph, state, weights, config, schedulerType);
40+
}
41+
42+
/**
43+
* Layer 0 receives {@code wrapX} from the decode activation graph;
44+
* layers 1+ receive it from the previous decode layer.
45+
*/
46+
@Override
47+
protected String predecessorGraphName(int layerIndex) {
48+
return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1);
49+
}
50+
51+
/**
52+
* Layer 0: delegates to the base class (FIRST_EXECUTION for wrapKeyCache/wrapValueCache +
53+
* all working buffers). KV cache is allocated here on the first forward pass.
54+
*
55+
* <p>Layers 1+: mirrors {@link LlamaFP16FFNLayersDecode} — {@code consumeFromDevice} with
56+
* an explicit predecessor name for every object, required by interpreter mode.</p>
57+
*/
58+
@Override
59+
protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) {
60+
if (layerIndex == 0) {
61+
return super.configureLayerDataTransfers(layer, 0);
62+
}
63+
String pred = "layer_" + (layerIndex - 1);
64+
layer.consumeFromDevice(pred,
65+
context,
66+
state.wrapXb, state.wrapXb2,
67+
state.wrapQ, state.wrapK, state.wrapV,
68+
state.wrapKeyCache, state.wrapValueCache,
69+
state.wrapAtt, state.wrapHb,
70+
state.positionHolder, state.wrapXbFP16);
71+
return layer;
72+
}
73+
}

0 commit comments

Comments
 (0)