Skip to content

Commit 3fb3a0c

Browse files
committed
Add Phi3 TornadoVM layer planners for FP16 and Q8_0 quantizations
Introduce `Phi3FP16LayerPlanner` and `Phi3Q8_0LayerPlanner`, enabling TornadoVM support for the Phi3 model with FP16 and Q8_0 weights. These planners implement Phi3-specific layer components and caching mechanisms for task graphs and schedulers.
1 parent e2be881 commit 3fb3a0c

2 files changed

Lines changed: 247 additions & 0 deletions

File tree

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
2+
3+
import org.beehive.gpullama3.auxiliary.Tuple2;
4+
import org.beehive.gpullama3.inference.state.Phi3State;
5+
import org.beehive.gpullama3.inference.weights.tornado.FP16Weights.Phi3TornadoWeights;
6+
import org.beehive.gpullama3.model.Model;
7+
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
8+
import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner;
9+
import org.beehive.gpullama3.tornadovm.layers.Activation;
10+
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
11+
import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers;
12+
import uk.ac.manchester.tornado.api.GridScheduler;
13+
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
14+
15+
import java.util.ArrayList;
16+
import java.util.List;
17+
18+
/**
19+
* Phi3FP16LayerPlanner: Phi3 model with FP16 weights.
20+
*
21+
* Follows the same pattern as Qwen3FP16LayerPlanner but with:
22+
* - Phi3-specific FFN layers (combined QKV + gate/up FFN)
23+
* - Phi3TornadoWeights
24+
* - Phi3Configuration
25+
*
26+
* Inherits from FP16LayerPlanner<Phi3State, Phi3Configuration, Phi3TornadoWeights>
27+
*/
28+
public class Phi3FP16LayerPlanner extends FP16LayerPlanner<Phi3State, Phi3Configuration, Phi3TornadoWeights> {
29+
30+
private Activation activationLayer;
31+
private Phi3FP16FFNLayers ffnLayers;
32+
private LogitsFP16Layer logitsLayer;
33+
34+
// Cache
35+
private List<ImmutableTaskGraph> cachedTaskGraphs;
36+
private GridScheduler cachedScheduler;
37+
38+
public Phi3FP16LayerPlanner(Phi3State state, Model model) {
39+
super(state, model);
40+
validateQuantizationType();
41+
setupTornadoForwardPlan();
42+
}
43+
44+
@Override
45+
protected void initializeLayerComponents() {
46+
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
47+
48+
this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", this.state, this.weights, this.config);
49+
50+
this.logitsLayer = new LogitsFP16Layer("phi3Logits", this.state, this.weights, this.config,
51+
ffnLayers.getLastTaskGraphID());
52+
}
53+
54+
@Override
55+
public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered() {
56+
if (this.cachedTaskGraphs != null && this.cachedScheduler != null) {
57+
return new Tuple2<>(this.cachedTaskGraphs, this.cachedScheduler);
58+
}
59+
60+
List<ImmutableTaskGraph> allTaskGraphs = new ArrayList<>();
61+
GridScheduler masterScheduler = new GridScheduler();
62+
63+
// 1. Activation layer
64+
allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
65+
activationLayer.updateGridScheduler(masterScheduler);
66+
67+
// 2. FFN layers (N transformer layers)
68+
allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs());
69+
ffnLayers.updateGridScheduler(masterScheduler);
70+
71+
// 3. Logits layer
72+
allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot());
73+
logitsLayer.updateGridScheduler(masterScheduler);
74+
75+
// Cache
76+
this.cachedTaskGraphs = allTaskGraphs;
77+
this.cachedScheduler = masterScheduler;
78+
79+
return new Tuple2<>(allTaskGraphs, masterScheduler);
80+
}
81+
82+
public void setupTornadoForwardPlan() {
83+
List<ImmutableTaskGraph> allTaskGraphs = new ArrayList<>();
84+
GridScheduler masterScheduler = new GridScheduler();
85+
86+
// 1. Activation layer
87+
allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
88+
activationLayer.updateGridScheduler(masterScheduler);
89+
90+
// 2. FFN layers (N transformer layers)
91+
allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs());
92+
ffnLayers.updateGridScheduler(masterScheduler);
93+
94+
// 3. Logits layer
95+
allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot());
96+
logitsLayer.updateGridScheduler(masterScheduler);
97+
98+
// Cache
99+
this.cachedTaskGraphs = allTaskGraphs;
100+
this.cachedScheduler = masterScheduler;
101+
}
102+
103+
@Override
104+
public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayeredNonNvidia() {
105+
// For now, same as NVIDIA version
106+
// Hardware strategy will optimize scheduler
107+
return setupTornadoForwardPlanLayered();
108+
}
109+
110+
public List<ImmutableTaskGraph> getCachedTaskGraphs() {
111+
return this.cachedTaskGraphs;
112+
}
113+
114+
@Override
115+
public GridScheduler getCachedGridScheduler() {
116+
return this.cachedScheduler;
117+
}
118+
119+
public void clearCache() {
120+
this.cachedTaskGraphs = null;
121+
this.cachedScheduler = null;
122+
}
123+
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
2+
3+
import org.beehive.gpullama3.auxiliary.Tuple2;
4+
import org.beehive.gpullama3.inference.state.Phi3State;
5+
import org.beehive.gpullama3.inference.weights.tornado.Q8_0Weights.Phi3TornadoWeightsQ8_0;
6+
import org.beehive.gpullama3.model.Model;
7+
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
8+
import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner;
9+
import org.beehive.gpullama3.tornadovm.layers.Activation;
10+
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
11+
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers;
12+
import uk.ac.manchester.tornado.api.GridScheduler;
13+
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
14+
15+
import java.util.ArrayList;
16+
import java.util.List;
17+
18+
/**
19+
* Phi3Q8_0LayerPlanner: Phi3 model with Q8_0-quantized weights.
20+
*
21+
* Follows the same pattern as Qwen3Q8_0LayerPlanner but with:
22+
* - Phi3-specific FFN layers (combined QKV + gate/up FFN)
23+
* - Phi3TornadoWeightsQ8_0 (8-bit integer quantization)
24+
* - Phi3Configuration
25+
* - 2x memory compression vs FP16
26+
*
27+
* Inherits from Q8_0LayerPlanner<Phi3State, Phi3Configuration, Phi3TornadoWeightsQ8_0>
28+
*/
29+
public class Phi3Q8_0LayerPlanner extends Q8_0LayerPlanner<Phi3State, Phi3Configuration, Phi3TornadoWeightsQ8_0> {
30+
31+
private Activation activationLayer;
32+
private Phi3Q8_0FFNLayers ffnLayers;
33+
private LogitsQ8_0Layer logitsLayer;
34+
35+
// Cache
36+
private List<ImmutableTaskGraph> cachedTaskGraphs;
37+
private GridScheduler cachedScheduler;
38+
39+
public Phi3Q8_0LayerPlanner(Phi3State state, Model model) {
40+
super(state, model);
41+
validateQuantizationType();
42+
setupTornadoForwardPlan();
43+
}
44+
45+
@Override
46+
protected void initializeLayerComponents() {
47+
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
48+
49+
this.ffnLayers = new Phi3Q8_0FFNLayers("phi3FFN", this.state, this.weights, this.config);
50+
51+
this.logitsLayer = new LogitsQ8_0Layer("phi3Logits", this.state, this.weights, this.config,
52+
ffnLayers.getLastTaskGraphID());
53+
}
54+
55+
@Override
56+
public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered() {
57+
if (this.cachedTaskGraphs != null && this.cachedScheduler != null) {
58+
return new Tuple2<>(this.cachedTaskGraphs, this.cachedScheduler);
59+
}
60+
61+
List<ImmutableTaskGraph> allTaskGraphs = new ArrayList<>();
62+
GridScheduler masterScheduler = new GridScheduler();
63+
64+
// 1. Activation layer
65+
allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
66+
activationLayer.updateGridScheduler(masterScheduler);
67+
68+
// 2. FFN layers (N transformer layers with Q8_0 quantization)
69+
allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs());
70+
ffnLayers.updateGridScheduler(masterScheduler);
71+
72+
// 3. Logits layer
73+
allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot());
74+
logitsLayer.updateGridScheduler(masterScheduler);
75+
76+
// Cache
77+
this.cachedTaskGraphs = allTaskGraphs;
78+
this.cachedScheduler = masterScheduler;
79+
80+
return new Tuple2<>(allTaskGraphs, masterScheduler);
81+
}
82+
83+
public void setupTornadoForwardPlan() {
84+
List<ImmutableTaskGraph> allTaskGraphs = new ArrayList<>();
85+
GridScheduler masterScheduler = new GridScheduler();
86+
87+
// 1. Activation layer
88+
allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
89+
activationLayer.updateGridScheduler(masterScheduler);
90+
91+
// 2. FFN layers (N transformer layers with Q8_0 quantization)
92+
allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs());
93+
ffnLayers.updateGridScheduler(masterScheduler);
94+
95+
// 3. Logits layer
96+
allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot());
97+
logitsLayer.updateGridScheduler(masterScheduler);
98+
99+
// Cache
100+
this.cachedTaskGraphs = allTaskGraphs;
101+
this.cachedScheduler = masterScheduler;
102+
}
103+
104+
@Override
105+
public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayeredNonNvidia() {
106+
// For now, same as NVIDIA version
107+
// Hardware strategy will optimize scheduler
108+
return setupTornadoForwardPlanLayered();
109+
}
110+
111+
public List<ImmutableTaskGraph> getCachedTaskGraphs() {
112+
return this.cachedTaskGraphs;
113+
}
114+
115+
@Override
116+
public GridScheduler getCachedGridScheduler() {
117+
return this.cachedScheduler;
118+
}
119+
120+
public void clearCache() {
121+
this.cachedTaskGraphs = null;
122+
this.cachedScheduler = null;
123+
}
124+
}

0 commit comments

Comments
 (0)