Skip to content

Commit 9ac6e96

Browse files
Rename loadTensors to loadTensorsStandard and enhance documentation and code clarity for tensor data loading.
1 parent cec6c9d commit 9ac6e96

1 file changed

Lines changed: 33 additions & 4 deletions

File tree

  • src/main/java/org/beehive/gpullama3/tensor

src/main/java/org/beehive/gpullama3/tensor/GGUF.java

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,45 @@ public static GGUF loadGGUFMetadata(Path modelPath) throws IOException {
9090
}
9191
}
9292

93-
public static Map<String, GGMLTensorEntry> loadTensors(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
93+
/**
94+
* Loads tensor data from a given file channel based on the tensor metadata information.
95+
* The mapping is read-only and spans from tensorDataOffset to the end of the file.
96+
*/
97+
public static Map<String, GGMLTensorEntry> loadTensorsStandard(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
9498
Arena arena = Arena.ofAuto();
95-
MemorySegment tensorData = fileChannel.map(FileChannel.MapMode.READ_ONLY, tensorDataOffset, fileChannel.size() - tensorDataOffset, arena);
99+
100+
// absolute file offset where the tensor-data section begins
101+
long mappingOffset = tensorDataOffset;
102+
// size of the entire tensor-data section
103+
long mappingSize = fileChannel.size() - tensorDataOffset;
104+
105+
MemorySegment tensorData =
106+
fileChannel.map(FileChannel.MapMode.READ_ONLY, mappingOffset, mappingSize, arena);
107+
96108
Map<String, GGMLTensorEntry> tensorEntries = HashMap.newHashMap(tensorInfos.size());
109+
97110
for (Map.Entry<String, GGUFTensorInfo> entry : tensorInfos.entrySet()) {
98111
GGUFTensorInfo ti = entry.getValue();
112+
113+
// skip rope_freqs.weight (not needed for inference)
114+
if (ti.name().equals("rope_freqs.weight")) {
115+
continue;
116+
}
117+
99118
int numberOfElements = FloatTensor.numberOfElements(ti.dimensions());
100119
int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements));
101-
MemorySegment memorySegment = tensorData.asSlice(ti.offset(), sizeInBytes);
102-
tensorEntries.put(ti.name(), new GGMLTensorEntry(tensorData, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment));
120+
121+
// per-tensor slice offset; ti.offset() is relative to tensor-data start
122+
long offset = ti.offset();
123+
124+
// per-tensor slice segment
125+
MemorySegment memorySegment = tensorData.asSlice(offset, sizeInBytes);
126+
127+
tensorEntries.put(ti.name(),
128+
new GGMLTensorEntry(tensorData, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment));
129+
}
130+
return tensorEntries;
131+
}
103132
}
104133
return tensorEntries;
105134
}

0 commit comments

Comments
 (0)