1- package sk.ai.net.core.factory
2-
3- import sk.ai.net.core .tensor.Shape
4- import sk.ai.net.core .tensor.data.TensorData
5- import sk.ai.net.core .types.DType
6- import sk.ai.net.core .types.FP32
7- import sk.ai.net.core .types.FP16
8- import sk.ai.net.core .types.Int32
9- import sk.ai.net.core .types.Int8
10- import sk.ai.net.core .types.Int4
11- import sk.ai.net.core .types.Ternary
1+ package sk.ai.net.core.tensor.data
2+
3+ import sk.ai.net.lang .tensor.Shape
4+ import sk.ai.net.lang .tensor.data.TensorData
5+ import sk.ai.net.lang .types.DType
6+ import sk.ai.net.lang .types.FP16
7+ import sk.ai.net.lang .types.FP32
8+ import sk.ai.net.lang .types.Int32
9+ import sk.ai.net.lang .types.Int4
10+ import sk.ai.net.lang .types.Int8
11+ import sk.ai.net.lang .types.Ternary
1212import kotlin.jvm.JvmName
1313
14- class DenseTensorDataFactory {
15- fun from (value : Int ): TensorData <Int32 , Int > {
14+ public class DenseTensorDataFactory {
15+ public fun from (value : Int ): TensorData <
16+ Int32 , Int > {
1617 return object : TensorData <Int32 , Int > {
1718 override val shape: Shape
18- get() = Shape (1 )
19-
20- override fun materialize (): TensorData <Int32 , Int > = this
19+ get() = Shape .Companion (1 )
2120
2221 override fun get (vararg indices : Int ): Int = value
2322 }
2423 }
2524
26- fun from (value : Float ): TensorData <FP32 , Float > {
25+ public fun from (value : Float ): TensorData <FP32 , Float > {
2726 return object : TensorData <FP32 , Float > {
2827 override val shape: Shape
29- get() = Shape (1 )
30-
31- override fun materialize (): TensorData <FP32 , Float > = this
28+ get() = Shape .Companion (1 )
3229
3330 override fun get (vararg indices : Int ): Float = value
3431 }
3532 }
3633
3734 @JvmName(" vectorFromInt" )
38- fun fromArray (arrayOf : Array <Int >): TensorData <Int32 , Int > {
35+ public fun fromArray (arrayOf : Array <Int >): TensorData <Int32 , Int > {
3936 class IntTensorData (private val data : IntArray ) : TensorData<Int32, Int> {
4037 override val shape: Shape
41- get() = Shape (data.size)
42-
43- override fun materialize (): TensorData <Int32 , Int > = this
38+ get() = Shape .Companion (data.size)
4439
4540 override fun get (vararg indices : Int ): Int = data[indices[0 ]]
4641 }
4742 return IntTensorData (arrayOf.toIntArray())
4843 }
4944
5045 @JvmName(" vectorFromFloat" )
51- fun fromArray (arrayOf : Array <Float >): TensorData <FP32 , Float > {
46+ public fun fromArray (arrayOf : Array <Float >): TensorData <FP32 , Float > {
5247 class FloatTensorData (private val data : FloatArray ) : TensorData<FP32, Float> {
5348 override val shape: Shape
54- get() = Shape (data.size)
55-
56- override fun materialize (): TensorData <FP32 , Float > = this
49+ get() = Shape .Companion (data.size)
5750
5851 override fun get (vararg indices : Int ): Float = data[indices[0 ]]
5952 }
6053 return FloatTensorData (arrayOf.toFloatArray())
6154 }
6255
6356 @Suppress(" UNCHECKED_CAST" )
64- fun <T : DType , V > fromFloatArray (
57+ public fun <T : DType , V > fromFloatArray (
6558 data : FloatArray ,
6659 dtype : T
6760 ): TensorData <T , V > {
6861 return when (dtype) {
6962 is FP32 -> {
7063 class FP32FloatTensorData (private val data : FloatArray ) : TensorData<FP32, Float> {
7164 override val shape: Shape
72- get() = Shape (data.size)
73-
74- override fun materialize (): TensorData <FP32 , Float > = this
65+ get() = Shape .Companion (data.size)
7566
7667 override fun get (vararg indices : Int ): Float = data[indices[0 ]]
7768 }
@@ -80,9 +71,7 @@ class DenseTensorDataFactory {
8071 is FP16 -> {
8172 class FP16FloatTensorData (private val data : FloatArray ) : TensorData<FP16, Float> {
8273 override val shape: Shape
83- get() = Shape (data.size)
84-
85- override fun materialize (): TensorData <FP16 , Float > = this
74+ get() = Shape .Companion (data.size)
8675
8776 override fun get (vararg indices : Int ): Float = data[indices[0 ]]
8877 }
@@ -93,17 +82,15 @@ class DenseTensorDataFactory {
9382 }
9483
9584 @Suppress(" UNCHECKED_CAST" )
96- fun <T : DType , V > fromIntArray (
85+ public fun <T : DType , V > fromIntArray (
9786 data : IntArray ,
9887 dtype : T
9988 ): TensorData <T , V > {
10089 return when (dtype) {
10190 is Int32 -> {
10291 class Int32IntTensorData (private val data : IntArray ) : TensorData<Int32, Int> {
10392 override val shape: Shape
104- get() = Shape (data.size)
105-
106- override fun materialize (): TensorData <Int32 , Int > = this
93+ get() = Shape .Companion (data.size)
10794
10895 override fun get (vararg indices : Int ): Int = data[indices[0 ]]
10996 }
@@ -114,17 +101,15 @@ class DenseTensorDataFactory {
114101 }
115102
116103 @Suppress(" UNCHECKED_CAST" )
117- fun <T : DType , V > fromByteArray (
118- bytes : ByteArray ,
104+ public fun <T : DType , V > fromByteArray (
105+ bytes : ByteArray ,
119106 dtype : T
120107 ): TensorData <T , V > {
121108 return when (dtype) {
122109 is FP32 -> {
123110 class FP32ByteTensorData (private val data : ByteArray ) : TensorData<FP32, Float> {
124111 override val shape: Shape
125- get() = Shape (data.size / 4 ) // 4 bytes per float
126-
127- override fun materialize (): TensorData <FP32 , Float > = this
112+ get() = Shape .Companion (data.size / 4 ) // 4 bytes per float
128113
129114 override fun get (vararg indices : Int ): Float {
130115 val index = indices[0 ] * 4
@@ -141,9 +126,7 @@ class DenseTensorDataFactory {
141126 is Int8 -> {
142127 class Int8ByteTensorData (private val data : ByteArray ) : TensorData<Int8, Byte> {
143128 override val shape: Shape
144- get() = Shape (data.size)
145-
146- override fun materialize (): TensorData <Int8 , Byte > = this
129+ get() = Shape .Companion (data.size)
147130
148131 override fun get (vararg indices : Int ): Byte = data[indices[0 ]]
149132 }
@@ -152,9 +135,7 @@ class DenseTensorDataFactory {
152135 is Int4 -> {
153136 class Int4ByteTensorData (private val data : ByteArray ) : TensorData<Int4, Byte> {
154137 override val shape: Shape
155- get() = Shape (data.size * 2 ) // 2 int4 values per byte
156-
157- override fun materialize (): TensorData <Int4 , Byte > = this
138+ get() = Shape .Companion (data.size * 2 ) // 2 int4 values per byte
158139
159140 override fun get (vararg indices : Int ): Byte {
160141 val byteIndex = indices[0 ] / 2
@@ -172,9 +153,7 @@ class DenseTensorDataFactory {
172153 is Ternary -> {
173154 class TernaryByteTensorData (private val data : ByteArray ) : TensorData<Ternary, Byte> {
174155 override val shape: Shape
175- get() = Shape (data.size * 4 ) // 4 ternary values per byte (2 bits each)
176-
177- override fun materialize (): TensorData <Ternary , Byte > = this
156+ get() = Shape .Companion (data.size * 4 ) // 4 ternary values per byte (2 bits each)
178157
179158 override fun get (vararg indices : Int ): Byte {
180159 val byteIndex = indices[0 ] / 4
0 commit comments