44import org .beehive .gpullama3 .inference .state .State ;
55import org .beehive .gpullama3 .model .Configuration ;
66import org .beehive .gpullama3 .model .Model ;
7- import org .beehive .gpullama3 .model .ModelType ;
87import org .beehive .gpullama3 .tornadovm .layerplanner .base .QuantizationPlannerFactory ;
9- import org .beehive .gpullama3 .tornadovm .layers .SchedulerDetectionService ;
10- import org .beehive .gpullama3 .tornadovm .layers .SchedulerType ;
118import uk .ac .manchester .tornado .api .ImmutableTaskGraph ;
129import uk .ac .manchester .tornado .api .TornadoExecutionPlan ;
13- import uk .ac .manchester .tornado .api .TornadoRuntime ;
14- import uk .ac .manchester .tornado .api .runtime .TornadoRuntimeProvider ;
1510import uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
1611
17- import java .util .Locale ;
18-
1912public class TornadoVMMasterPlan {
2013 public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean .parseBoolean (System .getProperty ("llama.EnableTimingForTornadoVMInit" , "False" ));
2114
2215 private final State state ;
2316 private final Configuration config ;
2417 public TornadoExecutionPlan executionPlan ;
25- private SchedulerType schedulerDetectionService ;
26- TornadoVMGenericLayerPlanner tornadoVMLayerPlanner ;
18+ GenericLayerPlanner tornadoVMLayerPlanner ;
2719
2820 public TornadoVMMasterPlan (State state , Model model ) {
29- // this.schedulerDetectionService = SchedulerDetectionService.determineSchedulerType(model);
30-
3121 this .tornadoVMLayerPlanner = createPlannerWithStrategy (state , model );
3222 this .executionPlan = new TornadoExecutionPlan (tornadoVMLayerPlanner .getCachedTaskGraphs ().toArray (new ImmutableTaskGraph [tornadoVMLayerPlanner .getCachedTaskGraphs ().size ()]));
33-
3423 this .state = state ;
3524 this .config = model .configuration ();
3625 }
@@ -57,7 +46,7 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
5746 }
5847
5948 // 1. Pre-allocate the TornadoVM plan
60- TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan (state , model );
49+ TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan (state , model );
6150
6251 // Record time after plan creation
6352 if (ENABLE_TORNADOVM_INIT_TIME ) {
@@ -89,81 +78,16 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
8978 return tornadoVMPlan ;
9079 }
9180
92- /**
93- * Dispatcher method to select the TornadoVMLayerPlanner for the model.
94- */
95- // TornadoVMGenericLayerPlanner createPlanner(State state, Model model) {
96- // return switch (model.getModelType()) {
97- // case LLAMA_3, MISTRAL -> whatcreateLlama3Planner(state, model);
98- // // case PHI_3 -> createPhi3Planner(state, model);
99- // // case QWEN_2, DEEPSEEK_R1_DISTILL_QWEN -> createQWEN2Planner(state, model);
100- // // case QWEN_3 -> createQWEN3Planner(state, model);
101- // case QWEN_2 -> null;
102- // case QWEN_3 -> null;
103- // case DEEPSEEK_R1_DISTILL_QWEN -> null;
104- // case PHI_3 -> null;
105- // case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
106- // };
107- // }
108-
109- // private TornadoVMGenericLayerPlanner whatcreateLlama3Planner(State state, Model model) {
110- // if (model.weights().getWeightType().equals(GGMLType.Q8_0)) {
111- // return new TornadoVMQ8_0LayerPlanner(state, model);
112- // } else {
113- // return new TornadoVMLayerPlanner(state, model);
114- // }
115- // }
116-
117- // private TornadoVMGenericLayerPlanner createQWEN2Planner(State state, Model model) {
118- // if (model.weights().getWeightType().equals(GGMLType.Q8_0)) {
119- // return new Qwen2Q8_0TornadoVMLayerPlanner((Qwen2State) state, model);
120- // } else {
121- // return new Qwen2TornadoVMLayerPlanner((Qwen2State) state, model);
122- // }
123- // }
124- //
125- // private TornadoVMGenericLayerPlanner createPhi3Planner(State state, Model model) {
126- // if (model.weights().getWeightType().equals(GGMLType.Q8_0)) {
127- // return new Phi3TornadoVMLayerPlannerQ8_0((Phi3State) state, model);
128- // } else {
129- // return new Phi3TornadoVMLayerPlanner((Phi3State) state, model);
130- // }
131- // }
132- //
133- // private TornadoVMGenericLayerPlanner createQWEN3Planner(State state, Model model) {
134- // if (model.weights().getWeightType().equals(GGMLType.Q8_0)) {
135- // return new Qwen3Q8_0TornadoVMLayerPlanner((Qwen3State) state, model);
136- // } else {
137- // return new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model);
138- // }
139- // }
140-
141- private TornadoVMGenericLayerPlanner createPlannerWithStrategy (State state , Model model ) {
81+ private GenericLayerPlanner createPlannerWithStrategy (State state , Model model ) {
14282
14383 // ========== STEP 1: Detect Quantization Type ==========
14484 GGMLType weightType = model .weights ().getWeightType ();
14585
14686 // ========== STEP 2: Route via Factory ==========
14787 // Factory handles all model × quantization combinations
148- TornadoVMGenericLayerPlanner basePlanner = QuantizationPlannerFactory .create (weightType , state , model );
149-
150- return basePlanner ;
151- }
152-
153-
154- public static SchedulerType shouldUseNvidiaScheduler (Model model ) {
155- TornadoRuntime runtime = TornadoRuntimeProvider .getTornadoRuntime ();
156- String platformName = runtime .getBackend (0 ).getDefaultDevice ().getPlatformName ().toLowerCase (Locale .ROOT );
88+ GenericLayerPlanner basePlanner = QuantizationPlannerFactory .create (weightType , state , model );
15789
158- boolean isNvidia = platformName .contains ("nvidia" ) || platformName .contains ("cuda" ) || platformName .contains ("ptx" );
159- boolean isNotMistral = model .getModelType () != ModelType .MISTRAL ;
160-
161-
162- if (isNvidia && isNotMistral ) {
163- return SchedulerType .NVIDIA ;
164- } else {
165- return SchedulerType .NON_NVIDIA ;
166- }
90+ return basePlanner ;
16791 }
16892
16993 /**
0 commit comments