Skip to content

Commit 8ee0c5f

Browse files
Merge pull request #741 from SKaiNET-developers/feature/rowdequant-gather
RowDequantSource in the engine + ops.gather row-dequant path
2 parents fa3610d + a24f21d commit 8ee0c5f

3 files changed

Lines changed: 85 additions & 5 deletions

File tree

skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import sk.ainet.lang.ops.TensorOp
1010
import sk.ainet.lang.ops.InProgress
1111
import sk.ainet.backend.api.kernel.KernelProvider
1212
import sk.ainet.backend.api.kernel.KernelRegistry
13+
import sk.ainet.lang.tensor.data.RowDequantSource
1314
import sk.ainet.lang.tensor.data.FloatArrayTensorData
1415
import sk.ainet.lang.tensor.data.IntArrayTensorData
1516
import sk.ainet.lang.tensor.data.Q4_0TensorData
@@ -2633,18 +2634,33 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
26332634
// Preserve index shape + embedding dim
26342635
Shape(IntArray(indices.rank) { indices.shape[it] } + intArrayOf(embDim))
26352636
}
2636-
val outData = dataFactory.init<T, V>(outShape, input.dtype) { outIdx ->
2637-
// Map multi-dim output index to flat index and embedding position
2637+
fun rowOf(outIdx: IntArray): Int {
2638+
// Map multi-dim output index to the flat index into the index list.
26382639
val flatIdx = if (outIdx.size == 2) outIdx[0] else {
26392640
var flat = 0
26402641
for (d in 0 until outIdx.size - 1) {
26412642
flat = flat * (if (d < indices.rank) indices.shape[d] else 1) + outIdx[d]
26422643
}
26432644
flat
26442645
}
2645-
val row = indexList[flatIdx]
2646-
val col = outIdx[outIdx.size - 1]
2647-
input.data[row, col]
2646+
return indexList[flatIdx]
2647+
}
2648+
val src = input.data
2649+
val outData = if (src is RowDequantSource) {
2650+
// Packed / oversized table (e.g. a Q-quantised embedding): dequantise only the rows
2651+
// actually touched — never materialise the whole table, never call get() (unsupported on
2652+
// such tensors). Each unique row is dequantised once; logical dtype is FP32.
2653+
val rowCache = HashMap<Int, FloatArray>()
2654+
dataFactory.init<T, V>(outShape, input.dtype) { outIdx ->
2655+
val row = rowOf(outIdx)
2656+
val col = outIdx[outIdx.size - 1]
2657+
@Suppress("UNCHECKED_CAST")
2658+
(rowCache.getOrPut(row) { src.dequantRow(row) }[col] as V)
2659+
}
2660+
} else {
2661+
dataFactory.init<T, V>(outShape, input.dtype) { outIdx ->
2662+
input.data[rowOf(outIdx), outIdx[outIdx.size - 1]]
2663+
}
26482664
}
26492665
return newTensor(outData, input.dtype, input)
26502666
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package sk.ainet.exec.tensor.ops
2+
3+
import sk.ainet.context.DirectCpuExecutionContext
4+
import sk.ainet.lang.tensor.Shape
5+
import sk.ainet.lang.tensor.Tensor
6+
import sk.ainet.lang.tensor.data.RowDequantSource
7+
import sk.ainet.lang.tensor.data.TensorData
8+
import sk.ainet.lang.types.FP32
9+
import sk.ainet.lang.types.Int32
10+
import kotlin.test.Test
11+
import kotlin.test.assertContentEquals
12+
import kotlin.test.assertEquals
13+
14+
/**
15+
* `ops.gather` on a [RowDequantSource] table must dequantise only the touched rows — never materialise the
16+
* whole table and never call `get()` (which such tensors don't support). The fake table below throws from
17+
* `get`/`set`, so the test passes only if gather went through [RowDequantSource.dequantRow].
18+
*/
19+
class GatherRowDequantTest {
20+
21+
/** A 4×3 "packed" table: row r dequants to [r*10, r*10+1, r*10+2]. Element access is unsupported. */
22+
private class FakeRowDequantTable : TensorData<FP32, Float>, RowDequantSource {
23+
override val shape: Shape = Shape(4, 3)
24+
override fun dequantRow(rowIdx: Int): FloatArray = FloatArray(3) { rowIdx * 10f + it }
25+
override fun get(vararg indices: Int): Float = error("get() must not be called — use dequantRow()")
26+
override fun set(vararg indices: Int, value: Float) = error("set() unsupported")
27+
override fun copyToFloatArray(): FloatArray = error("copyToFloatArray() must not be called")
28+
}
29+
30+
@Test
31+
fun gatherDequantsTouchedRowsOnly() {
32+
val ctx = DirectCpuExecutionContext.create()
33+
val table = ctx.fromData<FP32, Float>(FakeRowDequantTable(), FP32::class)
34+
val ids = ctx.fromIntArray<Int32, Int>(Shape(3), Int32::class, intArrayOf(2, 0, 3))
35+
36+
@Suppress("UNCHECKED_CAST")
37+
val out = ctx.ops.gather(table, ids as Tensor<sk.ainet.lang.types.DType, *>, dim = 0)
38+
39+
assertEquals(listOf(3, 3), out.shape.dimensions.toList())
40+
assertContentEquals(
41+
floatArrayOf(20f, 21f, 22f, /* row 2 */ 0f, 1f, 2f, /* row 0 */ 30f, 31f, 32f /* row 3 */),
42+
out.data.copyToFloatArray(),
43+
)
44+
}
45+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package sk.ainet.lang.tensor.data
2+
3+
/**
4+
* Marker for a 2-D [TensorData] whose rows can be **dequantised on demand**, for tables that cannot (or
5+
* should not) be materialised as a single dense `FloatArray` — e.g. a packed-quant embedding whose logical
6+
* size exceeds `Int.MAX_VALUE` elements / 2 GB, or one kept packed to save memory.
7+
*
8+
* Such a tensor declares its **logical** dtype `FP32` (the dequantised value type); its packed bytes are an
9+
* internal storage detail, and `get`/`copyToFloatArray()` are typically unsupported. Ops that read whole
10+
* rows — primarily **embedding lookup** (`ops.gather` / `ops.indexSelect`, `dim = 0`, indices = token ids)
11+
* — MUST use [dequantRow] instead of element access, dequantising only the rows actually touched.
12+
*
13+
* This is the engine-level home of the contract; model-specific implementations (e.g. a GGUF Q6_K /
14+
* SafeTensors BF16 embedding) provide [dequantRow] over their own packed source.
15+
*/
16+
public interface RowDequantSource {
17+
/** Dequantise logical row [rowIdx] (`0 until shape[0]`) to a fresh `FloatArray` of length `shape[1]`. */
18+
public fun dequantRow(rowIdx: Int): FloatArray
19+
}

0 commit comments

Comments
 (0)