@@ -1960,6 +1960,194 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
19601960 return newTensor(outData, tensor.dtype, tensor)
19611961 }
19621962
// ---- TinyFoA ops: abs, sign, clamp, lt, ge ----

@TensorOp()
@InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-abs")
override fun <T : DType, V> abs(tensor: Tensor<T, V>): Tensor<T, V> {
    // Element-wise absolute value. FP32/FP16 values are stored as Float,
    // Int32 as Int; any other dtype is rejected.
    val resultData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { index ->
        when (tensor.dtype) {
            sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> {
                val element = tensor.data.get(*index) as Float
                @Suppress("UNCHECKED_CAST")
                kotlin.math.abs(element) as V
            }
            sk.ainet.lang.types.Int32::class -> {
                val element = tensor.data.get(*index) as Int
                @Suppress("UNCHECKED_CAST")
                kotlin.math.abs(element) as V
            }
            else -> throw IllegalArgumentException("Unsupported dtype for abs: ${tensor.dtype}")
        }
    }
    return newTensor(resultData, tensor.dtype, tensor)
}
1985+
@TensorOp()
@InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-sign")
override fun <T : DType, V> sign(tensor: Tensor<T, V>): Tensor<T, V> {
    // Element-wise signum: 1 for positive, -1 for negative, 0 otherwise.
    // Note: the explicit comparisons map NaN to 0f (both comparisons are
    // false), so kotlin.math.sign (NaN -> NaN) must not be substituted here.
    val resultData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { index ->
        when (tensor.dtype) {
            sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> {
                val element = tensor.data.get(*index) as Float
                val signum = when {
                    element > 0f -> 1f
                    element < 0f -> -1f
                    else -> 0f
                }
                @Suppress("UNCHECKED_CAST")
                signum as V
            }
            sk.ainet.lang.types.Int32::class -> {
                val element = tensor.data.get(*index) as Int
                val signum = when {
                    element > 0 -> 1
                    element < 0 -> -1
                    else -> 0
                }
                @Suppress("UNCHECKED_CAST")
                signum as V
            }
            else -> throw IllegalArgumentException("Unsupported dtype for sign: ${tensor.dtype}")
        }
    }
    return newTensor(resultData, tensor.dtype, tensor)
}
2006+
@TensorOp()
@InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-clamp")
override fun <T : DType, V> clamp(tensor: Tensor<T, V>, minVal: Float, maxVal: Float): Tensor<T, V> {
    // Element-wise clamp of every value into [minVal, maxVal].
    require(minVal <= maxVal) { "clamp: minVal ($minVal) must be <= maxVal ($maxVal)" }
    // BUG FIX: the Int32 branch previously truncated the float bounds with
    // toInt() (rounding toward zero), so e.g. minVal = 1.5f became 1 and the
    // value 1 wrongly survived the lower bound. Correct integer bounds round
    // toward the interval interior: ceil for the minimum, floor for the
    // maximum. (Float.toInt() saturates at Int.MIN/MAX_VALUE, as before.)
    val intMin = kotlin.math.ceil(minVal).toInt()
    val intMax = kotlin.math.floor(maxVal).toInt()
    val outData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { idx ->
        when (tensor.dtype) {
            sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> {
                val v = tensor.data.get(*idx) as Float
                @Suppress("UNCHECKED_CAST")
                v.coerceIn(minVal, maxVal) as V
            }
            sk.ainet.lang.types.Int32::class -> {
                val v = tensor.data.get(*idx) as Int
                // Sequential clamp (max then min) instead of coerceIn so a
                // degenerate integer interval (e.g. [1.2, 1.8] -> intMin=2,
                // intMax=1) clamps to the upper bound rather than throwing.
                @Suppress("UNCHECKED_CAST")
                maxOf(v, intMin).coerceAtMost(intMax) as V
            }
            else -> throw IllegalArgumentException("Unsupported dtype for clamp: ${tensor.dtype}")
        }
    }
    return newTensor(outData, tensor.dtype, tensor)
}
2028+
@TensorOp()
@InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-lt")
override fun <T : DType, V> lt(tensor: Tensor<T, V>, value: Float): Tensor<T, V> {
    // Element-wise "less than" against a scalar; produces a 0/1 mask in the
    // tensor's own dtype.
    val outData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { idx ->
        when (tensor.dtype) {
            sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> {
                val v = tensor.data.get(*idx) as Float
                @Suppress("UNCHECKED_CAST")
                (if (v < value) 1f else 0f) as V
            }
            sk.ainet.lang.types.Int32::class -> {
                val v = tensor.data.get(*idx) as Int
                // BUG FIX: previously compared against value.toInt(), so a
                // fractional threshold was truncated (v = 1, value = 1.5f
                // yielded 0 instead of 1). Int.compareTo(Float) compares in
                // floating point, so compare against the scalar directly.
                @Suppress("UNCHECKED_CAST")
                (if (v < value) 1 else 0) as V
            }
            else -> throw IllegalArgumentException("Unsupported dtype for lt: ${tensor.dtype}")
        }
    }
    return newTensor(outData, tensor.dtype, tensor)
}
2049+
@TensorOp()
@InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-ge")
override fun <T : DType, V> ge(tensor: Tensor<T, V>, value: Float): Tensor<T, V> {
    // Element-wise "greater than or equal" against a scalar; produces a 0/1
    // mask in the tensor's own dtype.
    val outData = dataFactory.init<T, V>(tensor.shape, tensor.dtype) { idx ->
        when (tensor.dtype) {
            sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> {
                val v = tensor.data.get(*idx) as Float
                @Suppress("UNCHECKED_CAST")
                (if (v >= value) 1f else 0f) as V
            }
            sk.ainet.lang.types.Int32::class -> {
                val v = tensor.data.get(*idx) as Int
                // BUG FIX: previously compared against value.toInt(), so a
                // fractional threshold was truncated (v = 1, value = 1.5f
                // yielded 1 instead of 0). Int.compareTo(Float) compares in
                // floating point, so compare against the scalar directly.
                @Suppress("UNCHECKED_CAST")
                (if (v >= value) 1 else 0) as V
            }
            else -> throw IllegalArgumentException("Unsupported dtype for ge: ${tensor.dtype}")
        }
    }
    return newTensor(outData, tensor.dtype, tensor)
}
2070+
// ---- narrow, pad2d, unfold ----

@TensorOp()
@InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-narrow")
override fun <T : DType, V> narrow(tensor: Tensor<T, V>, dim: Int, start: Int, length: Int): Tensor<T, V> {
    // Returns a copy of the slice [start, start + length) along `dim`.
    // Negative `dim` counts from the end, as elsewhere in this class.
    val actualDim = if (dim < 0) tensor.shape.rank + dim else dim
    require(actualDim in 0 until tensor.shape.rank) { "narrow dim $dim out of bounds for rank ${tensor.shape.rank}" }
    // BUG FIX: a negative length (e.g. start=5, length=-1) previously passed
    // the bounds check below and produced an invalid negative dimension.
    require(length >= 0) { "narrow: length must be non-negative, got $length" }
    require(start >= 0 && start + length <= tensor.shape.dimensions[actualDim]) {
        "narrow: start=$start length=$length exceeds dim size ${tensor.shape.dimensions[actualDim]}"
    }
    val resultDims = tensor.shape.dimensions.copyOf()
    resultDims[actualDim] = length
    val outShape = Shape(resultDims)
    val outData = dataFactory.init<T, V>(outShape, tensor.dtype) { idx ->
        // Shift the narrowed dimension's index back into the source tensor.
        val srcIdx = idx.copyOf()
        srcIdx[actualDim] = srcIdx[actualDim] + start
        tensor.data.get(*srcIdx)
    }
    return newTensor(outData, tensor.dtype, tensor)
}
2091+
@TensorOp()
@InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-pad2d")
override fun <T : DType, V> pad2d(tensor: Tensor<T, V>, padLeft: Int, padRight: Int, padTop: Int, padBottom: Int): Tensor<T, V> {
    // Zero-pads the spatial dims (H, W) of a 4D [N, C, H, W] tensor.
    require(tensor.shape.rank == 4) { "pad2d requires 4D tensor [N,C,H,W], got rank ${tensor.shape.rank}" }
    val (n, c, h, w) = tensor.shape.dimensions.toList()
    val newH = h + padTop + padBottom
    val newW = w + padLeft + padRight
    // PERF: resolve the zero element once instead of re-dispatching on dtype
    // for every padded position; this also fails fast on unsupported dtypes
    // instead of only failing once a padded element is materialized.
    @Suppress("UNCHECKED_CAST")
    val zero: V = when (tensor.dtype) {
        sk.ainet.lang.types.FP32::class, sk.ainet.lang.types.FP16::class -> 0f as V
        sk.ainet.lang.types.Int32::class -> 0 as V
        else -> throw IllegalArgumentException("Unsupported dtype for pad2d: ${tensor.dtype}")
    }
    val outShape = Shape(n, c, newH, newW)
    val outData = dataFactory.init<T, V>(outShape, tensor.dtype) { idx ->
        val srcRow = idx[2] - padTop
        val srcCol = idx[3] - padLeft
        if (srcRow in 0 until h && srcCol in 0 until w) {
            tensor.data.get(idx[0], idx[1], srcRow, srcCol)
        } else {
            zero // outside the source extent -> zero padding
        }
    }
    return newTensor(outData, tensor.dtype, tensor)
}
2120+
@TensorOp()
@InProgress("cpu", owner = "team:tinyfoa", issue = "PRD-tinyFoA#op-unfold")
override fun <T : DType, V> unfold(tensor: Tensor<T, V>, dim: Int, size: Int, step: Int): Tensor<T, V> {
    // Extracts sliding windows of `size` elements taken every `step` along
    // `dim`. The unfolded dimension shrinks to the window count and a new
    // trailing dimension of length `size` holds each window's elements.
    val actualDim = if (dim < 0) tensor.shape.rank + dim else dim
    require(actualDim in 0 until tensor.shape.rank) { "unfold dim $dim out of bounds for rank ${tensor.shape.rank}" }
    // BUG FIX: step <= 0 previously reached the window-count division and
    // raised a raw ArithmeticException (step == 0) or produced a nonsense
    // window count; size <= 0 produced an invalid trailing dimension.
    require(size > 0) { "unfold: size must be positive, got $size" }
    require(step > 0) { "unfold: step must be positive, got $step" }
    val dimSize = tensor.shape.dimensions[actualDim]
    require(size <= dimSize) { "unfold size $size > dim size $dimSize" }
    val numWindows = (dimSize - size) / step + 1
    val resultDims = IntArray(tensor.shape.rank + 1)
    for (i in 0 until tensor.shape.rank) {
        resultDims[i] = if (i == actualDim) numWindows else tensor.shape.dimensions[i]
    }
    resultDims[tensor.shape.rank] = size
    val outShape = Shape(resultDims)
    val outData = dataFactory.init<T, V>(outShape, tensor.dtype) { idx ->
        // idx has rank+1 dimensions. Last dimension is the window element index.
        val windowIdx = idx[tensor.shape.rank]
        val srcIdx = IntArray(tensor.shape.rank)
        for (i in 0 until tensor.shape.rank) {
            srcIdx[i] = if (i == actualDim) {
                idx[i] * step + windowIdx
            } else {
                idx[i]
            }
        }
        tensor.data.get(*srcIdx)
    }
    return newTensor(outData, tensor.dtype, tensor)
}
2150+
19632151 @TensorOp()
19642152 @InProgress(" cpu" , owner = " team:cpu" , issue = " task-ops.md#op-convert" )
19652153 override fun <TFrom : DType , TTo : DType , V > convert (
0 commit comments