@@ -120,7 +120,8 @@ public class GemmaNetworkLoader @PublishedApi internal constructor(
120120 * Load weights and build a fully initialized DSL network.
121121 */
122122 public suspend inline fun <reified T : DType , V > load (
123- ctx : ExecutionContext
123+ ctx : ExecutionContext ,
124+ maxInferenceLen : Int? = null,
124125 ): Module <T , V > {
125126 val rawWeights: Gemma4Weights <T , V > = when (val wp = weightsProvider) {
126127 is WeightsProvider .GgufSource -> {
@@ -160,14 +161,15 @@ public class GemmaNetworkLoader @PublishedApi internal constructor(
160161 rawWeights
161162 }
162163
163- return applyWeightsToNetwork(ctx, weights)
164+ return applyWeightsToNetwork(ctx, weights, maxInferenceLen )
164165 }
165166
166167 @PublishedApi
167168 internal inline fun <reified T : DType , V > applyWeightsToNetwork (
168169 ctx : ExecutionContext ,
169- weights : Gemma4Weights <T , V >
170- ): Module <T , V > = applyWeightsToNetworkNonReified(ctx, weights, T ::class , debug)
170+ weights : Gemma4Weights <T , V >,
171+ maxInferenceLen : Int? = null,
172+ ): Module <T , V > = applyWeightsToNetworkNonReified(ctx, weights, T ::class , debug, maxInferenceLen)
171173}
172174
173175/* * Shared non-reified impl used by both the inline-reified companion helpers
@@ -177,7 +179,8 @@ internal fun <T : DType, V> applyWeightsToNetworkNonReified(
177179 ctx : ExecutionContext ,
178180 weights : Gemma4Weights <T , V >,
179181 dtype : kotlin.reflect.KClass <T >,
180- debug : Boolean
182+ debug : Boolean ,
183+ maxInferenceLen : Int? = null,
181184): Module <T , V > {
182185 // Enable optional Gemma 4 features iff the checkpoint actually carries
183186 // their weights. Real Gemma 4 GGUFs do; synthetic toy-model tests do not,
@@ -197,6 +200,7 @@ internal fun <T : DType, V> applyWeightsToNetworkNonReified(
197200 val model = gemmaNetwork<T , V >(
198201 weights.metadata,
199202 dtype,
203+ maxInferenceLen = maxInferenceLen ? : minOf(weights.metadata.contextLength, 4096 ),
200204 qkNorm = hasQKNorm,
201205 sandwichNorms = hasSandwichNorms,
202206 layerOutputScale = hasLayerOutputScale,
0 commit comments