Skip to content

Commit bc98607

Browse files
Remove unnecessary size field and constructor from TornadoTensor and refactor subclasses accordingly.
1 parent 1267d73 commit bc98607

4 files changed

Lines changed: 16 additions & 26 deletions

File tree

src/main/java/org/beehive/gpullama3/tensor/tornado/FP16TornadoTensor.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
package org.beehive.gpullama3.tensor.tornado;
22

33
import org.beehive.gpullama3.tensor.GGMLType;
4+
import uk.ac.manchester.tornado.api.types.HalfFloat;
45
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
56

67
import java.lang.foreign.MemorySegment;
78

89
public class FP16TornadoTensor extends TornadoTensor {
9-
private final HalfFloatArray values;
10+
private final HalfFloatArray tornadoNativeArray;
1011

11-
public FP16TornadoTensor(int size, MemorySegment segment) {
12-
super(size);
13-
this.values = new HalfFloatArray(size);
14-
this.values.getSegment().copyFrom(segment);
12+
public FP16TornadoTensor(HalfFloatArray halfFloatArray) {
13+
this.tornadoNativeArray = halfFloatArray;
14+
}
15+
16+
public FP16TornadoTensor(MemorySegment segment) {
17+
this.tornadoNativeArray = new HalfFloatArray(segment);
1518
}
1619

1720
@Override
1821
public HalfFloatArray asHalfFloatArray() {
19-
return values;
22+
return tornadoNativeArray;
2023
}
2124

2225
@Override

src/main/java/org/beehive/gpullama3/tensor/tornado/FP32TornadoTensor.java

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,19 @@
66
import java.lang.foreign.MemorySegment;
77

88
public class FP32TornadoTensor extends TornadoTensor {
9-
private final FloatArray values;
9+
private final FloatArray tornadoNativeArray;
1010

11-
public FP32TornadoTensor(FloatArray values) {
12-
super(values.getSize());
13-
this.values = values;
11+
public FP32TornadoTensor(FloatArray floatArray) {
12+
this.tornadoNativeArray = floatArray;
1413
}
1514

16-
public FP32TornadoTensor(int size, MemorySegment segment) {
17-
super(size);
18-
this.values = new FloatArray(size);
19-
this.values.getSegment().copyFrom(segment);
15+
public FP32TornadoTensor(MemorySegment segment) {
16+
this.tornadoNativeArray = new FloatArray(segment);
2017
}
2118

2219
@Override
2320
public FloatArray asFloatArray() {
24-
return values;
21+
return tornadoNativeArray;
2522
}
2623

2724
@Override

src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ public class Q8_0TornadoTensor extends TornadoTensor {
1717
private final Int8Array quants; // Quantized int8 values
1818
private MemorySegment segment;
1919

20-
public Q8_0TornadoTensor(int size, HalfFloatArray scales, Int8Array quants, MemorySegment segment) {
21-
super(size);
20+
public Q8_0TornadoTensor(HalfFloatArray scales, Int8Array quants, MemorySegment segment) {
2221
this.scales = scales;
2322
this.quants = quants;
2423
this.segment = segment;

src/main/java/org/beehive/gpullama3/tensor/tornado/TornadoTensor.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,6 @@
1010
* These tensors wrap TornadoVM native arrays for GPU execution.
1111
*/
1212
public abstract class TornadoTensor {
13-
protected final int size;
14-
15-
protected TornadoTensor(int size) {
16-
this.size = size;
17-
}
18-
19-
public int size() {
20-
return size;
21-
}
2213

2314
public abstract GGMLType type();
2415

0 commit comments

Comments
 (0)