Skip to content

Commit d8a3586

Browse files
Merge pull request #636 from SKaiNET-developers/feature/tensor-convert-pr3
Implement CPU tensor convert op
2 parents d2950c6 + 5a3d497 commit d8a3586

2 files changed

Lines changed: 179 additions & 1 deletion

File tree

skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,20 @@ import sk.ainet.lang.ops.Backend
99
import sk.ainet.lang.ops.TensorOp
1010
import sk.ainet.lang.ops.InProgress
1111
import sk.ainet.lang.tensor.data.FloatArrayTensorData
12+
import sk.ainet.lang.tensor.data.IntArrayTensorData
13+
import sk.ainet.lang.tensor.data.TensorData
1214
import sk.ainet.lang.tensor.data.TensorDataFactory
1315
import sk.ainet.lang.tensor.ops.UpsampleMode
16+
import sk.ainet.lang.types.FP16
1417
import sk.ainet.lang.types.FP32
18+
import sk.ainet.lang.types.Int32
19+
import sk.ainet.lang.types.Int8
1520
import kotlin.math.ln
1621
import kotlin.math.log10 as kmLog10
1722
import kotlin.math.log2 as kmLog2
1823
import kotlin.math.pow
1924
import kotlin.math.sqrt
25+
import kotlin.reflect.KClass
2026

2127
@Backend(id = "cpu", displayName = "CPU")
2228
@InProgress("cpu", owner = "team:cpu", issue = "task-ops.md#defaultcpuops")
@@ -43,6 +49,56 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
4349
vararg inputs: Tensor<T, V>
4450
): Tensor<T, V> = CpuTensor(data, this, dtype, gradStateFrom(*inputs))
4551

52+
/**
 * Computes row-major strides for [shape].
 *
 * The last dimension gets stride 1; every earlier dimension gets the product of
 * `shape[j]` for all later dimensions `j`.
 *
 * @param shape shape whose `rank` and per-dimension values are read.
 * @return an array of `shape.rank` strides (empty when the rank is 0).
 */
private fun rowMajorStrides(shape: Shape): IntArray {
    val rank = shape.rank
    val result = IntArray(rank)
    // Walk from the innermost dimension outwards, accumulating the running product.
    var running = 1
    var dim = rank - 1
    while (dim >= 0) {
        result[dim] = running
        running *= shape[dim]
        dim--
    }
    return result
}
61+
62+
/**
 * Splits a flat offset into one index per dimension using the supplied strides.
 *
 * For each stride, in order, the index is the integer quotient of what is left of
 * [flatIndex] by that stride, and the remainder is carried on to the next stride.
 *
 * @param flatIndex flat offset to decompose.
 * @param strides per-dimension strides, consumed from first to last.
 * @return an array with `strides.size` entries (empty when [strides] is empty).
 */
private fun flatIndexToIndices(flatIndex: Int, strides: IntArray): IntArray {
    var rest = flatIndex
    // The array initializer runs for positions 0, 1, 2, ... in order, which the carried remainder relies on.
    return IntArray(strides.size) { dim ->
        val step = strides[dim]
        val index = rest / step
        rest %= step
        index
    }
}
71+
72+
/**
 * Returns the values of [tensor] as a newly allocated [FloatArray].
 *
 * - [FloatArrayTensorData]: the backing buffer is copied as-is.
 * - [IntArrayTensorData]: every buffer entry is converted with `toFloat()`.
 * - Anything else: elements are read one by one through `data.get(...)`, using
 *   indices derived from row-major flat offsets, and cast to [Number].
 *
 * NOTE(review): the two buffer fast paths presumably rely on the buffer holding
 * exactly the tensor's elements in row-major order — confirm for views/strided data.
 *
 * @param tensor source tensor; it is only read.
 * @return the tensor values converted to floats.
 */
private fun <T : DType, V> copyTensorValuesAsFloatArray(tensor: Tensor<T, V>): FloatArray {
    val source = tensor.data
    if (source is FloatArrayTensorData<*>) {
        return source.buffer.copyOf()
    }
    if (source is IntArrayTensorData<*>) {
        return FloatArray(source.buffer.size) { i -> source.buffer[i].toFloat() }
    }
    // Generic fallback: go through the element accessor one flat offset at a time.
    val strides = rowMajorStrides(tensor.shape)
    return FloatArray(tensor.shape.volume) { offset ->
        val position = flatIndexToIndices(offset, strides)
        val element = source.get(*position)
        (element as Number).toFloat()
    }
}
86+
87+
/**
 * Returns the values of [tensor] as a newly allocated [IntArray].
 *
 * - [IntArrayTensorData]: the backing buffer is copied as-is.
 * - [FloatArrayTensorData]: every buffer entry is converted with `toInt()`.
 * - Anything else: elements are read one by one through `data.get(...)`, using
 *   indices derived from row-major flat offsets, and cast to [Number].
 *
 * NOTE(review): the two buffer fast paths presumably rely on the buffer holding
 * exactly the tensor's elements in row-major order — confirm for views/strided data.
 *
 * @param tensor source tensor; it is only read.
 * @return the tensor values converted to ints.
 */
private fun <T : DType, V> copyTensorValuesAsIntArray(tensor: Tensor<T, V>): IntArray {
    val source = tensor.data
    if (source is IntArrayTensorData<*>) {
        return source.buffer.copyOf()
    }
    if (source is FloatArrayTensorData<*>) {
        return IntArray(source.buffer.size) { i -> source.buffer[i].toInt() }
    }
    // Generic fallback: go through the element accessor one flat offset at a time.
    val strides = rowMajorStrides(tensor.shape)
    return IntArray(tensor.shape.volume) { offset ->
        val position = flatIndexToIndices(offset, strides)
        val element = source.get(*position)
        (element as Number).toInt()
    }
}
101+
46102
protected fun broadcastShapes(a: Shape, b: Shape): Shape {
47103
val ad = a.dimensions
48104
val bd = b.dimensions
@@ -2427,7 +2483,30 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
24272483
tensor: Tensor<TFrom, V>,
24282484
targetType: TTo
24292485
): Tensor<TTo, V> {
2430-
TODO("Not yet implemented")
2486+
@Suppress("UNCHECKED_CAST")
2487+
val targetClass = targetType::class as KClass<TTo>
2488+
if (tensor.dtype == targetClass) {
2489+
@Suppress("UNCHECKED_CAST")
2490+
return tensor as Tensor<TTo, V>
2491+
}
2492+
2493+
@Suppress("UNCHECKED_CAST")
2494+
val outData = when (targetClass) {
2495+
FP32::class, FP16::class -> dataFactory.fromFloatArray<TTo, Float>(
2496+
tensor.shape,
2497+
targetClass,
2498+
copyTensorValuesAsFloatArray(tensor)
2499+
) as TensorData<TTo, V>
2500+
Int32::class, Int8::class -> dataFactory.fromIntArray<TTo, Int>(
2501+
tensor.shape,
2502+
targetClass,
2503+
copyTensorValuesAsIntArray(tensor)
2504+
) as TensorData<TTo, V>
2505+
else -> throw IllegalArgumentException(
2506+
"convert supports FP32, FP16, Int32, and Int8 targets, got ${targetType.name}"
2507+
)
2508+
}
2509+
return CpuTensor(outData, this, targetClass, GradState(requiresGrad = tensor.requiresGrad))
24312510
}
24322511

24332512
override fun <T : DType, V> gather(input: Tensor<T, V>, indices: Tensor<DType, *>, dim: Int): Tensor<T, V> {
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package sk.ainet.exec.tensor.ops
2+
3+
import kotlin.test.Test
4+
import kotlin.test.assertContentEquals
5+
import kotlin.test.assertEquals
6+
import kotlin.test.assertFailsWith
7+
import kotlin.test.assertSame
8+
import kotlin.test.assertTrue
9+
import sk.ainet.lang.tensor.GradState
10+
import sk.ainet.lang.tensor.Shape
11+
import sk.ainet.lang.tensor.VoidOpsTensor
12+
import sk.ainet.lang.tensor.data.DenseTensorDataFactory
13+
import sk.ainet.lang.tensor.data.FloatArrayTensorData
14+
import sk.ainet.lang.tensor.data.IntArrayTensorData
15+
import sk.ainet.lang.types.FP16
16+
import sk.ainet.lang.types.FP32
17+
import sk.ainet.lang.types.Int16
18+
import sk.ainet.lang.types.Int32
19+
20+
/**
 * Tests for `convert` on [DefaultCpuOps]: float-to-float, int-to-float and
 * float-to-int conversions, the same-dtype shortcut, and rejection of an
 * unsupported target dtype.
 */
class DefaultCpuOpsConvertTest {
    private val dataFactory = DenseTensorDataFactory()
    private val ops = DefaultCpuOps(dataFactory)

    /** Builds an FP32 tensor over [values] with the requested gradient flag. */
    private fun fp32Tensor(
        shape: Shape,
        values: FloatArray,
        requiresGrad: Boolean = false
    ): VoidOpsTensor<FP32, Float> {
        val gradState = GradState(requiresGrad = requiresGrad)
        val floatData = dataFactory.fromFloatArray<FP32, Float>(shape, FP32::class, values)
        return VoidOpsTensor(floatData, FP32::class, gradState)
    }

    /** Builds an Int32 tensor over [values]. */
    private fun int32Tensor(shape: Shape, values: IntArray): VoidOpsTensor<Int32, Int> {
        val intData = dataFactory.fromIntArray<Int32, Int>(shape, Int32::class, values)
        return VoidOpsTensor(intData, Int32::class)
    }

    @Test
    fun convertFp32ToFp16PreservesShapeValuesAndGradRequirement() {
        val shape = Shape(2, 2)
        val source = fp32Tensor(shape, floatArrayOf(1.25f, -2.5f, 3.75f, 4.5f), requiresGrad = true)

        val converted = ops.convert(source, FP16)

        assertEquals(Shape(2, 2), converted.shape)
        assertEquals(FP16::class, converted.dtype)
        assertTrue(converted.requiresGrad)
        val actual = (converted.data as FloatArrayTensorData<*>).buffer
        assertContentEquals(floatArrayOf(1.25f, -2.5f, 3.75f, 4.5f), actual)
    }

    @Test
    fun convertInt32ToFp32CastsValuesToFloat() {
        val source = int32Tensor(Shape(2, 2), intArrayOf(1, -2, 3, 4))

        val converted = ops.convert(source, FP32)

        assertEquals(Shape(2, 2), converted.shape)
        assertEquals(FP32::class, converted.dtype)
        val actual = (converted.data as FloatArrayTensorData<*>).buffer
        assertContentEquals(floatArrayOf(1f, -2f, 3f, 4f), actual)
    }

    @Test
    fun convertFp32ToInt32CastsValuesToInt() {
        val source = fp32Tensor(Shape(4), floatArrayOf(1.9f, -2.1f, 3.0f, 4.8f))

        val converted = ops.convert(source, Int32)

        assertEquals(Shape(4), converted.shape)
        assertEquals(Int32::class, converted.dtype)
        // The expected values drop the fractional part of each input.
        val actual = (converted.data as IntArrayTensorData<*>).buffer
        assertContentEquals(intArrayOf(1, -2, 3, 4), actual)
    }

    @Test
    fun convertToSameDtypeReturnsInputTensor() {
        val source = fp32Tensor(Shape(2), floatArrayOf(1f, 2f))

        val converted = ops.convert(source, FP32)

        // Same dtype must hand back the very same instance, not a copy.
        assertSame(source, converted)
    }

    @Test
    fun convertRejectsUnsupportedTargetDtype() {
        val source = fp32Tensor(Shape(2), floatArrayOf(1f, 2f))

        assertFailsWith<IllegalArgumentException> {
            ops.convert(source, Int16)
        }
    }
}

0 commit comments

Comments
 (0)