Skip to content

Commit 7083247

Browse files
committed
Fix factory and tests
Related-To: #126
1 parent 036b97e commit 7083247

2 files changed

Lines changed: 49 additions & 79 deletions

File tree

skainet-lang/skainet-lang-memory/src/commonTest/kotlin/sk/ai/net/core/factory/DenseTensorDataFactory.kt renamed to skainet-lang/skainet-lang-memory/src/commonMain/kotlin/sk/ai/net/core/tensor/data/DenseTensorDataFactory.kt

Lines changed: 32 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,68 @@
1-
package sk.ai.net.core.factory
2-
3-
import sk.ai.net.core.tensor.Shape
4-
import sk.ai.net.core.tensor.data.TensorData
5-
import sk.ai.net.core.types.DType
6-
import sk.ai.net.core.types.FP32
7-
import sk.ai.net.core.types.FP16
8-
import sk.ai.net.core.types.Int32
9-
import sk.ai.net.core.types.Int8
10-
import sk.ai.net.core.types.Int4
11-
import sk.ai.net.core.types.Ternary
1+
package sk.ai.net.core.tensor.data
2+
3+
import sk.ai.net.lang.tensor.Shape
4+
import sk.ai.net.lang.tensor.data.TensorData
5+
import sk.ai.net.lang.types.DType
6+
import sk.ai.net.lang.types.FP16
7+
import sk.ai.net.lang.types.FP32
8+
import sk.ai.net.lang.types.Int32
9+
import sk.ai.net.lang.types.Int4
10+
import sk.ai.net.lang.types.Int8
11+
import sk.ai.net.lang.types.Ternary
1212
import kotlin.jvm.JvmName
1313

14-
class DenseTensorDataFactory {
15-
fun from(value: Int): TensorData<Int32, Int> {
14+
public class DenseTensorDataFactory {
15+
public fun from(value: Int): TensorData<
16+
Int32, Int> {
1617
return object : TensorData<Int32, Int> {
1718
override val shape: Shape
18-
get() = Shape(1)
19-
20-
override fun materialize(): TensorData<Int32, Int> = this
19+
get() = Shape.Companion(1)
2120

2221
override fun get(vararg indices: Int): Int = value
2322
}
2423
}
2524

26-
fun from(value: Float): TensorData<FP32, Float> {
25+
public fun from(value: Float): TensorData<FP32, Float> {
2726
return object : TensorData<FP32, Float> {
2827
override val shape: Shape
29-
get() = Shape(1)
30-
31-
override fun materialize(): TensorData<FP32, Float> = this
28+
get() = Shape.Companion(1)
3229

3330
override fun get(vararg indices: Int): Float = value
3431
}
3532
}
3633

3734
@JvmName("vectorFromInt")
38-
fun fromArray(arrayOf: Array<Int>): TensorData<Int32, Int> {
35+
public fun fromArray(arrayOf: Array<Int>): TensorData<Int32, Int> {
3936
class IntTensorData(private val data: IntArray) : TensorData<Int32, Int> {
4037
override val shape: Shape
41-
get() = Shape(data.size)
42-
43-
override fun materialize(): TensorData<Int32, Int> = this
38+
get() = Shape.Companion(data.size)
4439

4540
override fun get(vararg indices: Int): Int = data[indices[0]]
4641
}
4742
return IntTensorData(arrayOf.toIntArray())
4843
}
4944

5045
@JvmName("vectorFromFloat")
51-
fun fromArray(arrayOf: Array<Float>): TensorData<FP32, Float> {
46+
public fun fromArray(arrayOf: Array<Float>): TensorData<FP32, Float> {
5247
class FloatTensorData(private val data: FloatArray) : TensorData<FP32, Float> {
5348
override val shape: Shape
54-
get() = Shape(data.size)
55-
56-
override fun materialize(): TensorData<FP32, Float> = this
49+
get() = Shape.Companion(data.size)
5750

5851
override fun get(vararg indices: Int): Float = data[indices[0]]
5952
}
6053
return FloatTensorData(arrayOf.toFloatArray())
6154
}
6255

6356
@Suppress("UNCHECKED_CAST")
64-
fun <T : DType, V> fromFloatArray(
57+
public fun <T : DType, V> fromFloatArray(
6558
data: FloatArray,
6659
dtype: T
6760
): TensorData<T, V> {
6861
return when (dtype) {
6962
is FP32 -> {
7063
class FP32FloatTensorData(private val data: FloatArray) : TensorData<FP32, Float> {
7164
override val shape: Shape
72-
get() = Shape(data.size)
73-
74-
override fun materialize(): TensorData<FP32, Float> = this
65+
get() = Shape.Companion(data.size)
7566

7667
override fun get(vararg indices: Int): Float = data[indices[0]]
7768
}
@@ -80,9 +71,7 @@ class DenseTensorDataFactory {
8071
is FP16 -> {
8172
class FP16FloatTensorData(private val data: FloatArray) : TensorData<FP16, Float> {
8273
override val shape: Shape
83-
get() = Shape(data.size)
84-
85-
override fun materialize(): TensorData<FP16, Float> = this
74+
get() = Shape.Companion(data.size)
8675

8776
override fun get(vararg indices: Int): Float = data[indices[0]]
8877
}
@@ -93,17 +82,15 @@ class DenseTensorDataFactory {
9382
}
9483

9584
@Suppress("UNCHECKED_CAST")
96-
fun <T : DType, V> fromIntArray(
85+
public fun <T : DType, V> fromIntArray(
9786
data: IntArray,
9887
dtype: T
9988
): TensorData<T, V> {
10089
return when (dtype) {
10190
is Int32 -> {
10291
class Int32IntTensorData(private val data: IntArray) : TensorData<Int32, Int> {
10392
override val shape: Shape
104-
get() = Shape(data.size)
105-
106-
override fun materialize(): TensorData<Int32, Int> = this
93+
get() = Shape.Companion(data.size)
10794

10895
override fun get(vararg indices: Int): Int = data[indices[0]]
10996
}
@@ -114,17 +101,15 @@ class DenseTensorDataFactory {
114101
}
115102

116103
@Suppress("UNCHECKED_CAST")
117-
fun <T : DType, V> fromByteArray(
118-
bytes: ByteArray,
104+
public fun <T : DType, V> fromByteArray(
105+
bytes: ByteArray,
119106
dtype: T
120107
): TensorData<T, V> {
121108
return when (dtype) {
122109
is FP32 -> {
123110
class FP32ByteTensorData(private val data: ByteArray) : TensorData<FP32, Float> {
124111
override val shape: Shape
125-
get() = Shape(data.size / 4) // 4 bytes per float
126-
127-
override fun materialize(): TensorData<FP32, Float> = this
112+
get() = Shape.Companion(data.size / 4) // 4 bytes per float
128113

129114
override fun get(vararg indices: Int): Float {
130115
val index = indices[0] * 4
@@ -141,9 +126,7 @@ class DenseTensorDataFactory {
141126
is Int8 -> {
142127
class Int8ByteTensorData(private val data: ByteArray) : TensorData<Int8, Byte> {
143128
override val shape: Shape
144-
get() = Shape(data.size)
145-
146-
override fun materialize(): TensorData<Int8, Byte> = this
129+
get() = Shape.Companion(data.size)
147130

148131
override fun get(vararg indices: Int): Byte = data[indices[0]]
149132
}
@@ -152,9 +135,7 @@ class DenseTensorDataFactory {
152135
is Int4 -> {
153136
class Int4ByteTensorData(private val data: ByteArray) : TensorData<Int4, Byte> {
154137
override val shape: Shape
155-
get() = Shape(data.size * 2) // 2 int4 values per byte
156-
157-
override fun materialize(): TensorData<Int4, Byte> = this
138+
get() = Shape.Companion(data.size * 2) // 2 int4 values per byte
158139

159140
override fun get(vararg indices: Int): Byte {
160141
val byteIndex = indices[0] / 2
@@ -172,9 +153,7 @@ class DenseTensorDataFactory {
172153
is Ternary -> {
173154
class TernaryByteTensorData(private val data: ByteArray) : TensorData<Ternary, Byte> {
174155
override val shape: Shape
175-
get() = Shape(data.size * 4) // 4 ternary values per byte (2 bits each)
176-
177-
override fun materialize(): TensorData<Ternary, Byte> = this
156+
get() = Shape.Companion(data.size * 4) // 4 ternary values per byte (2 bits each)
178157

179158
override fun get(vararg indices: Int): Byte {
180159
val byteIndex = indices[0] / 4

skainet-lang/skainet-lang-memory/src/commonTest/kotlin/sk/ai/net/core/factory/DenseTensorsTest.kt

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
package sk.ai.net.core.factory
22

3-
import sk.ai.net.core.tensor.Shape
4-
import sk.ai.net.core.types.FP32
5-
import sk.ai.net.core.types.FP16
6-
import sk.ai.net.core.types.Int32
7-
import sk.ai.net.core.types.Int8
8-
import sk.ai.net.core.types.Int4
9-
import sk.ai.net.core.types.Ternary
3+
import sk.ai.net.core.tensor.data.DenseTensorDataFactory
4+
import sk.ai.net.lang.tensor.Shape
5+
import sk.ai.net.lang.types.*
106
import kotlin.test.Test
11-
import kotlin.test.assertTrue
12-
import kotlin.test.assertFalse
137
import kotlin.test.assertEquals
14-
import kotlin.test.assertNotNull
15-
import kotlin.test.assertNull
168
import kotlin.test.assertFailsWith
179

1810
class DenseTensorsTest {
@@ -59,9 +51,9 @@ class DenseTensorsTest {
5951
// Test FP32 fromByteArray with 1.0f in IEEE 754 little-endian
6052
val floatBytes = byteArrayOf(0x00, 0x00, 0x80.toByte(), 0x3F.toByte())
6153
val fp32TensorData = fromByteArray<FP32, Float>(floatBytes, FP32)
62-
54+
6355
assertEquals(Shape(1), fp32TensorData.shape)
64-
assertEquals(1.0f, fp32TensorData.get(0))
56+
assertEquals(1.0f, fp32TensorData[0])
6557
}
6658
}
6759

@@ -71,7 +63,7 @@ class DenseTensorsTest {
7163
// Test Int8 fromByteArray
7264
val int8Bytes = byteArrayOf(42, -10, 100)
7365
val int8TensorData = fromByteArray<Int8, Byte>(int8Bytes, Int8)
74-
66+
7567
assertEquals(Shape(3), int8TensorData.shape)
7668
assertEquals(42.toByte(), int8TensorData[0])
7769
assertEquals((-10).toByte(), int8TensorData[1])
@@ -85,10 +77,10 @@ class DenseTensorsTest {
8577
// Test Int4 fromByteArray - 0x5A = 0101 1010 -> lower nibble 10 (0xA), upper nibble 5
8678
val int4Bytes = byteArrayOf(0x5A.toByte())
8779
val int4TensorData = fromByteArray<Int4, Byte>(int4Bytes, Int4)
88-
80+
8981
assertEquals(Shape(2), int4TensorData.shape)
90-
assertEquals(10.toByte(), int4TensorData.get(0)) // Lower nibble: 0xA = 10
91-
assertEquals(5.toByte(), int4TensorData.get(1)) // Upper nibble: 0x5 = 5
82+
assertEquals(10.toByte(), int4TensorData[0]) // Lower nibble: 0xA = 10
83+
assertEquals(5.toByte(), int4TensorData[1]) // Upper nibble: 0x5 = 5
9284
}
9385
}
9486

@@ -98,12 +90,12 @@ class DenseTensorsTest {
9890
// Test Ternary fromByteArray - 0x1B = 00011011 -> bits read as: 11 10 01 00 -> 0, 1, 0, -1
9991
val ternaryBytes = byteArrayOf(0x1B.toByte())
10092
val ternaryTensorData = fromByteArray<Ternary, Byte>(ternaryBytes, Ternary)
101-
93+
10294
assertEquals(Shape(4), ternaryTensorData.shape)
103-
assertEquals(0.toByte(), ternaryTensorData.get(0)) // bits 11 -> 0 (fallback)
104-
assertEquals(1.toByte(), ternaryTensorData.get(1)) // bits 10 -> 1
105-
assertEquals(0.toByte(), ternaryTensorData.get(2)) // bits 01 -> 0
106-
assertEquals((-1).toByte(), ternaryTensorData.get(3)) // bits 00 -> -1
95+
assertEquals(0.toByte(), ternaryTensorData[0]) // bits 11 -> 0 (fallback)
96+
assertEquals(1.toByte(), ternaryTensorData[1]) // bits 10 -> 1
97+
assertEquals(0.toByte(), ternaryTensorData[2]) // bits 01 -> 0
98+
assertEquals((-1).toByte(), ternaryTensorData[3]) // bits 00 -> -1
10799
}
108100
}
109101

@@ -113,7 +105,7 @@ class DenseTensorsTest {
113105
// Test FP32 fromFloatArray
114106
val floatData = floatArrayOf(1.5f, 2.5f, 3.5f)
115107
val fp32TensorData = fromFloatArray<FP32, Float>(floatData, FP32)
116-
108+
117109
assertEquals(Shape(3), fp32TensorData.shape)
118110
assertEquals(1.5f, fp32TensorData[0])
119111
assertEquals(2.5f, fp32TensorData[1])
@@ -127,7 +119,7 @@ class DenseTensorsTest {
127119
// Test FP16 fromFloatArray
128120
val floatData = floatArrayOf(0.5f, -1.0f, 10.0f)
129121
val fp16TensorData = fromFloatArray<FP16, Float>(floatData, FP16)
130-
122+
131123
assertEquals(Shape(3), fp16TensorData.shape)
132124
assertEquals(0.5f, fp16TensorData[0])
133125
assertEquals(-1.0f, fp16TensorData[1])
@@ -141,7 +133,7 @@ class DenseTensorsTest {
141133
// Test Int32 fromIntArray
142134
val intData = intArrayOf(42, -100, 1000)
143135
val int32TensorData = fromIntArray<Int32, Int>(intData, Int32)
144-
136+
145137
assertEquals(Shape(3), int32TensorData.shape)
146138
assertEquals(42, int32TensorData[0])
147139
assertEquals(-100, int32TensorData[1])
@@ -168,5 +160,4 @@ class DenseTensorsTest {
168160
}
169161
}
170162
}
171-
172163
}

0 commit comments

Comments
 (0)