Skip to content

Commit e263025

Browse files
committed
[WIP] Support Q8_0 for Phi3 - testing pending
1 parent a18a6c9 commit e263025

4 files changed

Lines changed: 451 additions & 6 deletions

File tree

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package org.beehive.gpullama3.inference.weights.tornado;
2+
3+
import org.beehive.gpullama3.core.model.GGMLType;
4+
import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
5+
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
6+
7+
8+
public class Phi3TornadoWeightsQ8_0 extends Q8_0Weights {
9+
10+
// Phi3-specific weight arrays
11+
public Q8_0QuantizedTensor[] wqkvLayered; // Combined QKV weights: (layer, op_size, dim) where op_size = dim + 2 * (n_kv_heads * head_dim)
12+
public Q8_0QuantizedTensor[] wDownLayered; // FFN down projection: (layer, dim, hidden_dim)
13+
public Q8_0QuantizedTensor[] wUpLayered; // FFN up projection: (layer, hidden_dim, dim)
14+
15+
// @formatter:off
16+
public Phi3TornadoWeightsQ8_0(
17+
FloatArray tokenEmbeddingTable,
18+
FloatArray[] rms_att_weightLayered,
19+
Q8_0QuantizedTensor[] wqkvLayered, // Combined QKV weights for Phi3
20+
Q8_0QuantizedTensor[] woLayered,
21+
FloatArray[] rms_ffn_weightLayered,
22+
Q8_0QuantizedTensor[] wDownLayered, // FFN down weights
23+
Q8_0QuantizedTensor[] wUpLayered, // FFN up weights
24+
FloatArray rms_final_weight_as_floatArray,
25+
FloatArray freq_cis_realFlat,
26+
FloatArray freq_cis_imagFlat,
27+
Q8_0QuantizedTensor wclsByteArray,
28+
GGMLType weightType) {
29+
30+
// Call to Q8_0Weights constructor with null values for unused standard weights
31+
super(tokenEmbeddingTable,
32+
rms_att_weightLayered,
33+
null, // wqLayered - not used in Phi3, using combined wqkv instead
34+
null, // wkLayered - not used in Phi3, using combined wqkv instead
35+
null, // wvLayered - not used in Phi3, using combined wqkv instead
36+
woLayered,
37+
rms_ffn_weightLayered,
38+
null, // w1Layered - not used in Phi3, using wUp instead
39+
null, // w2Layered - not used in Phi3, using wDown instead
40+
null, // w3Layered - not used in Phi3, using wUp instead
41+
rms_final_weight_as_floatArray,
42+
freq_cis_realFlat,
43+
freq_cis_imagFlat,
44+
wclsByteArray,
45+
weightType);
46+
47+
// Initialize Phi3-specific fields
48+
this.wqkvLayered = wqkvLayered;
49+
this.wDownLayered = wDownLayered;
50+
this.wUpLayered = wUpLayered;
51+
}
52+
// @formatter:on
53+
}

src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.beehive.gpullama3.inference.weights.Weights;
1313
import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
1414
import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
15+
import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeightsQ8_0;
1516
import org.beehive.gpullama3.model.Configuration;
1617
import org.beehive.gpullama3.model.format.ChatFormat;
1718
import org.beehive.gpullama3.model.phi3.Phi3;
@@ -100,20 +101,42 @@ private Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configur
100101

101102
if (useTornadovm) {
102103
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
103-
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
104+
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")");
105+
}
106+
if (outputWeight.ggmlType() == GGMLType.Q8_0) {
107+
return createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
108+
} else {
109+
return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
104110
}
105-
return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
106111
} else {
107112
return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
108113
}
109114
}
110115
// @formatter:on
111116

112117
// @formatter:off
113-
@Override
114-
public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config,
118+
public Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, Configuration config,
115119
Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
116120
GGMLTensorEntry outputWeight) {
121+
return new Phi3TornadoWeightsQ8_0(
122+
loadTensorAsFloatArray(tokenEmbeddings),
123+
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
124+
loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV
125+
loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
126+
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
127+
loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
128+
loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (not combined in reference)
129+
floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
130+
FloatArray.fromArray(ropeFreqs.first()),
131+
FloatArray.fromArray(ropeFreqs.second()),
132+
loadQ8_0QuantizedTensor(outputWeight),
133+
outputWeight.ggmlType()
134+
);
135+
}
136+
137+
public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config,
138+
Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
139+
GGMLTensorEntry outputWeight) {
117140
return new Phi3TornadoWeights(
118141
loadTensorAsFloatArray(tokenEmbeddings),
119142
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),

0 commit comments

Comments
 (0)