Skip to content

Commit e2be881

Browse files
committed
Refactor TornadoVM Qwen layers and grid scheduler:
- Remove `GenericLayerPlanner` interface and consolidate layer planner logic. - Standardize `qwen2State` and `qwen3State` usage across layers. - Adjust local work sizes for better efficiency in FP16 and Q8_0 layers. - Cleanup redundant code and comments for improved readability.
1 parent cec38b3 commit e2be881

6 files changed

Lines changed: 191 additions & 350 deletions

File tree

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java

Lines changed: 0 additions & 14 deletions
This file was deleted.

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.FP16Weights;
77
import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.LlamaTornadoWeights;
88
import org.beehive.gpullama3.model.Configuration;
9+
import org.beehive.gpullama3.model.Model;
910
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
1011
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
1112
import org.beehive.gpullama3.tornadovm.layers.AbstractLayer;
@@ -133,9 +134,12 @@ private GridScheduler setupGridSchedulerForLogits(Configuration config) {
133134
@Override
134135
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
135136
// RMSNorm operations
136-
WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
137-
rmsNormWorker.setGlobalWork(config.dim(), 1, 1);
138-
rmsNormWorker.setLocalWork(256, 1, 1);
137+
WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
138+
139+
rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension
140+
141+
//TODO: XXX
142+
rmsNormWorker.setLocalWork(32, 1, 1); // Set local work size to 256 (standard efficient size)
139143

140144
// OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1])
141145
// CUDA equivalent: kernel<<<dim3((config.vocabularySize+15)/16,1,1), dim3(16,1,1)>>>

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java

Lines changed: 44 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWe
5353
this.qwen2State = state;
5454
this.qwen2Config = config;
5555

56-
// state.temp.init(0.0f);
57-
// state.tempFFN.init(0.0f);
58-
// state.tempLogits.init(0.0f);
59-
// state.wrapLogits.init(0.0f);
56+
// qwen2State.temp.init(0.0f);
57+
// qwen2State.tempFFN.init(0.0f);
58+
// qwen2State.tempLogits.init(0.0f);
59+
// qwen2State.wrapLogits.init(0.0f);
6060

6161

6262
// Ensure we have Qwen2-specific weights
@@ -71,7 +71,6 @@ public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWe
7171

7272
@Override
7373
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
74-
7574
// Single worker for tasks running with a single thread
7675
// OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[1,1,1], localWorkSize=[1,1,1])
7776
// CUDA equivalent: kernel<<<dim3(1,1,1), dim3(1,1,1)>>>
@@ -209,9 +208,8 @@ private void setupLastID(String taskGraphID) {
209208
List<ImmutableTaskGraph> setupFFNLayered() {
210209
List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>();
211210

212-
state.temp.init(0.0f);
213-
qwen2State
214-
.tempFFN.init(0.0f);
211+
qwen2State.temp.init(0.0f);
212+
qwen2State.tempFFN.init(0.0f);
215213

216214

217215
for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) {
@@ -229,59 +227,39 @@ List<ImmutableTaskGraph> setupFFNLayered() {
229227
* Setup a single transformer layer for Qwen2 with GQA
230228
*/
231229
TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) {
232-
TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex);
230+
TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex);
233231
unifiedLayer.consumeFromDevice(state.wrapX);
234232
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
235233
//Copy-in weights per layer for batched-layered layout
236-
weights.rms_att_weightLayered[layerIndex],
237-
weights.wqLayered[layerIndex],
238-
weights.wkLayered[layerIndex],
239-
weights.wvLayered[layerIndex],
240-
weights.woLayered[layerIndex],
241-
weights.q_biasLayered[layerIndex],
242-
weights.k_biasLayered[layerIndex],
243-
weights.v_biasLayered[layerIndex],
244-
weights.rms_ffn_weightLayered[layerIndex],
245-
weights.w1Layered[layerIndex],
246-
weights.w2Layered[layerIndex],
247-
weights.w3Layered[layerIndex]
248-
);
234+
weights.rms_att_weightLayered[layerIndex], weights.wqLayered[layerIndex], weights.wkLayered[layerIndex], weights.wvLayered[layerIndex], weights.woLayered[layerIndex],
235+
weights.q_biasLayered[layerIndex], weights.k_biasLayered[layerIndex], weights.v_biasLayered[layerIndex], weights.rms_ffn_weightLayered[layerIndex], weights.w1Layered[layerIndex],
236+
weights.w2Layered[layerIndex], weights.w3Layered[layerIndex]);
249237
unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
250238

251-
unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp,
252-
state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
253-
.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
254-
state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp)
255-
.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
256-
state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
257-
.task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
258-
state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
259-
.task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
260-
state.wrapXb, state.wrapV, weights.wvLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
261-
.task("qbias", TransformerComputeKernelsLayered::addInPlace, state.wrapQ, weights.q_biasLayered[layerIndex], config.dim())
262-
.task("kbias", TransformerComputeKernelsLayered::addInPlace, state.wrapK, weights.k_biasLayered[layerIndex], config.kvDim())
263-
.task("vbias", TransformerComputeKernelsLayered::addInPlace, state.wrapV, weights.v_biasLayered[layerIndex], config.kvDim())
264-
.task("rope", Qwen3Kernels::ropeRotation,context, state.positionHolder, state.wrapQ, state.wrapK, config.numberOfKeyValueHeads(),
265-
config.headSize())
266-
.task("copyToCaches", TransformerComputeKernelsLayered::copyToCache,
267-
state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength())
268-
.task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context,
269-
state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
270-
config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
271-
state.positionHolder, layerIndex, config.contextLength())
272-
.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
273-
state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
274-
.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
275-
state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
276-
.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
277-
state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
278-
.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context,
279-
state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
280-
.task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
281-
state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
282-
.persistOnDevice(
283-
state.wrapX
284-
);
239+
unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.temp, qwen2State.wrapX, config.dim(), config.rmsNormEps(), qwen2State.localSize)
240+
.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_att_weightLayered[layerIndex], qwen2State.temp)
241+
.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapQ, weights.wqLayered[layerIndex], config.dim(), config.dim(),
242+
LOCAL_WORK_GROUP_SIZE_ALLOC)
243+
.task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapK, weights.wkLayered[layerIndex], config.dim(), config.kvDim(),
244+
LOCAL_WORK_GROUP_SIZE_ALLOC)
245+
.task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapV, weights.wvLayered[layerIndex], config.dim(), config.kvDim(),
246+
LOCAL_WORK_GROUP_SIZE_ALLOC).task("qbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapQ, weights.q_biasLayered[layerIndex], config.dim())
247+
.task("kbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapK, weights.k_biasLayered[layerIndex], config.kvDim())
248+
.task("vbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapV, weights.v_biasLayered[layerIndex], config.kvDim())
249+
.task("rope", Qwen3Kernels::ropeRotation, context, qwen2State.positionHolder, qwen2State.wrapQ, qwen2State.wrapK, config.numberOfKeyValueHeads(), config.headSize())
250+
.task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, qwen2State.wrapKeyCache, qwen2State.wrapK, qwen2State.wrapValueCache, qwen2State.wrapV, qwen2State.positionHolder, config.kvDim(),
251+
layerIndex, config.contextLength())
252+
.task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context, qwen2State.wrapQ, qwen2State.wrapKeyCache, qwen2State.wrapValueCache, qwen2State.wrapXb, config.numberOfHeads(),
253+
config.headSize(), config.kvDim(), config.kvMul(), qwen2State.positionHolder, layerIndex, config.contextLength())
254+
.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapXb, qwen2State.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(),
255+
LOCAL_WORK_GROUP_SIZE_ALLOC)
256+
.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.tempFFN, qwen2State.wrapX, config.dim(), config.rmsNormEps(), qwen2State.localSize)
257+
.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_ffn_weightLayered[layerIndex], qwen2State.tempFFN)
258+
.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen2State.wrapXb, qwen2State.wrapHb, weights.w1Layered[layerIndex],
259+
weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
260+
.task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapHb, qwen2State.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(),
261+
config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX);
262+
285263
return unifiedLayer;
286264
}
287265

@@ -292,19 +270,19 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
292270
// First layer: Transfer initial data to device (one-time transfer)
293271
if (layerIndex == 0) {
294272
// Transfer all attention-related data: query, key, value matrices and their caches
295-
unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); //
273+
unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen2State.positionHolder, qwen2State.temp, qwen2State.tempFFN); //
296274
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
297-
context, state.wrapXb, state.wrapXb2, //
298-
state.wrapQ, state.wrapK, state.wrapV, //
299-
state.wrapKeyCache, state.wrapValueCache, //
300-
state.wrapAtt, state.wrapHb); //
275+
context, qwen2State.wrapXb, qwen2State.wrapXb2, //
276+
qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, //
277+
qwen2State.wrapKeyCache, qwen2State.wrapValueCache, //
278+
qwen2State.wrapAtt, qwen2State.wrapHb); //
301279
} else {
302280
// Subsequent layers: Consume data already on device from previous layer
303-
unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
304-
state.wrapQ, state.wrapK, state.wrapV, //
305-
state.wrapKeyCache, state.wrapValueCache, //
306-
state.wrapAtt, state.wrapHb, //
307-
state.positionHolder //
281+
unifiedLayer.consumeFromDevice(context, qwen2State.wrapXb, qwen2State.wrapXb2, //
282+
qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, //
283+
qwen2State.wrapKeyCache, qwen2State.wrapValueCache, //
284+
qwen2State.wrapAtt, qwen2State.wrapHb, //
285+
qwen2State.positionHolder //
308286
);
309287
}
310288
return unifiedLayer;

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Confi
3737
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
3838
WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
3939
rmsNormWorker.setGlobalWork(config.dim(), 1, 1);
40-
rmsNormWorker.setLocalWork(256, 1, 1);
40+
rmsNormWorker.setLocalWork(32, 1, 1);
4141
// RMSNorm operations
4242
int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
4343
WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);

0 commit comments

Comments
 (0)