
Commit 2b1aaba

[prf/dec] Fix KV-cache propagation bug from prefill to decode path and refactor task graph consumption logic
Introduce `LogitsFP16LayerDecode` with KV-cache pass-through, and switch the Llama FFN layers' cross-graph transfers to the named-source `consumeFromDevice` form (plus KV-cache `persistOnDevice`) so device-pointer propagation works in both CUDA-graph and interpreter modes.
1 parent 6128793 commit 2b1aaba

6 files changed

Lines changed: 187 additions & 31 deletions
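
For orientation, the fix reduces to one recurring pattern: a producer graph persists a buffer under its own name, and the consumer names that producer explicitly when consuming. Below is a minimal two-graph sketch using only TornadoVM API calls that appear in this diff; the graph names, the FloatArray payload, and the fill/scale kernels are illustrative, not from this repository.

import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class NamedConsumeSketch {
    private static void fill(FloatArray a) {
        for (int i = 0; i < a.getSize(); i++) a.set(i, i);
    }

    private static void scale(FloatArray a) {
        for (int i = 0; i < a.getSize(); i++) a.set(i, a.get(i) * 2f);
    }

    public static void main(String[] args) {
        FloatArray buf = new FloatArray(1024);

        // Producer: writes buf on device, then persists it under its own name.
        TaskGraph producer = new TaskGraph("producer")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, buf)
                .task("fill", NamedConsumeSketch::fill, buf)
                .persistOnDevice(buf);

        // Consumer: the named-source form keys the device-pointer lookup on
        // "producer". The no-arg consumeFromDevice(buf) would key it on
        // "consumer" (this graph's own name), which misses in interpreter
        // mode -- the failure this commit fixes at every graph boundary.
        TaskGraph consumer = new TaskGraph("consumer")
                .consumeFromDevice("producer", buf)
                .task("scale", NamedConsumeSketch::scale, buf)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, buf);

        TornadoExecutionPlan plan =
                new TornadoExecutionPlan(producer.snapshot(), consumer.snapshot());
        plan.withGraph(0).execute();
        plan.withGraph(1).execute();
    }
}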


src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java

Lines changed: 31 additions & 14 deletions
@@ -10,7 +10,7 @@
 import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode;
 import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode;
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
 import uk.ac.manchester.tornado.api.KernelContext;
@@ -94,8 +94,8 @@ private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model,
 
         // [N+1] Decode activation (with KV-cache pass-through) ────────────────
         KernelContext decodeActCtx = new KernelContext();
-        all.add(buildDecodeActivationGraph(decodeActCtx).snapshot());
-        scheduler.addWorkerGrid("activationUpdate.updateX",
+        all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot());
+        scheduler.addWorkerGrid("decodeActivationUpdate.updateX",
                 WorkerGridFactory.genericWorker(config.dim(), 128));
 
         // [N+2..2N+1] Decode layer graphs ────────────────────────────────────
@@ -107,7 +107,10 @@ private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model,
         decodeLayers.updateGridScheduler(scheduler);
 
         // [2N+2] Logits ───────────────────────────────────────────────────────
-        LogitsFP16Layer logitsLayer = new LogitsFP16Layer("logits", state, weights, config,
+        // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache)
+        // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the
+        // KV-cache pointer survives the logits → decode-activation boundary across tokens.
+        LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config,
                 decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
         all.add(logitsLayer.getImmutableTaskGraph());
         logitsLayer.updateGridScheduler(scheduler);
@@ -123,9 +126,7 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
         return new TaskGraph("batchActivation")
                 .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch)
                 .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch)
-                .task("batchUpdateX",
-                        (KernelContext c, HalfFloatArray src, FloatArray dst) ->
-                                dst.set(c.globalIdx, src.get(c.globalIdx).getFloat32()),
+                .task("batchUpdateX", TransformerComputeKernels::convertFP16toFP32,
                         ctx, state.embeddingXBatch, state.wrapXBatch)
                 .persistOnDevice(state.wrapXBatch);
     }
@@ -139,17 +140,24 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
      * Both halves of the chain are required; without the re-persist the pointer is
      * not forwarded in interpreter (non-CUDA-graph) mode.</p>
      */
-    private TaskGraph buildDecodeActivationGraph(KernelContext ctx) {
-        return new TaskGraph("activationUpdate")
-                .consumeFromDevice(state.wrapKeyCache, state.wrapValueCache) // KV pass-through
-                // .transferToDevice(DataTransferMode.EVERY_EXECUTION,
-                //         state.wrapKeyCache,
-                //         state.wrapValueCache)
-                .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX)
+    private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) {
+        // System.out.println("lastBatchLayerID = " + lastBatchLayerID);
+        // System.out.println("[buildDecodeActivationGraph] state.wrapX = " + state.wrapX.toString());
+        // System.out.println("[buildDecodeActivationGraph] state.wrapKeyCache = " + state.wrapKeyCache.toString());
+        // System.out.println("[buildDecodeActivationGraph] state.wrapValueCache = " + state.wrapValueCache.toString());
+        return new TaskGraph("decodeActivationUpdate")
+                .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through
+                //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX, debugKV)
+                //.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX)
                 .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
                 .task("updateX",
                         TransformerComputeKernels::convertFP16toFP32,
                         ctx, (HalfFloatArray) state.embeddingX, state.wrapX)
+                // // DEBUG: snapshot first 8 elements of wrapKeyCache and wrapX for host-side probe
+                // .task("dbgKV",
+                //         TransformerComputeKernels::dbgCopyFirst8,
+                //         state.wrapKeyCache, debugKV)
+                // .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX, debugKV)
                 // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache
                 // re-persisted so updatePersistedObjectState() propagates the device
                 // pointer to decode layer 0's consumeFromDevice without CUDA graphs.
@@ -197,6 +205,7 @@ private void forceCopyInReadOnlyData() {
         state.batchStartPosHolder.init(0);
 
         for (int i = 0; i <= logitsIdx(); i++) {
+            //System.out.println(i + " " + executionPlan.withGraph(i).toString());
             var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler);
             if (CUDA_GRAPHS) g.withCUDAGraph();
             g.execute();
@@ -268,6 +277,14 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) {
         if (CUDA_GRAPHS) decodeAct.withCUDAGraph();
         //System.err.println("[DEBUG] about to execute decode activation (graph " + decodeActivationIdx() + "--)");
         decodeAct.execute();
+        // DEBUG: print first 4 of wrapX (should be non-zero FP32 embedding) and
+        // first 4 of debugKV (should be non-zero after batch prefill wrote the KV cache)
+        // if (position <= 290) {
+        //     System.err.printf("[DBG pos=%d] wrapX[0..3] = %.4f %.4f %.4f %.4f%n",
+        //             position, state.wrapX.get(0), state.wrapX.get(1), state.wrapX.get(2), state.wrapX.get(3));
+        //     System.err.printf("[DBG pos=%d] debugKV[0..3]= %.4f %.4f %.4f %.4f%n",
+        //             position, debugKV.get(0), debugKV.get(1), debugKV.get(2), debugKV.get(3));
+        // }
 
         // Graphs N+2..2N+1: decode transformer layers
         for (int l = 0; l < N; l++) {
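
The bracketed indices in the plan-construction comments above fix the layout of the unified graph list. A hypothetical helper makes the arithmetic explicit for N transformer layers; the actual class computes these inline (e.g. decodeActivationIdx() and logitsIdx() in the loops above):

// Graph layout of the unified prefill+decode plan, as built above (N = layer count):
//   [0]           "batchActivation"
//   [1..N]        "batchLayer_0" .. "batchLayer_{N-1}"
//   [N+1]         "decodeActivationUpdate"
//   [N+2..2N+1]   "layer_0" .. "layer_{N-1}"
//   [2N+2]        "logits"
final class GraphLayout {
    private final int n;

    GraphLayout(int n) { this.n = n; }

    int batchActivationIdx()  { return 0; }
    int batchLayerIdx(int l)  { return 1 + l; }
    int decodeActivationIdx() { return n + 1; }
    int decodeLayerIdx(int l) { return n + 2 + l; }
    int logitsIdx()           { return 2 * n + 2; }
}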

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java

Lines changed: 32 additions & 2 deletions
@@ -146,7 +146,17 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) {
         TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
 
         // === Data Setup ===
-        unifiedLayer.consumeFromDevice(state.wrapX);
+        // consumeFromDevice for wrapX: the no-arg form uses the current graph's own name as the
+        // source key, which works in CUDA-graph mode (pointers are frozen) but fails in interpreter
+        // mode (updatePersistedObjectState looks up the predecessor's name, not the current name).
+        // Subclasses that receive wrapX across a graph boundary override predecessorGraphName() to
+        // return the correct predecessor graph name so the XPUBuffer is propagated in both modes.
+        String wrapXSrc = predecessorGraphName(layerIndex);
+        if (wrapXSrc != null) {
+            unifiedLayer.consumeFromDevice(wrapXSrc, state.wrapX);
+        } else {
+            unifiedLayer.consumeFromDevice(state.wrapX);
+        }
         unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
                 weights.rms_att_weightLayered[layerIndex].asFloatArray(),
                 weights.wqLayered[layerIndex].asHalfFloatArray(),
@@ -248,11 +258,31 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) {
                 weights.w2Layered[layerIndex].asHalfFloatArray(),
                 config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
 
-        unifiedLayer.persistOnDevice(state.wrapX);
+        unifiedLayer.persistOnDevice(state.wrapX, state.wrapKeyCache,
+                state.wrapValueCache);
 
         return unifiedLayer;
     }
 
+    /**
+     * Returns the name of the predecessor task graph from which {@code wrapX} should be consumed,
+     * or {@code null} to fall back to the no-arg form (source key = own graph name).
+     *
+     * <p>The no-arg form is safe in CUDA-graph mode (device pointers are frozen at capture time)
+     * but fails in interpreter mode: {@code updatePersistedObjectState} looks up the predecessor's
+     * graph name, not the current graph's name, so the XPUBuffer is never propagated and
+     * {@code executeAlloc} NPEs on a null buffer.</p>
+     *
+     * <p>Override in subclasses that receive {@code wrapX} from a named predecessor graph:</p>
+     * <ul>
+     *   <li>layer 0: return the activation graph name (e.g. {@code "activationUpdate"})</li>
+     *   <li>layer k &gt; 0: return {@code "layer_" + (k-1)}</li>
+     * </ul>
+     */
+    protected String predecessorGraphName(int layerIndex) {
+        return null;
+    }
+
     protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
         if (layerIndex == 0) {
             // First layer: Transfer initial data to device (one-time transfer)
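
The comments and javadoc above both hinge on how updatePersistedObjectState resolves its source key. Below is a toy model of that resolution, an illustration of the failure mode only: PersistedStateModel, its fields, and the pointer values are hypothetical names, not TornadoVM internals.

import java.util.HashMap;
import java.util.Map;

// Toy model: each graph's persisted buffers are filed under that graph's name,
// so a consumer only finds a device pointer if it looks up the *producer's* name.
public class PersistedStateModel {
    private static final Map<String, Map<Object, Long>> persisted = new HashMap<>();

    static void persistOnDevice(String graphName, Object buffer, long devicePtr) {
        persisted.computeIfAbsent(graphName, k -> new HashMap<>()).put(buffer, devicePtr);
    }

    static Long consumeFromDevice(String sourceKey, Object buffer) {
        Map<Object, Long> m = persisted.get(sourceKey);
        return (m == null) ? null : m.get(buffer); // null models the later NPE in executeAlloc
    }

    public static void main(String[] args) {
        Object wrapX = new Object();
        persistOnDevice("layer_0", wrapX, 0xDEAD_0000L);

        // No-arg form: source key defaults to the consumer's own name -> miss.
        System.out.println(consumeFromDevice("layer_1", wrapX)); // null
        // Named-source form (this commit): key on the predecessor -> hit.
        System.out.println(consumeFromDevice("layer_0", wrapX)); // 3735879680
    }
}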

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java

Lines changed: 15 additions & 0 deletions
@@ -22,11 +22,25 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration
         super(name, state, weights, config, lastTaskGraphID, schedulerType);
     }
 
+    /**
+     * Hook called before any data transfers or tasks. Override to prepend
+     * {@code consumeFromDevice} declarations that must precede the bytecode
+     * (e.g. KV-cache pass-through in the Phase 4 unified plan).
+     */
+    protected void configureAdditionalConsumes(TaskGraph logits) {}
+
+    /**
+     * Hook called after {@code transferToHost}. Override to append
+     * {@code persistOnDevice} declarations (e.g. KV-cache pass-through in Phase 4).
+     */
+    protected void configureAdditionalPersists(TaskGraph logits) {}
+
     // @formatter:off
     @Override
     protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
         var logits = new TaskGraph("logits");
         // === Data Setup ===
+        configureAdditionalConsumes(logits);
         logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
         logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
         logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
@@ -80,6 +94,7 @@ protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration c
 
         // === Transfer Results to Host ===
         logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
+        configureAdditionalPersists(logits);
         return logits;
     }
     // @formatter:on
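
The two hooks are a small template-method seam: the base class pins the declaration order, and subclasses inject declarations only at the two points the bytecode allows. A standalone sketch of the same shape (all names hypothetical; strings stand in for the real TaskGraph calls):

import java.util.ArrayList;
import java.util.List;

class LogitsSketchBase {
    // Hook points, empty by default (mirrors LogitsFP16Layer above).
    protected void configureAdditionalConsumes(List<String> g) {}
    protected void configureAdditionalPersists(List<String> g) {}

    final List<String> setupLogitsGraph() {
        List<String> g = new ArrayList<>();
        configureAdditionalConsumes(g);            // before any transfers or tasks
        g.add("consumeFromDevice(lastGraph, wrapX)");
        g.add("task(logitsProjection)");
        g.add("transferToHost(wrapLogits)");
        configureAdditionalPersists(g);            // after transferToHost
        return g;
    }
}

class LogitsSketchDecode extends LogitsSketchBase {
    @Override protected void configureAdditionalConsumes(List<String> g) {
        g.add("consumeFromDevice(lastGraph, KV caches)");
    }
    @Override protected void configureAdditionalPersists(List<String> g) {
        g.add("persistOnDevice(KV caches)");
    }

    public static void main(String[] args) {
        new LogitsSketchDecode().setupLogitsGraph().forEach(System.out::println);
        // Prints the KV consume first and the KV persist last -- the ordering
        // the LogitsFP16LayerDecode javadoc below says the bytecode requires.
    }
}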

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java

Lines changed: 44 additions & 12 deletions
@@ -9,12 +9,22 @@
 import uk.ac.manchester.tornado.api.enums.DataTransferMode;
 
 /**
- * Identical to {@link LlamaFP16FFNLayers} except decode layer 0 uses
- * {@code consumeFromDevice} for the KV cache instead of {@code FIRST_EXECUTION}.
+ * Decode-path FFN layers for the Phase 4 unified plan.
  *
- * <p>This ensures decode layer 0 receives the KV-cache device pointer that was
- * persisted by the last batch prefill layer and passed through the decode
- * activation graph.</p>
+ * <p>Overrides data-transfer declarations so that all cross-graph boundaries use
+ * the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by
+ * the base class) passes the <em>current</em> graph's own name as the source key.
+ * In CUDA-graph mode this is harmless (device pointers are frozen at capture time),
+ * but in interpreter mode {@code updatePersistedObjectState} looks up the
+ * <em>predecessor's</em> name, so the lookup always misses and the XPUBuffer is
+ * never propagated — causing either a null-pointer crash or a silent re-upload
+ * from host (zeros), corrupting the hidden state and KV cache.</p>
+ *
+ * <p>Two boundaries are fixed here:</p>
+ * <ul>
+ *   <li>{@code wrapX}: via {@link #predecessorGraphName} hook in the base class.</li>
+ *   <li>All other consumed objects: via the {@link #configureLayerDataTransfers} override.</li>
+ * </ul>
 */
 public class LlamaFP16FFNLayersDecode extends LlamaFP16FFNLayers {
     public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state,
@@ -23,24 +33,46 @@ public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state,
         super(taskGraph, state, weights, config, schedulerType);
     }
 
+    /**
+     * Supplies the correct predecessor graph name for {@code consumeFromDevice(wrapX)}.
+     *
+     * <p>Layer 0 receives {@code wrapX} from the decode activation graph;
+     * layers 1+ receive it from the previous decode layer.
+     * Must match the {@code TaskGraph} names used in
+     * {@code buildDecodeActivationGraph()} and {@code createFFNLayerTaskGraph()}.</p>
+     */
+    @Override
+    protected String predecessorGraphName(int layerIndex) {
+        return (layerIndex == 0) ? "decodeActivationUpdate" : "layer_" + (layerIndex - 1);
+    }
+
     @Override
     protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) {
         if (layerIndex == 0) {
-            // Same as parent layer 0 BUT wrapKeyCache/wrapValueCache come
-            // from device (passed through by the decode activation graph).
+            // Same as parent layer 0, but wrapKeyCache/wrapValueCache come from device
+            // (passed through by the decode activation graph, which relays them from
+            // the last batch prefill layer). No FIRST_EXECUTION for KV cache here.
             layer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
                     state.positionHolder, state.temp, state.tempFFN);
             layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
                     context,
                     state.wrapXb, state.wrapXb2,
                     state.wrapQ, state.wrapK, state.wrapV,
                     state.wrapAtt, state.wrapHb, state.wrapXbFP16);
-            // KV cache: consume from device (device pointer supplied by
-            // decode activation's pass-through from last batch layer).
-            layer.consumeFromDevice(state.wrapKeyCache, state.wrapValueCache);
+            // Explicit source — must match the TaskGraph name in buildDecodeActivationGraph().
+            layer.consumeFromDevice("decodeActivationUpdate", state.wrapKeyCache, state.wrapValueCache);
         } else {
-            // Identical to parent for layers 1+ (already uses consumeFromDevice).
-            return super.configureLayerDataTransfers(layer, layerIndex);
+            // Layers 1+: use explicit predecessor name for ALL consumed objects.
+            // Calling super here would use the no-arg form (source key = own graph name),
+            // which silently fails in interpreter mode and causes re-upload from host.
+            String pred = "layer_" + (layerIndex - 1);
+            layer.consumeFromDevice(pred,
                    context,
                    state.wrapXb, state.wrapXb2,
                    state.wrapQ, state.wrapK, state.wrapV,
                    state.wrapKeyCache, state.wrapValueCache,
                    state.wrapAtt, state.wrapHb,
                    state.positionHolder, state.wrapXbFP16);
         }
         return layer;
     }
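
The override's naming contract is easy to sanity-check in isolation. A standalone reproduction of predecessorGraphName() (the class and main are scaffolding; the mapping is verbatim from the diff above):

public class PredecessorNameCheck {
    // Verbatim logic from LlamaFP16FFNLayersDecode.predecessorGraphName().
    static String predecessorGraphName(int layerIndex) {
        return (layerIndex == 0) ? "decodeActivationUpdate" : "layer_" + (layerIndex - 1);
    }

    public static void main(String[] args) {
        for (int l = 0; l < 3; l++) {
            System.out.println("layer_" + l + " consumes wrapX from " + predecessorGraphName(l));
        }
        // layer_0 consumes wrapX from decodeActivationUpdate
        // layer_1 consumes wrapX from layer_0
        // layer_2 consumes wrapX from layer_1
    }
}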
src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import uk.ac.manchester.tornado.api.TaskGraph;
+
+/**
+ * Logits layer for the unified prefill-decode plan (Phase 4).
+ *
+ * <p>Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device
+ * pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the
+ * logits → decode-activation boundary across decode tokens.</p>
+ *
+ * <p>In interpreter (non-CUDA-graph) mode, {@code updatePersistedObjectState()}
+ * propagates device pointers from the predecessor graph's persisted set. After the
+ * last decode token the predecessor of the next decode-activation graph is the
+ * logits graph. Without the pass-through here, the KV-cache pointer is absent from
+ * the logits persisted set, cleared to null, and the first decode layer crashes with
+ * an NPE in {@code executeAlloc}.</p>
+ *
+ * <p>Bytecode order matters: {@code consumeFromDevice} must precede task declarations,
+ * and {@code persistOnDevice} must follow {@code transferToHost}. The hooks in
+ * {@link LogitsFP16Layer} guarantee this ordering.</p>
+ */
+public class LogitsFP16LayerDecode extends LogitsFP16Layer {
+
+    public LogitsFP16LayerDecode(String name, State state, Weights weights, Configuration config,
+                                 String lastTaskGraphID, SchedulerType schedulerType) {
+        super(name, state, weights, config, lastTaskGraphID, schedulerType);
+    }
+
+    /**
+     * Prepends {@code consumeFromDevice(lastTaskGraphID, wrapKeyCache, wrapValueCache)} before all tasks.
+     *
+     * <p>Must use the named-source form so that {@code updatePersistedObjectState()} adds the KV cache
+     * to the source-keyed map. Without the source name, the fallback in {@code updatePersistedObjectState}
+     * uses the current graph's general persisted list, which causes the XPUBuffer from the predecessor
+     * (last decode layer) to never be propagated into the logits graph's device state.</p>
+     */
+    @Override
+    protected void configureAdditionalConsumes(TaskGraph logits) {
+        logits.consumeFromDevice(lastTaskGraphID, state.wrapKeyCache, state.wrapValueCache);
+    }
+
+    /** Appends {@code persistOnDevice(wrapKeyCache, wrapValueCache)} after {@code transferToHost}. */
+    @Override
+    protected void configureAdditionalPersists(TaskGraph logits) {
+        logits.persistOnDevice(state.wrapKeyCache, state.wrapValueCache);
+    }
+}
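
Across consecutive decode tokens the graph chain wraps around: for token t+1 the predecessor of the decode-activation graph is the logits graph of token t. A toy walk-through of that boundary with and without the pass-through (illustrative names only; not TornadoVM code):

import java.util.List;
import java.util.Map;
import java.util.Set;

public class CrossTokenBoundarySim {
    public static void main(String[] args) {
        // Persisted sets per graph at the end of token t.
        Map<String, Set<String>> withoutFix = Map.of(
                "layer_last", Set.of("wrapX", "wrapKeyCache", "wrapValueCache"),
                "logits", Set.of());                                  // plain LogitsFP16Layer
        Map<String, Set<String>> withFix = Map.of(
                "layer_last", Set.of("wrapX", "wrapKeyCache", "wrapValueCache"),
                "logits", Set.of("wrapKeyCache", "wrapValueCache"));  // LogitsFP16LayerDecode

        for (Map<String, Set<String>> persisted : List.of(withoutFix, withFix)) {
            // Token t+1: decodeActivationUpdate consumes the KV cache from its
            // predecessor, which at this boundary is the logits graph of token t.
            boolean kvSurvives = persisted.get("logits").contains("wrapKeyCache");
            System.out.println(kvSurvives
                    ? "KV pointer propagated; decode layer 0 runs"
                    : "KV pointer missing; executeAlloc NPEs on decode layer 0");
        }
    }
}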

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java

Lines changed: 12 additions & 3 deletions
@@ -69,10 +69,19 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
                     state.wrapXbBatch,
                     state.wrapHbBatch,
                     state.wrapKeyCache, state.wrapValueCache);
-            // wrapXBatch produced by the batch activation graph
-            layer.consumeFromDevice(state.wrapXBatch);
+            // wrapXBatch produced by the batch activation graph.
+            // Explicit source name required: the no-arg form uses the current graph's own
+            // name ("batchLayer_0") which never matches "batchActivation" in interpreter mode,
+            // causing wrapXBatch to be re-uploaded from host (zeros) instead of using the
+            // FP32 embeddings computed by the activation graph's convertFP16toFP32 kernel.
+            layer.consumeFromDevice("batchActivation", state.wrapXBatch);
         } else {
-            layer.consumeFromDevice(
+            // Explicit predecessor name for all objects.
+            // The no-arg form would use "batchLayer_k" as the source key, which never matches
+            // "batchLayer_{k-1}" in interpreter mode — every object would be re-uploaded from
+            // host (zeros or stale), corrupting the KV cache written by the previous layer.
+            String pred = "batchLayer_" + (layerIndex - 1);
+            layer.consumeFromDevice(pred,
                     context,
                     state.wrapXBatch,
                     state.wrapXbFP16Batch,
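
The prefill chain follows the same naming discipline as the decode chain. A hypothetical helper mirroring the inline logic above:

public class PrefillPredecessorCheck {
    // Mirrors the naming above: batchLayer_0 consumes from "batchActivation",
    // batchLayer_k (k > 0) consumes from "batchLayer_{k-1}".
    static String prefillPredecessor(int layerIndex) {
        return (layerIndex == 0) ? "batchActivation" : "batchLayer_" + (layerIndex - 1);
    }

    public static void main(String[] args) {
        for (int l = 0; l < 3; l++) {
            System.out.println("batchLayer_" + l + " consumes from " + prefillPredecessor(l));
        }
    }
}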
