Skip to content

Commit 7fbbe28

Browse files
Add loadTensorsTornado method in GGUF for TornadoVM-compatible tensor loading
1 parent 9ac6e96 commit 7fbbe28

2 files changed

Lines changed: 65 additions & 1 deletion

File tree

src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@ public final M loadModel() {
5656
C config = createConfiguration(metadata);
5757

5858
// Step 4: Load tensor entries
59-
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
59+
Map<String, GGMLTensorEntry> tensorEntries;
60+
if (useTornadovm) {
61+
tensorEntries = GGUF.loadTensorsTornado(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
62+
} else {
63+
tensorEntries = GGUF.loadTensorsStandard(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
64+
}
6065

6166
// Step 4: Load weights
6267
Weights weights = loadWeights(tensorEntries, config);

src/main/java/org/beehive/gpullama3/tensor/GGUF.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,65 @@ public static Map<String, GGMLTensorEntry> loadTensorsStandard(FileChannel fileC
129129
}
130130
return tensorEntries;
131131
}
132+
133+
/**
134+
* Loads GGUF tensor data using a TornadoVM-compatible memory layout.
135+
*
136+
* <p>This method parses the GGUF tensor list and memory-maps each tensor
137+
* in {@link TornadoNativeArray} layout directly from the underlying{@link FileChannel}.
138+
* For compatibility with {@link TornadoNativeArray} layout, an additional header is required at
139+
* the start of each tensor region. To satisfy this requirement, each tensor
140+
* is mapped using {@link FileChannel.MapMode#PRIVATE} starting 16 bytes
141+
* before the actual tensor position, providing a writable header region
142+
* without modifying the underlying GGUF file.</p>
143+
*
144+
*
145+
* @param fileChannel the channel from which tensor storage is read
146+
* @param tensorDataOffset the absolute byte offset of the GGUF tensor-data section
147+
* @param tensorInfos metadata describing all GGUF tensors
148+
*
149+
* @return a map from tensor name to {@link GGMLTensorEntry} containing
150+
* TornadoVM-compatible memory segments for each tensor
151+
*
152+
* @throws IOException if memory mapping fails or the channel cannot be read
153+
*/
154+
public static Map<String, GGMLTensorEntry> loadTensorsTornado(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
155+
156+
Arena arena = Arena.ofAuto();
157+
Map<String, GGMLTensorEntry> tensorEntries = HashMap.newHashMap(tensorInfos.size());
158+
159+
for (Map.Entry<String, GGUFTensorInfo> entry : tensorInfos.entrySet()) {
160+
GGUFTensorInfo ti = entry.getValue();
161+
162+
// skip rope_freqs.weight (not required for inference)
163+
if (ti.name().equals("rope_freqs.weight")) {
164+
continue;
165+
}
166+
167+
int numberOfElements = FloatTensor.numberOfElements(ti.dimensions());
168+
int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements));
169+
170+
// absolute tensor offset - relative to start of the file
171+
long mappingOffset = tensorDataOffset + ti.offset();
172+
173+
// create memory segment in TornadoVM NativeArray layout:
174+
// TornadoNativeArray.ARRAY_HEADER (16-byte) + tensor data
175+
long headerBytes = TornadoNativeArray.ARRAY_HEADER;
176+
177+
// start 16 bytes before the tensor position to include header space
178+
long offset = mappingOffset - headerBytes;
179+
long size = sizeInBytes + headerBytes;
180+
MemorySegment memorySegment =
181+
fileChannel.map(FileChannel.MapMode.PRIVATE, offset, size, arena);
182+
183+
// zero out the 16-byte header
184+
for (int i = 0; i < headerBytes; i++) {
185+
memorySegment.set(ValueLayout.JAVA_BYTE, i, (byte) 0);
186+
}
187+
188+
// store tornado-compatible segment
189+
tensorEntries.put(ti.name(),
190+
new GGMLTensorEntry(memorySegment, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment));
132191
}
133192
return tensorEntries;
134193
}

0 commit comments

Comments
 (0)