Skip to content

Commit cec6c9d

Browse files
Refactor tensor creation methods to use static factory methods and remove redundant constructors.
1 parent: bc98607 · commit: cec6c9d

4 files changed

Lines changed: 8 additions & 13 deletions

File tree

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -126,8 +126,8 @@ public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) {
126126
GGMLType ggmlType = entry.ggmlType();
127127
int size = FloatTensor.numberOfElements(entry.shape());
128128
return switch (ggmlType) {
129-
case F32 -> new FP32TornadoTensor(size, entry.memorySegment());
130-
case F16 -> new FP16TornadoTensor(size, entry.memorySegment());
129+
case F32 -> FP32TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
130+
case F16 -> FP16TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
131131
case Q8_0 -> Q8_0TornadoTensor.create(entry);
132132
case Q4_0 -> throw new UnsupportedOperationException("Q4 format not supported yet");
133133
default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);

src/main/java/org/beehive/gpullama3/tensor/tornado/FP16TornadoTensor.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,8 @@ public FP16TornadoTensor(HalfFloatArray halfFloatArray) {
1313
this.tornadoNativeArray = halfFloatArray;
1414
}
1515

16-
public FP16TornadoTensor(MemorySegment segment) {
17-
this.tornadoNativeArray = new HalfFloatArray(segment);
16+
public static FP16TornadoTensor fromTornadoMemorySegment(MemorySegment segment) {
17+
return new FP16TornadoTensor(HalfFloatArray.fromSegmentShallow(segment));
1818
}
1919

2020
@Override

src/main/java/org/beehive/gpullama3/tensor/tornado/FP32TornadoTensor.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,8 @@ public FP32TornadoTensor(FloatArray floatArray) {
1212
this.tornadoNativeArray = floatArray;
1313
}
1414

15-
public FP32TornadoTensor(MemorySegment segment) {
16-
this.tornadoNativeArray = new FloatArray(segment);
15+
public static FP32TornadoTensor fromTornadoMemorySegment(MemorySegment segment) {
16+
return new FP32TornadoTensor(FloatArray.fromSegmentShallow(segment));
1717
}
1818

1919
@Override

src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -41,11 +41,6 @@ public Int8Array getQuants() {
4141
return quants;
4242
}
4343

44-
@Override
45-
public int size() {
46-
return size;
47-
}
48-
4944
@Override
5045
public GGMLType type() {
5146
return GGMLType.Q8_0;
@@ -62,7 +57,7 @@ public MemorySegment asMemorySegment() {
6257
* @return Dequantized float value
6358
*/
6459
public float getFloat(int index) {
65-
assert 0 <= index && index < size;
60+
assert 0 <= index;
6661
int blockIdx = index / GGMLType.Q8_0.getBlockSize();
6762
float scale = scales.get(blockIdx).getFloat32();
6863
byte quant = quants.get(index);
@@ -108,6 +103,6 @@ public static Q8_0TornadoTensor create(GGMLTensorEntry entry) {
108103
}
109104
}
110105

111-
return new Q8_0TornadoTensor(size, scales, quants, q8Segment);
106+
return new Q8_0TornadoTensor(scales, quants, q8Segment);
112107
}
113108
}

0 commit comments

Comments (0)