|
5 | 5 | import org.beehive.gpullama3.inference.weights.Weights; |
6 | 6 | import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.FP16Weights; |
7 | 7 | import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.LlamaTornadoWeights; |
| 8 | +import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.Qwen2TornadoWeights; |
8 | 9 | import org.beehive.gpullama3.model.Configuration; |
9 | 10 | import org.beehive.gpullama3.model.Model; |
10 | 11 | import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; |
11 | 12 | import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; |
| 13 | +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; |
12 | 14 | import org.beehive.gpullama3.tornadovm.layers.AbstractLayer; |
13 | 15 | import uk.ac.manchester.tornado.api.GridScheduler; |
14 | 16 | import uk.ac.manchester.tornado.api.ImmutableTaskGraph; |
@@ -36,31 +38,6 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration |
36 | 38 | this.logitsTaskGraph = setupLogitsTaskGraph(fp16Weights , config); |
37 | 39 | } |
38 | 40 |
|
39 | | - private TaskGraph setupLogitNonNVidia(FP16Weights weights, Configuration config) { |
40 | | - TaskGraph logits = new TaskGraph("logits") |
41 | | - .consumeFromDevice(lastTaskGraphID, |
42 | | - state.wrapX |
43 | | - ) |
44 | | - .transferToDevice(DataTransferMode.EVERY_EXECUTION, |
45 | | - state.tempLogits |
46 | | - ) |
47 | | - .transferToDevice(DataTransferMode.FIRST_EXECUTION, |
48 | | - context, |
49 | | - state.wrapLogits, |
50 | | - weights.wclsHalfFloat, |
51 | | - weights.rms_final_weight_as_floatArray |
52 | | - ) |
53 | | - .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, |
54 | | - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) |
55 | | - .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, |
56 | | - weights.rms_final_weight_as_floatArray, state.tempLogits); |
57 | | - logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, // |
58 | | - context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat, // |
59 | | - config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); // |
60 | | - logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); |
61 | | - return logits; |
62 | | - } |
63 | | - |
64 | 41 | /** |
65 | 42 | * Builds the logits computation graph. |
66 | 43 | */ |
@@ -99,57 +76,38 @@ private GridScheduler setupGridSchedulerForLogits(Configuration config) { |
99 | 76 | rmsNormWorker.setGlobalWork(config.dim(), 1, 1); |
100 | 77 | rmsNormWorker.setLocalWork(256, 1, 1); |
101 | 78 |
|
| 79 | + WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); |
| 80 | + |
102 | 81 | // Projection kernel (vocabulary size × hidden dim) |
103 | 82 | int vocabSizeGlobal = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; |
104 | 83 | WorkerGrid projectionWorker = new WorkerGrid1D(vocabSizeGlobal); |
105 | 84 | projectionWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); |
106 | 85 |
|
107 | 86 | scheduler.addWorkerGrid("logits.projection", projectionWorker); |
108 | | - scheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker); |
109 | | - scheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker); |
| 87 | + scheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS); |
| 88 | + scheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); |
110 | 89 |
|
111 | 90 | return scheduler; |
112 | 91 | } |
113 | 92 |
|
114 | | -// @Override |
115 | | -// public GridScheduler updateGridScheduler(GridScheduler scheduler) { |
116 | | -// // RMSNorm operations |
117 | | -// WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim()); |
118 | | -// rmsNormWorker.setGlobalWork(config.dim(), 1, 1); |
119 | | -// rmsNormWorker.setLocalWork(256, 1, 1); |
120 | | -// |
121 | | -// // Projection kernel (vocabulary size × hidden dim) |
122 | | -// int vocabSizeGlobal = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; |
123 | | -// WorkerGrid projectionWorker = new WorkerGrid1D(vocabSizeGlobal); |
124 | | -// projectionWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); |
125 | | -// |
126 | | -// scheduler.addWorkerGrid("logits.projection", projectionWorker); |
127 | | -// scheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker); |
128 | | -// scheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker); |
129 | | -// |
130 | | -// return scheduler; |
131 | | -// } |
132 | | - |
133 | 93 |
|
134 | 94 | @Override |
135 | 95 | public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { |
136 | | - // RMSNorm operations |
137 | | - WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim()); |
138 | | - |
139 | | - rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension |
140 | 96 |
|
141 | | - //TODO: XXX |
142 | | - rmsNormWorker.setLocalWork(32, 1, 1); // Set local work size to 256 (standard efficient size) |
| 97 | + WorkerGrid logitsRMS = null; |
| 98 | + if (weights instanceof Qwen2TornadoWeights ) { |
| 99 | + logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32); |
| 100 | + } else { |
| 101 | + logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); |
| 102 | + } |
143 | 103 |
|
144 | | - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1]) |
145 | | - // CUDA equivalent: kernel<<<dim3((config.vocabularySize+15)/16,1,1), dim3(16,1,1)>>> |
146 | 104 | int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; |
147 | 105 | WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); |
148 | 106 | vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); |
149 | 107 |
|
150 | 108 | tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); |
151 | | - tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker); |
152 | | - tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker); |
| 109 | + tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS); |
| 110 | + tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); |
153 | 111 | return scheduler; |
154 | 112 | } |
155 | 113 |
|
|
0 commit comments