Skip to content

Commit 4acec7f

Browse files
michalharakalclaude
andcommitted
feat(gemma): optional maxInferenceLen on GemmaNetworkLoader.load() (#178)
The eager network sizes its KV cache + RoPE tables for maxInferenceLen (= min(contextLength, 4096) by default). On the 1.9 GB SL2610 that ~0.4 GB KV cache (allocated at the first forward) OOMs the board even after the packed Q8_0 lm_head dropped the weight footprint to ~1.06 GB resident. Thread an optional `maxInferenceLen: Int? = null` through load() -> applyWeightsToNetwork -> applyWeightsToNetworkNonReified -> gemmaNetwork so a constrained-device consumer can cap the context (e.g. 32 for a short tool-call prompt), shrinking the KV cache ~100x. Default null preserves the existing min(contextLength, 4096) behaviour. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent f94ce6c commit 4acec7f

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaNetworkLoader.kt

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ public class GemmaNetworkLoader @PublishedApi internal constructor(
120120
* Load weights and build a fully initialized DSL network.
121121
*/
122122
public suspend inline fun <reified T : DType, V> load(
123-
ctx: ExecutionContext
123+
ctx: ExecutionContext,
124+
maxInferenceLen: Int? = null,
124125
): Module<T, V> {
125126
val rawWeights: Gemma4Weights<T, V> = when (val wp = weightsProvider) {
126127
is WeightsProvider.GgufSource -> {
@@ -160,14 +161,15 @@ public class GemmaNetworkLoader @PublishedApi internal constructor(
160161
rawWeights
161162
}
162163

163-
return applyWeightsToNetwork(ctx, weights)
164+
return applyWeightsToNetwork(ctx, weights, maxInferenceLen)
164165
}
165166

166167
@PublishedApi
167168
internal inline fun <reified T : DType, V> applyWeightsToNetwork(
168169
ctx: ExecutionContext,
169-
weights: Gemma4Weights<T, V>
170-
): Module<T, V> = applyWeightsToNetworkNonReified(ctx, weights, T::class, debug)
170+
weights: Gemma4Weights<T, V>,
171+
maxInferenceLen: Int? = null,
172+
): Module<T, V> = applyWeightsToNetworkNonReified(ctx, weights, T::class, debug, maxInferenceLen)
171173
}
172174

173175
/** Shared non-reified impl used by both the inline-reified companion helpers
@@ -177,7 +179,8 @@ internal fun <T : DType, V> applyWeightsToNetworkNonReified(
177179
ctx: ExecutionContext,
178180
weights: Gemma4Weights<T, V>,
179181
dtype: kotlin.reflect.KClass<T>,
180-
debug: Boolean
182+
debug: Boolean,
183+
maxInferenceLen: Int? = null,
181184
): Module<T, V> {
182185
// Enable optional Gemma 4 features iff the checkpoint actually carries
183186
// their weights. Real Gemma 4 GGUFs do; synthetic toy-model tests do not,
@@ -197,6 +200,7 @@ internal fun <T : DType, V> applyWeightsToNetworkNonReified(
197200
val model = gemmaNetwork<T, V>(
198201
weights.metadata,
199202
dtype,
203+
maxInferenceLen = maxInferenceLen ?: minOf(weights.metadata.contextLength, 4096),
200204
qkNorm = hasQKNorm,
201205
sandwichNorms = hasSandwichNorms,
202206
layerOutputScale = hasLayerOutputScale,

0 commit comments

Comments
 (0)