@@ -9,14 +9,20 @@ import sk.ainet.lang.ops.Backend
99import sk.ainet.lang.ops.TensorOp
1010import sk.ainet.lang.ops.InProgress
1111import sk.ainet.lang.tensor.data.FloatArrayTensorData
12+ import sk.ainet.lang.tensor.data.IntArrayTensorData
13+ import sk.ainet.lang.tensor.data.TensorData
1214import sk.ainet.lang.tensor.data.TensorDataFactory
1315import sk.ainet.lang.tensor.ops.UpsampleMode
16+ import sk.ainet.lang.types.FP16
1417import sk.ainet.lang.types.FP32
18+ import sk.ainet.lang.types.Int32
19+ import sk.ainet.lang.types.Int8
1520import kotlin.math.ln
1621import kotlin.math.log10 as kmLog10
1722import kotlin.math.log2 as kmLog2
1823import kotlin.math.pow
1924import kotlin.math.sqrt
25+ import kotlin.reflect.KClass
2026
2127@Backend(id = " cpu" , displayName = " CPU" )
2228@InProgress(" cpu" , owner = " team:cpu" , issue = " task-ops.md#defaultcpuops" )
@@ -43,6 +49,56 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
4349 vararg inputs : Tensor <T , V >
4450 ): Tensor <T , V > = CpuTensor (data, this , dtype, gradStateFrom(* inputs))
4551
/**
 * Builds the row-major stride table for [shape].
 *
 * The last axis receives stride 1; walking towards axis 0, each stride is the
 * previous stride multiplied by `shape[axis]` of the axis just visited.
 *
 * @param shape shape whose `rank` fixes the length of the result.
 * @return a new array of `shape.rank` strides (empty when the rank is 0).
 */
private fun rowMajorStrides(shape: Shape): IntArray {
    val rank = shape.rank
    val result = IntArray(rank)
    var running = 1
    var axis = rank - 1
    // Walk from the innermost (last) axis outwards, accumulating the product.
    while (axis >= 0) {
        result[axis] = running
        running *= shape[axis]
        axis--
    }
    return result
}
61+
/**
 * Decomposes a flat element offset into one index per axis.
 *
 * For each entry of [strides], in order, the index is the integer quotient of
 * the remaining offset by that stride; the remainder is carried to the next axis.
 *
 * @param flatIndex flat offset to decompose.
 * @param strides per-axis strides; every entry is used as a divisor, so a zero
 *   entry raises [ArithmeticException].
 * @return a new array with one index per entry of [strides].
 */
private fun flatIndexToIndices(flatIndex: Int, strides: IntArray): IntArray {
    var rest = flatIndex
    // The initializer runs for axis 0, 1, 2, ... in order, so `rest` is carried across axes.
    return IntArray(strides.size) { axis ->
        val step = strides[axis]
        val coordinate = rest / step
        rest %= step
        coordinate
    }
}
71+
/**
 * Copies every element of [tensor] into a freshly allocated [FloatArray].
 *
 * - [FloatArrayTensorData]: the backing buffer is duplicated with `copyOf()`.
 * - [IntArrayTensorData]: each buffer entry is converted with `toFloat()`.
 * - Any other storage: elements are read one by one through `data.get(...)`,
 *   using row-major indices derived from `tensor.shape`, and cast to [Number].
 *
 * @param tensor source tensor; it is only read.
 * @return a new array holding the converted values.
 * @throws ClassCastException in the generic path when an element is not a [Number].
 */
private fun <T : DType, V> copyTensorValuesAsFloatArray(tensor: Tensor<T, V>): FloatArray {
    val source = tensor.data

    // Fast path: already float-backed, a plain buffer copy is enough.
    if (source is FloatArrayTensorData<*>) {
        return source.buffer.copyOf()
    }

    // Int-backed storage: widen each entry.
    if (source is IntArrayTensorData<*>) {
        val ints = source.buffer
        return FloatArray(ints.size) { position -> ints[position].toFloat() }
    }

    // Generic storage: address each element by its multi-dimensional index.
    val strides = rowMajorStrides(tensor.shape)
    return FloatArray(tensor.shape.volume) { flatIndex ->
        val element = source.get(*flatIndexToIndices(flatIndex, strides))
        (element as Number).toFloat()
    }
}
86+
/**
 * Copies every element of [tensor] into a freshly allocated [IntArray].
 *
 * - [IntArrayTensorData]: the backing buffer is duplicated with `copyOf()`.
 * - [FloatArrayTensorData]: each buffer entry is converted with `toInt()`.
 * - Any other storage: elements are read one by one through `data.get(...)`,
 *   using row-major indices derived from `tensor.shape`, and cast to [Number].
 *
 * @param tensor source tensor; it is only read.
 * @return a new array holding the converted values.
 * @throws ClassCastException in the generic path when an element is not a [Number].
 */
private fun <T : DType, V> copyTensorValuesAsIntArray(tensor: Tensor<T, V>): IntArray {
    val source = tensor.data

    // Fast path: already int-backed, a plain buffer copy is enough.
    if (source is IntArrayTensorData<*>) {
        return source.buffer.copyOf()
    }

    // Float-backed storage: convert each entry with Float.toInt().
    if (source is FloatArrayTensorData<*>) {
        val floats = source.buffer
        return IntArray(floats.size) { position -> floats[position].toInt() }
    }

    // Generic storage: address each element by its multi-dimensional index.
    val strides = rowMajorStrides(tensor.shape)
    return IntArray(tensor.shape.volume) { flatIndex ->
        val element = source.get(*flatIndexToIndices(flatIndex, strides))
        (element as Number).toInt()
    }
}
101+
46102 protected fun broadcastShapes (a : Shape , b : Shape ): Shape {
47103 val ad = a.dimensions
48104 val bd = b.dimensions
@@ -2427,7 +2483,30 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
24272483 tensor : Tensor <TFrom , V >,
24282484 targetType : TTo
24292485 ): Tensor <TTo , V > {
2430- TODO (" Not yet implemented" )
2486+ @Suppress(" UNCHECKED_CAST" )
2487+ val targetClass = targetType::class as KClass <TTo >
2488+ if (tensor.dtype == targetClass) {
2489+ @Suppress(" UNCHECKED_CAST" )
2490+ return tensor as Tensor <TTo , V >
2491+ }
2492+
2493+ @Suppress(" UNCHECKED_CAST" )
2494+ val outData = when (targetClass) {
2495+ FP32 ::class , FP16 ::class -> dataFactory.fromFloatArray<TTo , Float >(
2496+ tensor.shape,
2497+ targetClass,
2498+ copyTensorValuesAsFloatArray(tensor)
2499+ ) as TensorData <TTo , V >
2500+ Int32 ::class , Int8 ::class -> dataFactory.fromIntArray<TTo , Int >(
2501+ tensor.shape,
2502+ targetClass,
2503+ copyTensorValuesAsIntArray(tensor)
2504+ ) as TensorData <TTo , V >
2505+ else -> throw IllegalArgumentException (
2506+ " convert supports FP32, FP16, Int32, and Int8 targets, got ${targetType.name} "
2507+ )
2508+ }
2509+ return CpuTensor (outData, this , targetClass, GradState (requiresGrad = tensor.requiresGrad))
24312510 }
24322511
24332512 override fun <T : DType , V > gather (input : Tensor <T , V >, indices : Tensor <DType , * >, dim : Int ): Tensor <T , V > {
0 commit comments