@@ -591,7 +591,7 @@ public class CpuTensorInt8(
591591 // For integer tensors, sigmoid is approximated and scaled to byte range
592592 val result = this .data.map {
593593 val sigmoid = 1.0 / (1.0 + exp(- it.toDouble() / 127.0 )) // Scale input to [-1, 1] range
594- clampToByte((sigmoid * 254 - 127 ).toDouble() ) // Scale output to byte range
594+ clampToByte((sigmoid * 254 - 127 )) // Scale output to byte range
595595 }.toByteArray()
596596 return CpuTensorInt8 (this .shape, result)
597597 }
@@ -601,7 +601,7 @@ public class CpuTensorInt8(
601601 // For integer tensors, tanh is approximated and scaled to byte range
602602 val result = this .data.map {
603603 val tanhValue = tanh(it.toDouble() / 127.0 ) // Scale input to [-1, 1] range
604- clampToByte((tanhValue * 127 ).toDouble() ) // Scale output to byte range
604+ clampToByte((tanhValue * 127 )) // Scale output to byte range
605605 }.toByteArray()
606606 return CpuTensorInt8 (this .shape, result)
607607 }
@@ -2365,7 +2365,7 @@ public class CpuBackendInt8 : ComputeBackend<Int8, Byte> {
23652365 override fun Double.minus (t : Tensor <Int8 , Byte >): Tensor <Int8 , Byte > =
23662366 (t as CpuTensorInt8 ).minus(this ).let {
23672367 CpuTensorInt8 .fromArray(t.shape, ByteArray (t.shape.volume) { i ->
2368- ((this - (t as CpuTensorInt8 ) .data[i].toInt()).toInt().toByte())
2368+ ((this - t .data[i].toInt()).toInt().toByte())
23692369 })
23702370 }
23712371
@@ -2395,11 +2395,47 @@ public class CpuBackendInt8 : ComputeBackend<Int8, Byte> {
23952395 override fun Tensor <Int8 , Byte >.flatten (startDim : Int , endDim : Int ): Tensor <Int8 , Byte > =
23962396 (this as CpuTensorInt8 ).flatten(startDim, endDim)
23972397
2398- override fun Tensor <Int8 , Byte >.reshape (newShape : Shape ): Tensor <Int8 , Byte > =
2399- (this as CpuTensorInt8 ).reshape(newShape)
2398+ override fun Tensor <Int8 , Byte >.reshape (newShape : Shape ): Tensor <Int8 , Byte > {
2399+ require(this .shape.volume == newShape.volume) {
2400+ " Cannot reshape tensor with ${this .shape.volume} elements to shape with ${newShape.volume} elements"
2401+ }
2402+
2403+ // Handle both CpuTensorInt8 and sliced tensors
2404+ if (this is CpuTensorInt8 ) {
2405+ // Direct access to data for CpuTensorInt8
2406+ return CpuTensorInt8 .fromArray(newShape, this .data.copyOf())
2407+ } else {
2408+ // For sliced tensors, extract data using copyTo
2409+ val tensorData = Array <Byte >(this .shape.volume) { 0 }
2410+ this .copyTo(tensorData)
2411+ return CpuTensorInt8 .fromArray(newShape, tensorData.toByteArray())
2412+ }
2413+ }
24002414
2401- override fun Tensor <Int8 , Byte >.reshape (vararg dimensions : Int ): Tensor <Int8 , Byte > =
2402- (this as CpuTensorInt8 ).reshape(* dimensions)
2415+ override fun Tensor <Int8 , Byte >.reshape (vararg dimensions : Int ): Tensor <Int8 , Byte > {
2416+ // Count -1 dimensions and validate
2417+ val minusOneCount = dimensions.count { it == - 1 }
2418+ require(minusOneCount <= 1 ) { " Only one dimension can be -1, found $minusOneCount " }
2419+
2420+ val totalKnownElements = dimensions.filter { it != - 1 }.fold(1 ) { acc, dim ->
2421+ require(dim > 0 ) { " All dimensions must be positive or -1, got $dim " }
2422+ acc * dim
2423+ }
2424+
2425+ val inferredDimensions = if (minusOneCount == 1 ) {
2426+ // Calculate the missing dimension
2427+ val missingDimSize = this .shape.volume / totalKnownElements
2428+ require(this .shape.volume % totalKnownElements == 0 ) {
2429+ " Cannot infer dimension size: volume ${this .shape.volume} is not divisible by known dimensions product $totalKnownElements "
2430+ }
2431+ dimensions.map { if (it == - 1 ) missingDimSize else it }.toIntArray()
2432+ } else {
2433+ dimensions
2434+ }
2435+
2436+ val newShape = Shape (inferredDimensions)
2437+ return this .reshape(newShape)
2438+ }
24032439
24042440 override fun zeros (shape : Shape ): Tensor <Int8 , Byte > =
24052441 CpuTensorInt8 .zeros(shape)
@@ -2570,11 +2606,47 @@ public class CpuBackendInt32 : ComputeBackend<Int32, Int> {
25702606 override fun Tensor <Int32 , Int >.flatten (startDim : Int , endDim : Int ): Tensor <Int32 , Int > =
25712607 (this as CpuTensorInt32 ).flatten(startDim, endDim)
25722608
2573- override fun Tensor <Int32 , Int >.reshape (newShape : Shape ): Tensor <Int32 , Int > =
2574- (this as CpuTensorInt32 ).reshape(newShape)
2609+ override fun Tensor <Int32 , Int >.reshape (newShape : Shape ): Tensor <Int32 , Int > {
2610+ require(this .shape.volume == newShape.volume) {
2611+ " Cannot reshape tensor with ${this .shape.volume} elements to shape with ${newShape.volume} elements"
2612+ }
2613+
2614+ // Handle both CpuTensorInt32 and sliced tensors
2615+ if (this is CpuTensorInt32 ) {
2616+ // Direct access to data for CpuTensorInt32
2617+ return CpuTensorInt32 .fromArray(newShape, this .data.copyOf())
2618+ } else {
2619+ // For sliced tensors, extract data using copyTo
2620+ val tensorData = Array <Int >(this .shape.volume) { 0 }
2621+ this .copyTo(tensorData)
2622+ return CpuTensorInt32 .fromArray(newShape, tensorData.toIntArray())
2623+ }
2624+ }
25752625
2576- override fun Tensor <Int32 , Int >.reshape (vararg dimensions : Int ): Tensor <Int32 , Int > =
2577- (this as CpuTensorInt32 ).reshape(* dimensions)
2626+ override fun Tensor <Int32 , Int >.reshape (vararg dimensions : Int ): Tensor <Int32 , Int > {
2627+ // Count -1 dimensions and validate
2628+ val minusOneCount = dimensions.count { it == - 1 }
2629+ require(minusOneCount <= 1 ) { " Only one dimension can be -1, found $minusOneCount " }
2630+
2631+ val totalKnownElements = dimensions.filter { it != - 1 }.fold(1 ) { acc, dim ->
2632+ require(dim > 0 ) { " All dimensions must be positive or -1, got $dim " }
2633+ acc * dim
2634+ }
2635+
2636+ val inferredDimensions = if (minusOneCount == 1 ) {
2637+ // Calculate the missing dimension
2638+ val missingDimSize = this .shape.volume / totalKnownElements
2639+ require(this .shape.volume % totalKnownElements == 0 ) {
2640+ " Cannot infer dimension size: volume ${this .shape.volume} is not divisible by known dimensions product $totalKnownElements "
2641+ }
2642+ dimensions.map { if (it == - 1 ) missingDimSize else it }.toIntArray()
2643+ } else {
2644+ dimensions
2645+ }
2646+
2647+ val newShape = Shape (inferredDimensions)
2648+ return this .reshape(newShape)
2649+ }
25782650
25792651 override fun zeros (shape : Shape ): Tensor <Int32 , Int > =
25802652 CpuTensorInt32 .zeros(shape)
0 commit comments