Skip to content

Commit 8a00ded

Browse files
authored
Merge pull request #101 from beehive-lab/refactor/simplify-layerplanner
[refactor] Simplify and unify the TornadoVM layer planner infrastructure
2 parents e12a4c1 + a3f1450 commit 8a00ded

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

45 files changed

+1076
-1285
lines changed

src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,12 @@ public List<Integer> encodeMessage(Message message) {
101101

102102
@Override
103103
public int getBeginOfText() {
104-
return beginOfText;
104+
if (beginOfText == -1) {
105+
// deepseek-r1
106+
return startHeader;
107+
} else {
108+
return beginOfText;
109+
}
105110
}
106111

107112
@Override

src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
1313
import org.beehive.gpullama3.model.format.ChatFormat;
1414
import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens;
15+
import org.beehive.gpullama3.model.qwen2.DeepSeekR1Qwen;
1516
import org.beehive.gpullama3.model.qwen2.Qwen2;
1617
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
1718
import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer;
@@ -85,7 +86,9 @@ protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weig
8586
// Qwen2.5-Coder uses <|endoftext|> as stop-token.
8687
ChatTokens chatTokens = isDeepSeekR1DistillQwen ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
8788
: new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
88-
return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
89+
return isDeepSeekR1DistillQwen
90+
? new DeepSeekR1Qwen(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens))
91+
: new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
8992
}
9093
// @formatter:on
9194

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package org.beehive.gpullama3.model.qwen2;
2+
3+
import org.beehive.gpullama3.inference.weights.Weights;
4+
import org.beehive.gpullama3.model.ModelType;
5+
import org.beehive.gpullama3.model.format.ChatFormat;
6+
import org.beehive.gpullama3.tokenizer.Tokenizer;
7+
8+
public class DeepSeekR1Qwen extends Qwen2 {
9+
10+
public DeepSeekR1Qwen(Qwen2Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) {
11+
super(configuration, tokenizer, weights, chatFormat);
12+
}
13+
14+
@Override
15+
public ModelType getModelType() {
16+
return ModelType.DEEPSEEK_R1_DISTILL_QWEN;
17+
}
18+
19+
@Override
20+
public boolean shouldAddBeginOfText() {
21+
return true;
22+
}
23+
}

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import org.beehive.gpullama3.model.Configuration;
55
import org.beehive.gpullama3.model.Model;
66
import org.beehive.gpullama3.tensor.GGMLType;
7-
import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory;
7+
import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner;
8+
import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory;
89
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
910
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
1011
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java renamed to src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package org.beehive.gpullama3.tornadovm;
1+
package org.beehive.gpullama3.tornadovm.layerplanner;
22

33
import uk.ac.manchester.tornado.api.GridScheduler;
44
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java renamed to src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package org.beehive.gpullama3.tornadovm.layerplanner.base;
1+
package org.beehive.gpullama3.tornadovm.layerplanner;
22

33
import org.beehive.gpullama3.inference.state.GraniteState;
44
import org.beehive.gpullama3.tensor.GGMLType;
@@ -8,14 +8,15 @@
88
import org.beehive.gpullama3.inference.state.Qwen3State;
99
import org.beehive.gpullama3.inference.state.State;
1010
import org.beehive.gpullama3.model.Model;
11-
import org.beehive.gpullama3.tornadovm.GenericLayerPlanner;
1211
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner;
1312
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner;
13+
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.MistralFP16LayerPlanner;
1414
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner;
1515
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner;
1616
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner;
1717
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.GraniteQ8_0LayerPlanner;
1818
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner;
19+
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.MistralQ8_0LayerPlanner;
1920
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner;
2021
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner;
2122
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner;
@@ -54,7 +55,8 @@ public static GenericLayerPlanner create(GGMLType quantization, State state, Mod
5455
// ============ FP16 Planners ============
5556
private static GenericLayerPlanner createFP16Planner(State state, Model model) {
5657
return switch (model.getModelType()) {
57-
case LLAMA_3, MISTRAL -> new LlamaFP16LayerPlanner((LlamaState) state, model);
58+
case LLAMA_3 -> new LlamaFP16LayerPlanner((LlamaState) state, model);
59+
case MISTRAL -> new MistralFP16LayerPlanner((LlamaState) state, model);
5860
case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
5961
case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model);
6062
case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model);
@@ -67,7 +69,8 @@ private static GenericLayerPlanner createFP16Planner(State state, Model model) {
6769
// ============ Q8_0 Planners ============
6870
private static GenericLayerPlanner createQ8_0Planner(State state, Model model) {
6971
return switch (model.getModelType()) {
70-
case LLAMA_3, MISTRAL -> new LlamaQ8_0LayerPlanner((LlamaState) state, model);
72+
case LLAMA_3 -> new LlamaQ8_0LayerPlanner((LlamaState) state, model);
73+
case MISTRAL -> new MistralQ8_0LayerPlanner((LlamaState) state, model);
7174
case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
7275
case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model);
7376
case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model);
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package org.beehive.gpullama3.tornadovm.layerplanner;
2+
3+
import org.beehive.gpullama3.inference.state.State;
4+
import org.beehive.gpullama3.inference.weights.Weights;
5+
import org.beehive.gpullama3.model.Configuration;
6+
import org.beehive.gpullama3.model.Model;
7+
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
8+
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
9+
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
10+
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
11+
import org.beehive.gpullama3.tornadovm.layers.Activation;
12+
import uk.ac.manchester.tornado.api.GridScheduler;
13+
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
14+
import uk.ac.manchester.tornado.api.KernelContext;
15+
16+
import java.util.ArrayList;
17+
import java.util.List;
18+
19+
/**
20+
* Abstract base for all quantization-specific planners.
21+
*
22+
* Extracts common state from the model, detects the hardware scheduler type,
23+
* and assembles the full execution plan via createTornadoInferencePlan().
24+
* Subclasses (FP16LayerPlanner, Q8_0LayerPlanner) only provide quantization validation.
25+
*/
26+
public abstract class QuantizedLayerPlanner<S extends State, C extends Configuration, W extends Weights>
27+
implements GenericLayerPlanner {
28+
29+
protected final S state;
30+
protected final C config;
31+
protected final W weights;
32+
protected final KernelContext context;
33+
protected final Model model;
34+
protected final SchedulerType schedulerType;
35+
36+
protected Activation activationLayer;
37+
protected AbstractFFNLayers<W, C> ffnLayers;
38+
protected AbstractLogitsLayer logitsLayer;
39+
40+
private List<ImmutableTaskGraph> immutableTaskGraphs;
41+
private GridScheduler gridScheduler;
42+
43+
@SuppressWarnings("unchecked")
44+
protected QuantizedLayerPlanner(S state, Model model) {
45+
this.state = state;
46+
this.model = model;
47+
this.config = (C) model.configuration();
48+
this.weights = (W) model.weights();
49+
this.context = new KernelContext();
50+
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
51+
validateQuantizationType();
52+
}
53+
54+
/** Validates that the model weights match the expected quantization type. */
55+
protected abstract void validateQuantizationType();
56+
57+
/**
58+
* Creates the TornadoVM inference execution pipeline.
59+
* It represents the entire Feed-Forward Network (FFN) and consists of:
60+
* <ul>
61+
* <li>Activation layer</li>
62+
* <li>FFN layers (N transformer layers, model-specific)</li>
63+
* <li>Logits layer</li>
64+
* </ul>
65+
* <p>
66+
* Each component is represented as an {@link ImmutableTaskGraph}, along with a
67+
* corresponding {@link GridScheduler} configuration that defines how tasks are
68+
* mapped on the GPU.
69+
* </p>
70+
* This method assembles all components into a unified execution pipeline and
71+
* caches the resulting task graphs and scheduler for reuse across inference runs.
72+
*/
73+
protected final void createTornadoInferencePlan() {
74+
List<ImmutableTaskGraph> allTaskGraphs = new ArrayList<>();
75+
GridScheduler masterScheduler = new GridScheduler();
76+
77+
// 1. Activation layer (common to all models)
78+
allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
79+
activationLayer.updateGridScheduler(masterScheduler);
80+
81+
// 2. FFN layers (N transformer layers - model-specific)
82+
allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs());
83+
ffnLayers.updateGridScheduler(masterScheduler);
84+
85+
// 3. Logits layer (common to all models)
86+
allTaskGraphs.add(logitsLayer.getImmutableTaskGraph());
87+
logitsLayer.updateGridScheduler(masterScheduler);
88+
89+
// Cache for future retrievals
90+
this.immutableTaskGraphs = allTaskGraphs;
91+
this.gridScheduler = masterScheduler;
92+
}
93+
94+
@Override
95+
public final List<ImmutableTaskGraph> getImmutableTaskGraphs() {
96+
return this.immutableTaskGraphs;
97+
}
98+
99+
@Override
100+
public final GridScheduler getGridScheduler() {
101+
return this.gridScheduler;
102+
}
103+
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java

Lines changed: 0 additions & 65 deletions
This file was deleted.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
2+
3+
import org.beehive.gpullama3.tensor.GGMLType;
4+
import org.beehive.gpullama3.inference.state.State;
5+
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
6+
import org.beehive.gpullama3.model.Configuration;
7+
import org.beehive.gpullama3.model.Model;
8+
import org.beehive.gpullama3.tornadovm.layerplanner.QuantizedLayerPlanner;
9+
10+
/**
11+
* Base for all FP16-quantized layer planners.
12+
*/
13+
public abstract class FP16LayerPlanner<S extends State, C extends Configuration, W extends TornadoWeights> extends QuantizedLayerPlanner<S, C, W> {
14+
15+
protected FP16LayerPlanner(S state, Model model) {
16+
super(state, model);
17+
}
18+
19+
@Override
20+
protected void validateQuantizationType() {
21+
if (this.weights.getWeightType() != GGMLType.F16) {
22+
throw new IllegalArgumentException("FP16LayerPlanner requires GGMLType.F16, got: " + this.weights.getWeightType());
23+
}
24+
}
25+
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,17 @@
44
import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
55
import org.beehive.gpullama3.model.Model;
66
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
7-
import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner;
87
import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
98
import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers;
109
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer;
1110

1211
public class GraniteFP16LayerPlanner extends FP16LayerPlanner<GraniteState, GraniteConfiguration, GraniteTornadoWeights> {
12+
1313
public GraniteFP16LayerPlanner(GraniteState state, Model model) {
1414
super(state, model);
15-
validateQuantizationType();
16-
setupTornadoForwardPlan();
17-
}
18-
19-
@Override
20-
protected void initializeLayerComponents() {
21-
this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config);
22-
this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType);
23-
this.logitsLayer = new LogitsGraniteFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
15+
this.activationLayer = new ActivationGranite("activationUpdate", state, weights, config);
16+
this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType);
17+
this.logitsLayer = new LogitsGraniteFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
18+
createTornadoInferencePlan();
2419
}
25-
2620
}

0 commit comments

Comments (0)