Skip to content

Commit ce8b9ee

Browse files
michalharakal and claude committed
feat(bert): wire ScratchPool into the embedding hot path
Two changes that activate SKaiNET 0.21.0's ScratchPool SPI for the BERT encoder workload: 1. Wrap BertRuntime.forward in ctx.scratch.scope { ... }. Upstream SIMD kernels (matmul, dequant) acquire workspace from ctx.scratch internally; the scope drains acquired buffers back to the pool on exit. With the default NoopScratchPool this is a pass-through; with a real pool it eliminates per-forward FloatArray allocations on what is typically the busiest path for an embedding workload (encode() called many times in a row). 2. Add PooledExecutionContext — a thin ExecutionContext delegate that provides a SizeClassedScratchPool. Wire it as the default ctx in KBertJava.loadSafeTensors, since Java embedding consumers virtually always batch many encode() calls. Default behavior is preserved: callers that construct BertRuntime with a plain DirectCpuExecutionContext (no PooledExecutionContext wrapper) continue to use NoopScratchPool and see no change. 22/22 BertRuntime + BertNumericalAccuracy + HuggingFaceTokenizer tests green on JDK 25. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2da0a2a commit ce8b9ee

3 files changed

Lines changed: 67 additions & 20 deletions

File tree

llm-inference/bert/src/commonMain/kotlin/sk/ainet/models/bert/BertRuntime.kt

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -110,27 +110,34 @@ public class BertRuntime<T : DType>(
110110
* @param tokenTypeIds optional segment IDs, shape [seqLen] (defaults to all zeros)
111111
* @return hidden states tensor of shape [seqLen, hiddenSize]
112112
*/
113-
public fun forward(tokenIds: IntArray, tokenTypeIds: IntArray? = null): Tensor<T, Float> {
114-
val seqLen = tokenIds.size
115-
val typeIds = tokenTypeIds ?: IntArray(seqLen) { 0 }
116-
val positionIds = IntArray(seqLen) { it }
117-
118-
// Embedding: word + position + token_type
119-
val wordEmb = wordEmbedding.forward(tokenIds, ctx)
120-
val posEmb = positionEmbedding.forward(positionIds, ctx)
121-
val typeEmb = tokenTypeEmbedding.forward(typeIds, ctx)
122-
123-
var hidden = wordEmb + posEmb + typeEmb
124-
hidden = embeddingLayerNorm.forward(hidden, ctx)
125-
126-
// Encoder layers
127-
for (i in weights.layers.indices) {
128-
hidden = runEncoderLayer(i, hidden)
113+
public fun forward(tokenIds: IntArray, tokenTypeIds: IntArray? = null): Tensor<T, Float> =
114+
ctx.scratch.scope {
115+
// ScratchPool scope for the whole forward pass: upstream SIMD
116+
// kernels (matmul, dequant) acquire their per-call workspace
117+
// from ctx.scratch and the buffers are returned to the pool on
118+
// scope exit. With the default NoopScratchPool this is a plain
119+
// pass-through; with a SizeClassedScratchPool it eliminates per-
120+
// forward FloatArray allocations on the embedding hot path.
121+
val seqLen = tokenIds.size
122+
val typeIds = tokenTypeIds ?: IntArray(seqLen) { 0 }
123+
val positionIds = IntArray(seqLen) { it }
124+
125+
// Embedding: word + position + token_type
126+
val wordEmb = wordEmbedding.forward(tokenIds, ctx)
127+
val posEmb = positionEmbedding.forward(positionIds, ctx)
128+
val typeEmb = tokenTypeEmbedding.forward(typeIds, ctx)
129+
130+
var hidden = wordEmb + posEmb + typeEmb
131+
hidden = embeddingLayerNorm.forward(hidden, ctx)
132+
133+
// Encoder layers
134+
for (i in weights.layers.indices) {
135+
hidden = runEncoderLayer(i, hidden)
136+
}
137+
138+
hidden
129139
}
130140

131-
return hidden
132-
}
133-
134141
/**
135142
* Encode text tokens into a single embedding vector (mean pooling + optional projection + L2 norm).
136143
*
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package sk.ainet.models.bert
2+
3+
import sk.ainet.context.ExecutionContext
4+
import sk.ainet.lang.tensor.scratch.ScratchPool
5+
import sk.ainet.lang.tensor.scratch.SizeClassedScratchPool
6+
7+
/**
8+
* Wraps an [ExecutionContext] with a [SizeClassedScratchPool] so that
9+
* upstream SIMD kernels and per-forward intermediates are pooled across
10+
* encoder calls.
11+
*
12+
* Use this when you intend to compute many embeddings from the same model:
13+
*
14+
* ```kotlin
15+
* val baseCtx = DirectCpuExecutionContext(tensorDataFactory = memSegFactory)
16+
* val pooledCtx = PooledExecutionContext(baseCtx)
17+
*
18+
* val runtime = BertRuntime(pooledCtx, weights, FP32::class)
19+
*
20+
* // Each forward acquires + releases scratch buffers in a per-call scope.
21+
* val v1 = runtime.encode(tokens1)
22+
* val v2 = runtime.encode(tokens2) // reuses pooled buffers
23+
* ```
24+
*
25+
* For one-shot use the default `NoopScratchPool` on a plain
26+
* `DirectCpuExecutionContext` is fine — pooling has no benefit when the
27+
* pool is never reused.
28+
*
29+
* **Threading:** `SizeClassedScratchPool` is single-threaded by intent.
30+
* Concurrent encoder calls must each have their own pooled context.
31+
*/
32+
public class PooledExecutionContext(
33+
private val delegate: ExecutionContext,
34+
override val scratch: ScratchPool = SizeClassedScratchPool(),
35+
) : ExecutionContext by delegate

llm-inference/bert/src/jvmMain/kotlin/sk/ainet/models/bert/java/KBertJava.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package sk.ainet.models.bert.java
55
import kotlinx.coroutines.runBlocking
66
import sk.ainet.models.bert.*
77
import sk.ainet.context.DirectCpuExecutionContext
8+
import sk.ainet.models.bert.PooledExecutionContext
89
import sk.ainet.io.JvmRandomAccessSource
910
import sk.ainet.io.safetensors.SafeTensorsParametersLoader
1011
import sk.ainet.lang.types.FP32
@@ -46,7 +47,11 @@ public object KBertJava {
4647
val config = detectConfig(modelDir)
4748

4849
val tokenizer = HuggingFaceTokenizer.fromVocabTxt(vocabPath.readText())
49-
val ctx = DirectCpuExecutionContext()
50+
// Pool scratch buffers across encode() calls — embedding workloads
51+
// typically encode many strings in a row, so the SizeClassedScratchPool
52+
// returns real wins. With a single one-shot call the pool is no
53+
// worse than NoopScratchPool.
54+
val ctx = PooledExecutionContext(DirectCpuExecutionContext())
5055

5156
val ingestion = BertIngestion<FP32>(ctx, FP32::class, config)
5257
val loader = SafeTensorsParametersLoader(

0 commit comments

Comments (0)