Skip to content

Commit 3d5c340

Browse files
Refactor loadTornadoTensorAsFP32 to perform a temporary manual conversion to FP32
1 parent 7fbbe28 commit 3d5c340

1 file changed

Lines changed: 17 additions & 13 deletions

File tree

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,21 +147,25 @@ public static TornadoTensor[] loadArrayOfTornadoTensors(int size, IntFunction<GG
147147
}
148148

149149
/**
150-
* Load a tensor and ensure it's FP32 (FloatArray).
151-
* Used for embeddings and normalization weights that must always be FP32.
150+
* Load a tensor and manually convert to FP32 (FloatArray).
151+
* Used for embeddings, which are currently treated as FP32.
152+
* TODO: this element-by-element conversion is very slow and will be removed.
152153
*/
153154
public static TornadoTensor loadTornadoTensorAsFP32(GGMLTensorEntry entry) {
154-
// If already F32, load directly
155-
if (entry.ggmlType() == GGMLType.F32) {
156-
return new FP32TornadoTensor(
157-
FloatTensor.numberOfElements(entry.shape()),
158-
entry.memorySegment()
159-
);
160-
}
161-
162-
// Otherwise, dequantize to F32
163-
FloatArray floatArray = loadTensorAsFloatArray(entry);
164-
return new FP32TornadoTensor(floatArray);
155+
TornadoTensor tensor = loadTornadoTensor(entry);
156+
return switch (tensor.type()) {
157+
case F32 -> tensor;
158+
case F16 -> {
159+
HalfFloatArray tensorHFA = tensor.asHalfFloatArray();
160+
int numOfElements = tensorHFA.getSize();
161+
FloatArray tensorFA = new FloatArray(numOfElements);
162+
for(int i = 0; i < numOfElements; i++) {
163+
tensorFA.set(i, tensorHFA.get(i).getFloat32());
164+
}
165+
yield new FP32TornadoTensor(tensorFA);
166+
}
167+
default -> { throw new UnsupportedOperationException("Unsupported tensor type: " + tensor.type()); }
168+
};
165169
}
166170

167171
// Helper methods

0 commit comments

Comments
 (0)