@@ -53,10 +53,10 @@ public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWe
5353 this .qwen2State = state ;
5454 this .qwen2Config = config ;
5555
56- // state .temp.init(0.0f);
57- // state .tempFFN.init(0.0f);
58- // state .tempLogits.init(0.0f);
59- // state .wrapLogits.init(0.0f);
56+ // qwen2State .temp.init(0.0f);
57+ // qwen2State .tempFFN.init(0.0f);
58+ // qwen2State .tempLogits.init(0.0f);
59+ // qwen2State .wrapLogits.init(0.0f);
6060
6161
6262 // Ensure we have Qwen2-specific weights
@@ -71,7 +71,6 @@ public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWe
7171
7272 @ Override
7373 public GridScheduler updateGridScheduler (GridScheduler tornadoForwardScheduler ) {
74-
7574 // Single worker for tasks running with a single thread
7675 // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[1,1,1], localWorkSize=[1,1,1])
7776 // CUDA equivalent: kernel<<<dim3(1,1,1), dim3(1,1,1)>>>
@@ -209,9 +208,8 @@ private void setupLastID(String taskGraphID) {
209208 List <ImmutableTaskGraph > setupFFNLayered () {
210209 List <ImmutableTaskGraph > ffnGraphs = new ArrayList <>();
211210
212- state .temp .init (0.0f );
213- qwen2State
214- .tempFFN .init (0.0f );
211+ qwen2State .temp .init (0.0f );
212+ qwen2State .tempFFN .init (0.0f );
215213
216214
217215 for (int layerIndex = 0 ; layerIndex < qwen2Config .numberOfLayers (); layerIndex ++) {
@@ -229,59 +227,39 @@ List<ImmutableTaskGraph> setupFFNLayered() {
229227 * Setup a single transformer layer for Qwen2 with GQA
230228 */
231229 TaskGraph setupSingleQwen2FFNLayer (Qwen2TornadoWeights weights , int layerIndex ) {
232- TaskGraph unifiedLayer = new TaskGraph ("layer_" + layerIndex );
230+ TaskGraph unifiedLayer = new TaskGraph ("layer_" + layerIndex );
233231 unifiedLayer .consumeFromDevice (state .wrapX );
234232 unifiedLayer .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
235233 //Copy-in weights per layer for batched-layered layout
236- weights .rms_att_weightLayered [layerIndex ],
237- weights .wqLayered [layerIndex ],
238- weights .wkLayered [layerIndex ],
239- weights .wvLayered [layerIndex ],
240- weights .woLayered [layerIndex ],
241- weights .q_biasLayered [layerIndex ],
242- weights .k_biasLayered [layerIndex ],
243- weights .v_biasLayered [layerIndex ],
244- weights .rms_ffn_weightLayered [layerIndex ],
245- weights .w1Layered [layerIndex ],
246- weights .w2Layered [layerIndex ],
247- weights .w3Layered [layerIndex ]
248- );
234+ weights .rms_att_weightLayered [layerIndex ], weights .wqLayered [layerIndex ], weights .wkLayered [layerIndex ], weights .wvLayered [layerIndex ], weights .woLayered [layerIndex ],
235+ weights .q_biasLayered [layerIndex ], weights .k_biasLayered [layerIndex ], weights .v_biasLayered [layerIndex ], weights .rms_ffn_weightLayered [layerIndex ], weights .w1Layered [layerIndex ],
236+ weights .w2Layered [layerIndex ], weights .w3Layered [layerIndex ]);
249237 unifiedLayer = configureLayerDataTransfers (unifiedLayer , layerIndex );
250238
251- unifiedLayer .task ("reductionsOneBlock" , TransformerComputeKernelsLayered ::reductionOneBlockWithLayer , context , state .temp ,
252- state .wrapX , config .dim (), config .rmsNormEps (), state .localSize )
253- .task ("mapContext" , TransformerComputeKernelsLayered ::reductionOneBlock2WithLayer , context , state .wrapXb ,
254- state .wrapX , weights .rms_att_weightLayered [layerIndex ], state .temp )
255- .task ("qmatmul" , TransformerComputeKernelsLayered ::matrixVectorGeneric , context ,
256- state .wrapXb , state .wrapQ , weights .wqLayered [layerIndex ], config .dim (), config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
257- .task ("kmatmul" , TransformerComputeKernelsLayered ::matrixVectorGeneric , context ,
258- state .wrapXb , state .wrapK , weights .wkLayered [layerIndex ], config .dim (), config .kvDim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
259- .task ("vmatmul" , TransformerComputeKernelsLayered ::matrixVectorGeneric , context ,
260- state .wrapXb , state .wrapV , weights .wvLayered [layerIndex ], config .dim (), config .kvDim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
261- .task ("qbias" , TransformerComputeKernelsLayered ::addInPlace , state .wrapQ , weights .q_biasLayered [layerIndex ], config .dim ())
262- .task ("kbias" , TransformerComputeKernelsLayered ::addInPlace , state .wrapK , weights .k_biasLayered [layerIndex ], config .kvDim ())
263- .task ("vbias" , TransformerComputeKernelsLayered ::addInPlace , state .wrapV , weights .v_biasLayered [layerIndex ], config .kvDim ())
264- .task ("rope" , Qwen3Kernels ::ropeRotation ,context , state .positionHolder , state .wrapQ , state .wrapK , config .numberOfKeyValueHeads (),
265- config .headSize ())
266- .task ("copyToCaches" , TransformerComputeKernelsLayered ::copyToCache ,
267- state .wrapKeyCache , state .wrapK , state .wrapValueCache , state .wrapV , state .positionHolder , config .kvDim (), layerIndex , config .contextLength ())
268- .task ("parallel-attention" , Qwen2Kernels ::processHeadsFlashAttention , context ,
269- state .wrapQ , state .wrapKeyCache , state .wrapValueCache , state .wrapXb ,
270- config .numberOfHeads (), config .headSize (), config .kvDim (), config .kvMul (),
271- state .positionHolder , layerIndex , config .contextLength ())
272- .task ("matmul1" , TransformerComputeKernelsLayered ::matrixVectorGenericWithResidual , context ,
273- state .wrapXb , state .wrapX , weights .woLayered [layerIndex ], config .dim (), config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
274- .task ("reductionsOneBlockFFN" , TransformerComputeKernelsLayered ::reductionOneBlockWithLayer , context , state .tempFFN ,
275- state .wrapX , config .dim (), config .rmsNormEps (), state .localSize )
276- .task ("mapContextFFN" , TransformerComputeKernelsLayered ::reductionOneBlock2WithLayer , context , state .wrapXb ,
277- state .wrapX , weights .rms_ffn_weightLayered [layerIndex ], state .tempFFN )
278- .task ("fused_ffn_w1_w3" , TransformerComputeKernelsLayered ::fusedFeedForwardWithSiLUAndGLUActivation , context ,
279- state .wrapXb , state .wrapHb , weights .w1Layered [layerIndex ], weights .w3Layered [layerIndex ], config .dim (), config .hiddenDim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
280- .task ("projectionTwo" , TransformerComputeKernelsLayered ::matrixVectorGenericWithResidual , context ,
281- state .wrapHb , state .wrapX , weights .w2Layered [layerIndex ], config .hiddenDim (), config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
282- .persistOnDevice (
283- state .wrapX
284- );
239+ unifiedLayer .task ("reductionsOneBlock" , TransformerComputeKernelsLayered ::reductionOneBlockWithLayer , context , qwen2State .temp , qwen2State .wrapX , config .dim (), config .rmsNormEps (), qwen2State .localSize )
240+ .task ("mapContext" , TransformerComputeKernelsLayered ::reductionOneBlock2WithLayer , context , qwen2State .wrapXb , qwen2State .wrapX , weights .rms_att_weightLayered [layerIndex ], qwen2State .temp )
241+ .task ("qmatmul" , TransformerComputeKernelsLayered ::matrixVectorGeneric , context , qwen2State .wrapXb , qwen2State .wrapQ , weights .wqLayered [layerIndex ], config .dim (), config .dim (),
242+ LOCAL_WORK_GROUP_SIZE_ALLOC )
243+ .task ("kmatmul" , TransformerComputeKernelsLayered ::matrixVectorGeneric , context , qwen2State .wrapXb , qwen2State .wrapK , weights .wkLayered [layerIndex ], config .dim (), config .kvDim (),
244+ LOCAL_WORK_GROUP_SIZE_ALLOC )
245+ .task ("vmatmul" , TransformerComputeKernelsLayered ::matrixVectorGeneric , context , qwen2State .wrapXb , qwen2State .wrapV , weights .wvLayered [layerIndex ], config .dim (), config .kvDim (),
246+ LOCAL_WORK_GROUP_SIZE_ALLOC ).task ("qbias" , TransformerComputeKernelsLayered ::addInPlace , qwen2State .wrapQ , weights .q_biasLayered [layerIndex ], config .dim ())
247+ .task ("kbias" , TransformerComputeKernelsLayered ::addInPlace , qwen2State .wrapK , weights .k_biasLayered [layerIndex ], config .kvDim ())
248+ .task ("vbias" , TransformerComputeKernelsLayered ::addInPlace , qwen2State .wrapV , weights .v_biasLayered [layerIndex ], config .kvDim ())
249+ .task ("rope" , Qwen3Kernels ::ropeRotation , context , qwen2State .positionHolder , qwen2State .wrapQ , qwen2State .wrapK , config .numberOfKeyValueHeads (), config .headSize ())
250+ .task ("copyToCaches" , TransformerComputeKernelsLayered ::copyToCache , qwen2State .wrapKeyCache , qwen2State .wrapK , qwen2State .wrapValueCache , qwen2State .wrapV , qwen2State .positionHolder , config .kvDim (),
251+ layerIndex , config .contextLength ())
252+ .task ("parallel-attention" , Qwen2Kernels ::processHeadsFlashAttention , context , qwen2State .wrapQ , qwen2State .wrapKeyCache , qwen2State .wrapValueCache , qwen2State .wrapXb , config .numberOfHeads (),
253+ config .headSize (), config .kvDim (), config .kvMul (), qwen2State .positionHolder , layerIndex , config .contextLength ())
254+ .task ("matmul1" , TransformerComputeKernelsLayered ::matrixVectorGenericWithResidual , context , qwen2State .wrapXb , qwen2State .wrapX , weights .woLayered [layerIndex ], config .dim (), config .dim (),
255+ LOCAL_WORK_GROUP_SIZE_ALLOC )
256+ .task ("reductionsOneBlockFFN" , TransformerComputeKernelsLayered ::reductionOneBlockWithLayer , context , qwen2State .tempFFN , qwen2State .wrapX , config .dim (), config .rmsNormEps (), qwen2State .localSize )
257+ .task ("mapContextFFN" , TransformerComputeKernelsLayered ::reductionOneBlock2WithLayer , context , qwen2State .wrapXb , qwen2State .wrapX , weights .rms_ffn_weightLayered [layerIndex ], qwen2State .tempFFN )
258+ .task ("fused_ffn_w1_w3" , TransformerComputeKernelsLayered ::fusedFeedForwardWithSiLUAndGLUActivation , context , qwen2State .wrapXb , qwen2State .wrapHb , weights .w1Layered [layerIndex ],
259+ weights .w3Layered [layerIndex ], config .dim (), config .hiddenDim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
260+ .task ("projectionTwo" , TransformerComputeKernelsLayered ::matrixVectorGenericWithResidual , context , qwen2State .wrapHb , qwen2State .wrapX , weights .w2Layered [layerIndex ], config .hiddenDim (),
261+ config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC ).persistOnDevice (state .wrapX );
262+
285263 return unifiedLayer ;
286264 }
287265
@@ -292,19 +270,19 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
292270 // First layer: Transfer initial data to device (one-time transfer)
293271 if (layerIndex == 0 ) {
294272 // Transfer all attention-related data: query, key, value matrices and their caches
295- unifiedLayer .transferToDevice (DataTransferMode .EVERY_EXECUTION , state .positionHolder , state .temp , state .tempFFN ); //
273+ unifiedLayer .transferToDevice (DataTransferMode .EVERY_EXECUTION , qwen2State .positionHolder , qwen2State .temp , qwen2State .tempFFN ); //
296274 unifiedLayer .transferToDevice (DataTransferMode .FIRST_EXECUTION , //
297- context , state .wrapXb , state .wrapXb2 , //
298- state .wrapQ , state .wrapK , state .wrapV , //
299- state .wrapKeyCache , state .wrapValueCache , //
300- state .wrapAtt , state .wrapHb ); //
275+ context , qwen2State .wrapXb , qwen2State .wrapXb2 , //
276+ qwen2State .wrapQ , qwen2State .wrapK , qwen2State .wrapV , //
277+ qwen2State .wrapKeyCache , qwen2State .wrapValueCache , //
278+ qwen2State .wrapAtt , qwen2State .wrapHb ); //
301279 } else {
302280 // Subsequent layers: Consume data already on device from previous layer
303- unifiedLayer .consumeFromDevice (context , state .wrapXb , state .wrapXb2 , //
304- state .wrapQ , state .wrapK , state .wrapV , //
305- state .wrapKeyCache , state .wrapValueCache , //
306- state .wrapAtt , state .wrapHb , //
307- state .positionHolder //
281+ unifiedLayer .consumeFromDevice (context , qwen2State .wrapXb , qwen2State .wrapXb2 , //
282+ qwen2State .wrapQ , qwen2State .wrapK , qwen2State .wrapV , //
283+ qwen2State .wrapKeyCache , qwen2State .wrapValueCache , //
284+ qwen2State .wrapAtt , qwen2State .wrapHb , //
285+ qwen2State .positionHolder //
308286 );
309287 }
310288 return unifiedLayer ;
0 commit comments