Skip to content

Commit 2dd506c

Browse files
[prf/dec][refactor] Standardize task graph and grid scheduler naming for prefill and decode paths in TornadoVM
1 parent 1cbe491 commit 2dd506c

3 files changed

Lines changed: 32 additions & 37 deletions

File tree

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,10 @@ public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMaste
110110

111111
/** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. */
112112
private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
113-
return new TaskGraph("batchActivation")
113+
return new TaskGraph("prefillActivation")
114114
.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch)
115115
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch)
116-
.task("batchUpdateX", TransformerComputeKernels::convertFP16toFP32,
116+
.task("updateX", TransformerComputeKernels::convertFP16toFP32,
117117
ctx, state.embeddingXBatch, state.wrapXBatch)
118118
.persistOnDevice(state.wrapXBatch);
119119
}
@@ -128,7 +128,7 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
128128
* not forwarded in interpreter (non-CUDA-graph) mode.</p>
129129
*/
130130
private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) {
131-
return new TaskGraph("decodeActivationUpdate")
131+
return new TaskGraph("decodeActivation")
132132
.consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through
133133
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
134134
.task("updateX",
@@ -153,7 +153,7 @@ public TornadoExecutionPlan createExecutionPlan() {
153153
// [0] Batch prefill activation ────────────────────────────────────────────────
154154
KernelContext batchActCtx = new KernelContext();
155155
all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot());
156-
gridScheduler.addWorkerGrid("batchActivation.batchUpdateX",
156+
gridScheduler.addWorkerGrid("prefillActivation.updateX",
157157
WorkerGridFactory.genericWorker(batchSize * config.dim(), 128));
158158

159159
// [1..N] Batch prefill layer graphs ───────────────────────────────────────────
@@ -165,14 +165,14 @@ public TornadoExecutionPlan createExecutionPlan() {
165165
// [N+1] Decode activation (with KV-cache pass-through) ────────────────
166166
KernelContext decodeActCtx = new KernelContext();
167167
all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot());
168-
gridScheduler.addWorkerGrid("decodeActivationUpdate.updateX",
168+
gridScheduler.addWorkerGrid("decodeActivation.updateX",
169169
WorkerGridFactory.genericWorker(config.dim(), 128));
170170

171171
// [N+2..2N+1] Decode layer graphs ────────────────────────────────────
172172
// Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload).
173173
LlamaFP16FFNLayersDecode decodeLayers =
174174
new LlamaFP16FFNLayersDecode(
175-
"llamaFFNDecode", state, weights, config, schedulerType);
175+
"decode", state, weights, config, schedulerType);
176176
all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
177177
decodeLayers.updateGridScheduler(gridScheduler);
178178

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state,
4343
*/
4444
@Override
4545
protected String predecessorGraphName(int layerIndex) {
46-
return (layerIndex == 0) ? "decodeActivationUpdate" : "layer_" + (layerIndex - 1);
46+
return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1);
4747
}
4848

4949
@Override
@@ -60,7 +60,7 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex)
6060
state.wrapQ, state.wrapK, state.wrapV,
6161
state.wrapAtt, state.wrapHb, state.wrapXbFP16);
6262
// Explicit source — must match the TaskGraph name in buildDecodeActivationGraph().
63-
layer.consumeFromDevice("decodeActivationUpdate", state.wrapKeyCache, state.wrapValueCache);
63+
layer.consumeFromDevice("decodeActivation", state.wrapKeyCache, state.wrapValueCache);
6464
} else {
6565
// Layers 1+: use explicit predecessor name for ALL consumed objects.
6666
// Calling super here would use the no-arg form (source key = own graph name),

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -51,37 +51,32 @@ public LlamaFP16LayersBatchPrefill(LlamaState state, LlamaTornadoWeights weights
5151

5252
// @formatter:off
5353
private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
54-
String graphName = "batchLayer_" + layerIndex;
54+
String graphName = "batchPrefillLayer_" + layerIndex;
5555
if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName;
5656

57-
TaskGraph layer = new TaskGraph(graphName);
57+
TaskGraph batchPrefillLayer = new TaskGraph(graphName);
5858

5959
// ── Data Transfers ─────────────────────────────────────────────────────
6060
if (layerIndex == 0) {
6161
// batchStartPosHolder is set by host before each chunk → EVERY_EXECUTION
62-
layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder);
62+
batchPrefillLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder);
6363
// Allocate persistent GPU-side intermediates once
64-
layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
64+
batchPrefillLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
6565
context,
6666
state.attnScaleBatch, state.ffnScaleBatch,
6767
state.wrapXbFP16Batch,
6868
state.wrapQBatch, state.wrapKBatch, state.wrapVBatch,
6969
state.wrapXbBatch,
7070
state.wrapHbBatch,
7171
state.wrapKeyCache, state.wrapValueCache);
72-
// wrapXBatch produced by the batch activation graph.
73-
// Explicit source name required: the no-arg form uses the current graph's own
74-
// name ("batchLayer_0") which never matches "batchActivation" in interpreter mode,
75-
// causing wrapXBatch to be re-uploaded from host (zeros) instead of using the
76-
// FP32 embeddings computed by the activation graph's convertFP16toFP32 kernel.
77-
layer.consumeFromDevice("batchActivation", state.wrapXBatch);
72+
// wrapXBatch is produced by the prefillActivation graph and persists in device memory.
73+
// To consume it from there, we must pass the explicit uniqueTaskGraph name;
74+
// the no-arg form would use the current graph's name, which causes an NPE without CUDA Graphs.
75+
batchPrefillLayer.consumeFromDevice("prefillActivation", state.wrapXBatch);
7876
} else {
79-
// Explicit predecessor name for all objects.
80-
// The no-arg form would use "batchLayer_k" as the source key, which never matches
81-
// "batchLayer_{k-1}" in interpreter mode — every object would be re-uploaded from
82-
// host (zeros or stale), corrupting the KV cache written by the previous layer.
83-
String pred = "batchLayer_" + (layerIndex - 1);
84-
layer.consumeFromDevice(pred,
77+
// For the same reason as above, consume from the previous layer using its explicit uniqueTaskGraph name.
78+
String pred = "batchPrefillLayer_" + (layerIndex - 1);
79+
batchPrefillLayer.consumeFromDevice(pred,
8580
context,
8681
state.wrapXBatch,
8782
state.wrapXbFP16Batch,
@@ -94,7 +89,7 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
9489
}
9590

9691
// Per-layer weights: upload once
97-
layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
92+
batchPrefillLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
9893
weights.rms_att_weightLayered[layerIndex].asFloatArray(),
9994
weights.wqLayered[layerIndex].asHalfFloatArray(),
10095
weights.wkLayered[layerIndex].asHalfFloatArray(),
@@ -110,18 +105,18 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
110105
int hidDim = config.hiddenDim();
111106

112107
// ── Attention Block ────────────────────────────────────────────────────
113-
layer.task("batch_attn_rms",
108+
batchPrefillLayer.task("batch_attn_rms",
114109
TransformerBatchPrefillKernels::batchedRmsReduce,
115110
context, state.wrapXBatch, state.attnScaleBatch,
116111
dim, config.rmsNormEps());
117112

118-
layer.task("batch_attn_rms_apply",
113+
batchPrefillLayer.task("batch_attn_rms_apply",
119114
TransformerBatchPrefillKernels::batchedRmsApplyFP16,
120115
context, state.wrapXbFP16Batch, state.wrapXBatch,
121116
weights.rms_att_weightLayered[layerIndex].asFloatArray(),
122117
state.attnScaleBatch, dim);
123118

124-
layer.task("batch_qkv",
119+
batchPrefillLayer.task("batch_qkv",
125120
TransformerBatchPrefillKernels::batchedFusedQKVMatmul,
126121
context,
127122
state.wrapXbFP16Batch,
@@ -131,34 +126,34 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
131126
weights.wvLayered[layerIndex].asHalfFloatArray(),
132127
dim, kvDim, LOCAL_WORK_GROUP_SIZE);
133128

134-
layer.task("batch_rope_kv",
129+
batchPrefillLayer.task("batch_rope_kv",
135130
TransformerBatchPrefillKernels::batchedRopeWithKVCache,
136131
context, state.batchStartPosHolder,
137132
state.wrapQBatch, state.wrapKBatch, state.wrapVBatch,
138133
state.wrapKeyCache, state.wrapValueCache,
139134
kvDim, config.headSize(), layerIndex, config.contextLength(), dim);
140135

141-
layer.task("batch_attention",
136+
batchPrefillLayer.task("batch_attention",
142137
TransformerBatchPrefillKernels::batchedFlashAttention,
143138
context, state.batchStartPosHolder,
144139
state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache,
145140
state.wrapXbBatch,
146141
config.numberOfHeads(), config.headSize(),
147142
kvDim, config.kvMul(), layerIndex, config.contextLength(), dim);
148143

149-
layer.task("batch_attn_out",
144+
batchPrefillLayer.task("batch_attn_out",
150145
TransformerBatchPrefillKernels::batchedMatVecWithResidual,
151146
context, state.wrapXbBatch, state.wrapXBatch,
152147
weights.woLayered[layerIndex].asHalfFloatArray(),
153148
dim, dim, LOCAL_WORK_GROUP_SIZE);
154149

155150
// ── FFN Block ──────────────────────────────────────────────────────────
156-
layer.task("batch_ffn_rms",
151+
batchPrefillLayer.task("batch_ffn_rms",
157152
TransformerBatchPrefillKernels::batchedFFNRmsReduce,
158153
context, state.wrapXBatch, state.ffnScaleBatch,
159154
dim, config.rmsNormEps());
160155

161-
layer.task("batch_ffn_gate_up",
156+
batchPrefillLayer.task("batch_ffn_gate_up",
162157
TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUp,
163158
context, state.wrapXBatch, state.wrapHbBatch,
164159
weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
@@ -167,17 +162,17 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
167162
weights.w3Layered[layerIndex].asHalfFloatArray(),
168163
dim, hidDim, LOCAL_WORK_GROUP_SIZE);
169164

170-
layer.task("batch_ffn_down",
165+
batchPrefillLayer.task("batch_ffn_down",
171166
TransformerBatchPrefillKernels::batchedMatVecWithResidual,
172167
context, state.wrapHbBatch, state.wrapXBatch,
173168
weights.w2Layered[layerIndex].asHalfFloatArray(),
174169
hidDim, dim, LOCAL_WORK_GROUP_SIZE);
175170

176171
// Persist wrapXBatch for the next layer, and KV cache so the decode
177172
// layers can consume it via the activation graph pass-through.
178-
layer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache);
173+
batchPrefillLayer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache);
179174

180-
return layer;
175+
return batchPrefillLayer;
181176
}
182177
// @formatter:on
183178

@@ -218,7 +213,7 @@ public void updateGridScheduler(GridScheduler scheduler) {
218213
batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE);
219214

220215
for (int i = 0; i < config.numberOfLayers(); i++) {
221-
String p = "batchLayer_" + i + ".";
216+
String p = "batchPrefillLayer_" + i + ".";
222217
scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker);
223218
scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker);
224219
scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker);

0 commit comments

Comments
 (0)