@@ -34,10 +34,13 @@ import {
3434import { GenerationConfig } from '../generation/configuration_utils.js' ;
3535import { EosTokenCriteria , MaxLengthCriteria , StoppingCriteriaList } from '../generation/stopping_criteria.js' ;
3636import { LogitsSampler } from '../generation/logits_sampler.js' ;
37- import { pick } from '../utils/core.js' ;
37+ import { DefaultProgressCallback , pick } from '../utils/core.js' ;
3838import { ModelOutput } from './modeling_outputs.js' ;
3939import { logger } from '../utils/logger.js' ;
4040import { DynamicCache } from '../cache_utils.js' ;
41+ import { get_model_files } from '../utils/model_registry/get_model_files.js' ;
42+ import { get_file_metadata } from '../utils/model_registry/get_file_metadata.js' ;
43+ import { MODEL_SESSION_CONFIG , MODEL_TYPES } from './session_config.js' ;
4144
4245/**
4346 * Converts an array or Tensor of integers to an int64 Tensor.
@@ -83,212 +86,81 @@ export function boolTensor(value) {
8386 return new Tensor ( 'bool' , [ value ] , [ 1 ] ) ;
8487}
8588
86- export const MODEL_TYPES = {
87- EncoderOnly : 0 ,
88- EncoderDecoder : 1 ,
89- Seq2Seq : 2 ,
90- Vision2Seq : 3 ,
91- DecoderOnly : 4 ,
92- DecoderOnlyWithoutHead : 5 ,
93- MaskGeneration : 6 ,
94- ImageTextToText : 7 ,
95- Musicgen : 8 ,
96- MultiModality : 9 ,
97- Phi3V : 10 ,
98- AudioTextToText : 11 ,
99- AutoEncoder : 12 ,
100- ImageAudioTextToText : 13 ,
101- Supertonic : 14 ,
102- Chatterbox : 15 ,
103- VoxtralRealtime : 16 ,
104- } ;
89+ export { getSessionsConfig , getTextOnlySessions , MODEL_TYPES } from './session_config.js' ;
10590
106- const MODEL_TYPE_CONFIG = {
91+ /**
92+ * Runtime-only model type configuration (forward functions, generation flags).
93+ * Session/file configuration lives in `MODEL_SESSION_CONFIG` (session_config.js)
94+ * and is merged in at lookup time by `resolveTypeConfig` to avoid duplication.
95+ */
96+ const MODEL_RUNTIME_CONFIG = {
10797 [ MODEL_TYPES . DecoderOnly ] : {
10898 can_generate : true ,
10999 forward : decoder_forward ,
110100 prepare_inputs : decoder_prepare_inputs_for_generation ,
111- sessions : ( config , options ) => ( { model : options . model_file_name ?? 'model' } ) ,
112- cache_sessions : { model : true } ,
113- optional_configs : { generation_config : 'generation_config.json' } ,
114101 } ,
115102 [ MODEL_TYPES . DecoderOnlyWithoutHead ] : {
116103 can_generate : false ,
117104 forward : decoder_forward ,
118105 prepare_inputs : decoder_prepare_inputs_for_generation ,
119- sessions : ( config , options ) => ( { model : options . model_file_name ?? 'model' } ) ,
120106 } ,
121107 [ MODEL_TYPES . Seq2Seq ] : {
122108 can_generate : true ,
123109 forward : seq2seq_forward ,
124110 prepare_inputs : encoder_decoder_prepare_inputs_for_generation ,
125- sessions : ( ) => ( { model : 'encoder_model' , decoder_model_merged : 'decoder_model_merged' } ) ,
126- cache_sessions : { decoder_model_merged : true } ,
127- optional_configs : { generation_config : 'generation_config.json' } ,
128111 } ,
129112 [ MODEL_TYPES . Vision2Seq ] : {
130113 can_generate : true ,
131114 forward : seq2seq_forward ,
132115 prepare_inputs : encoder_decoder_prepare_inputs_for_generation ,
133- sessions : ( ) => ( { model : 'encoder_model' , decoder_model_merged : 'decoder_model_merged' } ) ,
134- cache_sessions : { decoder_model_merged : true } ,
135- optional_configs : { generation_config : 'generation_config.json' } ,
136116 } ,
137117 [ MODEL_TYPES . Musicgen ] : {
138118 can_generate : true ,
139119 forward : seq2seq_forward ,
140- sessions : ( ) => ( {
141- model : 'text_encoder' ,
142- decoder_model_merged : 'decoder_model_merged' ,
143- encodec_decode : 'encodec_decode' ,
144- } ) ,
145- cache_sessions : { decoder_model_merged : true } ,
146- optional_configs : { generation_config : 'generation_config.json' } ,
147120 } ,
148121 [ MODEL_TYPES . EncoderDecoder ] : {
149122 can_generate : false ,
150123 forward : seq2seq_forward ,
151- sessions : ( ) => ( { model : 'encoder_model' , decoder_model_merged : 'decoder_model_merged' } ) ,
152- cache_sessions : { decoder_model_merged : true } ,
153- } ,
154- [ MODEL_TYPES . MaskGeneration ] : {
155- sessions : ( ) => ( { model : 'vision_encoder' , prompt_encoder_mask_decoder : 'prompt_encoder_mask_decoder' } ) ,
156124 } ,
157125 [ MODEL_TYPES . ImageTextToText ] : {
158126 can_generate : true ,
159127 forward : image_text_to_text_forward ,
160128 prepare_inputs : multimodal_text_to_text_prepare_inputs_for_generation ,
161- text_only_sessions : { embed_tokens : 'embed_tokens' , decoder_model_merged : 'decoder_model_merged' } ,
162- sessions : ( config , options , textOnly ) => {
163- const s = { ...MODEL_TYPE_CONFIG [ MODEL_TYPES . ImageTextToText ] . text_only_sessions } ;
164- if ( ! textOnly ) s [ 'vision_encoder' ] = 'vision_encoder' ;
165- if ( config . is_encoder_decoder ) s [ 'model' ] = 'encoder_model' ;
166- return s ;
167- } ,
168- cache_sessions : { decoder_model_merged : true } ,
169- optional_configs : { generation_config : 'generation_config.json' } ,
170129 } ,
171130 [ MODEL_TYPES . AudioTextToText ] : {
172131 can_generate : true ,
173132 forward : audio_text_to_text_forward ,
174133 prepare_inputs : multimodal_text_to_text_prepare_inputs_for_generation ,
175- text_only_sessions : { embed_tokens : 'embed_tokens' , decoder_model_merged : 'decoder_model_merged' } ,
176- sessions : ( config , options , textOnly ) => {
177- const s = { ...MODEL_TYPE_CONFIG [ MODEL_TYPES . AudioTextToText ] . text_only_sessions } ;
178- if ( ! textOnly ) s [ 'audio_encoder' ] = 'audio_encoder' ;
179- return s ;
180- } ,
181- cache_sessions : { decoder_model_merged : true } ,
182- optional_configs : { generation_config : 'generation_config.json' } ,
183134 } ,
184135 [ MODEL_TYPES . ImageAudioTextToText ] : {
185136 can_generate : true ,
186137 prepare_inputs : multimodal_text_to_text_prepare_inputs_for_generation ,
187- text_only_sessions : { embed_tokens : 'embed_tokens' , decoder_model_merged : 'decoder_model_merged' } ,
188- sessions : ( config , options , textOnly ) => {
189- const s = { ...MODEL_TYPE_CONFIG [ MODEL_TYPES . ImageAudioTextToText ] . text_only_sessions } ;
190- if ( ! textOnly ) {
191- s [ 'audio_encoder' ] = 'audio_encoder' ;
192- s [ 'vision_encoder' ] = 'vision_encoder' ;
193- }
194- return s ;
195- } ,
196- optional_configs : { generation_config : 'generation_config.json' } ,
197138 } ,
198139 [ MODEL_TYPES . Phi3V ] : {
199140 can_generate : true ,
200141 prepare_inputs : multimodal_text_to_text_prepare_inputs_for_generation ,
201- sessions : ( ) => ( {
202- prepare_inputs_embeds : 'prepare_inputs_embeds' ,
203- model : 'model' ,
204- vision_encoder : 'vision_encoder' ,
205- } ) ,
206- cache_sessions : { model : true } ,
207- optional_configs : { generation_config : 'generation_config.json' } ,
208142 } ,
209143 [ MODEL_TYPES . MultiModality ] : {
210144 can_generate : true ,
211- sessions : ( ) => ( {
212- prepare_inputs_embeds : 'prepare_inputs_embeds' ,
213- model : 'language_model' ,
214- lm_head : 'lm_head' ,
215- gen_head : 'gen_head' ,
216- gen_img_embeds : 'gen_img_embeds' ,
217- image_decode : 'image_decode' ,
218- } ) ,
219- cache_sessions : { model : true } ,
220- optional_configs : { generation_config : 'generation_config.json' } ,
221145 } ,
222146 [ MODEL_TYPES . AutoEncoder ] : {
223147 can_generate : false ,
224148 forward : auto_encoder_forward ,
225- sessions : ( ) => ( { encoder_model : 'encoder_model' , decoder_model : 'decoder_model' } ) ,
226- } ,
227- [ MODEL_TYPES . Supertonic ] : {
228- sessions : ( ) => ( {
229- text_encoder : 'text_encoder' ,
230- latent_denoiser : 'latent_denoiser' ,
231- voice_decoder : 'voice_decoder' ,
232- } ) ,
233149 } ,
234150 [ MODEL_TYPES . Chatterbox ] : {
235151 can_generate : true ,
236152 forward : encoder_forward ,
237- sessions : ( ) => ( {
238- embed_tokens : 'embed_tokens' ,
239- speech_encoder : 'speech_encoder' ,
240- model : 'language_model' ,
241- conditional_decoder : 'conditional_decoder' ,
242- } ) ,
243- cache_sessions : { model : true } ,
244- optional_configs : { generation_config : 'generation_config.json' } ,
245153 } ,
246154 [ MODEL_TYPES . VoxtralRealtime ] : {
247155 can_generate : true ,
248156 prepare_inputs : decoder_prepare_inputs_for_generation ,
249- text_only_sessions : { embed_tokens : 'embed_tokens' , decoder_model_merged : 'decoder_model_merged' } ,
250- sessions : ( config , options , textOnly ) => {
251- const s = { ...MODEL_TYPE_CONFIG [ MODEL_TYPES . VoxtralRealtime ] . text_only_sessions } ;
252- if ( ! textOnly ) s [ 'audio_encoder' ] = 'audio_encoder' ;
253- return s ;
254- } ,
255- cache_sessions : { decoder_model_merged : true , audio_encoder : true } ,
256- optional_configs : { generation_config : 'generation_config.json' } ,
257157 } ,
258158 default : {
259159 can_generate : false ,
260160 forward : encoder_forward ,
261- sessions : ( config , options ) => ( { model : options . model_file_name ?? 'model' } ) ,
262161 } ,
263162} ;
264163
265- /**
266- * Get the session configuration for a given model type.
267- * @param {number } modelType The model type enum value.
268- * @param {Object } config The model config.
269- * @param {Object } [options] Loading options.
270- * @returns {{ sessions: Record<string, string>, cache_sessions?: Record<string, true>, optional_configs?: Record<string, string> } }
271- */
272- export function getSessionsConfig ( modelType , config , options = { } ) {
273- const typeConfig = MODEL_TYPE_CONFIG [ modelType ] ?? MODEL_TYPE_CONFIG . default ;
274- return {
275- sessions : typeConfig . sessions ( config , options ) ,
276- cache_sessions : typeConfig . cache_sessions ,
277- optional_configs : typeConfig . optional_configs ,
278- } ;
279- }
280-
281- /**
282- * Returns the text-only session names for a given model type, or `null` if
283- * the model type does not define a text-only session set.
284- * @param {number } modelType The model type enum value.
285- * @returns {Record<string, string>|null }
286- */
287- export function getTextOnlySessions ( modelType ) {
288- const typeConfig = MODEL_TYPE_CONFIG [ modelType ] ;
289- return typeConfig ?. text_only_sessions ?? null ;
290- }
291-
292164/**
293165 * Resolves the model type config for a given class name and config.
294166 * @param {string } modelName The name of the class being used to load.
@@ -315,7 +187,9 @@ function resolveTypeConfig(modelName, config) {
315187 }
316188 }
317189
318- return { typeConfig : MODEL_TYPE_CONFIG [ modelType ] ?? MODEL_TYPE_CONFIG . default , textOnly, modelType } ;
190+ const runtimeConfig = MODEL_RUNTIME_CONFIG [ modelType ] ?? MODEL_RUNTIME_CONFIG . default ;
191+ const sessionConfig = MODEL_SESSION_CONFIG [ modelType ] ?? MODEL_SESSION_CONFIG . default ;
192+ return { typeConfig : { ...runtimeConfig , ...sessionConfig } , textOnly, modelType } ;
319193}
320194
321195export const MODEL_TYPE_MAPPING = new Map ( ) ;
@@ -431,6 +305,46 @@ export class PreTrainedModel extends Callable {
431305 }
432306 }
433307
308+ // If a progress callback is provided AND it hasn't already been wrapped
309+ // by pipeline() (which does its own aggregation), gather file metadata
310+ // upfront so we can emit `progress_total` events. This lets consumers
311+ // render a single overall progress bar when calling from_pretrained() directly.
312+ if ( progress_callback && ! ( progress_callback instanceof DefaultProgressCallback ) ) {
313+ /** @type {import('../utils/core.js').FilesLoadingMap } */
314+ const files_loading = { } ;
315+
316+ try {
317+ const expected_files = await get_model_files ( pretrained_model_name_or_path , {
318+ config,
319+ dtype,
320+ device,
321+ model_file_name,
322+ } ) ;
323+
324+ const metadata = await Promise . all (
325+ expected_files . map ( ( file ) => get_file_metadata ( pretrained_model_name_or_path , file , options ) ) ,
326+ ) ;
327+ metadata . forEach ( ( m , i ) => {
328+ if ( m . exists ) {
329+ // config.json is fetched by AutoConfig.from_pretrained() above
330+ const isAlreadyLoaded = expected_files [ i ] === 'config.json' ;
331+ files_loading [ expected_files [ i ] ] = {
332+ loaded : isAlreadyLoaded ? ( m . size ?? 0 ) : 0 ,
333+ total : m . size ?? 0 ,
334+ } ;
335+ }
336+ } ) ;
337+ } catch ( e ) {
338+ // If we fail to get metadata, we can still proceed without total progress.
339+ // This may happen with local-only models or custom cache setups.
340+ logger . warn ( `Unable to fetch model file metadata for total progress tracking: ${ e } ` ) ;
341+ }
342+
343+ if ( Object . keys ( files_loading ) . length > 0 ) {
344+ options . progress_callback = new DefaultProgressCallback ( progress_callback , files_loading ) ;
345+ }
346+ }
347+
434348 const sessions = typeConfig . sessions ( config , options , textOnly ) ;
435349 const promises = [
436350 constructSessions ( pretrained_model_name_or_path , sessions , options , typeConfig . cache_sessions ) ,
@@ -1271,7 +1185,8 @@ export class PreTrainedModel extends Callable {
12711185 * @private
12721186 */
12731187export async function seq2seq_forward ( self , model_inputs ) {
1274- let { encoder_outputs, input_ids, decoder_input_ids, decoder_attention_mask, ...other_decoder_inputs } = model_inputs ;
1188+ let { encoder_outputs, input_ids, decoder_input_ids, decoder_attention_mask, ...other_decoder_inputs } =
1189+ model_inputs ;
12751190 // Encode if needed
12761191 if ( ! encoder_outputs ) {
12771192 const encoder_inputs = pick ( model_inputs , self . sessions [ 'model' ] . inputNames ) ;
0 commit comments