Skip to content

Commit 19d62d4

Browse files
Merge pull request #180 from SKaiNET-developers/fix/gemma-board-embed-nocopy
feat(gemma): optional maxInferenceLen on load() to cap KV cache on constrained devices (#178)
2 parents 689a283 + 4acec7f commit 19d62d4

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaNetworkLoader.kt

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,8 @@ public class GemmaNetworkLoader @PublishedApi internal constructor(
120120
* Load weights and build a fully initialized DSL network.
121121
*/
122122
public suspend inline fun <reified T : DType, V> load(
123-
ctx: ExecutionContext
123+
ctx: ExecutionContext,
124+
maxInferenceLen: Int? = null,
124125
): Module<T, V> {
125126
val rawWeights: Gemma4Weights<T, V> = when (val wp = weightsProvider) {
126127
is WeightsProvider.GgufSource -> {
@@ -160,14 +161,15 @@ public class GemmaNetworkLoader @PublishedApi internal constructor(
160161
rawWeights
161162
}
162163

163-
return applyWeightsToNetwork(ctx, weights)
164+
return applyWeightsToNetwork(ctx, weights, maxInferenceLen)
164165
}
165166

166167
@PublishedApi
167168
internal inline fun <reified T : DType, V> applyWeightsToNetwork(
168169
ctx: ExecutionContext,
169-
weights: Gemma4Weights<T, V>
170-
): Module<T, V> = applyWeightsToNetworkNonReified(ctx, weights, T::class, debug)
170+
weights: Gemma4Weights<T, V>,
171+
maxInferenceLen: Int? = null,
172+
): Module<T, V> = applyWeightsToNetworkNonReified(ctx, weights, T::class, debug, maxInferenceLen)
171173
}
172174

173175
/** Shared non-reified impl used by both the inline-reified companion helpers
@@ -177,7 +179,8 @@ internal fun <T : DType, V> applyWeightsToNetworkNonReified(
177179
ctx: ExecutionContext,
178180
weights: Gemma4Weights<T, V>,
179181
dtype: kotlin.reflect.KClass<T>,
180-
debug: Boolean
182+
debug: Boolean,
183+
maxInferenceLen: Int? = null,
181184
): Module<T, V> {
182185
// Enable optional Gemma 4 features iff the checkpoint actually carries
183186
// their weights. Real Gemma 4 GGUFs do; synthetic toy-model tests do not,
@@ -197,6 +200,7 @@ internal fun <T : DType, V> applyWeightsToNetworkNonReified(
197200
val model = gemmaNetwork<T, V>(
198201
weights.metadata,
199202
dtype,
203+
maxInferenceLen = maxInferenceLen ?: minOf(weights.metadata.contextLength, 4096),
200204
qkNorm = hasQKNorm,
201205
sandwichNorms = hasSandwichNorms,
202206
layerOutputScale = hasLayerOutputScale,

0 commit comments

Comments
 (0)