Skip to content

Commit b93766d

Browse files
authored
Use inputMetadata API for simplified internals (#1657)
* use session.inputMetadata for cache initialization
* no longer need kv_cache_dtype
* Update modeling_utils.js
1 parent 6eaf11a commit b93766d

4 files changed

Lines changed: 112 additions & 200 deletions

File tree

packages/transformers/src/configs.js

Lines changed: 53 additions & 159 deletions
Original file line number | Diff line number | Diff line change
@@ -322,234 +322,129 @@ function getNormalizedConfig(config) {
322322
}
323323

324324
/**
325-
*
326325
* @param {PretrainedConfig} config
327-
* @returns {Record<string, number[]>}
326+
* @param {{ prefix?: string, session_name?: string }} [options]
327+
* @returns {Set<string>}
328328
*/
329-
export function getCacheShapes(config, options) {
329+
export function getCacheNames(config, options) {
330330
if (!(config instanceof PretrainedConfig)) {
331331
config = new PretrainedConfig(config);
332332
}
333333

334-
const batch_size = options?.batch_size ?? 1;
335-
if (['lfm2', 'lfm2_moe'].includes(config.model_type)) {
336-
const pkv_prefix = options?.prefix ?? 'past_key_values';
337-
const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past';
334+
const pkv_prefix = options?.prefix ?? 'past_key_values';
335+
const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past';
336+
/** @type {Set<string>} */
337+
const names = new Set();
338338

339-
/** @type {Record<string, number[]>} */
340-
const cache_values = {};
341-
const { layer_types, num_attention_heads, num_key_value_heads, hidden_size, conv_L_cache } =
342-
/** @type {any} */ (config);
343-
const head_dim = hidden_size / num_attention_heads;
339+
if (['lfm2', 'lfm2_moe'].includes(config.model_type)) {
340+
const { layer_types } = /** @type {any} */ (config);
344341
for (let i = 0; i < layer_types.length; ++i) {
345342
if (layer_types[i] === 'full_attention') {
346-
for (const kv of ['key', 'value']) {
347-
cache_values[`${pkv_prefix}.${i}.${kv}`] = [batch_size, num_key_value_heads, 0, head_dim];
348-
}
343+
names.add(`${pkv_prefix}.${i}.key`);
344+
names.add(`${pkv_prefix}.${i}.value`);
349345
} else if (layer_types[i] === 'conv') {
350-
cache_values[`${conv_prefix}_conv.${i}`] = [batch_size, hidden_size, conv_L_cache];
346+
names.add(`${conv_prefix}_conv.${i}`);
351347
} else {
352348
throw new Error(`Unsupported layer type: ${layer_types[i]}`);
353349
}
354350
}
355-
return cache_values;
351+
return names;
356352
} else if (['granitemoehybrid', 'falcon_h1', 'nemotron_h'].includes(config.model_type)) {
357-
const pkv_prefix = options?.prefix ?? 'past_key_values';
358-
const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past';
359-
360353
const c = /** @type {any} */ (config);
361-
362-
// Normalize config field names across model types
363354
const layer_types = c.layer_types ?? c.layers_block_type;
364355
const num_layers = c.num_hidden_layers ?? layer_types?.length;
365-
const num_key_value_heads = c.num_key_value_heads;
366-
const head_dim = c.head_dim ?? c.hidden_size / c.num_attention_heads;
367-
const mamba_n_heads = c.mamba_n_heads ?? c.mamba_num_heads;
368-
const mamba_d_head = c.mamba_d_head ?? c.mamba_head_dim;
369-
const mamba_d_state = c.mamba_d_state ?? c.ssm_state_size;
370-
const mamba_n_groups = c.mamba_n_groups ?? c.n_groups;
371-
const mamba_d_conv = c.mamba_d_conv ?? c.conv_kernel;
372-
const mamba_d_ssm =
373-
c.mamba_d_ssm ?? (c.mamba_expand ? c.mamba_expand * c.hidden_size : mamba_n_heads * mamba_d_head);
374-
const conv_d_inner = mamba_d_ssm + 2 * mamba_n_groups * mamba_d_state;
375-
376-
/** @type {Record<string, number[]>} */
377-
const cache_values = {};
378356

379357
for (let i = 0; i < num_layers; ++i) {
380358
if (!layer_types || layer_types[i] === 'mamba') {
381-
cache_values[`${conv_prefix}_conv.${i}`] = [batch_size, conv_d_inner, mamba_d_conv];
382-
cache_values[`${conv_prefix}_ssm.${i}`] = [batch_size, mamba_n_heads, mamba_d_head, mamba_d_state];
359+
names.add(`${conv_prefix}_conv.${i}`);
360+
names.add(`${conv_prefix}_ssm.${i}`);
383361
}
384362
if (!layer_types || layer_types[i] === 'attention') {
385-
for (const kv of ['key', 'value']) {
386-
cache_values[`${pkv_prefix}.${i}.${kv}`] = [batch_size, num_key_value_heads, 0, head_dim];
387-
}
363+
names.add(`${pkv_prefix}.${i}.key`);
364+
names.add(`${pkv_prefix}.${i}.value`);
388365
}
389366
}
390-
return cache_values;
367+
return names;
391368
} else if (['qwen3_next', 'qwen3_5_text', 'qwen3_5_moe_text', 'olmo_hybrid'].includes(config.model_type)) {
392-
const pkv_prefix = options?.prefix ?? 'past_key_values';
393-
const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past';
394-
395-
/** @type {Record<string, number[]>} */
396-
const cache_values = {};
397-
const {
398-
head_dim,
399-
layer_types,
400-
num_attention_heads,
401-
num_key_value_heads,
402-
hidden_size,
403-
linear_num_value_heads,
404-
linear_num_key_heads,
405-
linear_key_head_dim,
406-
linear_value_head_dim,
407-
linear_conv_kernel_dim,
408-
} = /** @type {any} */ (config);
409-
410-
const key_dim = linear_key_head_dim * linear_num_key_heads;
411-
const value_dim = linear_value_head_dim * linear_num_value_heads;
412-
413-
const final_head_dim = head_dim ?? hidden_size / num_attention_heads;
369+
const { layer_types } = /** @type {any} */ (config);
414370
for (let i = 0; i < layer_types.length; ++i) {
415371
if (layer_types[i] === 'full_attention') {
416-
for (const kv of ['key', 'value']) {
417-
cache_values[`${pkv_prefix}.${i}.${kv}`] = [batch_size, num_key_value_heads, 0, final_head_dim];
418-
}
372+
names.add(`${pkv_prefix}.${i}.key`);
373+
names.add(`${pkv_prefix}.${i}.value`);
419374
} else if (layer_types[i] === 'linear_attention') {
420375
if (config.model_type === 'olmo_hybrid') {
421-
cache_values[`${conv_prefix}_conv.${i}.key`] = [batch_size, key_dim, linear_conv_kernel_dim];
422-
cache_values[`${conv_prefix}_conv.${i}.value`] = [batch_size, value_dim, linear_conv_kernel_dim];
423-
cache_values[`${conv_prefix}_conv.${i}.query`] = [batch_size, key_dim, linear_conv_kernel_dim];
376+
names.add(`${conv_prefix}_conv.${i}.key`);
377+
names.add(`${conv_prefix}_conv.${i}.value`);
378+
names.add(`${conv_prefix}_conv.${i}.query`);
424379
} else {
425-
const conv_dim = key_dim * 2 + value_dim;
426-
cache_values[`${conv_prefix}_conv.${i}`] = [batch_size, conv_dim, linear_conv_kernel_dim];
380+
names.add(`${conv_prefix}_conv.${i}`);
427381
}
428-
cache_values[`${conv_prefix}_recurrent.${i}`] = [
429-
batch_size,
430-
linear_num_value_heads,
431-
linear_key_head_dim,
432-
linear_value_head_dim,
433-
];
382+
names.add(`${conv_prefix}_recurrent.${i}`);
434383
} else {
435384
throw new Error(`Unsupported layer type: ${layer_types[i]}`);
436385
}
437386
}
438-
return cache_values;
387+
return names;
439388
} else if (['gemma4', 'gemma4_text'].includes(config.model_type)) {
440389
const c = /** @type {any} */ (
441390
config.model_type === 'gemma4' ? /** @type {any} */ (config).text_config : config
442391
);
443-
const pkv_prefix = options?.prefix ?? 'past_key_values';
444-
445-
/** @type {Record<string, number[]>} */
446-
const cache_values = {};
447392
const num_hidden_layers = c.num_hidden_layers;
448393
const num_kv_shared_layers = c.num_kv_shared_layers ?? 0;
449394
const num_kv_layers = num_hidden_layers - num_kv_shared_layers;
450-
const num_key_value_heads = c.num_key_value_heads;
451-
const head_dim = c.head_dim;
452-
const global_head_dim = c.global_head_dim ?? head_dim;
453-
const layer_types = c.layer_types ?? [];
454395

455-
// Create `num_kv_layers` unique KV entries, corresponding to the first `num_kv_layers`
456-
// model layers (the remaining layers share caches with earlier ones).
457-
// Full attention layers use global_head_dim, sliding attention layers use head_dim.
458396
for (let i = 0; i < num_kv_layers; ++i) {
459-
const dim = layer_types[i] === 'full_attention' ? global_head_dim : head_dim;
460-
for (const kv of ['key', 'value']) {
461-
cache_values[`${pkv_prefix}.${i}.${kv}`] = [batch_size, num_key_value_heads, 0, dim];
462-
}
397+
names.add(`${pkv_prefix}.${i}.key`);
398+
names.add(`${pkv_prefix}.${i}.value`);
463399
}
464-
return cache_values;
400+
return names;
465401
} else if (['lfm2_vl', 'qwen3_5', 'qwen3_5_moe', 'voxtral_realtime'].includes(config.model_type)) {
466402
let subConfig;
467403
if (config.model_type === 'voxtral_realtime' && options?.session_name === 'audio_encoder') {
468404
subConfig = /** @type {any} */ (config).audio_config;
469405
} else {
470406
subConfig = /** @type {any} */ (config).text_config;
471407
}
472-
return getCacheShapes(subConfig, options);
408+
return getCacheNames(subConfig, options);
473409
}
474410

475-
return getKeyValueShapes(config, options);
411+
return getKeyValueNames(config, { prefix: pkv_prefix });
476412
}
477413

478-
/** @type {typeof getKeyValueShapes} */
479-
function getKeyValueShapes(config, { prefix = 'past_key_values', batch_size = 1 } = {}) {
480-
/** @type {Record<string, number[]>} */
481-
const decoderFeeds = {};
414+
/**
415+
* @param {PretrainedConfig} config
416+
* @param {{ prefix?: string }} [options]
417+
* @returns {Set<string>}
418+
*/
419+
function getKeyValueNames(config, { prefix = 'past_key_values' } = {}) {
420+
/** @type {Set<string>} */
421+
const names = new Set();
482422
const normalized_config = config.normalized_config;
483423

484424
if (
485425
normalized_config.is_encoder_decoder &&
486426
'num_encoder_heads' in normalized_config &&
487427
'num_decoder_heads' in normalized_config
488428
) {
489-
const encoder_dim_kv =
490-
normalized_config.encoder_dim_kv ??
491-
normalized_config.encoder_hidden_size / normalized_config.num_encoder_heads;
492-
const decoder_dim_kv =
493-
normalized_config.decoder_dim_kv ??
494-
normalized_config.decoder_hidden_size / normalized_config.num_decoder_heads;
495-
496-
const encoder_dims = [batch_size, normalized_config.num_encoder_heads, 0, encoder_dim_kv];
497-
const decoder_dims = [batch_size, normalized_config.num_decoder_heads, 0, decoder_dim_kv];
498429
for (let i = 0; i < normalized_config.num_decoder_layers; ++i) {
499-
decoderFeeds[`${prefix}.${i}.encoder.key`] = encoder_dims;
500-
decoderFeeds[`${prefix}.${i}.encoder.value`] = encoder_dims;
501-
decoderFeeds[`${prefix}.${i}.decoder.key`] = decoder_dims;
502-
decoderFeeds[`${prefix}.${i}.decoder.value`] = decoder_dims;
430+
names.add(`${prefix}.${i}.encoder.key`);
431+
names.add(`${prefix}.${i}.encoder.value`);
432+
names.add(`${prefix}.${i}.decoder.key`);
433+
names.add(`${prefix}.${i}.decoder.value`);
434+
}
435+
} else if (normalized_config.multi_query) {
436+
// e.g., for `gpt_bigcode`
437+
for (let i = 0; i < normalized_config.num_layers; ++i) {
438+
names.add(`${prefix}.${i}.key_value`);
503439
}
504440
} else {
505-
// Decoders
506-
const num_heads = normalized_config.num_heads;
507-
const num_layers = normalized_config.num_layers;
508-
const dim_kv =
509-
normalized_config.dim_kv ??
510-
normalized_config.hidden_size / (normalized_config.num_attention_heads ?? num_heads);
511-
512-
if (normalized_config.model_type === 'falcon') {
513-
// NOTE: Custom implementation for Falcon
514-
const dims = [batch_size * num_heads, 0, dim_kv];
515-
for (let i = 0; i < num_layers; ++i) {
516-
decoderFeeds[`${prefix}.${i}.key`] = dims;
517-
decoderFeeds[`${prefix}.${i}.value`] = dims;
518-
}
519-
} else if (normalized_config.multi_query) {
520-
// e.g., for `gpt_bigcode`
521-
const dims = [batch_size * num_heads, 0, 2 * dim_kv];
522-
523-
for (let i = 0; i < num_layers; ++i) {
524-
decoderFeeds[`${prefix}.${i}.key_value`] = dims;
525-
}
526-
} else if (normalized_config.model_type === 'bloom') {
527-
// NOTE: Custom implementation for Bloom
528-
529-
const keyDims = [batch_size * num_heads, dim_kv, 0]; // [batch_size x num_heads,64,past_sequence_length]
530-
const valueDims = [batch_size * num_heads, 0, dim_kv]; // [batch_size x num_heads,past_sequence_length,64]
531-
for (let i = 0; i < num_layers; ++i) {
532-
decoderFeeds[`${prefix}.${i}.key`] = keyDims;
533-
decoderFeeds[`${prefix}.${i}.value`] = valueDims;
534-
}
535-
} else if (normalized_config.model_type === 'openelm') {
536-
for (let i = 0; i < num_layers; ++i) {
537-
const dims = [batch_size, num_heads[i], 0, dim_kv];
538-
539-
decoderFeeds[`${prefix}.${i}.key`] = dims;
540-
decoderFeeds[`${prefix}.${i}.value`] = dims;
541-
}
542-
} else {
543-
// Decoder-only
544-
const dims = [batch_size, num_heads, 0, dim_kv];
545-
for (let i = 0; i < num_layers; ++i) {
546-
decoderFeeds[`${prefix}.${i}.key`] = dims;
547-
decoderFeeds[`${prefix}.${i}.value`] = dims;
548-
}
441+
for (let i = 0; i < normalized_config.num_layers; ++i) {
442+
names.add(`${prefix}.${i}.key`);
443+
names.add(`${prefix}.${i}.value`);
549444
}
550445
}
551446

552-
return decoderFeeds;
447+
return names;
553448
}
554449
/**
555450
* Base class for all configuration classes. For more information, see the corresponding
@@ -626,7 +521,6 @@ export class AutoConfig {
626521
* Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
627522
* @typedef {Object} TransformersJSConfig
628523
* @property {Record<import('./utils/devices.js').DeviceType, DeviceConfig>} [device_config] Device-specific configurations.
629-
* @property {import('./utils/tensor.js').DataType|Record<import('./utils/dtypes.js').DataType, import('./utils/tensor.js').DataType>|boolean} [kv_cache_dtype] The data type of the key-value cache.
630524
* @property {Record<string, number>} [free_dimension_overrides] Override the free dimensions of the model.
631525
* See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides
632526
* for more information.

packages/transformers/src/models/modeling_utils.js

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { Callable } from '../utils/generic.js';
22
import { constructSessions, sessionRun } from './session.js';
3-
import { AutoConfig, getCacheShapes } from '../configs.js';
3+
import { AutoConfig, getCacheNames } from '../configs.js';
44
import { Tensor, full_like, cat, zeros_like, ones_like, ones } from '../utils/tensor.js';
55
import { DataTypeMap } from '../utils/dtypes.js';
66

@@ -1233,6 +1233,21 @@ export function getAttentions(model_output) {
12331233
return attentions;
12341234
}
12351235

1236+
/**
1237+
* Resolve symbolic dims from ONNX inputMetadata for empty-cache initialization.
1238+
* Each symbolic dim name is looked up in `symbols`; numeric dims pass through.
1239+
* Any unresolved symbolic dim defaults to 0.
1240+
* @param {ReadonlyArray<number|string>} metadataShape
1241+
* @param {Record<string, number>} symbols
1242+
* @returns {number[]}
1243+
*/
1244+
export function resolveCacheShape(metadataShape, symbols) {
1245+
return metadataShape.map((d) => {
1246+
if (typeof d === 'number') return d;
1247+
return symbols[d] ?? 0;
1248+
});
1249+
}
1250+
12361251
/**
12371252
* Adds past key values to the decoder feeds object. If pastKeyValues is null,
12381253
* creates a new DynamicCache with zero-filled tensors for each cache entry.
@@ -1251,16 +1266,23 @@ export function addPastKeyValues(self, decoderFeeds, pastKeyValues) {
12511266
const session = self.sessions['decoder_model_merged'] ?? self.sessions['model'];
12521267
const batch_size = (decoderFeeds[self.main_input_name] ?? decoderFeeds.attention_mask)?.dims?.[0] ?? 1;
12531268

1254-
const dtype = session?.config?.kv_cache_dtype ?? 'float32';
1255-
const cls = dtype === 'float16' ? DataTypeMap.float16 : DataTypeMap.float32;
1256-
const shapes = getCacheShapes(self.config, { batch_size });
1269+
const names = getCacheNames(self.config);
1270+
const num_heads = self.config?.normalized_config?.num_heads;
1271+
/** @type {Record<string, number>} */
1272+
const symbols = { batch_size };
1273+
if (typeof num_heads === 'number') {
1274+
symbols['batch_size x num_heads'] = batch_size * num_heads;
1275+
}
12571276
/** @type {Record<string, Tensor>} */
12581277
const entries = Object.create(null);
1259-
for (const name in shapes) {
1260-
const size = shapes[name].reduce((a, b) => a * b, 1);
1261-
const t = new Tensor(dtype, new cls(size), shapes[name]);
1262-
decoderFeeds[name] = t;
1263-
entries[name] = t;
1278+
for (const meta of session.inputMetadata) {
1279+
if (!names.has(meta.name)) continue;
1280+
const shape = resolveCacheShape(meta.shape, symbols);
1281+
const size = shape.reduce((a, b) => a * b, 1);
1282+
const cls = DataTypeMap[meta.type];
1283+
const t = new Tensor(meta.type, new cls(size), shape);
1284+
decoderFeeds[meta.name] = t;
1285+
entries[meta.name] = t;
12641286
}
12651287
if (pastKeyValues) {
12661288
// Populate the (empty) user-provided cache in-place

0 commit comments

Comments
 (0)