Skip to content

Commit 5fd5950

Browse files
committed
Improve test with reshaping
Related-To: #121 #95
1 parent c87b31b commit 5fd5950

6 files changed

Lines changed: 94 additions & 36 deletions

File tree

skainet-core/skainet-tensors-api/src/commonMain/kotlin/sk/ainet/core/tensor/dsl/TensorsRangeBuilder.kt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
package sk.ainet.core.tensor.dsl
22

33
import sk.ainet.core.tensor.DType
4-
import sk.ainet.core.tensor.NCHWViewHelper
54
import sk.ainet.core.tensor.Slice
65
import sk.ainet.core.tensor.SliceDescriptor
7-
import sk.ainet.core.tensor.SliceIndexMapper
86
import sk.ainet.core.tensor.Tensor
97

108
/**

skainet-core/skainet-tensors/src/commonMain/kotlin/sk/ainet/core/tensor/backend/CpuBackend.kt

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ public class CpuTensorInt8(
591591
// For integer tensors, sigmoid is approximated and scaled to byte range
592592
val result = this.data.map {
593593
val sigmoid = 1.0 / (1.0 + exp(-it.toDouble() / 127.0)) // Scale input to [-1, 1] range
594-
clampToByte((sigmoid * 254 - 127).toDouble()) // Scale output to byte range
594+
clampToByte((sigmoid * 254 - 127)) // Scale output to byte range
595595
}.toByteArray()
596596
return CpuTensorInt8(this.shape, result)
597597
}
@@ -601,7 +601,7 @@ public class CpuTensorInt8(
601601
// For integer tensors, tanh is approximated and scaled to byte range
602602
val result = this.data.map {
603603
val tanhValue = tanh(it.toDouble() / 127.0) // Scale input to [-1, 1] range
604-
clampToByte((tanhValue * 127).toDouble()) // Scale output to byte range
604+
clampToByte((tanhValue * 127)) // Scale output to byte range
605605
}.toByteArray()
606606
return CpuTensorInt8(this.shape, result)
607607
}
@@ -2365,7 +2365,7 @@ public class CpuBackendInt8 : ComputeBackend<Int8, Byte> {
23652365
override fun Double.minus(t: Tensor<Int8, Byte>): Tensor<Int8, Byte> =
23662366
(t as CpuTensorInt8).minus(this).let {
23672367
CpuTensorInt8.fromArray(t.shape, ByteArray(t.shape.volume) { i ->
2368-
((this - (t as CpuTensorInt8).data[i].toInt()).toInt().toByte())
2368+
((this - t.data[i].toInt()).toInt().toByte())
23692369
})
23702370
}
23712371

@@ -2395,11 +2395,47 @@ public class CpuBackendInt8 : ComputeBackend<Int8, Byte> {
23952395
override fun Tensor<Int8, Byte>.flatten(startDim: Int, endDim: Int): Tensor<Int8, Byte> =
23962396
(this as CpuTensorInt8).flatten(startDim, endDim)
23972397

2398-
override fun Tensor<Int8, Byte>.reshape(newShape: Shape): Tensor<Int8, Byte> =
2399-
(this as CpuTensorInt8).reshape(newShape)
2398+
override fun Tensor<Int8, Byte>.reshape(newShape: Shape): Tensor<Int8, Byte> {
2399+
require(this.shape.volume == newShape.volume) {
2400+
"Cannot reshape tensor with ${this.shape.volume} elements to shape with ${newShape.volume} elements"
2401+
}
2402+
2403+
// Handle both CpuTensorInt8 and sliced tensors
2404+
if (this is CpuTensorInt8) {
2405+
// Direct access to data for CpuTensorInt8
2406+
return CpuTensorInt8.fromArray(newShape, this.data.copyOf())
2407+
} else {
2408+
// For sliced tensors, extract data using copyTo
2409+
val tensorData = Array<Byte>(this.shape.volume) { 0 }
2410+
this.copyTo(tensorData)
2411+
return CpuTensorInt8.fromArray(newShape, tensorData.toByteArray())
2412+
}
2413+
}
24002414

2401-
override fun Tensor<Int8, Byte>.reshape(vararg dimensions: Int): Tensor<Int8, Byte> =
2402-
(this as CpuTensorInt8).reshape(*dimensions)
2415+
override fun Tensor<Int8, Byte>.reshape(vararg dimensions: Int): Tensor<Int8, Byte> {
2416+
// Count -1 dimensions and validate
2417+
val minusOneCount = dimensions.count { it == -1 }
2418+
require(minusOneCount <= 1) { "Only one dimension can be -1, found $minusOneCount" }
2419+
2420+
val totalKnownElements = dimensions.filter { it != -1 }.fold(1) { acc, dim ->
2421+
require(dim > 0) { "All dimensions must be positive or -1, got $dim" }
2422+
acc * dim
2423+
}
2424+
2425+
val inferredDimensions = if (minusOneCount == 1) {
2426+
// Calculate the missing dimension
2427+
val missingDimSize = this.shape.volume / totalKnownElements
2428+
require(this.shape.volume % totalKnownElements == 0) {
2429+
"Cannot infer dimension size: volume ${this.shape.volume} is not divisible by known dimensions product $totalKnownElements"
2430+
}
2431+
dimensions.map { if (it == -1) missingDimSize else it }.toIntArray()
2432+
} else {
2433+
dimensions
2434+
}
2435+
2436+
val newShape = Shape(inferredDimensions)
2437+
return this.reshape(newShape)
2438+
}
24032439

24042440
override fun zeros(shape: Shape): Tensor<Int8, Byte> =
24052441
CpuTensorInt8.zeros(shape)
@@ -2570,11 +2606,47 @@ public class CpuBackendInt32 : ComputeBackend<Int32, Int> {
25702606
override fun Tensor<Int32, Int>.flatten(startDim: Int, endDim: Int): Tensor<Int32, Int> =
25712607
(this as CpuTensorInt32).flatten(startDim, endDim)
25722608

2573-
override fun Tensor<Int32, Int>.reshape(newShape: Shape): Tensor<Int32, Int> =
2574-
(this as CpuTensorInt32).reshape(newShape)
2609+
override fun Tensor<Int32, Int>.reshape(newShape: Shape): Tensor<Int32, Int> {
2610+
require(this.shape.volume == newShape.volume) {
2611+
"Cannot reshape tensor with ${this.shape.volume} elements to shape with ${newShape.volume} elements"
2612+
}
2613+
2614+
// Handle both CpuTensorInt32 and sliced tensors
2615+
if (this is CpuTensorInt32) {
2616+
// Direct access to data for CpuTensorInt32
2617+
return CpuTensorInt32.fromArray(newShape, this.data.copyOf())
2618+
} else {
2619+
// For sliced tensors, extract data using copyTo
2620+
val tensorData = Array<Int>(this.shape.volume) { 0 }
2621+
this.copyTo(tensorData)
2622+
return CpuTensorInt32.fromArray(newShape, tensorData.toIntArray())
2623+
}
2624+
}
25752625

2576-
override fun Tensor<Int32, Int>.reshape(vararg dimensions: Int): Tensor<Int32, Int> =
2577-
(this as CpuTensorInt32).reshape(*dimensions)
2626+
override fun Tensor<Int32, Int>.reshape(vararg dimensions: Int): Tensor<Int32, Int> {
2627+
// Count -1 dimensions and validate
2628+
val minusOneCount = dimensions.count { it == -1 }
2629+
require(minusOneCount <= 1) { "Only one dimension can be -1, found $minusOneCount" }
2630+
2631+
val totalKnownElements = dimensions.filter { it != -1 }.fold(1) { acc, dim ->
2632+
require(dim > 0) { "All dimensions must be positive or -1, got $dim" }
2633+
acc * dim
2634+
}
2635+
2636+
val inferredDimensions = if (minusOneCount == 1) {
2637+
// Calculate the missing dimension
2638+
val missingDimSize = this.shape.volume / totalKnownElements
2639+
require(this.shape.volume % totalKnownElements == 0) {
2640+
"Cannot infer dimension size: volume ${this.shape.volume} is not divisible by known dimensions product $totalKnownElements"
2641+
}
2642+
dimensions.map { if (it == -1) missingDimSize else it }.toIntArray()
2643+
} else {
2644+
dimensions
2645+
}
2646+
2647+
val newShape = Shape(inferredDimensions)
2648+
return this.reshape(newShape)
2649+
}
25782650

25792651
override fun zeros(shape: Shape): Tensor<Int32, Int> =
25802652
CpuTensorInt32.zeros(shape)

skainet-core/skainet-tensors/src/commonMain/kotlin/sk/ainet/core/tensor/factory/FP32TensorFactory.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package sk.ainet.core.tensor.factory
22

3-
import sk.ainet.core.tensor.DType
43
import sk.ainet.core.tensor.FP32
54
import sk.ainet.core.tensor.Shape
65
import sk.ainet.core.tensor.Tensor

skainet-core/skainet-tensors/src/commonMain/kotlin/sk/ainet/core/tensor/factory/TensorFactoryDocumentation.kt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
package sk.ainet.core.tensor.factory
22

3-
import sk.ainet.core.tensor.DType
4-
import sk.ainet.core.tensor.Shape
5-
import sk.ainet.core.tensor.Tensor
6-
73
/**
84
* # Tensor Factory Documentation and Examples
95
*

skainet-core/skainet-tensors/src/commonTest/kotlin/sk/ainet/core/tensor/ReshapeSlicingDslTest.kt

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import sk.ainet.core.tensor.backend.CpuBackendInt8
88
import sk.ainet.core.tensor.backend.CpuBackendInt32
99
import sk.ainet.core.tensor.backend.CpuTensorFP16
1010
import sk.ainet.core.tensor.backend.CpuBackendFP16
11-
import sk.ainet.core.tensor.dsl.*
1211
import kotlin.test.*
1312

1413
class ReshapeSlicingDslTest {
@@ -125,7 +124,7 @@ class ReshapeSlicingDslTest {
125124
segment { all() }
126125
segment { all() }
127126
}.let { view ->
128-
original.reshape(Shape(6, 2))
127+
view.reshape(Shape(6, 2))
129128
}
130129
}
131130

@@ -153,7 +152,7 @@ class ReshapeSlicingDslTest {
153152
segment { all() }
154153
segment { all() }
155154
}.let { view ->
156-
original.reshape(Shape(2, 3))
155+
view.reshape(Shape(2, 3))
157156
}
158157
}
159158

@@ -178,7 +177,7 @@ class ReshapeSlicingDslTest {
178177
sliceTensor(original) {
179178
segment { all() }
180179
}.let { view ->
181-
original.reshape(Shape(1, 1))
180+
view.reshape(Shape(1, 1))
182181
}
183182
}
184183
assertEquals(Shape(1, 1), reshaped1.shape)
@@ -188,7 +187,7 @@ class ReshapeSlicingDslTest {
188187
sliceTensor(original) {
189188
segment { all() }
190189
}.let { view ->
191-
original.reshape(Shape(1, 1, 1))
190+
view.reshape(Shape(1, 1, 1))
192191
}
193192
}
194193
assertEquals(Shape(1, 1, 1), reshaped2.shape)
@@ -205,7 +204,7 @@ class ReshapeSlicingDslTest {
205204
sliceTensor(original) {
206205
segment { all() }
207206
}.let { view ->
208-
original.reshape(Shape(2, 3, 2, 2))
207+
view.reshape(Shape(2, 3, 2, 2))
209208
}
210209
}
211210

@@ -260,7 +259,7 @@ class ReshapeSlicingDslTest {
260259
sliceTensor(original) {
261260
segment { all() }
262261
}.let { view ->
263-
original.reshape(Shape(2, 2))
262+
view.reshape(Shape(2, 2))
264263
}
265264
}
266265

@@ -288,7 +287,7 @@ class ReshapeSlicingDslTest {
288287
segment { all() }
289288
segment { all() }
290289
}.let { view ->
291-
original.reshape(Shape(12, 5))
290+
view.reshape(Shape(12, 5))
292291
}
293292
}
294293

@@ -297,15 +296,15 @@ class ReshapeSlicingDslTest {
297296
segment { all() }
298297
segment { all() }
299298
}.let { view ->
300-
step1.reshape(Shape(60))
299+
view.reshape(Shape(60))
301300
}
302301
}
303302

304303
val step3 = with(backendFP32) {
305304
sliceTensor(step2) {
306305
segment { all() }
307306
}.let { view ->
308-
step2.reshape(Shape(6, 10))
307+
view.reshape(Shape(6, 10))
309308
}
310309
}
311310

@@ -314,7 +313,7 @@ class ReshapeSlicingDslTest {
314313
segment { all() }
315314
segment { all() }
316315
}.let { view ->
317-
step3.reshape(Shape(2, 3, 10))
316+
view.reshape(Shape(2, 3, 10))
318317
}
319318
}
320319

@@ -390,11 +389,6 @@ class ReshapeSlicingDslTest {
390389

391390
assertEquals(firstRow.shape, Shape(1, 5))
392391

393-
val lastRow = sliceTensor(original) {
394-
segment { range(3, 4) } // Last row only
395-
segment { all() } // All columns
396-
}
397-
398392
// For this test, we'll just demonstrate with first row reshaped
399393
// Convert view to CpuTensorFP32 first, then reshape
400394
val dataArray = Array<Float>(firstRow.shape.volume) { 0f }
@@ -424,7 +418,7 @@ class ReshapeSlicingDslTest {
424418
segment { all() }
425419
segment { all() }
426420
}.let { view ->
427-
tensor.reshape(Shape(2, 4)) // Volume = 8, original volume = 6
421+
view.reshape(Shape(2, 4)) // Volume = 8, original volume = 6
428422
}
429423
}
430424
}

skainet-core/skainet-tensors/src/commonTest/kotlin/sk/ainet/core/tensor/factory/ByteArrayConverterTest.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package sk.ainet.core.tensor.factory
22

3-
import sk.ainet.core.tensor.*
43
import kotlin.test.*
54

65
/**

0 commit comments

Comments
 (0)