Skip to content

Commit 869c67d

Browse files
[prf/dec] Refactor TornadoVM execution plans to unify GPU paths for standard, prefill-decode, and batched-prefill-decode setups.
1 parent 9aff199 commit 869c67d

4 files changed

Lines changed: 220 additions & 152 deletions

File tree

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package org.beehive.gpullama3.tornadovm;
22

3-
import org.beehive.gpullama3.inference.state.LlamaState;
43
import org.beehive.gpullama3.inference.state.State;
54
import org.beehive.gpullama3.model.Model;
5+
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
66
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
77

88
/**
@@ -35,14 +35,14 @@ public interface TornadoVMMasterPlan {
3535
int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1);
3636

3737
/**
38-
* Factory: creates, JIT-compiles, and warms up the appropriate plan.
38+
* Factory: creates, JIT-compiles, and warms up the appropriate TornadoVMMasterPlan.
3939
*
4040
* <p>When {@code llama.withPrefillDecode=true} and {@code llama.prefillBatchSize > 1},
4141
* a {@link TornadoVMMasterPlanWithBatchPrefillDecode} is returned.
4242
* With only {@code llama.withPrefillDecode=true} (batch size at most 1), a
4343
* {@link TornadoVMMasterPlanWithPrefillDecode} is returned; otherwise a {@link TornadoVMMasterPlanStandard} is returned (baseline path).</p>
4444
*
45-
* @param state the model state (must be {@link LlamaState} when batch size {@code > 1})
45+
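* @param state the model state (must be a {@code LlamaState} when the batched prefill/decode plan is selected, because that plan casts it)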
4646
* @param model the model instance
4747
* @return the initialized plan, also stored via {@link Model#setTornadoVMPlan}
4848
*/
@@ -51,29 +51,26 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) {
5151

5252
if (WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE > 1) {
5353
// GPU path with batched prefill/decode
54-
plan = TornadoVMMasterPlanWithBatchPrefillDecode.initializeUnifiedPlan(
55-
(LlamaState) state, model, PREFILL_BATCH_SIZE);
54+
plan = new TornadoVMMasterPlanWithBatchPrefillDecode(state, model);
5655
} else if (WITH_PREFILL_DECODE) {
5756
// GPU path with simple prefill/decode
58-
plan = TornadoVMMasterPlanWithPrefillDecode.initialize(state, model);
57+
plan = new TornadoVMMasterPlanWithPrefillDecode(state, model);
5958
} else {
6059
// GPU path with no prefill/decode
61-
plan = TornadoVMMasterPlanStandard.initialize(state, model);
60+
plan = new TornadoVMMasterPlanStandard(state, model);
6261
}
6362
model.setTornadoVMPlan(plan);
6463
return plan;
6564
}
6665

6766
/**
68-
* Single-token forward pass returning output logits.
69-
*
70-
* <p>Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM})
71-
* and the Phase 2 sequential decode path. Not applicable to
72-
* {@link TornadoVMMasterPlanWithBatchPrefillDecode} — that plan uses its own typed methods.</p>
73-
*
74-
* @param position sequence position of the current token
75-
* @return logits array for token sampling
67+
* Creates the appropriate {@link TornadoExecutionPlan} instance
68+
* for the given {@link Model} and {@link State}.
7669
*/
70+
TornadoExecutionPlan createExecutionPlan();
71+
72+
void forceCopyInReadOnlyData();
73+
7774
FloatArray tornadoVMForwardExecuteLayered(int position);
7875

7976
/** Releases all device memory held by this plan. */

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,14 @@
1919
*/
2020
public class TornadoVMMasterPlanStandard implements TornadoVMMasterPlan {
2121

22-
public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
23-
2422
private final State state;
23+
private final Model model;
2524
private final Configuration config;
26-
public TornadoExecutionPlan executionPlan;
25+
2726
GenericLayerPlanner tornadoVMLayerPlanner;
27+
public TornadoExecutionPlan executionPlan;
2828

2929
public TornadoVMMasterPlanStandard(State state, Model model) {
30-
this.tornadoVMLayerPlanner = createPlanner(state, model);
31-
this.executionPlan = createExecutionPlan();
32-
this.state = state;
33-
this.config = model.configuration();
34-
}
35-
36-
/**
37-
* Initializes and warms up the standard TornadoVM plan.
38-
*
39-
* @param state the model state containing KV cache
40-
* @param model the model instance
41-
* @return the initialized plan ready for inference
42-
*/
43-
static TornadoVMMasterPlanStandard initialize(State state, Model model) {
4430
long startTime = System.nanoTime();
4531
long planCreationTime = 0;
4632
long warmupTime = 0;
@@ -49,43 +35,52 @@ static TornadoVMMasterPlanStandard initialize(State state, Model model) {
4935
System.err.println("\nStarting TornadoVM initialization...");
5036
}
5137

52-
TornadoVMMasterPlanStandard tornadoVMPlan = new TornadoVMMasterPlanStandard(state, model);
38+
this.state = state;
39+
this.model = model;
40+
this.config = model.configuration();
41+
42+
this.executionPlan = createExecutionPlan();
5343

5444
if (ENABLE_TORNADOVM_INIT_TIME) {
5545
planCreationTime = System.nanoTime();
56-
System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0);
46+
System.err.printf("TornadoVM GPU standard execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0);
5747
}
5848

59-
if (CUDA_GRAPHS) tornadoVMPlan.executionPlan.withAllGraphs().withCUDAGraph();
60-
tornadoVMPlan.executionPlan.withPreCompilation();
49+
if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph();
50+
executionPlan.withPreCompilation();
6151

6252
if (ENABLE_TORNADOVM_INIT_TIME) {
6353
warmupTime = System.nanoTime();
6454
System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0);
6555
}
6656

67-
tornadoVMPlan.forceCopyInReadOnlyDataLayered();
57+
forceCopyInReadOnlyData();
6858

6959
if (ENABLE_TORNADOVM_INIT_TIME) {
7060
long copyTime = System.nanoTime();
7161
System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0);
7262
System.err.printf("Finished TornadoVM initialization...\n \n");
7363
}
74-
75-
return tornadoVMPlan;
7664
}
7765

78-
private TornadoExecutionPlan createExecutionPlan() {
66+
// @Override
67+
// public GenericLayerPlanner createPlanner() {
68+
// GGMLType weightType = model.weights().getWeightType();
69+
// return QuantizationPlannerFactory.create(weightType, state, model);
70+
// }
71+
72+
/**
73+
* Creates the {@link TornadoExecutionPlan} for the *simple/standard* single-token forward pass.
74+
*/
75+
@Override
76+
public TornadoExecutionPlan createExecutionPlan() {
77+
GGMLType weightType = model.weights().getWeightType();
78+
this.tornadoVMLayerPlanner = QuantizationPlannerFactory.create(weightType, state, model);
7979
var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs();
8080
var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]);
8181
return new TornadoExecutionPlan(taskGraphArray);
8282
}
8383

84-
private GenericLayerPlanner createPlanner(State state, Model model) {
85-
GGMLType weightType = model.weights().getWeightType();
86-
return QuantizationPlannerFactory.create(weightType, state, model);
87-
}
88-
8984
@Override
9085
public FloatArray tornadoVMForwardExecuteLayered(int position) {
9186
// @formatter:off
@@ -126,7 +121,8 @@ private int getFinalLogitsGraphIndex() {
126121
return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1;
127122
}
128123

129-
public void forceCopyInReadOnlyDataLayered() {
124+
@Override
125+
public void forceCopyInReadOnlyData() {
130126
state.wrapX.clear();
131127
state.positionHolder.init(0);
132128

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java

Lines changed: 73 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.beehive.gpullama3.tornadovm;
22

33
import org.beehive.gpullama3.inference.state.LlamaState;
4+
import org.beehive.gpullama3.inference.state.State;
45
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
56
import org.beehive.gpullama3.model.Model;
67
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
@@ -25,7 +26,7 @@
2526
import java.util.List;
2627

2728
/**
28-
* Unified GPU execution plan for Phase 4: batched prefill + single-token decode.
29+
* GPU execution plan for batched prefill + single-token decode.
2930
*
3031
* <p>A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache
3132
* ({@code wrapKeyCache}, {@code wrapValueCache}) is shared on device via
@@ -50,10 +51,8 @@
5051
*/
5152
public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan {
5253

53-
private static final boolean ENABLE_TIMING =
54-
Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
55-
5654
private final LlamaState state;
55+
private final Model model;
5756
private final LlamaConfiguration config;
5857
private final int batchSize;
5958
private final int N; // numberOfLayers
@@ -68,55 +67,44 @@ public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMaste
6867
private int logitsIdx() { return 2 * N + 2; }
6968

7069
// ── Construction ─────────────────────────────────────────────────────────
71-
private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model, int batchSize) {
72-
this.state = state;
70+
TornadoVMMasterPlanWithBatchPrefillDecode(State initialState, Model model) {
71+
long startTime = System.nanoTime();
72+
long planCreationTime = 0;
73+
long warmupTime = 0;
74+
75+
if (ENABLE_TORNADOVM_INIT_TIME) {
76+
System.err.println("\nStarting TornadoVM initialization...");
77+
}
78+
79+
this.state = (LlamaState) initialState; // only LlamaFP16 supports batched prefill for now
80+
this.model = model;
7381
this.config = (LlamaConfiguration) model.configuration();
74-
this.batchSize = batchSize;
82+
this.batchSize = PREFILL_BATCH_SIZE;
7583
this.N = config.numberOfLayers();
7684

77-
LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
78-
SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model);
79-
80-
List<ImmutableTaskGraph> all = new ArrayList<>(2 * N + 3);
81-
GridScheduler scheduler = new GridScheduler();
85+
this.gridScheduler = new GridScheduler();
86+
this.executionPlan = createExecutionPlan();
8287

83-
// [0] Batch prefill activation ────────────────────────────────────────────────
84-
KernelContext batchActCtx = new KernelContext();
85-
all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot());
86-
scheduler.addWorkerGrid("batchActivation.batchUpdateX",
87-
WorkerGridFactory.genericWorker(batchSize * config.dim(), 128));
88+
if (ENABLE_TORNADOVM_INIT_TIME) {
89+
planCreationTime = System.nanoTime();
90+
System.err.printf("TornadoVM GPU batched prefill/decode execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0);
91+
}
8892

89-
// [1..N] Batch prefill layer graphs ───────────────────────────────────────────
90-
LlamaFP16LayersBatchPrefill batchLayers =
91-
new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize);
92-
all.addAll(batchLayers.getLayerImmutableTaskGraphs());
93-
batchLayers.updateGridScheduler(scheduler);
93+
if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph();
94+
executionPlan.withPreCompilation();
9495

95-
// [N+1] Decode activation (with KV-cache pass-through) ────────────────
96-
KernelContext decodeActCtx = new KernelContext();
97-
all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot());
98-
scheduler.addWorkerGrid("decodeActivationUpdate.updateX",
99-
WorkerGridFactory.genericWorker(config.dim(), 128));
96+
if (ENABLE_TORNADOVM_INIT_TIME) {
97+
warmupTime = System.nanoTime();
98+
System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0);
99+
}
100100

101-
// [N+2..2N+1] Decode layer graphs ────────────────────────────────────
102-
// Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload).
103-
LlamaFP16FFNLayersDecode decodeLayers =
104-
new LlamaFP16FFNLayersDecode(
105-
"llamaFFNDecode", state, weights, config, schedulerType);
106-
all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
107-
decodeLayers.updateGridScheduler(scheduler);
101+
forceCopyInReadOnlyData();
108102

109-
// [2N+2] Logits ───────────────────────────────────────────────────────
110-
// LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache)
111-
// at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the
112-
// KV-cache pointer survives the logits → decode-activation boundary across tokens.
113-
LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config,
114-
decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
115-
all.add(logitsLayer.getImmutableTaskGraph());
116-
logitsLayer.updateGridScheduler(scheduler);
117-
118-
this.gridScheduler = scheduler;
119-
this.executionPlan = new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0]));
103+
if (ENABLE_TORNADOVM_INIT_TIME) {
104+
long copyTime = System.nanoTime();
105+
System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0);
106+
System.err.printf("Finished TornadoVM initialization...\n \n");
107+
}
120108
}
121109

122110
// ── Batch Prefill Activation graphs ─────────────────────────────────────────────────────
@@ -164,41 +152,58 @@ private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatch
164152
.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache);
165153
}
166154

167-
// ── Static factory ────────────────────────────────────────────────────────
168-
169155
/**
170-
* Creates, JIT-compiles, and warms up the unified plan.
171-
* Mirrors {@link TornadoVMMasterPlan#initializeTornadoVMPlan}.
156+
* Creates the {@link TornadoExecutionPlan} for the forward pass with *batched prefill and a separate decode step*.
172157
*/
173-
public static TornadoVMMasterPlanWithBatchPrefillDecode initializeUnifiedPlan(
174-
LlamaState state, Model model, int batchSize) {
158+
@Override
159+
public TornadoExecutionPlan createExecutionPlan() {
160+
LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
161+
SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model);
175162

176-
long t0 = System.nanoTime();
177-
TornadoVMMasterPlanWithBatchPrefillDecode plan =
178-
new TornadoVMMasterPlanWithBatchPrefillDecode(state, model, batchSize);
163+
List<ImmutableTaskGraph> all = new ArrayList<>(2 * N + 3);
179164

180-
if (ENABLE_TIMING)
181-
System.err.printf("[BatchPlan] Graph construction: %.2f ms%n",
182-
(System.nanoTime() - t0) / 1e6);
165+
// [0] Batch prefill activation ────────────────────────────────────────────────
166+
KernelContext batchActCtx = new KernelContext();
167+
all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot());
168+
gridScheduler.addWorkerGrid("batchActivation.batchUpdateX",
169+
WorkerGridFactory.genericWorker(batchSize * config.dim(), 128));
183170

184-
if (CUDA_GRAPHS) plan.executionPlan.withAllGraphs().withCUDAGraph();
185-
plan.executionPlan.withPreCompilation();
171+
// [1..N] Batch prefill layer graphs ───────────────────────────────────────────
172+
LlamaFP16LayersBatchPrefill batchLayers =
173+
new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize);
174+
all.addAll(batchLayers.getLayerImmutableTaskGraphs());
175+
batchLayers.updateGridScheduler(gridScheduler);
186176

187-
if (ENABLE_TIMING)
188-
System.err.printf("[BatchPlan] JIT compilation: %.2f ms%n",
189-
(System.nanoTime() - t0) / 1e6);
177+
// [N+1] Decode activation (with KV-cache pass-through) ────────────────
178+
KernelContext decodeActCtx = new KernelContext();
179+
all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot());
180+
gridScheduler.addWorkerGrid("decodeActivationUpdate.updateX",
181+
WorkerGridFactory.genericWorker(config.dim(), 128));
190182

191-
plan.forceCopyInReadOnlyData();
183+
// [N+2..2N+1] Decode layer graphs ────────────────────────────────────
184+
// Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload).
185+
LlamaFP16FFNLayersDecode decodeLayers =
186+
new LlamaFP16FFNLayersDecode(
187+
"llamaFFNDecode", state, weights, config, schedulerType);
188+
all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
189+
decodeLayers.updateGridScheduler(gridScheduler);
192190

193-
if (ENABLE_TIMING)
194-
System.err.printf("[BatchPlan] Init complete: %.2f ms%n",
195-
(System.nanoTime() - t0) / 1e6);
191+
// [2N+2] Logits ───────────────────────────────────────────────────────
192+
// LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache)
193+
// at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the
194+
// KV-cache pointer survives the logits → decode-activation boundary across tokens.
195+
LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config,
196+
decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
197+
all.add(logitsLayer.getImmutableTaskGraph());
198+
logitsLayer.updateGridScheduler(gridScheduler);
196199

197-
return plan;
200+
return new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0]));
198201
}
199202

203+
200204
/** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */
201-
private void forceCopyInReadOnlyData() {
205+
@Override
206+
public void forceCopyInReadOnlyData() {
202207
state.wrapXBatch.clear();
203208
state.wrapX.clear();
204209
state.positionHolder.init(0);

0 commit comments

Comments
 (0)