@@ -51,37 +51,32 @@ public LlamaFP16LayersBatchPrefill(LlamaState state, LlamaTornadoWeights weights
5151
5252 // @formatter:off
5353 private TaskGraph createBatchPrefillLayerTaskGraph (int layerIndex ) {
54- String graphName = "batchLayer_ " + layerIndex ;
54+ String graphName = "batchPrefillLayer_ " + layerIndex ;
5555 if (layerIndex == config .numberOfLayers () - 1 ) lastLayerTaskGraphID = graphName ;
5656
57- TaskGraph layer = new TaskGraph (graphName );
57+ TaskGraph batchPrefillLayer = new TaskGraph (graphName );
5858
5959 // ── Data Transfers ─────────────────────────────────────────────────────
6060 if (layerIndex == 0 ) {
6161 // batchStartPosHolder is set by host before each chunk → EVERY_EXECUTION
62- layer .transferToDevice (DataTransferMode .EVERY_EXECUTION , state .batchStartPosHolder );
62+ batchPrefillLayer .transferToDevice (DataTransferMode .EVERY_EXECUTION , state .batchStartPosHolder );
6363 // Allocate persistent GPU-side intermediates once
64- layer .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
64+ batchPrefillLayer .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
6565 context ,
6666 state .attnScaleBatch , state .ffnScaleBatch ,
6767 state .wrapXbFP16Batch ,
6868 state .wrapQBatch , state .wrapKBatch , state .wrapVBatch ,
6969 state .wrapXbBatch ,
7070 state .wrapHbBatch ,
7171 state .wrapKeyCache , state .wrapValueCache );
72- // wrapXBatch produced by the batch activation graph.
73- // Explicit source name required: the no-arg form uses the current graph's own
74- // name ("batchLayer_0") which never matches "batchActivation" in interpreter mode,
75- // causing wrapXBatch to be re-uploaded from host (zeros) instead of using the
76- // FP32 embeddings computed by the activation graph's convertFP16toFP32 kernel.
77- layer .consumeFromDevice ("batchActivation" , state .wrapXBatch );
72+ // wrapXBatch is produced by the prefillActivation graph and persists in device memory.
73+ // To consume it from there, we must pass the explicit uniqueTaskGraph name:
74+ // the no-arg form would use the current graph's name, which causes an NPE without CUDA Graphs.
75+ batchPrefillLayer .consumeFromDevice ("prefillActivation" , state .wrapXBatch );
7876 } else {
79- // Explicit predecessor name for all objects.
80- // The no-arg form would use "batchLayer_k" as the source key, which never matches
81- // "batchLayer_{k-1}" in interpreter mode — every object would be re-uploaded from
82- // host (zeros or stale), corrupting the KV cache written by the previous layer.
83- String pred = "batchLayer_" + (layerIndex - 1 );
84- layer .consumeFromDevice (pred ,
77+ // For the same reason as above, consume from the previous layer's graph via its explicit uniqueTaskGraph name.
78+ String pred = "batchPrefillLayer_" + (layerIndex - 1 );
79+ batchPrefillLayer .consumeFromDevice (pred ,
8580 context ,
8681 state .wrapXBatch ,
8782 state .wrapXbFP16Batch ,
@@ -94,7 +89,7 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
9489 }
9590
9691 // Per-layer weights: upload once
97- layer .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
92+ batchPrefillLayer .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
9893 weights .rms_att_weightLayered [layerIndex ].asFloatArray (),
9994 weights .wqLayered [layerIndex ].asHalfFloatArray (),
10095 weights .wkLayered [layerIndex ].asHalfFloatArray (),
@@ -110,18 +105,18 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
110105 int hidDim = config .hiddenDim ();
111106
112107 // ── Attention Block ────────────────────────────────────────────────────
113- layer .task ("batch_attn_rms" ,
108+ batchPrefillLayer .task ("batch_attn_rms" ,
114109 TransformerBatchPrefillKernels ::batchedRmsReduce ,
115110 context , state .wrapXBatch , state .attnScaleBatch ,
116111 dim , config .rmsNormEps ());
117112
118- layer .task ("batch_attn_rms_apply" ,
113+ batchPrefillLayer .task ("batch_attn_rms_apply" ,
119114 TransformerBatchPrefillKernels ::batchedRmsApplyFP16 ,
120115 context , state .wrapXbFP16Batch , state .wrapXBatch ,
121116 weights .rms_att_weightLayered [layerIndex ].asFloatArray (),
122117 state .attnScaleBatch , dim );
123118
124- layer .task ("batch_qkv" ,
119+ batchPrefillLayer .task ("batch_qkv" ,
125120 TransformerBatchPrefillKernels ::batchedFusedQKVMatmul ,
126121 context ,
127122 state .wrapXbFP16Batch ,
@@ -131,34 +126,34 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
131126 weights .wvLayered [layerIndex ].asHalfFloatArray (),
132127 dim , kvDim , LOCAL_WORK_GROUP_SIZE );
133128
134- layer .task ("batch_rope_kv" ,
129+ batchPrefillLayer .task ("batch_rope_kv" ,
135130 TransformerBatchPrefillKernels ::batchedRopeWithKVCache ,
136131 context , state .batchStartPosHolder ,
137132 state .wrapQBatch , state .wrapKBatch , state .wrapVBatch ,
138133 state .wrapKeyCache , state .wrapValueCache ,
139134 kvDim , config .headSize (), layerIndex , config .contextLength (), dim );
140135
141- layer .task ("batch_attention" ,
136+ batchPrefillLayer .task ("batch_attention" ,
142137 TransformerBatchPrefillKernels ::batchedFlashAttention ,
143138 context , state .batchStartPosHolder ,
144139 state .wrapQBatch , state .wrapKeyCache , state .wrapValueCache ,
145140 state .wrapXbBatch ,
146141 config .numberOfHeads (), config .headSize (),
147142 kvDim , config .kvMul (), layerIndex , config .contextLength (), dim );
148143
149- layer .task ("batch_attn_out" ,
144+ batchPrefillLayer .task ("batch_attn_out" ,
150145 TransformerBatchPrefillKernels ::batchedMatVecWithResidual ,
151146 context , state .wrapXbBatch , state .wrapXBatch ,
152147 weights .woLayered [layerIndex ].asHalfFloatArray (),
153148 dim , dim , LOCAL_WORK_GROUP_SIZE );
154149
155150 // ── FFN Block ──────────────────────────────────────────────────────────
156- layer .task ("batch_ffn_rms" ,
151+ batchPrefillLayer .task ("batch_ffn_rms" ,
157152 TransformerBatchPrefillKernels ::batchedFFNRmsReduce ,
158153 context , state .wrapXBatch , state .ffnScaleBatch ,
159154 dim , config .rmsNormEps ());
160155
161- layer .task ("batch_ffn_gate_up" ,
156+ batchPrefillLayer .task ("batch_ffn_gate_up" ,
162157 TransformerBatchPrefillKernels ::batchedFusedRmsNormFFNGateUp ,
163158 context , state .wrapXBatch , state .wrapHbBatch ,
164159 weights .rms_ffn_weightLayered [layerIndex ].asFloatArray (),
@@ -167,17 +162,17 @@ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
167162 weights .w3Layered [layerIndex ].asHalfFloatArray (),
168163 dim , hidDim , LOCAL_WORK_GROUP_SIZE );
169164
170- layer .task ("batch_ffn_down" ,
165+ batchPrefillLayer .task ("batch_ffn_down" ,
171166 TransformerBatchPrefillKernels ::batchedMatVecWithResidual ,
172167 context , state .wrapHbBatch , state .wrapXBatch ,
173168 weights .w2Layered [layerIndex ].asHalfFloatArray (),
174169 hidDim , dim , LOCAL_WORK_GROUP_SIZE );
175170
176171 // Persist wrapXBatch for the next layer, and KV cache so the decode
177172 // layers can consume it via the activation graph pass-through.
178- layer .persistOnDevice (state .wrapXBatch , state .wrapKeyCache , state .wrapValueCache );
173+ batchPrefillLayer .persistOnDevice (state .wrapXBatch , state .wrapKeyCache , state .wrapValueCache );
179174
180- return layer ;
175+ return batchPrefillLayer ;
181176 }
182177 // @formatter:on
183178
@@ -218,7 +213,7 @@ public void updateGridScheduler(GridScheduler scheduler) {
218213 batchSize * hidDim * LOCAL_WORK_GROUP_SIZE , LOCAL_WORK_GROUP_SIZE );
219214
220215 for (int i = 0 ; i < config .numberOfLayers (); i ++) {
221- String p = "batchLayer_ " + i + "." ;
216+ String p = "batchPrefillLayer_ " + i + "." ;
222217 scheduler .addWorkerGrid (p + "batch_attn_rms" , rmsWorker );
223218 scheduler .addWorkerGrid (p + "batch_attn_rms_apply" , rmsApplyWorker );
224219 scheduler .addWorkerGrid (p + "batch_qkv" , qkvWorker );
0 commit comments