 import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode;
 import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode;
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
 import uk.ac.manchester.tornado.api.KernelContext;
@@ -94,8 +94,8 @@ private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model,
 
         // [N+1] Decode activation (with KV-cache pass-through) ────────────────
         KernelContext decodeActCtx = new KernelContext();
-        all.add(buildDecodeActivationGraph(decodeActCtx).snapshot());
-        scheduler.addWorkerGrid("activationUpdate.updateX",
+        all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot());
+        scheduler.addWorkerGrid("decodeActivationUpdate.updateX",
                 WorkerGridFactory.genericWorker(config.dim(), 128));
 
         // [N+2..2N+1] Decode layer graphs ────────────────────────────────────
@@ -107,7 +107,10 @@ private TornadoVMMasterPlanWithBatchPrefillDecode(LlamaState state, Model model,
         decodeLayers.updateGridScheduler(scheduler);
 
         // [2N+2] Logits ───────────────────────────────────────────────────────
-        LogitsFP16Layer logitsLayer = new LogitsFP16Layer("logits", state, weights, config,
+        // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache)
+        // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the
+        // KV-cache pointer survives the logits → decode-activation boundary across tokens.
+        LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config,
                 decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
         all.add(logitsLayer.getImmutableTaskGraph());
         logitsLayer.updateGridScheduler(scheduler);
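
The new subclass itself is outside this diff. A minimal sketch of the shape the
comment above describes, assuming LogitsFP16Layer exposes an overridable
graph-building hook (the hook name buildTaskGraph and the Weights/Configuration
parameter types are assumptions; only the consumeFromDevice/persistOnDevice
bracketing comes from the diff comment):

    // Hypothetical sketch: not the actual class body from this commit.
    public class LogitsFP16LayerDecode extends LogitsFP16Layer {

        public LogitsFP16LayerDecode(String name, LlamaState state, Weights weights,
                Configuration config, String lastGraphID, SchedulerType schedulerType) {
            super(name, state, weights, config, lastGraphID, schedulerType);
        }

        @Override
        protected TaskGraph buildTaskGraph(TaskGraph graph) {
            graph.consumeFromDevice(state.wrapKeyCache);      // adopt the KV-cache pointer first
            graph = super.buildTaskGraph(graph);              // unchanged logits computation
            return graph.persistOnDevice(state.wrapKeyCache); // hand the pointer to the next graph
        }
    }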
@@ -123,9 +126,7 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
         return new TaskGraph("batchActivation")
                 .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch)
                 .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch)
-                .task("batchUpdateX",
-                        (KernelContext c, HalfFloatArray src, FloatArray dst) ->
-                                dst.set(c.globalIdx, src.get(c.globalIdx).getFloat32()),
+                .task("batchUpdateX", TransformerComputeKernels::convertFP16toFP32,
                         ctx, state.embeddingXBatch, state.wrapXBatch)
                 .persistOnDevice(state.wrapXBatch);
     }
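
The method reference swaps a named, shared kernel in for the inline lambda; the
same reference already backs the single-token "updateX" task further down, so both
activation paths now run identical code. Judging from the removed lambda, the
kernel presumably reduces to the following (its exact body in
TransformerComputeKernels is not shown in this diff):

    // Presumed body, inferred from the removed lambda; the real implementation may differ.
    public static void convertFP16toFP32(KernelContext context, HalfFloatArray src, FloatArray dst) {
        int i = context.globalIdx;            // one work-item per element
        dst.set(i, src.get(i).getFloat32());  // widen half-precision to FP32
    }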
@@ -139,17 +140,24 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
      * Both halves of the chain are required; without the re-persist the pointer is
      * not forwarded in interpreter (non-CUDA-graph) mode.</p>
      */
-    private TaskGraph buildDecodeActivationGraph(KernelContext ctx) {
-        return new TaskGraph("activationUpdate")
-                .consumeFromDevice(state.wrapKeyCache, state.wrapValueCache) // KV pass-through
-                // .transferToDevice(DataTransferMode.EVERY_EXECUTION,
-                //         state.wrapKeyCache,
-                //         state.wrapValueCache)
-                .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX)
+    private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) {
+        // System.out.println("lastBatchLayerID = " + lastBatchLayerID);
+        // System.out.println("[buildDecodeActivationGraph] state.wrapX = " + state.wrapX.toString());
+        // System.out.println("[buildDecodeActivationGraph] state.wrapKeyCache = " + state.wrapKeyCache.toString());
+        // System.out.println("[buildDecodeActivationGraph] state.wrapValueCache = " + state.wrapValueCache.toString());
+        return new TaskGraph("decodeActivationUpdate")
+                .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through
+                // .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX, debugKV)
+                // .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapX)
                 .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
                 .task("updateX",
                         TransformerComputeKernels::convertFP16toFP32,
                         ctx, (HalfFloatArray) state.embeddingX, state.wrapX)
+                // DEBUG: snapshot first 8 elements of wrapKeyCache and wrapX for a host-side probe
+                // .task("dbgKV",
+                //         TransformerComputeKernels::dbgCopyFirst8,
+                //         state.wrapKeyCache, debugKV)
+                // .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX, debugKV)
                 // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache
                 // re-persisted so updatePersistedObjectState() propagates the device
                 // pointer to decode layer 0's consumeFromDevice without CUDA graphs.
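
The consume/re-persist contract described in the Javadoc above is not specific to
the KV cache; stripped to its essentials, the pattern between any two task graphs
looks like this (graph, kernel, and array names are placeholders, not code from
this file):

    // Placeholder names throughout; this only illustrates the pointer-forwarding chain.
    FloatArray buffer = new FloatArray(1024);

    TaskGraph producer = new TaskGraph("producer")
            .transferToDevice(DataTransferMode.FIRST_EXECUTION, buffer)
            .task("fill", Kernels::fill, buffer)
            .persistOnDevice(buffer);              // half 1: keep the device pointer alive

    TaskGraph consumer = new TaskGraph("consumer")
            .consumeFromDevice("producer", buffer) // half 2: adopt producer's pointer
            .task("use", Kernels::use, buffer)
            .persistOnDevice(buffer);              // re-persist so a third graph can consume it

Dropping the final persistOnDevice breaks the chain exactly as the Javadoc warns:
in interpreter mode updatePersistedObjectState() has nothing to propagate to the
next graph's consumeFromDevice.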
@@ -197,6 +205,7 @@ private void forceCopyInReadOnlyData() {
         state.batchStartPosHolder.init(0);
 
         for (int i = 0; i <= logitsIdx(); i++) {
+            // System.out.println(i + " " + executionPlan.withGraph(i).toString());
             var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler);
             if (CUDA_GRAPHS) g.withCUDAGraph();
             g.execute();
@@ -268,6 +277,14 @@ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) {
         if (CUDA_GRAPHS) decodeAct.withCUDAGraph();
         // System.err.println("[DEBUG] about to execute decode activation (graph " + decodeActivationIdx() + ")");
         decodeAct.execute();
+        // DEBUG: print first 4 of wrapX (should be non-zero FP32 embedding) and
+        // first 4 of debugKV (should be non-zero after batch prefill wrote the KV cache)
+        // if (position <= 290) {
+        //     System.err.printf("[DBG pos=%d] wrapX[0..3]   = %.4f %.4f %.4f %.4f%n",
+        //             position, state.wrapX.get(0), state.wrapX.get(1), state.wrapX.get(2), state.wrapX.get(3));
+        //     System.err.printf("[DBG pos=%d] debugKV[0..3] = %.4f %.4f %.4f %.4f%n",
+        //             position, debugKV.get(0), debugKV.get(1), debugKV.get(2), debugKV.get(3));
+        // }
 
         // Graphs N+2..2N+1: decode transformer layers
         for (int l = 0; l < N; l++) {