Skip to content

Commit a8a8c68

Browse files
committed
Add Q8_0 and FP16 layer implementations for Qwen3 and TornadoVM
This commit introduces new layer components and planners tailored for Qwen3 and TornadoVM environments, including: - LogitsQ8_0Layer class for handling Q8_0 weights - Qwen3FP16FFNLayers for managing FP16 weights in Qwen3 architecture - Qwen3FP16LayerPlanner for planning TornadoVM operations using FP16 layers and weights These additions enhance compatibility and extend functionality for the Qwen3 model.
1 parent 3ead7e4 commit a8a8c68

18 files changed

Lines changed: 3254 additions & 0 deletions
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package org.beehive.gpullama3.tornadovm;
2+
3+
import java.io.IOException;
4+
5+
public class GPULLlama3TypeException extends IllegalArgumentException {
6+
public GPULLlama3TypeException(String message) {
7+
super(message);
8+
}
9+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package org.beehive.gpullama3.tornadovm.layerplanner;
2+
3+
import org.beehive.gpullama3.auxiliary.Tuple2;
4+
import uk.ac.manchester.tornado.api.GridScheduler;
5+
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
6+
7+
import java.util.List;
8+
9+
public interface GenericLayerPlanner {
10+
Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered();
11+
12+
Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayeredNonNvidia();
13+
14+
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package org.beehive.gpullama3.tornadovm.layerplanner.base;
2+
3+
import org.beehive.gpullama3.core.model.GGMLType;
4+
import org.beehive.gpullama3.inference.state.LlamaState;
5+
import org.beehive.gpullama3.inference.state.Qwen2State;
6+
import org.beehive.gpullama3.inference.state.Qwen3State;
7+
import org.beehive.gpullama3.inference.state.State;
8+
import org.beehive.gpullama3.model.Model;
9+
import org.beehive.gpullama3.tornadovm.TornadoVMGenericLayerPlanner;
10+
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner;
11+
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner;
12+
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner;
13+
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner;
14+
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner;
15+
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner;
16+
17+
/**
18+
* Factory: Creates the appropriate planner based on model type + quantization.
19+
*
20+
* Routing Logic: 1. Determine quantization type from GGMLType 2. Determine model type from Model 3. Instantiate appropriate planner
21+
*
22+
* Example: QuantizationType.FP16 + ModelType.LLAMA_3 → LlamaFP16LayerPlanner QuantizationType.Q8_0 + ModelType.QWEN_2 → Qwen2Q8_0LayerPlanner
23+
*/
24+
public class QuantizationPlannerFactory {
25+
26+
/**
27+
* Main factory method: create planner for given model + quantization
28+
*/
29+
public static TornadoVMGenericLayerPlanner create(GGMLType quantization, State state, Model model) {
30+
return switch (quantization) {
31+
case F32 -> createFP32Planner(state, model);
32+
case F16 -> createFP16Planner(state, model);
33+
case Q8_0 -> createQ8_0Planner(state, model);
34+
default -> throw new UnsupportedOperationException("Quantization not supported: " + quantization);
35+
};
36+
}
37+
38+
// ============ FP16 Planners ============
39+
40+
private static TornadoVMGenericLayerPlanner createFP16Planner(State state, Model model) {
41+
return switch (model.getModelType()) {
42+
case LLAMA_3, MISTRAL -> new LlamaFP16LayerPlanner((LlamaState) state, model);
43+
// case MISTRAL -> new MistralFP16LayerPlanner(state, model);
44+
case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
45+
case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model);
46+
// case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model);
47+
// case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
48+
default -> throw new UnsupportedOperationException("FP16 not supported for model: " + model.getModelType());
49+
};
50+
}
51+
52+
// ============ Q8_0 Planners ============
53+
54+
private static TornadoVMGenericLayerPlanner createQ8_0Planner(State state, Model model) {
55+
return switch (model.getModelType()) {
56+
case LLAMA_3, MISTRAL -> new LlamaQ8_0LayerPlanner((LlamaState) state, model);
57+
case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
58+
case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model);
59+
// case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model);
60+
// case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
61+
// case MISTRAL -> throw new UnsupportedOperationException(
62+
// "Q8_0 not supported for MISTRAL (use FP16)");
63+
default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType());
64+
};
65+
}
66+
67+
// ============ FP32 Planners (FUTURE) ============
68+
69+
private static TornadoVMGenericLayerPlanner createFP32Planner(State state, Model model) {
70+
throw new UnsupportedOperationException("FP32 planners not yet implemented");
71+
}
72+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package org.beehive.gpullama3.tornadovm.layerplanner.base;
2+
3+
import org.beehive.gpullama3.inference.state.State;
4+
import org.beehive.gpullama3.inference.weights.Weights;
5+
import org.beehive.gpullama3.model.Configuration;
6+
import org.beehive.gpullama3.model.Model;
7+
import org.beehive.gpullama3.tornadovm.TornadoVMGenericLayerPlanner;
8+
import uk.ac.manchester.tornado.api.KernelContext;
9+
10+
/**
11+
* Abstract base for all quantization-specific planners.
12+
*
13+
* Contains shared logic that works regardless of model type but depends on quantization. Subclasses: FP16LayerPlanner, Q8_0LayerPlanner, etc.
14+
*/
15+
public abstract class QuantizedLayerPlanner<S extends State, C extends Configuration, W extends Weights> implements TornadoVMGenericLayerPlanner {
16+
17+
// Common state for all quantizations
18+
protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32;
19+
protected static final int THREAD_SCALE_FOR_LOGITS = 8;
20+
21+
protected final S state;
22+
protected final C config;
23+
protected final W weights;
24+
protected final KernelContext context;
25+
26+
/**
27+
* Constructor: validate quantization type, extract model components
28+
*/
29+
protected QuantizedLayerPlanner(S state, Model model) {
30+
this.state = state;
31+
this.config = (C) model.configuration();
32+
this.weights = (W) model.weights();
33+
this.context = new KernelContext();
34+
35+
validateQuantizationType();
36+
}
37+
38+
/**
39+
* Override in subclasses to validate correct quantization format. E.g., FP16LayerPlanner checks: weights instanceof FP16Weights
40+
*/
41+
protected abstract void validateQuantizationType();
42+
43+
/**
44+
* Override in subclasses for model-specific initialization
45+
*/
46+
protected abstract void initializeLayerComponents();
47+
48+
// Common helper methods for all quantizations
49+
protected C getConfig() {
50+
return config;
51+
}
52+
53+
protected W getWeights() {
54+
return weights;
55+
}
56+
57+
protected S getState() {
58+
return state;
59+
}
60+
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
2+
3+
import org.beehive.gpullama3.auxiliary.Tuple2;
4+
import org.beehive.gpullama3.inference.state.LlamaState;
5+
import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.LlamaTornadoWeights;
6+
import org.beehive.gpullama3.model.Model;
7+
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
8+
import org.beehive.gpullama3.tornadovm.GPULLlama3TypeException;
9+
import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner;
10+
import org.beehive.gpullama3.tornadovm.layers.Activation;
11+
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
12+
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
13+
import uk.ac.manchester.tornado.api.GridScheduler;
14+
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
15+
16+
import java.util.ArrayList;
17+
import java.util.List;
18+
19+
public class LlamaFP16LayerPlanner extends FP16LayerPlanner<LlamaState, LlamaConfiguration, LlamaTornadoWeights> {
20+
21+
private Activation activationLayer;
22+
private LlamaFP16FFNLayers ffnLayers;
23+
private LogitsFP16Layer logitsLayer;
24+
25+
// Cache
26+
private List<ImmutableTaskGraph> cachedTaskGraphs;
27+
private GridScheduler cachedScheduler;
28+
29+
public LlamaFP16LayerPlanner(LlamaState state, Model model) {
30+
super(state, model);
31+
validateQuantizationType();
32+
setupTornadoForwardPlan();
33+
}
34+
35+
@Override
36+
protected void initializeLayerComponents() {
37+
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
38+
39+
this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", this.state, this.weights, this.config);
40+
41+
this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
42+
}
43+
44+
@Override
45+
public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered() {
46+
if (this.cachedTaskGraphs != null && this.cachedScheduler != null) {
47+
return new Tuple2<>(this.cachedTaskGraphs, this.cachedScheduler);
48+
}
49+
50+
List<ImmutableTaskGraph> allTaskGraphs = new ArrayList<>();
51+
GridScheduler masterScheduler = new GridScheduler();
52+
53+
// 1. Activation layer
54+
allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
55+
activationLayer.updateGridScheduler(masterScheduler);
56+
57+
// 2. FFN layers (N transformer layers)
58+
allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs());
59+
ffnLayers.updateGridScheduler(masterScheduler);
60+
61+
// 3. Logits layer
62+
allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot());
63+
logitsLayer.updateGridScheduler(masterScheduler);
64+
65+
// Cache
66+
this.cachedTaskGraphs = allTaskGraphs;
67+
this.cachedScheduler = masterScheduler;
68+
69+
return new Tuple2<>(allTaskGraphs, masterScheduler);
70+
}
71+
72+
public void setupTornadoForwardPlan() {
73+
74+
List<ImmutableTaskGraph> allTaskGraphs = new ArrayList<>();
75+
GridScheduler masterScheduler = new GridScheduler();
76+
77+
// 1. Activation layer
78+
allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
79+
activationLayer.updateGridScheduler(masterScheduler);
80+
81+
// 2. FFN layers (N transformer layers)
82+
allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs());
83+
ffnLayers.updateGridScheduler(masterScheduler);
84+
85+
// 3. Logits layer
86+
allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot());
87+
logitsLayer.updateGridScheduler(masterScheduler);
88+
89+
// Cache
90+
this.cachedTaskGraphs = allTaskGraphs;
91+
this.cachedScheduler = masterScheduler;
92+
93+
}
94+
95+
@Override
96+
public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayeredNonNvidia() {
97+
// For now, same as NVIDIA version
98+
// Hardware strategy will optimize scheduler
99+
return setupTornadoForwardPlanLayered();
100+
}
101+
102+
public List<ImmutableTaskGraph> getCachedTaskGraphs() {
103+
return this.cachedTaskGraphs;
104+
}
105+
106+
@Override
107+
public GridScheduler getCachedGridScheduler() {
108+
return this.cachedScheduler;
109+
}
110+
111+
public void clearCache() {
112+
this.cachedTaskGraphs = null;
113+
this.cachedScheduler = null;
114+
}
115+
}

0 commit comments

Comments
 (0)