Skip to content

Commit cec38b3

Browse files
committed
Refactor Qwen2 and Qwen3 TornadoVM layers:
- Simplified data transfer logic and removed unnecessary comments.
- Deprecated `Qwen3TornadoVMLayerPlanner` and related Qwen3 implementation elements.
- Renamed and aligned temporary buffers for clarity.
1 parent a8a8c68 commit cec38b3

11 files changed

Lines changed: 2654 additions & 1582 deletions

src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java

Lines changed: 357 additions & 357 deletions
Large diffs are not rendered by default.

src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlannerQ8_0.java

Lines changed: 362 additions & 362 deletions
Large diffs are not rendered by default.

src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Q8_0TornadoVMLayerPlanner.java

Lines changed: 397 additions & 397 deletions
Large diffs are not rendered by default.

src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java

Lines changed: 386 additions & 386 deletions
Large diffs are not rendered by default.

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,6 @@ private TornadoVMGenericLayerPlanner createPlannerWithStrategy(State state, Mode
147147
// Factory handles all model × quantization combinations
148148
TornadoVMGenericLayerPlanner basePlanner = QuantizationPlannerFactory.create(weightType, state, model);
149149

150-
// ========== STEP 3: Detect Hardware ==========
151-
SchedulerType hardwareType = this.schedulerDetectionService; // Already set in constructor
152-
153-
// ========== STEP 4: Select Strategy ==========
154-
// HardwareStrategy strategy = selectStrategy(hardwareType);
155-
156-
// ========== STEP 5: Wrap with Hardware Optimization ==========
157150
return basePlanner;
158151
}
159152

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,36 @@
22

33
import org.beehive.gpullama3.core.model.GGMLType;
44
import org.beehive.gpullama3.inference.state.LlamaState;
5+
import org.beehive.gpullama3.inference.state.Phi3State;
56
import org.beehive.gpullama3.inference.state.Qwen2State;
67
import org.beehive.gpullama3.inference.state.Qwen3State;
78
import org.beehive.gpullama3.inference.state.State;
89
import org.beehive.gpullama3.model.Model;
910
import org.beehive.gpullama3.tornadovm.TornadoVMGenericLayerPlanner;
1011
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner;
12+
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner;
1113
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner;
1214
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner;
1315
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner;
14-
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner;
16+
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner;
1517
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner;
18+
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner;
1619

1720
/**
18-
* Factory: Creates the appropriate planner based on model type + quantization.
19-
*
20-
* Routing Logic: 1. Determine quantization type from GGMLType 2. Determine model type from Model 3. Instantiate appropriate planner
21-
*
22-
* Example: QuantizationType.FP16 + ModelType.LLAMA_3 → LlamaFP16LayerPlanner QuantizationType.Q8_0 + ModelType.QWEN_2 → Qwen2Q8_0LayerPlanner
21+
* Factory class responsible for creating appropriate layer planners based on model type and quantization.
22+
* <p>
23+
* The factory follows a routing logic:
24+
* <ol>
25+
* <li>Determine quantization type from {@link GGMLType}</li>
26+
* <li>Determine model type from {@link Model}</li>
27+
* <li>Instantiate appropriate planner implementation</li>
28+
* </ol>
29+
* <p>
30+
* Examples:
31+
* <ul>
32+
* <li>{@code QuantizationType.FP16 + ModelType.LLAMA_3 → LlamaFP16LayerPlanner}</li>
33+
* <li>{@code QuantizationType.Q8_0 + ModelType.QWEN_2 → Qwen2Q8_0LayerPlanner}</li>
34+
* </ul>
2335
*/
2436
public class QuantizationPlannerFactory {
2537

@@ -36,36 +48,30 @@ public static TornadoVMGenericLayerPlanner create(GGMLType quantization, State s
3648
}
3749

3850
// ============ FP16 Planners ============
39-
4051
private static TornadoVMGenericLayerPlanner createFP16Planner(State state, Model model) {
4152
return switch (model.getModelType()) {
4253
case LLAMA_3, MISTRAL -> new LlamaFP16LayerPlanner((LlamaState) state, model);
43-
// case MISTRAL -> new MistralFP16LayerPlanner(state, model);
4454
case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
4555
case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model);
46-
// case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model);
47-
// case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
56+
case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model);
57+
case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
4858
default -> throw new UnsupportedOperationException("FP16 not supported for model: " + model.getModelType());
4959
};
5060
}
5161

5262
// ============ Q8_0 Planners ============
53-
5463
private static TornadoVMGenericLayerPlanner createQ8_0Planner(State state, Model model) {
5564
return switch (model.getModelType()) {
5665
case LLAMA_3, MISTRAL -> new LlamaQ8_0LayerPlanner((LlamaState) state, model);
5766
case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
5867
case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model);
59-
// case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model);
60-
// case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
61-
// case MISTRAL -> throw new UnsupportedOperationException(
62-
// "Q8_0 not supported for MISTRAL (use FP16)");
68+
case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model);
69+
case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
6370
default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType());
6471
};
6572
}
6673

6774
// ============ FP32 Planners (FUTURE) ============
68-
6975
private static TornadoVMGenericLayerPlanner createFP32Planner(State state, Model model) {
7076
throw new UnsupportedOperationException("FP32 planners not yet implemented");
7177
}

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java

Lines changed: 82 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -35,34 +35,59 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration
3535
this.logitsTaskGraph = setupLogitsTaskGraph(fp16Weights , config);
3636
}
3737

38+
private TaskGraph setupLogitNonNVidia(FP16Weights weights, Configuration config) {
39+
TaskGraph logits = new TaskGraph("logits")
40+
.consumeFromDevice(lastTaskGraphID,
41+
state.wrapX
42+
)
43+
.transferToDevice(DataTransferMode.EVERY_EXECUTION,
44+
state.tempLogits
45+
)
46+
.transferToDevice(DataTransferMode.FIRST_EXECUTION,
47+
context,
48+
state.wrapLogits,
49+
weights.wclsHalfFloat,
50+
weights.rms_final_weight_as_floatArray
51+
)
52+
.task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits,
53+
state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
54+
.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX,
55+
weights.rms_final_weight_as_floatArray, state.tempLogits);
56+
logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
57+
context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat, //
58+
config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); //
59+
logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
60+
return logits;
61+
}
62+
3863
/**
3964
* Builds the logits computation graph.
4065
*/
4166
private TaskGraph setupLogitsTaskGraph(FP16Weights weights, Configuration config) {
4267

43-
TaskGraph logits = new TaskGraph("logits")
44-
.consumeFromDevice(lastTaskGraphID,
45-
state.wrapX
46-
)
47-
.transferToDevice(DataTransferMode.EVERY_EXECUTION,
48-
state.tempLogits
49-
)
50-
.transferToDevice(DataTransferMode.FIRST_EXECUTION,
51-
context,
52-
state.wrapLogits,
53-
weights.wclsHalfFloat,
54-
weights.rms_final_weight_as_floatArray
55-
)
56-
.task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits,
57-
state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
58-
.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX,
59-
weights.rms_final_weight_as_floatArray, state.tempLogits);
60-
logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric,
61-
context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat,
62-
config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);
63-
logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
64-
65-
return logits;
68+
TaskGraph logits = new TaskGraph("logits")
69+
.consumeFromDevice(lastTaskGraphID,
70+
state.wrapX
71+
)
72+
.transferToDevice(DataTransferMode.EVERY_EXECUTION,
73+
state.tempLogits
74+
)
75+
.transferToDevice(DataTransferMode.FIRST_EXECUTION,
76+
context,
77+
state.wrapLogits,
78+
weights.wclsHalfFloat,
79+
weights.rms_final_weight_as_floatArray
80+
)
81+
.task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits,
82+
state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
83+
.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX,
84+
weights.rms_final_weight_as_floatArray, state.tempLogits);
85+
logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric,
86+
context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat,
87+
config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);
88+
logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
89+
90+
return logits;
6691
}
6792

6893
private GridScheduler setupGridSchedulerForLogits(Configuration config) {
@@ -85,22 +110,42 @@ private GridScheduler setupGridSchedulerForLogits(Configuration config) {
85110
return scheduler;
86111
}
87112

88-
@Override
89-
public GridScheduler updateGridScheduler(GridScheduler scheduler) {
90-
// RMSNorm operations
91-
WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
92-
rmsNormWorker.setGlobalWork(config.dim(), 1, 1);
93-
rmsNormWorker.setLocalWork(256, 1, 1);
94-
95-
// Projection kernel (vocabulary size × hidden dim)
96-
int vocabSizeGlobal = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
97-
WorkerGrid projectionWorker = new WorkerGrid1D(vocabSizeGlobal);
98-
projectionWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
113+
// @Override
114+
// public GridScheduler updateGridScheduler(GridScheduler scheduler) {
115+
// // RMSNorm operations
116+
// WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
117+
// rmsNormWorker.setGlobalWork(config.dim(), 1, 1);
118+
// rmsNormWorker.setLocalWork(256, 1, 1);
119+
//
120+
// // Projection kernel (vocabulary size × hidden dim)
121+
// int vocabSizeGlobal = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
122+
// WorkerGrid projectionWorker = new WorkerGrid1D(vocabSizeGlobal);
123+
// projectionWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
124+
//
125+
// scheduler.addWorkerGrid("logits.projection", projectionWorker);
126+
// scheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);
127+
// scheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker);
128+
//
129+
// return scheduler;
130+
// }
99131

100-
scheduler.addWorkerGrid("logits.projection", projectionWorker);
101-
scheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);
102-
scheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker);
103132

133+
@Override
134+
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
135+
// RMSNorm operations
136+
WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
137+
rmsNormWorker.setGlobalWork(config.dim(), 1, 1);
138+
rmsNormWorker.setLocalWork(256, 1, 1);
139+
140+
// OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1])
141+
// CUDA equivalent: kernel<<<dim3((config.vocabularySize+15)/16,1,1), dim3(16,1,1)>>>
142+
int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
143+
WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
144+
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
145+
146+
tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
147+
tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);
148+
tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker);
104149
return scheduler;
105150
}
106151

0 commit comments

Comments
 (0)