Skip to content

Commit 2befb97

Browse files
committed
Refactor TornadoVM layer planners and worker grid logic:
- Replaced `TornadoVMGenericLayerPlanner` with `GenericLayerPlanner` for consistency across planners. - Updated `QuantizationPlannerFactory` and related classes to use the new interface. - Added `createSingleWorker` method to `WorkerGridFactory` for standardized single worker creation. - Simplified and cleaned up TornadoVMMasterPlan, removing unused methods and comments.
1 parent 72c6619 commit 2befb97

7 files changed

Lines changed: 25 additions & 92 deletions

File tree

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMGenericLayerPlanner.java renamed to src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import java.util.List;
88

9-
public interface TornadoVMGenericLayerPlanner {
9+
public interface GenericLayerPlanner {
1010

1111
Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered();
1212

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java

Lines changed: 5 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,22 @@
44
import org.beehive.gpullama3.inference.state.State;
55
import org.beehive.gpullama3.model.Configuration;
66
import org.beehive.gpullama3.model.Model;
7-
import org.beehive.gpullama3.model.ModelType;
87
import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory;
9-
import org.beehive.gpullama3.tornadovm.layers.SchedulerDetectionService;
10-
import org.beehive.gpullama3.tornadovm.layers.SchedulerType;
118
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
129
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
13-
import uk.ac.manchester.tornado.api.TornadoRuntime;
14-
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
1510
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
1611

17-
import java.util.Locale;
18-
1912
public class TornadoVMMasterPlan {
2013
public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
2114

2215
private final State state;
2316
private final Configuration config;
2417
public TornadoExecutionPlan executionPlan;
25-
private SchedulerType schedulerDetectionService;
26-
TornadoVMGenericLayerPlanner tornadoVMLayerPlanner;
18+
GenericLayerPlanner tornadoVMLayerPlanner;
2719

2820
public TornadoVMMasterPlan(State state, Model model) {
29-
// this.schedulerDetectionService = SchedulerDetectionService.determineSchedulerType(model);
30-
3121
this.tornadoVMLayerPlanner = createPlannerWithStrategy(state, model);
3222
this.executionPlan = new TornadoExecutionPlan(tornadoVMLayerPlanner.getCachedTaskGraphs().toArray(new ImmutableTaskGraph[tornadoVMLayerPlanner.getCachedTaskGraphs().size()]));
33-
3423
this.state = state;
3524
this.config = model.configuration();
3625
}
@@ -57,7 +46,7 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
5746
}
5847

5948
// 1. Pre-allocate the TornadoVM plan
60-
TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan(state, model );
49+
TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan(state, model);
6150

6251
// Record time after plan creation
6352
if (ENABLE_TORNADOVM_INIT_TIME) {
@@ -89,81 +78,16 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
8978
return tornadoVMPlan;
9079
}
9180

92-
/**
93-
* Dispatcher method to select the TornadoVMLayerPlanner for the model.
94-
*/
95-
// TornadoVMGenericLayerPlanner createPlanner(State state, Model model) {
96-
// return switch (model.getModelType()) {
97-
// case LLAMA_3, MISTRAL -> whatcreateLlama3Planner(state, model);
98-
// // case PHI_3 -> createPhi3Planner(state, model);
99-
// // case QWEN_2, DEEPSEEK_R1_DISTILL_QWEN -> createQWEN2Planner(state, model);
100-
// // case QWEN_3 -> createQWEN3Planner(state, model);
101-
// case QWEN_2 -> null;
102-
// case QWEN_3 -> null;
103-
// case DEEPSEEK_R1_DISTILL_QWEN -> null;
104-
// case PHI_3 -> null;
105-
// case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
106-
// };
107-
// }
108-
109-
// private TornadoVMGenericLayerPlanner whatcreateLlama3Planner(State state, Model model) {
110-
// if (model.weights().getWeightType().equals(GGMLType.Q8_0)) {
111-
// return new TornadoVMQ8_0LayerPlanner(state, model);
112-
// } else {
113-
// return new TornadoVMLayerPlanner(state, model);
114-
// }
115-
// }
116-
117-
// private TornadoVMGenericLayerPlanner createQWEN2Planner(State state, Model model) {
118-
// if (model.weights().getWeightType().equals(GGMLType.Q8_0)) {
119-
// return new Qwen2Q8_0TornadoVMLayerPlanner((Qwen2State) state, model);
120-
// } else {
121-
// return new Qwen2TornadoVMLayerPlanner((Qwen2State) state, model);
122-
// }
123-
// }
124-
//
125-
// private TornadoVMGenericLayerPlanner createPhi3Planner(State state, Model model) {
126-
// if (model.weights().getWeightType().equals(GGMLType.Q8_0)) {
127-
// return new Phi3TornadoVMLayerPlannerQ8_0((Phi3State) state, model);
128-
// } else {
129-
// return new Phi3TornadoVMLayerPlanner((Phi3State) state, model);
130-
// }
131-
// }
132-
//
133-
// private TornadoVMGenericLayerPlanner createQWEN3Planner(State state, Model model) {
134-
// if (model.weights().getWeightType().equals(GGMLType.Q8_0)) {
135-
// return new Qwen3Q8_0TornadoVMLayerPlanner((Qwen3State) state, model);
136-
// } else {
137-
// return new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model);
138-
// }
139-
// }
140-
141-
private TornadoVMGenericLayerPlanner createPlannerWithStrategy(State state, Model model) {
81+
private GenericLayerPlanner createPlannerWithStrategy(State state, Model model) {
14282

14383
// ========== STEP 1: Detect Quantization Type ==========
14484
GGMLType weightType = model.weights().getWeightType();
14585

14686
// ========== STEP 2: Route via Factory ==========
14787
// Factory handles all model × quantization combinations
148-
TornadoVMGenericLayerPlanner basePlanner = QuantizationPlannerFactory.create(weightType, state, model);
149-
150-
return basePlanner;
151-
}
152-
153-
154-
public static SchedulerType shouldUseNvidiaScheduler(Model model) {
155-
TornadoRuntime runtime = TornadoRuntimeProvider.getTornadoRuntime();
156-
String platformName = runtime.getBackend(0).getDefaultDevice().getPlatformName().toLowerCase(Locale.ROOT);
88+
GenericLayerPlanner basePlanner = QuantizationPlannerFactory.create(weightType, state, model);
15789

158-
boolean isNvidia = platformName.contains("nvidia") || platformName.contains("cuda") || platformName.contains("ptx");
159-
boolean isNotMistral = model.getModelType() != ModelType.MISTRAL;
160-
161-
162-
if (isNvidia && isNotMistral) {
163-
return SchedulerType.NVIDIA;
164-
} else {
165-
return SchedulerType.NON_NVIDIA;
166-
}
90+
return basePlanner;
16791
}
16892

16993
/**

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ public static WorkerGrid createRmsNormWorker(int dim, int localSize) {
1717
return worker;
1818
}
1919

20+
public static WorkerGrid createSingleWorker() {
21+
WorkerGrid worker = new WorkerGrid1D(1);
22+
worker.setGlobalWork(1, 1, 1);
23+
worker.setLocalWork(1, 1, 1);
24+
return worker;
25+
}
26+
2027
/**
2128
* QKV matmul worker: combined projection output
2229
*/

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import org.beehive.gpullama3.inference.state.Qwen3State;
88
import org.beehive.gpullama3.inference.state.State;
99
import org.beehive.gpullama3.model.Model;
10-
import org.beehive.gpullama3.tornadovm.TornadoVMGenericLayerPlanner;
10+
import org.beehive.gpullama3.tornadovm.GenericLayerPlanner;
1111
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner;
1212
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner;
1313
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner;
@@ -38,7 +38,7 @@ public class QuantizationPlannerFactory {
3838
/**
3939
* Main factory method: create planner for given model + quantization
4040
*/
41-
public static TornadoVMGenericLayerPlanner create(GGMLType quantization, State state, Model model) {
41+
public static GenericLayerPlanner create(GGMLType quantization, State state, Model model) {
4242
return switch (quantization) {
4343
case F32 -> createFP32Planner(state, model);
4444
case F16 -> createFP16Planner(state, model);
@@ -48,7 +48,7 @@ public static TornadoVMGenericLayerPlanner create(GGMLType quantization, State s
4848
}
4949

5050
// ============ FP16 Planners ============
51-
private static TornadoVMGenericLayerPlanner createFP16Planner(State state, Model model) {
51+
private static GenericLayerPlanner createFP16Planner(State state, Model model) {
5252
return switch (model.getModelType()) {
5353
case LLAMA_3, MISTRAL -> new LlamaFP16LayerPlanner((LlamaState) state, model);
5454
case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
@@ -60,7 +60,7 @@ private static TornadoVMGenericLayerPlanner createFP16Planner(State state, Model
6060
}
6161

6262
// ============ Q8_0 Planners ============
63-
private static TornadoVMGenericLayerPlanner createQ8_0Planner(State state, Model model) {
63+
private static GenericLayerPlanner createQ8_0Planner(State state, Model model) {
6464
return switch (model.getModelType()) {
6565
case LLAMA_3, MISTRAL -> new LlamaQ8_0LayerPlanner((LlamaState) state, model);
6666
case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
@@ -72,7 +72,7 @@ private static TornadoVMGenericLayerPlanner createQ8_0Planner(State state, Model
7272
}
7373

7474
// ============ FP32 Planners (FUTURE) ============
75-
private static TornadoVMGenericLayerPlanner createFP32Planner(State state, Model model) {
75+
private static GenericLayerPlanner createFP32Planner(State state, Model model) {
7676
throw new UnsupportedOperationException("FP32 planners not yet implemented");
7777
}
7878
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
import org.beehive.gpullama3.inference.weights.Weights;
55
import org.beehive.gpullama3.model.Configuration;
66
import org.beehive.gpullama3.model.Model;
7-
import org.beehive.gpullama3.tornadovm.TornadoVMGenericLayerPlanner;
7+
import org.beehive.gpullama3.tornadovm.GenericLayerPlanner;
88
import uk.ac.manchester.tornado.api.KernelContext;
99

1010
/**
1111
* Abstract base for all quantization-specific planners.
1212
*
1313
* Contains shared logic that works regardless of model type but depends on quantization. Subclasses: FP16LayerPlanner, Q8_0LayerPlanner, etc.
1414
*/
15-
public abstract class QuantizedLayerPlanner<S extends State, C extends Configuration, W extends Weights> implements TornadoVMGenericLayerPlanner {
15+
public abstract class QuantizedLayerPlanner<S extends State, C extends Configuration, W extends Weights> implements GenericLayerPlanner {
1616

1717
// Common state for all quantizations
1818
protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32;

src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import org.beehive.gpullama3.inference.weights.Weights;
55
import org.beehive.gpullama3.model.Configuration;
66
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
7+
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
78
import uk.ac.manchester.tornado.api.GridScheduler;
89
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
910
import uk.ac.manchester.tornado.api.TaskGraph;
@@ -25,9 +26,7 @@ public Activation(String taskGraphHandle, State state, Weights weights, Configur
2526

2627
@Override
2728
public GridScheduler updateGridScheduler(GridScheduler scheduler) {
28-
WorkerGrid singleWorker = new WorkerGrid1D(1);
29-
singleWorker.setGlobalWork(1, 1, 1);
30-
singleWorker.setLocalWork(1, 1, 1);
29+
WorkerGrid singleWorker = WorkerGridFactory.createSingleWorker();
3130
scheduler.addWorkerGrid("activationUpdate.updateX", singleWorker);
3231
return scheduler;
3332
}

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
4646
logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
4747
}
4848

49+
int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
50+
WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
51+
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
4952

5053
tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
5154
tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS);

0 commit comments

Comments
 (0)