-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathModelLoader.java
More file actions
282 lines (252 loc) · 11.5 KB
/
Copy pathModelLoader.java
File metadata and controls
282 lines (252 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
package org.beehive.gpullama3.model.loader;
import org.beehive.gpullama3.Options;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.ModelType;
import org.beehive.gpullama3.tensor.*;
import org.beehive.gpullama3.tensor.GGMLType;
import org.beehive.gpullama3.tensor.GGUF;
import org.beehive.gpullama3.tensor.standard.*;
import org.beehive.gpullama3.tensor.tornado.FP16TornadoTensor;
import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor;
import org.beehive.gpullama3.tensor.tornado.Q8_0TornadoTensor;
import org.beehive.gpullama3.tensor.tornado.TornadoTensor;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.*;

import java.io.IOException;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Locale;
import java.util.Map;
import java.util.function.IntFunction;
/**
 * Base class for model-specific GGUF loaders.
 * <p>
 * Holds the state a concrete loader is constructed with (file channel, parsed {@link GGUF}, context length and
 * flags). It also provides static helpers that convert {@link GGMLTensorEntry} instances into CPU tensors
 * ({@link FloatTensor}), TornadoVM tensors ({@link TornadoTensor}) or TornadoVM arrays ({@link FloatArray},
 * {@link HalfFloatArray}).
 * </p>
 */
public abstract class ModelLoader {
    /** Channel over the GGUF file; {@link #loadModel(Path, int, boolean, boolean)} opens it read-only. */
    protected FileChannel fileChannel;
    /** Parsed GGUF file (metadata and tensor entries). */
    protected GGUF gguf;
    /** Context length requested by the caller ({@code options.maxTokens()} when loaded via {@link #loadModel(Options)}). */
    protected int contextLength;
    /** Flag forwarded from {@code loadModel}; presumably controls whether weight tensors are materialized. */
    protected boolean loadWeights;
    /** Whether the TornadoVM (GPU) path was requested. */
    protected boolean useTornadovm;

    public ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
        this.fileChannel = fileChannel;
        this.gguf = gguf;
        this.contextLength = contextLength;
        this.loadWeights = loadWeights;
        this.useTornadovm = useTornadovm;
    }

    /**
     * Infers the model family from the GGUF {@code general.name} metadata entry.
     * <p>
     * Matching is a case-insensitive substring test in a fixed order, and the first match wins. For example, any
     * name containing "llama" is classified as {@link ModelType#LLAMA_3} before the DeepSeek check is reached.
     * </p>
     *
     * @param metadata
     *         GGUF metadata map
     * @return the detected model type, or {@link ModelType#UNKNOWN} if {@code general.name} is absent or matches
     *         none of the known families
     */
    private static ModelType detectModelType(Map<String, Object> metadata) {
        String name = (String) metadata.get("general.name");
        if (name == null) {
            return ModelType.UNKNOWN;
        }
        // Locale.ROOT: default-locale lowercasing breaks matching under e.g. a Turkish locale,
        // where "MISTRAL" would become "mıstral" (dotless i).
        String lowerName = name.toLowerCase(Locale.ROOT);
        if (lowerName.contains("mistral")) {
            return ModelType.MISTRAL;
        } else if (lowerName.contains("llama")) {
            return ModelType.LLAMA_3;
        } else if (lowerName.contains("qwen2")) {
            return ModelType.QWEN_2;
        } else if (lowerName.contains("qwen3")) {
            return ModelType.QWEN_3;
        } else if (lowerName.contains("deepseek r1 distill")) {
            return ModelType.DEEPSEEK_R1_DISTILL_QWEN;
        } else if (lowerName.contains("phi3") || lowerName.contains("phi-3")) {
            return ModelType.PHI_3;
        }
        return ModelType.UNKNOWN;
    }

    /**
     * Loads the language model described by the parsed CLI options.
     * <p>
     * Uses {@code options.modelPath()} as the GGUF file, {@code options.maxTokens()} as the context length and
     * {@code options.useTornadovm()} to select the TornadoVM path. Weights are always loaded.
     * </p>
     *
     * @param options
     *         the parsed CLI options containing model path, max token limit and TornadoVM flag
     * @return the loaded {@link Model} instance
     * @throws IOException
     *         if the model file cannot be read
     */
    public static Model loadModel(Options options) throws IOException {
        return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm());
    }

    /**
     * Parses the GGUF file, detects the model family from its metadata and delegates to the family-specific loader.
     * <p>
     * On success the opened {@link FileChannel} is handed to the returned model's loader and is not closed here.
     * If detection or loading fails, the channel is closed before the failure propagates.
     * </p>
     *
     * @param ggufPath
     *         path to the GGUF model file
     * @param contextLength
     *         context length to configure the model with
     * @param loadWeights
     *         forwarded to the model-specific loader
     * @param useTornadovm
     *         whether to use the TornadoVM (GPU) path
     * @return the loaded model
     * @throws IOException
     *         if the file cannot be opened or read
     */
    public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeights, boolean useTornadovm) throws IOException {
        // Initial load of metadata from the GGUF file.
        GGUF gguf = GGUF.loadModel(ggufPath);
        FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ);
        try {
            ModelType modelType = detectModelType(gguf.getMetadata());
            // Ownership of the channel passes to the model-specific loader, so it is intentionally left open here.
            return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
        } catch (Throwable failure) {
            // Nobody owns the channel once loading fails; close it so the file handle does not leak.
            try {
                fileChannel.close();
            } catch (IOException closeFailure) {
                failure.addSuppressed(closeFailure);
            }
            throw failure;
        }
    }

    /**
     * Dispatcher method for loading a standard (non-tornado) tensor based on GGML type.
     * Used in CPU-path. The returned tensor wraps the entry's memory segment rather than copying it.
     *
     * @throws UnsupportedOperationException
     *         for GGML types other than F32, Q8_0, Q4_0 and F16
     */
    public static FloatTensor loadTensor(GGMLTensorEntry entry) {
        GGMLType ggmlType = entry.ggmlType();
        return switch (ggmlType) {
            case F32 -> new FP32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
            case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
            case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
            case F16 -> new FP16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
            default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
        };
    }

    /**
     * Dispatcher method for loading a standard tensor array based on type.
     * Used in CPU-path. Element {@code i} is loaded from {@code getTensorEntry.apply(i)}.
     */
    public static FloatTensor[] loadArrayOfTensors(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
        FloatTensor[] array = new FloatTensor[size];
        for (int i = 0; i < size; i++) {
            array[i] = loadTensor(getTensorEntry.apply(i));
        }
        return array;
    }

    /**
     * Dispatcher method for loading a TornadoVM-compatible tensor based on GGML type.
     * Used in GPU-path.
     *
     * @throws UnsupportedOperationException
     *         for Q4_0 (not yet supported on this path) and any type other than F32, F16 and Q8_0
     */
    public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) {
        GGMLType ggmlType = entry.ggmlType();
        int size = FloatTensor.numberOfElements(entry.shape());
        // No per-tensor logging here: this runs once per weight tensor, and stdout carries generated text.
        return switch (ggmlType) {
            case F32 -> new FP32TornadoTensor(size, entry.memorySegment());
            case F16 -> new FP16TornadoTensor(size, entry.memorySegment());
            case Q8_0 -> Q8_0TornadoTensor.create(entry);
            case Q4_0 -> throw new UnsupportedOperationException("Q4 format not supported yet");
            default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
        };
    }

    /**
     * Dispatcher method for loading a TornadoVM tensor array based on type.
     * Used in GPU-path. Element {@code i} is loaded from {@code getTensorEntry.apply(i)}.
     */
    public static TornadoTensor[] loadArrayOfTornadoTensors(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
        TornadoTensor[] array = new TornadoTensor[size];
        for (int i = 0; i < size; i++) {
            array[i] = loadTornadoTensor(getTensorEntry.apply(i));
        }
        return array;
    }

    /**
     * Load a tensor and ensure it's FP32 (FloatArray).
     * Used for embeddings and normalization weights that must always be FP32.
     * F32 entries are wrapped directly; other supported types are dequantized into a new {@link FloatArray}.
     */
    public static TornadoTensor loadTornadoTensorAsFP32(GGMLTensorEntry entry) {
        // If already F32, load directly
        if (entry.ggmlType() == GGMLType.F32) {
            return new FP32TornadoTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
        }
        // Otherwise, dequantize to F32
        FloatArray floatArray = loadTensorAsFloatArray(entry);
        return new FP32TornadoTensor(floatArray);
    }

    /**
     * Load array of tensors as FP32.
     * Used for normalization weight arrays. Element {@code i} comes from {@code getTensorEntry.apply(i)}.
     */
    public static TornadoTensor[] loadArrayOfTornadoTensorsAsFP32(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
        TornadoTensor[] array = new TornadoTensor[size];
        for (int i = 0; i < size; i++) {
            array[i] = loadTornadoTensorAsFP32(getTensorEntry.apply(i));
        }
        return array;
    }

    // Helper methods

    /** Loads {@code size} entries, each converted via {@link #loadTensorAsFloatArray(GGMLTensorEntry)}. */
    public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
        FloatArray[] array = new FloatArray[size];
        for (int i = 0; i < size; i++) {
            array[i] = loadTensorAsFloatArray(getTensorEntry.apply(i));
        }
        return array;
    }

    /** Loads {@code size} entries, each converted via {@link #loadTensorAsHalfFloatArray(GGMLTensorEntry)}. */
    public static HalfFloatArray[] loadArrayAsHalfFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
        HalfFloatArray[] array = new HalfFloatArray[size];
        for (int i = 0; i < size; i++) {
            array[i] = loadTensorAsHalfFloatArray(getTensorEntry.apply(i));
        }
        return array;
    }

    /** Loads {@code size} entries as {@link Q8_0TornadoTensor}s via {@link Q8_0TornadoTensor#create}. */
    public static Q8_0TornadoTensor[] loadArrayAsQ8_0TornadoTensor(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
        Q8_0TornadoTensor[] array = new Q8_0TornadoTensor[size];
        for (int i = 0; i < size; i++) {
            array[i] = Q8_0TornadoTensor.create(getTensorEntry.apply(i));
        }
        return array;
    }

    /**
     * Converts an F32 entry to a {@link FloatArray} by reading its memory as little-endian floats.
     *
     * @throws UnsupportedOperationException
     *         if the entry is not F32
     */
    public static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
        if (tensorEntry.ggmlType() == GGMLType.F32) {
            FloatBuffer buffer = tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
            return FloatArray.fromFloatBuffer(buffer);
        } else {
            throw new UnsupportedOperationException("Conversion to FloatArray from " + tensorEntry.ggmlType());
        }
    }

    /** Loads {@code size} F32 entries via {@link #floatBufferToFloatArray(GGMLTensorEntry)}. */
    public static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
        FloatArray[] array = new FloatArray[size];
        for (int i = 0; i < size; i++) {
            array[i] = floatBufferToFloatArray(getTensorEntry.apply(i));
        }
        return array;
    }

    /** Wraps the memory segment of the tensor produced by {@link #loadTensor(GGMLTensorEntry)} in a {@link ByteArray}. */
    public static ByteArray createByteArrayFromTensor(GGMLTensorEntry entry) {
        FloatTensor tensor = loadTensor(entry);
        return ByteArray.fromSegment(tensor.asMemorySegment());
    }

    /**
     * Copies a tensor into a new {@link FloatArray}.
     * <p>
     * F32 data is read directly as little-endian floats. Other formats are dequantized element by element through
     * {@link #loadTensor(GGMLTensorEntry)}.
     * </p>
     *
     * @throws UnsupportedOperationException
     *         if the entry's type is not supported by {@link #loadTensor(GGMLTensorEntry)}
     */
    public static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
        if (entry.ggmlType() == GGMLType.F32) {
            // For F32, we can directly create FloatArray from memory
            FloatBuffer buffer = entry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
            // Capture the length once and use absolute get(i): a relative get() advances the position, so
            // re-evaluating remaining() in the loop condition would stop after copying only half the tensor.
            int count = buffer.remaining();
            FloatArray array = new FloatArray(count);
            for (int i = 0; i < count; i++) {
                array.set(i, buffer.get(i));
            }
            return array;
        } else {
            // For quantized formats, we need to load through FloatTensor
            FloatTensor tensor = loadTensor(entry);
            FloatArray array = new FloatArray(tensor.size());
            for (int i = 0; i < tensor.size(); i++) {
                array.set(i, tensor.getFloat(i));
            }
            return array;
        }
    }

    /**
     * Copies a tensor into a new {@link HalfFloatArray}, converting each element with {@code new HalfFloat(float)}.
     * <p>
     * Every format supported by {@link #loadTensor(GGMLTensorEntry)}, including F32, is read through
     * {@link FloatTensor#getFloat(int)}. The previous implementation returned {@code null} for F32, which only
     * deferred the failure to the caller.
     * </p>
     *
     * @throws UnsupportedOperationException
     *         if the entry's type is not supported by {@link #loadTensor(GGMLTensorEntry)}
     */
    public static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) {
        FloatTensor tensor = loadTensor(entry);
        HalfFloatArray array = new HalfFloatArray(tensor.size());
        for (int i = 0; i < tensor.size(); i++) {
            array.set(i, new HalfFloat(tensor.getFloat(i)));
        }
        return array;
    }

    /** Loads {@code size} F32 entries as little-endian {@link FloatBuffer} views via {@link #toFloatBuffer}. */
    public static FloatBuffer[] loadArrayOfFloatBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
        FloatBuffer[] array = new FloatBuffer[size];
        for (int i = 0; i < size; i++) {
            array[i] = toFloatBuffer(getTensorEntry.apply(i));
        }
        return array;
    }

    /**
     * Returns a little-endian {@link FloatBuffer} view over an F32 entry's memory (no copy).
     *
     * @throws UnsupportedOperationException
     *         if the entry is not F32
     */
    public static FloatBuffer toFloatBuffer(GGMLTensorEntry tensorEntry) {
        GGMLType ggmlType = tensorEntry.ggmlType();
        return switch (ggmlType) {
            case F32 -> tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
            default -> throw new UnsupportedOperationException("Conversion to FloatBuffer from " + ggmlType);
        };
    }
}