@@ -147,21 +147,25 @@ public static TornadoTensor[] loadArrayOfTornadoTensors(int size, IntFunction<GG
147147 }
148148
149149 /**
150- * Load a tensor and ensure it's FP32 (FloatArray).
151- * Used for embeddings and normalization weights that must always be FP32.
150+ * Load a tensor and manually convert it to FP32 (FloatArray).
151+ * Used for embeddings, which are currently treated as FP32.
152+ * TODO: the element-by-element F16 conversion is very slow; this method will be removed.
152153 */
153154 public static TornadoTensor loadTornadoTensorAsFP32 (GGMLTensorEntry entry ) {
154- // If already F32, load directly
155- if (entry .ggmlType () == GGMLType .F32 ) {
156- return new FP32TornadoTensor (
157- FloatTensor .numberOfElements (entry .shape ()),
158- entry .memorySegment ()
159- );
160- }
161-
162- // Otherwise, dequantize to F32
163- FloatArray floatArray = loadTensorAsFloatArray (entry );
164- return new FP32TornadoTensor (floatArray );
155+ TornadoTensor tensor = loadTornadoTensor (entry );
156+ return switch (tensor .type ()) {
157+ case F32 -> tensor ;
158+ case F16 -> {
159+ HalfFloatArray tensorHFA = tensor .asHalfFloatArray ();
160+ int numOfElements = tensorHFA .getSize ();
161+ FloatArray tensorFA = new FloatArray (numOfElements );
162+ for (int i = 0 ; i < numOfElements ; i ++) {
163+ tensorFA .set (i , tensorHFA .get (i ).getFloat32 ());
164+ }
165+ yield new FP32TornadoTensor (tensorFA );
166+ }
167+ default -> { throw new UnsupportedOperationException ("Unsupported tensor type: " + tensor .type ()); }
168+ };
165169 }
166170
167171 // Helper methods
0 commit comments