Skip to content

Commit 1e71fa6

Browse files
committed
Refactor RMSNorm worker creation in LogitsQ8_0Layer and LogitsFP16Layer:
- Introduced `WorkerGridFactory` for standardizing RMSNorm worker creation.
- Adjusted RMSNorm worker configuration for FP16 and Q8_0 layers with support for conditional weight types.
- Removed redundant code and outdated comments for clarity.
1 parent 3ab251b commit 1e71fa6

2 files changed

Lines changed: 27 additions & 65 deletions

File tree

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java

Lines changed: 14 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
import org.beehive.gpullama3.inference.weights.Weights;
66
import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.FP16Weights;
77
import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.LlamaTornadoWeights;
8+
import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.Qwen2TornadoWeights;
89
import org.beehive.gpullama3.model.Configuration;
910
import org.beehive.gpullama3.model.Model;
1011
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
1112
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
13+
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
1214
import org.beehive.gpullama3.tornadovm.layers.AbstractLayer;
1315
import uk.ac.manchester.tornado.api.GridScheduler;
1416
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -36,31 +38,6 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration
3638
this.logitsTaskGraph = setupLogitsTaskGraph(fp16Weights , config);
3739
}
3840

39-
private TaskGraph setupLogitNonNVidia(FP16Weights weights, Configuration config) {
40-
TaskGraph logits = new TaskGraph("logits")
41-
.consumeFromDevice(lastTaskGraphID,
42-
state.wrapX
43-
)
44-
.transferToDevice(DataTransferMode.EVERY_EXECUTION,
45-
state.tempLogits
46-
)
47-
.transferToDevice(DataTransferMode.FIRST_EXECUTION,
48-
context,
49-
state.wrapLogits,
50-
weights.wclsHalfFloat,
51-
weights.rms_final_weight_as_floatArray
52-
)
53-
.task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits,
54-
state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
55-
.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX,
56-
weights.rms_final_weight_as_floatArray, state.tempLogits);
57-
logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
58-
context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat, //
59-
config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); //
60-
logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
61-
return logits;
62-
}
63-
6441
/**
6542
* Builds the logits computation graph.
6643
*/
@@ -99,57 +76,38 @@ private GridScheduler setupGridSchedulerForLogits(Configuration config) {
9976
rmsNormWorker.setGlobalWork(config.dim(), 1, 1);
10077
rmsNormWorker.setLocalWork(256, 1, 1);
10178

79+
WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
80+
10281
// Projection kernel (vocabulary size × hidden dim)
10382
int vocabSizeGlobal = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
10483
WorkerGrid projectionWorker = new WorkerGrid1D(vocabSizeGlobal);
10584
projectionWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
10685

10786
scheduler.addWorkerGrid("logits.projection", projectionWorker);
108-
scheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);
109-
scheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker);
87+
scheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS);
88+
scheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS);
11089

11190
return scheduler;
11291
}
11392

114-
// @Override
115-
// public GridScheduler updateGridScheduler(GridScheduler scheduler) {
116-
// // RMSNorm operations
117-
// WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
118-
// rmsNormWorker.setGlobalWork(config.dim(), 1, 1);
119-
// rmsNormWorker.setLocalWork(256, 1, 1);
120-
//
121-
// // Projection kernel (vocabulary size × hidden dim)
122-
// int vocabSizeGlobal = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
123-
// WorkerGrid projectionWorker = new WorkerGrid1D(vocabSizeGlobal);
124-
// projectionWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
125-
//
126-
// scheduler.addWorkerGrid("logits.projection", projectionWorker);
127-
// scheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);
128-
// scheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker);
129-
//
130-
// return scheduler;
131-
// }
132-
13393

13494
@Override
13595
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
136-
// RMSNorm operations
137-
WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
138-
139-
rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension
14096

141-
//TODO: XXX
142-
rmsNormWorker.setLocalWork(32, 1, 1); // Set local work size to 256 (standard efficient size)
97+
WorkerGrid logitsRMS = null;
98+
if (weights instanceof Qwen2TornadoWeights ) {
99+
logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32);
100+
} else {
101+
logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
102+
}
143103

144-
// OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1])
145-
// CUDA equivalent: kernel<<<dim3((config.vocabularySize+15)/16,1,1), dim3(16,1,1)>>>
146104
int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
147105
WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
148106
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
149107

150108
tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
151-
tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);
152-
tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker);
109+
tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS);
110+
tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS);
153111
return scheduler;
154112
}
155113

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
import org.beehive.gpullama3.inference.state.State;
44
import org.beehive.gpullama3.inference.weights.Weights;
55
import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.LlamaTornadoWeights;
6+
import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.Qwen2TornadoWeights;
67
import org.beehive.gpullama3.inference.weights.tornado.Q8_0Weights.Q8_0Weights;
8+
import org.beehive.gpullama3.inference.weights.tornado.Q8_0Weights.Qwen3Q8_0TornadoWeights;
79
import org.beehive.gpullama3.model.Configuration;
810
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
911
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
12+
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
1013
import org.beehive.gpullama3.tornadovm.layers.AbstractLayer;
1114
import uk.ac.manchester.tornado.api.GridScheduler;
1215
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -35,17 +38,18 @@ public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Confi
3538

3639
@Override
3740
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
38-
WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
39-
rmsNormWorker.setGlobalWork(config.dim(), 1, 1);
40-
rmsNormWorker.setLocalWork(32, 1, 1);
41-
// RMSNorm operations
42-
int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
43-
WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
44-
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
41+
42+
WorkerGrid logitsRMS;
43+
if (weights instanceof Qwen3Q8_0TornadoWeights) {
44+
logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32);
45+
} else {
46+
logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
47+
}
48+
4549

4650
tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
47-
tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);
48-
tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker);
51+
tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS);
52+
tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS);
4953

5054
return tornadoForwardScheduler;
5155
}

0 commit comments

Comments
 (0)