|
12 | 12 | import org.beehive.gpullama3.inference.weights.Weights; |
13 | 13 | import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights; |
14 | 14 | import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; |
| 15 | +import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeightsQ8_0; |
15 | 16 | import org.beehive.gpullama3.model.Configuration; |
16 | 17 | import org.beehive.gpullama3.model.format.ChatFormat; |
17 | 18 | import org.beehive.gpullama3.model.phi3.Phi3; |
@@ -100,20 +101,42 @@ private Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configur |
100 | 101 |
|
101 | 102 | if (useTornadovm) { |
102 | 103 | if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { |
103 | | - System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")"); |
| 104 | + System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")"); |
| 105 | + } |
| 106 | + if (outputWeight.ggmlType() == GGMLType.Q8_0) { |
| 107 | + return createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); |
| 108 | + } else { |
| 109 | + return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); |
104 | 110 | } |
105 | | - return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); |
106 | 111 | } else { |
107 | 112 | return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); |
108 | 113 | } |
109 | 114 | } |
110 | 115 | // @formatter:on |
111 | 116 |
|
112 | 117 | // @formatter:off |
113 | | - @Override |
114 | | - public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, |
| 118 | + public Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, |
115 | 119 | Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, |
116 | 120 | GGMLTensorEntry outputWeight) { |
// Builds the Phi3TornadoWeightsQ8_0 instance; loadWeights selects this path on the TornadoVM branch
// when outputWeight.ggmlType() is GGMLType.Q8_0.
// Constructor arguments, in order:
//   - token embeddings, via loadTensorAsFloatArray;
//   - per-layer attn_norm weights, via loadArrayAsFloatArrayFromBuffer;
//   - per-layer attn_qkv and attn_output weights, via loadArrayAsQ8_0QuantizedTensor;
//   - per-layer ffn_norm weights, via loadArrayAsFloatArrayFromBuffer;
//   - per-layer ffn_down and ffn_up weights, via loadArrayAsQ8_0QuantizedTensor;
//   - output_norm.weight, via floatBufferToFloatArray;
//   - the two ropeFreqs arrays, each wrapped with FloatArray.fromArray;
//   - the output weight, via loadQ8_0QuantizedTensor, followed by its GGML type.
// NOTE(review): the Q8_0 loaders are applied to every per-layer tensor without checking each entry's own
// ggmlType(); presumably the whole model shares the output weight's quantization — confirm.
// NOTE(review): no ffn_gate tensor is looked up here; presumably ffn_up carries the fused gate/up
// projection for Phi3 — confirm against Phi3TornadoWeightsQ8_0 and its kernels.
| 121 | + return new Phi3TornadoWeightsQ8_0( |
| 122 | + loadTensorAsFloatArray(tokenEmbeddings), |
| 123 | + loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), |
| 124 | + loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV |
| 125 | + loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo |
| 126 | + loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), |
| 127 | + loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown |
| 128 | + loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (not combined in reference) |
| 129 | + floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), |
| 130 | + FloatArray.fromArray(ropeFreqs.first()), |
| 131 | + FloatArray.fromArray(ropeFreqs.second()), |
| 132 | + loadQ8_0QuantizedTensor(outputWeight), |
| 133 | + outputWeight.ggmlType() |
| 134 | + ); |
| 135 | + } |
| 136 | + |
| 137 | + public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, |
| 138 | + Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, |
| 139 | + GGMLTensorEntry outputWeight) { |
117 | 140 | return new Phi3TornadoWeights( |
118 | 141 | loadTensorAsFloatArray(tokenEmbeddings), |
119 | 142 | loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), |
|
0 commit comments