Skip to content

Commit 2376059

Browse files
authored
Merge pull request beehive-lab#64 from orionpapadakis/fix/bug-fixes
Bug fixes in sizes and names of GridScheduler
2 parents c9f00ca + 90c3592 commit 2376059

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ List<ImmutableTaskGraph> setupFFNLayered() {
172172
* Setup a single transformer layer for Qwen3 with GQA
173173
*/
174174
TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) {
175-
var taskGraphName = "ffn_layer_" + layerIndex;
175+
var taskGraphName = "layer_" + layerIndex;
176176
TaskGraph unifiedLayer = new TaskGraph(taskGraphName);
177177
unifiedLayer.consumeFromDevice(state.wrapX);
178178
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import org.beehive.gpullama3.inference.state.State;
44
import org.beehive.gpullama3.inference.weights.Weights;
55
import org.beehive.gpullama3.inference.weights.tornado.q8_0.LlamaTornadoWeightsQ8_0;
6-
import org.beehive.gpullama3.inference.weights.tornado.q8_0.Qwen3TornadoWeightsQ8_0;
6+
import org.beehive.gpullama3.inference.weights.tornado.q8_0.Qwen2TornadoWeightsQ8_0;
77
import org.beehive.gpullama3.model.Configuration;
88
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
99
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
@@ -34,7 +34,7 @@ public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Confi
3434
@Override
3535
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
3636
WorkerGrid logitsRMS;
37-
if (weights instanceof Qwen3TornadoWeightsQ8_0) {
37+
if (weights instanceof Qwen2TornadoWeightsQ8_0) {
3838
logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32);
3939
} else {
4040
logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);

0 commit comments

Comments
 (0)