Skip to content

Commit d834093

Browse files
xenova, JacobiusMakes, claude, nico-martin
authored
Emit progress_total events from PreTrainedModel.from_pretrained() (#1615)
* Emit progress_total events from PreTrainedModel.from_pretrained()

  Closes #1052.

  When a progress_callback is provided, gather file metadata (sizes) upfront via HEAD/Range requests before downloads begin. This lets consumers build a single aggregate progress bar across all model files, matching the behavior already present in the pipeline() API.

  The implementation reuses the existing get_model_files() and get_file_metadata() utilities and follows the same wrapping pattern used in pipelines.js, emitting progress_total events that include per-file and overall loaded/total byte counts.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: address review — config.json progress + pipeline() duplication

  Fixes two issues raised in code review:

  1. config.json is fetched by AutoConfig.from_pretrained() before the progress tracker is initialized. Pre-mark it as fully loaded in files_loading so total progress reaches 100%.
  2. pipeline() already wraps the callback with progress_total events. When from_pretrained() wrapped it again, users got duplicate events. Fix: mark wrapped callbacks with _progress_total_wrapped flag. from_pretrained() checks the flag and skips wrapping if already set. pipeline()'s wrapper now also sets this flag.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* switched to wrapper-class instead of _progress_total_wrapped property
* created /src/models/session_config.js to avoid circular dependencies
* Use callable objects instead
* fix jsdoc
* fix imports
* Add strict progress callback unit tests
* Deduplicate concurrent loads and fix progress w/ node loading from path
* formatting
* increase larger whisper models loading times
* make tests stricter

---------

Co-authored-by: JacobiusMakes <jgalperin98@gmail.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Nico Martin <mail@nico.dev>
1 parent a0d86d5 commit d834093

12 files changed

Lines changed: 635 additions & 178 deletions

packages/transformers/src/models/modeling_utils.js

Lines changed: 56 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,13 @@ import {
3434
import { GenerationConfig } from '../generation/configuration_utils.js';
3535
import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from '../generation/stopping_criteria.js';
3636
import { LogitsSampler } from '../generation/logits_sampler.js';
37-
import { pick } from '../utils/core.js';
37+
import { DefaultProgressCallback, pick } from '../utils/core.js';
3838
import { ModelOutput } from './modeling_outputs.js';
3939
import { logger } from '../utils/logger.js';
4040
import { DynamicCache } from '../cache_utils.js';
41+
import { get_model_files } from '../utils/model_registry/get_model_files.js';
42+
import { get_file_metadata } from '../utils/model_registry/get_file_metadata.js';
43+
import { MODEL_SESSION_CONFIG, MODEL_TYPES } from './session_config.js';
4144

4245
/**
4346
* Converts an array or Tensor of integers to an int64 Tensor.
@@ -83,212 +86,81 @@ export function boolTensor(value) {
8386
return new Tensor('bool', [value], [1]);
8487
}
8588

86-
export const MODEL_TYPES = {
87-
EncoderOnly: 0,
88-
EncoderDecoder: 1,
89-
Seq2Seq: 2,
90-
Vision2Seq: 3,
91-
DecoderOnly: 4,
92-
DecoderOnlyWithoutHead: 5,
93-
MaskGeneration: 6,
94-
ImageTextToText: 7,
95-
Musicgen: 8,
96-
MultiModality: 9,
97-
Phi3V: 10,
98-
AudioTextToText: 11,
99-
AutoEncoder: 12,
100-
ImageAudioTextToText: 13,
101-
Supertonic: 14,
102-
Chatterbox: 15,
103-
VoxtralRealtime: 16,
104-
};
89+
export { getSessionsConfig, getTextOnlySessions, MODEL_TYPES } from './session_config.js';
10590

106-
const MODEL_TYPE_CONFIG = {
91+
/**
92+
* Runtime-only model type configuration (forward functions, generation flags).
93+
* Session/file configuration lives in `MODEL_SESSION_CONFIG` (session_config.js)
94+
* and is merged in at lookup time by `resolveTypeConfig` to avoid duplication.
95+
*/
96+
const MODEL_RUNTIME_CONFIG = {
10797
[MODEL_TYPES.DecoderOnly]: {
10898
can_generate: true,
10999
forward: decoder_forward,
110100
prepare_inputs: decoder_prepare_inputs_for_generation,
111-
sessions: (config, options) => ({ model: options.model_file_name ?? 'model' }),
112-
cache_sessions: { model: true },
113-
optional_configs: { generation_config: 'generation_config.json' },
114101
},
115102
[MODEL_TYPES.DecoderOnlyWithoutHead]: {
116103
can_generate: false,
117104
forward: decoder_forward,
118105
prepare_inputs: decoder_prepare_inputs_for_generation,
119-
sessions: (config, options) => ({ model: options.model_file_name ?? 'model' }),
120106
},
121107
[MODEL_TYPES.Seq2Seq]: {
122108
can_generate: true,
123109
forward: seq2seq_forward,
124110
prepare_inputs: encoder_decoder_prepare_inputs_for_generation,
125-
sessions: () => ({ model: 'encoder_model', decoder_model_merged: 'decoder_model_merged' }),
126-
cache_sessions: { decoder_model_merged: true },
127-
optional_configs: { generation_config: 'generation_config.json' },
128111
},
129112
[MODEL_TYPES.Vision2Seq]: {
130113
can_generate: true,
131114
forward: seq2seq_forward,
132115
prepare_inputs: encoder_decoder_prepare_inputs_for_generation,
133-
sessions: () => ({ model: 'encoder_model', decoder_model_merged: 'decoder_model_merged' }),
134-
cache_sessions: { decoder_model_merged: true },
135-
optional_configs: { generation_config: 'generation_config.json' },
136116
},
137117
[MODEL_TYPES.Musicgen]: {
138118
can_generate: true,
139119
forward: seq2seq_forward,
140-
sessions: () => ({
141-
model: 'text_encoder',
142-
decoder_model_merged: 'decoder_model_merged',
143-
encodec_decode: 'encodec_decode',
144-
}),
145-
cache_sessions: { decoder_model_merged: true },
146-
optional_configs: { generation_config: 'generation_config.json' },
147120
},
148121
[MODEL_TYPES.EncoderDecoder]: {
149122
can_generate: false,
150123
forward: seq2seq_forward,
151-
sessions: () => ({ model: 'encoder_model', decoder_model_merged: 'decoder_model_merged' }),
152-
cache_sessions: { decoder_model_merged: true },
153-
},
154-
[MODEL_TYPES.MaskGeneration]: {
155-
sessions: () => ({ model: 'vision_encoder', prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder' }),
156124
},
157125
[MODEL_TYPES.ImageTextToText]: {
158126
can_generate: true,
159127
forward: image_text_to_text_forward,
160128
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
161-
text_only_sessions: { embed_tokens: 'embed_tokens', decoder_model_merged: 'decoder_model_merged' },
162-
sessions: (config, options, textOnly) => {
163-
const s = { ...MODEL_TYPE_CONFIG[MODEL_TYPES.ImageTextToText].text_only_sessions };
164-
if (!textOnly) s['vision_encoder'] = 'vision_encoder';
165-
if (config.is_encoder_decoder) s['model'] = 'encoder_model';
166-
return s;
167-
},
168-
cache_sessions: { decoder_model_merged: true },
169-
optional_configs: { generation_config: 'generation_config.json' },
170129
},
171130
[MODEL_TYPES.AudioTextToText]: {
172131
can_generate: true,
173132
forward: audio_text_to_text_forward,
174133
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
175-
text_only_sessions: { embed_tokens: 'embed_tokens', decoder_model_merged: 'decoder_model_merged' },
176-
sessions: (config, options, textOnly) => {
177-
const s = { ...MODEL_TYPE_CONFIG[MODEL_TYPES.AudioTextToText].text_only_sessions };
178-
if (!textOnly) s['audio_encoder'] = 'audio_encoder';
179-
return s;
180-
},
181-
cache_sessions: { decoder_model_merged: true },
182-
optional_configs: { generation_config: 'generation_config.json' },
183134
},
184135
[MODEL_TYPES.ImageAudioTextToText]: {
185136
can_generate: true,
186137
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
187-
text_only_sessions: { embed_tokens: 'embed_tokens', decoder_model_merged: 'decoder_model_merged' },
188-
sessions: (config, options, textOnly) => {
189-
const s = { ...MODEL_TYPE_CONFIG[MODEL_TYPES.ImageAudioTextToText].text_only_sessions };
190-
if (!textOnly) {
191-
s['audio_encoder'] = 'audio_encoder';
192-
s['vision_encoder'] = 'vision_encoder';
193-
}
194-
return s;
195-
},
196-
optional_configs: { generation_config: 'generation_config.json' },
197138
},
198139
[MODEL_TYPES.Phi3V]: {
199140
can_generate: true,
200141
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
201-
sessions: () => ({
202-
prepare_inputs_embeds: 'prepare_inputs_embeds',
203-
model: 'model',
204-
vision_encoder: 'vision_encoder',
205-
}),
206-
cache_sessions: { model: true },
207-
optional_configs: { generation_config: 'generation_config.json' },
208142
},
209143
[MODEL_TYPES.MultiModality]: {
210144
can_generate: true,
211-
sessions: () => ({
212-
prepare_inputs_embeds: 'prepare_inputs_embeds',
213-
model: 'language_model',
214-
lm_head: 'lm_head',
215-
gen_head: 'gen_head',
216-
gen_img_embeds: 'gen_img_embeds',
217-
image_decode: 'image_decode',
218-
}),
219-
cache_sessions: { model: true },
220-
optional_configs: { generation_config: 'generation_config.json' },
221145
},
222146
[MODEL_TYPES.AutoEncoder]: {
223147
can_generate: false,
224148
forward: auto_encoder_forward,
225-
sessions: () => ({ encoder_model: 'encoder_model', decoder_model: 'decoder_model' }),
226-
},
227-
[MODEL_TYPES.Supertonic]: {
228-
sessions: () => ({
229-
text_encoder: 'text_encoder',
230-
latent_denoiser: 'latent_denoiser',
231-
voice_decoder: 'voice_decoder',
232-
}),
233149
},
234150
[MODEL_TYPES.Chatterbox]: {
235151
can_generate: true,
236152
forward: encoder_forward,
237-
sessions: () => ({
238-
embed_tokens: 'embed_tokens',
239-
speech_encoder: 'speech_encoder',
240-
model: 'language_model',
241-
conditional_decoder: 'conditional_decoder',
242-
}),
243-
cache_sessions: { model: true },
244-
optional_configs: { generation_config: 'generation_config.json' },
245153
},
246154
[MODEL_TYPES.VoxtralRealtime]: {
247155
can_generate: true,
248156
prepare_inputs: decoder_prepare_inputs_for_generation,
249-
text_only_sessions: { embed_tokens: 'embed_tokens', decoder_model_merged: 'decoder_model_merged' },
250-
sessions: (config, options, textOnly) => {
251-
const s = { ...MODEL_TYPE_CONFIG[MODEL_TYPES.VoxtralRealtime].text_only_sessions };
252-
if (!textOnly) s['audio_encoder'] = 'audio_encoder';
253-
return s;
254-
},
255-
cache_sessions: { decoder_model_merged: true, audio_encoder: true },
256-
optional_configs: { generation_config: 'generation_config.json' },
257157
},
258158
default: {
259159
can_generate: false,
260160
forward: encoder_forward,
261-
sessions: (config, options) => ({ model: options.model_file_name ?? 'model' }),
262161
},
263162
};
264163

265-
/**
266-
* Get the session configuration for a given model type.
267-
* @param {number} modelType The model type enum value.
268-
* @param {Object} config The model config.
269-
* @param {Object} [options] Loading options.
270-
* @returns {{ sessions: Record<string, string>, cache_sessions?: Record<string, true>, optional_configs?: Record<string, string> }}
271-
*/
272-
export function getSessionsConfig(modelType, config, options = {}) {
273-
const typeConfig = MODEL_TYPE_CONFIG[modelType] ?? MODEL_TYPE_CONFIG.default;
274-
return {
275-
sessions: typeConfig.sessions(config, options),
276-
cache_sessions: typeConfig.cache_sessions,
277-
optional_configs: typeConfig.optional_configs,
278-
};
279-
}
280-
281-
/**
282-
* Returns the text-only session names for a given model type, or `null` if
283-
* the model type does not define a text-only session set.
284-
* @param {number} modelType The model type enum value.
285-
* @returns {Record<string, string>|null}
286-
*/
287-
export function getTextOnlySessions(modelType) {
288-
const typeConfig = MODEL_TYPE_CONFIG[modelType];
289-
return typeConfig?.text_only_sessions ?? null;
290-
}
291-
292164
/**
293165
* Resolves the model type config for a given class name and config.
294166
* @param {string} modelName The name of the class being used to load.
@@ -315,7 +187,9 @@ function resolveTypeConfig(modelName, config) {
315187
}
316188
}
317189

318-
return { typeConfig: MODEL_TYPE_CONFIG[modelType] ?? MODEL_TYPE_CONFIG.default, textOnly, modelType };
190+
const runtimeConfig = MODEL_RUNTIME_CONFIG[modelType] ?? MODEL_RUNTIME_CONFIG.default;
191+
const sessionConfig = MODEL_SESSION_CONFIG[modelType] ?? MODEL_SESSION_CONFIG.default;
192+
return { typeConfig: { ...runtimeConfig, ...sessionConfig }, textOnly, modelType };
319193
}
320194

321195
export const MODEL_TYPE_MAPPING = new Map();
@@ -431,6 +305,46 @@ export class PreTrainedModel extends Callable {
431305
}
432306
}
433307

308+
// If a progress callback is provided AND it hasn't already been wrapped
309+
// by pipeline() (which does its own aggregation), gather file metadata
310+
// upfront so we can emit `progress_total` events. This lets consumers
311+
// render a single overall progress bar when calling from_pretrained() directly.
312+
if (progress_callback && !(progress_callback instanceof DefaultProgressCallback)) {
313+
/** @type {import('../utils/core.js').FilesLoadingMap} */
314+
const files_loading = {};
315+
316+
try {
317+
const expected_files = await get_model_files(pretrained_model_name_or_path, {
318+
config,
319+
dtype,
320+
device,
321+
model_file_name,
322+
});
323+
324+
const metadata = await Promise.all(
325+
expected_files.map((file) => get_file_metadata(pretrained_model_name_or_path, file, options)),
326+
);
327+
metadata.forEach((m, i) => {
328+
if (m.exists) {
329+
// config.json is fetched by AutoConfig.from_pretrained() above
330+
const isAlreadyLoaded = expected_files[i] === 'config.json';
331+
files_loading[expected_files[i]] = {
332+
loaded: isAlreadyLoaded ? (m.size ?? 0) : 0,
333+
total: m.size ?? 0,
334+
};
335+
}
336+
});
337+
} catch (e) {
338+
// If we fail to get metadata, we can still proceed without total progress.
339+
// This may happen with local-only models or custom cache setups.
340+
logger.warn(`Unable to fetch model file metadata for total progress tracking: ${e}`);
341+
}
342+
343+
if (Object.keys(files_loading).length > 0) {
344+
options.progress_callback = new DefaultProgressCallback(progress_callback, files_loading);
345+
}
346+
}
347+
434348
const sessions = typeConfig.sessions(config, options, textOnly);
435349
const promises = [
436350
constructSessions(pretrained_model_name_or_path, sessions, options, typeConfig.cache_sessions),
@@ -1271,7 +1185,8 @@ export class PreTrainedModel extends Callable {
12711185
* @private
12721186
*/
12731187
export async function seq2seq_forward(self, model_inputs) {
1274-
let { encoder_outputs, input_ids, decoder_input_ids, decoder_attention_mask, ...other_decoder_inputs } = model_inputs;
1188+
let { encoder_outputs, input_ids, decoder_input_ids, decoder_attention_mask, ...other_decoder_inputs } =
1189+
model_inputs;
12751190
// Encode if needed
12761191
if (!encoder_outputs) {
12771192
const encoder_inputs = pick(model_inputs, self.sessions['model'].inputNames);

0 commit comments

Comments
 (0)