@@ -322,234 +322,129 @@ function getNormalizedConfig(config) {
322322}
323323
324324/**
325- *
326325 * @param {PretrainedConfig } config
327- * @returns {Record<string, number[]> }
326+ * @param {{ prefix?: string, session_name?: string } } [options]
327+ * @returns {Set<string> }
328328 */
329- export function getCacheShapes ( config , options ) {
329+ export function getCacheNames ( config , options ) {
330330 if ( ! ( config instanceof PretrainedConfig ) ) {
331331 config = new PretrainedConfig ( config ) ;
332332 }
333333
334- const batch_size = options ?. batch_size ?? 1 ;
335- if ( [ 'lfm2' , 'lfm2_moe' ] . includes ( config . model_type ) ) {
336- const pkv_prefix = options ?. prefix ?? 'past_key_values' ;
337- const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past' ;
334+ const pkv_prefix = options ?. prefix ?? 'past_key_values' ;
335+ const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past' ;
336+ /** @type { Set<string> } */
337+ const names = new Set ( ) ;
338338
339- /** @type {Record<string, number[]> } */
340- const cache_values = { } ;
341- const { layer_types, num_attention_heads, num_key_value_heads, hidden_size, conv_L_cache } =
342- /** @type {any } */ ( config ) ;
343- const head_dim = hidden_size / num_attention_heads ;
339+ if ( [ 'lfm2' , 'lfm2_moe' ] . includes ( config . model_type ) ) {
340+ const { layer_types } = /** @type {any } */ ( config ) ;
344341 for ( let i = 0 ; i < layer_types . length ; ++ i ) {
345342 if ( layer_types [ i ] === 'full_attention' ) {
346- for ( const kv of [ 'key' , 'value' ] ) {
347- cache_values [ `${ pkv_prefix } .${ i } .${ kv } ` ] = [ batch_size , num_key_value_heads , 0 , head_dim ] ;
348- }
343+ names . add ( `${ pkv_prefix } .${ i } .key` ) ;
344+ names . add ( `${ pkv_prefix } .${ i } .value` ) ;
349345 } else if ( layer_types [ i ] === 'conv' ) {
350- cache_values [ `${ conv_prefix } _conv.${ i } ` ] = [ batch_size , hidden_size , conv_L_cache ] ;
346+ names . add ( `${ conv_prefix } _conv.${ i } ` ) ;
351347 } else {
352348 throw new Error ( `Unsupported layer type: ${ layer_types [ i ] } ` ) ;
353349 }
354350 }
355- return cache_values ;
351+ return names ;
356352 } else if ( [ 'granitemoehybrid' , 'falcon_h1' , 'nemotron_h' ] . includes ( config . model_type ) ) {
357- const pkv_prefix = options ?. prefix ?? 'past_key_values' ;
358- const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past' ;
359-
360353 const c = /** @type {any } */ ( config ) ;
361-
362- // Normalize config field names across model types
363354 const layer_types = c . layer_types ?? c . layers_block_type ;
364355 const num_layers = c . num_hidden_layers ?? layer_types ?. length ;
365- const num_key_value_heads = c . num_key_value_heads ;
366- const head_dim = c . head_dim ?? c . hidden_size / c . num_attention_heads ;
367- const mamba_n_heads = c . mamba_n_heads ?? c . mamba_num_heads ;
368- const mamba_d_head = c . mamba_d_head ?? c . mamba_head_dim ;
369- const mamba_d_state = c . mamba_d_state ?? c . ssm_state_size ;
370- const mamba_n_groups = c . mamba_n_groups ?? c . n_groups ;
371- const mamba_d_conv = c . mamba_d_conv ?? c . conv_kernel ;
372- const mamba_d_ssm =
373- c . mamba_d_ssm ?? ( c . mamba_expand ? c . mamba_expand * c . hidden_size : mamba_n_heads * mamba_d_head ) ;
374- const conv_d_inner = mamba_d_ssm + 2 * mamba_n_groups * mamba_d_state ;
375-
376- /** @type {Record<string, number[]> } */
377- const cache_values = { } ;
378356
379357 for ( let i = 0 ; i < num_layers ; ++ i ) {
380358 if ( ! layer_types || layer_types [ i ] === 'mamba' ) {
381- cache_values [ `${ conv_prefix } _conv.${ i } ` ] = [ batch_size , conv_d_inner , mamba_d_conv ] ;
382- cache_values [ `${ conv_prefix } _ssm.${ i } ` ] = [ batch_size , mamba_n_heads , mamba_d_head , mamba_d_state ] ;
359+ names . add ( `${ conv_prefix } _conv.${ i } ` ) ;
360+ names . add ( `${ conv_prefix } _ssm.${ i } ` ) ;
383361 }
384362 if ( ! layer_types || layer_types [ i ] === 'attention' ) {
385- for ( const kv of [ 'key' , 'value' ] ) {
386- cache_values [ `${ pkv_prefix } .${ i } .${ kv } ` ] = [ batch_size , num_key_value_heads , 0 , head_dim ] ;
387- }
363+ names . add ( `${ pkv_prefix } .${ i } .key` ) ;
364+ names . add ( `${ pkv_prefix } .${ i } .value` ) ;
388365 }
389366 }
390- return cache_values ;
367+ return names ;
391368 } else if ( [ 'qwen3_next' , 'qwen3_5_text' , 'qwen3_5_moe_text' , 'olmo_hybrid' ] . includes ( config . model_type ) ) {
392- const pkv_prefix = options ?. prefix ?? 'past_key_values' ;
393- const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past' ;
394-
395- /** @type {Record<string, number[]> } */
396- const cache_values = { } ;
397- const {
398- head_dim,
399- layer_types,
400- num_attention_heads,
401- num_key_value_heads,
402- hidden_size,
403- linear_num_value_heads,
404- linear_num_key_heads,
405- linear_key_head_dim,
406- linear_value_head_dim,
407- linear_conv_kernel_dim,
408- } = /** @type {any } */ ( config ) ;
409-
410- const key_dim = linear_key_head_dim * linear_num_key_heads ;
411- const value_dim = linear_value_head_dim * linear_num_value_heads ;
412-
413- const final_head_dim = head_dim ?? hidden_size / num_attention_heads ;
369+ const { layer_types } = /** @type {any } */ ( config ) ;
414370 for ( let i = 0 ; i < layer_types . length ; ++ i ) {
415371 if ( layer_types [ i ] === 'full_attention' ) {
416- for ( const kv of [ 'key' , 'value' ] ) {
417- cache_values [ `${ pkv_prefix } .${ i } .${ kv } ` ] = [ batch_size , num_key_value_heads , 0 , final_head_dim ] ;
418- }
372+ names . add ( `${ pkv_prefix } .${ i } .key` ) ;
373+ names . add ( `${ pkv_prefix } .${ i } .value` ) ;
419374 } else if ( layer_types [ i ] === 'linear_attention' ) {
420375 if ( config . model_type === 'olmo_hybrid' ) {
421- cache_values [ `${ conv_prefix } _conv.${ i } .key` ] = [ batch_size , key_dim , linear_conv_kernel_dim ] ;
422- cache_values [ `${ conv_prefix } _conv.${ i } .value` ] = [ batch_size , value_dim , linear_conv_kernel_dim ] ;
423- cache_values [ `${ conv_prefix } _conv.${ i } .query` ] = [ batch_size , key_dim , linear_conv_kernel_dim ] ;
376+ names . add ( `${ conv_prefix } _conv.${ i } .key` ) ;
377+ names . add ( `${ conv_prefix } _conv.${ i } .value` ) ;
378+ names . add ( `${ conv_prefix } _conv.${ i } .query` ) ;
424379 } else {
425- const conv_dim = key_dim * 2 + value_dim ;
426- cache_values [ `${ conv_prefix } _conv.${ i } ` ] = [ batch_size , conv_dim , linear_conv_kernel_dim ] ;
380+ names . add ( `${ conv_prefix } _conv.${ i } ` ) ;
427381 }
428- cache_values [ `${ conv_prefix } _recurrent.${ i } ` ] = [
429- batch_size ,
430- linear_num_value_heads ,
431- linear_key_head_dim ,
432- linear_value_head_dim ,
433- ] ;
382+ names . add ( `${ conv_prefix } _recurrent.${ i } ` ) ;
434383 } else {
435384 throw new Error ( `Unsupported layer type: ${ layer_types [ i ] } ` ) ;
436385 }
437386 }
438- return cache_values ;
387+ return names ;
439388 } else if ( [ 'gemma4' , 'gemma4_text' ] . includes ( config . model_type ) ) {
440389 const c = /** @type {any } */ (
441390 config . model_type === 'gemma4' ? /** @type {any } */ ( config ) . text_config : config
442391 ) ;
443- const pkv_prefix = options ?. prefix ?? 'past_key_values' ;
444-
445- /** @type {Record<string, number[]> } */
446- const cache_values = { } ;
447392 const num_hidden_layers = c . num_hidden_layers ;
448393 const num_kv_shared_layers = c . num_kv_shared_layers ?? 0 ;
449394 const num_kv_layers = num_hidden_layers - num_kv_shared_layers ;
450- const num_key_value_heads = c . num_key_value_heads ;
451- const head_dim = c . head_dim ;
452- const global_head_dim = c . global_head_dim ?? head_dim ;
453- const layer_types = c . layer_types ?? [ ] ;
454395
455- // Create `num_kv_layers` unique KV entries, corresponding to the first `num_kv_layers`
456- // model layers (the remaining layers share caches with earlier ones).
457- // Full attention layers use global_head_dim, sliding attention layers use head_dim.
458396 for ( let i = 0 ; i < num_kv_layers ; ++ i ) {
459- const dim = layer_types [ i ] === 'full_attention' ? global_head_dim : head_dim ;
460- for ( const kv of [ 'key' , 'value' ] ) {
461- cache_values [ `${ pkv_prefix } .${ i } .${ kv } ` ] = [ batch_size , num_key_value_heads , 0 , dim ] ;
462- }
397+ names . add ( `${ pkv_prefix } .${ i } .key` ) ;
398+ names . add ( `${ pkv_prefix } .${ i } .value` ) ;
463399 }
464- return cache_values ;
400+ return names ;
465401 } else if ( [ 'lfm2_vl' , 'qwen3_5' , 'qwen3_5_moe' , 'voxtral_realtime' ] . includes ( config . model_type ) ) {
466402 let subConfig ;
467403 if ( config . model_type === 'voxtral_realtime' && options ?. session_name === 'audio_encoder' ) {
468404 subConfig = /** @type {any } */ ( config ) . audio_config ;
469405 } else {
470406 subConfig = /** @type {any } */ ( config ) . text_config ;
471407 }
472- return getCacheShapes ( subConfig , options ) ;
408+ return getCacheNames ( subConfig , options ) ;
473409 }
474410
475- return getKeyValueShapes ( config , options ) ;
411+ return getKeyValueNames ( config , { prefix : pkv_prefix } ) ;
476412}
477413
478- /** @type {typeof getKeyValueShapes } */
479- function getKeyValueShapes ( config , { prefix = 'past_key_values' , batch_size = 1 } = { } ) {
480- /** @type {Record<string, number[]> } */
481- const decoderFeeds = { } ;
414+ /**
415+ * @param {PretrainedConfig } config
416+ * @param {{ prefix?: string } } [options]
417+ * @returns {Set<string> }
418+ */
419+ function getKeyValueNames ( config , { prefix = 'past_key_values' } = { } ) {
420+ /** @type {Set<string> } */
421+ const names = new Set ( ) ;
482422 const normalized_config = config . normalized_config ;
483423
484424 if (
485425 normalized_config . is_encoder_decoder &&
486426 'num_encoder_heads' in normalized_config &&
487427 'num_decoder_heads' in normalized_config
488428 ) {
489- const encoder_dim_kv =
490- normalized_config . encoder_dim_kv ??
491- normalized_config . encoder_hidden_size / normalized_config . num_encoder_heads ;
492- const decoder_dim_kv =
493- normalized_config . decoder_dim_kv ??
494- normalized_config . decoder_hidden_size / normalized_config . num_decoder_heads ;
495-
496- const encoder_dims = [ batch_size , normalized_config . num_encoder_heads , 0 , encoder_dim_kv ] ;
497- const decoder_dims = [ batch_size , normalized_config . num_decoder_heads , 0 , decoder_dim_kv ] ;
498429 for ( let i = 0 ; i < normalized_config . num_decoder_layers ; ++ i ) {
499- decoderFeeds [ `${ prefix } .${ i } .encoder.key` ] = encoder_dims ;
500- decoderFeeds [ `${ prefix } .${ i } .encoder.value` ] = encoder_dims ;
501- decoderFeeds [ `${ prefix } .${ i } .decoder.key` ] = decoder_dims ;
502- decoderFeeds [ `${ prefix } .${ i } .decoder.value` ] = decoder_dims ;
430+ names . add ( `${ prefix } .${ i } .encoder.key` ) ;
431+ names . add ( `${ prefix } .${ i } .encoder.value` ) ;
432+ names . add ( `${ prefix } .${ i } .decoder.key` ) ;
433+ names . add ( `${ prefix } .${ i } .decoder.value` ) ;
434+ }
435+ } else if ( normalized_config . multi_query ) {
436+ // e.g., for `gpt_bigcode`
437+ for ( let i = 0 ; i < normalized_config . num_layers ; ++ i ) {
438+ names . add ( `${ prefix } .${ i } .key_value` ) ;
503439 }
504440 } else {
505- // Decoders
506- const num_heads = normalized_config . num_heads ;
507- const num_layers = normalized_config . num_layers ;
508- const dim_kv =
509- normalized_config . dim_kv ??
510- normalized_config . hidden_size / ( normalized_config . num_attention_heads ?? num_heads ) ;
511-
512- if ( normalized_config . model_type === 'falcon' ) {
513- // NOTE: Custom implementation for Falcon
514- const dims = [ batch_size * num_heads , 0 , dim_kv ] ;
515- for ( let i = 0 ; i < num_layers ; ++ i ) {
516- decoderFeeds [ `${ prefix } .${ i } .key` ] = dims ;
517- decoderFeeds [ `${ prefix } .${ i } .value` ] = dims ;
518- }
519- } else if ( normalized_config . multi_query ) {
520- // e.g., for `gpt_bigcode`
521- const dims = [ batch_size * num_heads , 0 , 2 * dim_kv ] ;
522-
523- for ( let i = 0 ; i < num_layers ; ++ i ) {
524- decoderFeeds [ `${ prefix } .${ i } .key_value` ] = dims ;
525- }
526- } else if ( normalized_config . model_type === 'bloom' ) {
527- // NOTE: Custom implementation for Bloom
528-
529- const keyDims = [ batch_size * num_heads , dim_kv , 0 ] ; // [batch_size x num_heads,64,past_sequence_length]
530- const valueDims = [ batch_size * num_heads , 0 , dim_kv ] ; // [batch_size x num_heads,past_sequence_length,64]
531- for ( let i = 0 ; i < num_layers ; ++ i ) {
532- decoderFeeds [ `${ prefix } .${ i } .key` ] = keyDims ;
533- decoderFeeds [ `${ prefix } .${ i } .value` ] = valueDims ;
534- }
535- } else if ( normalized_config . model_type === 'openelm' ) {
536- for ( let i = 0 ; i < num_layers ; ++ i ) {
537- const dims = [ batch_size , num_heads [ i ] , 0 , dim_kv ] ;
538-
539- decoderFeeds [ `${ prefix } .${ i } .key` ] = dims ;
540- decoderFeeds [ `${ prefix } .${ i } .value` ] = dims ;
541- }
542- } else {
543- // Decoder-only
544- const dims = [ batch_size , num_heads , 0 , dim_kv ] ;
545- for ( let i = 0 ; i < num_layers ; ++ i ) {
546- decoderFeeds [ `${ prefix } .${ i } .key` ] = dims ;
547- decoderFeeds [ `${ prefix } .${ i } .value` ] = dims ;
548- }
441+ for ( let i = 0 ; i < normalized_config . num_layers ; ++ i ) {
442+ names . add ( `${ prefix } .${ i } .key` ) ;
443+ names . add ( `${ prefix } .${ i } .value` ) ;
549444 }
550445 }
551446
552- return decoderFeeds ;
447+ return names ;
553448}
554449/**
555450 * Base class for all configuration classes. For more information, see the corresponding
@@ -626,7 +521,6 @@ export class AutoConfig {
626521 * Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
627522 * @typedef {Object } TransformersJSConfig
628523 * @property {Record<import('./utils/devices.js').DeviceType, DeviceConfig> } [device_config] Device-specific configurations.
629- * @property {import('./utils/tensor.js').DataType|Record<import('./utils/dtypes.js').DataType, import('./utils/tensor.js').DataType>|boolean } [kv_cache_dtype] The data type of the key-value cache.
630524 * @property {Record<string, number> } [free_dimension_overrides] Override the free dimensions of the model.
631525 * See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides
632526 * for more information.
0 commit comments