1818import re
1919
2020import torch
21- from cosmos3 .common .init import init_script
22-
23-
24- init_script ()
25-
26- from accelerate import init_empty_weights # noqa: E402
27- from cosmos3 .args import _CHECKPOINTS # noqa: E402
28- from cosmos3 .model import Cosmos3OmniModel # noqa: E402
29- from projects .cosmos3 .vfm .models .omni_mot_model import OmniMoTModel # noqa: E402
30- from transformers import AutoTokenizer # noqa: E402
31-
32- from diffusers import AutoencoderKLWan , UniPCMultistepScheduler # noqa: E402
33- from diffusers .models .autoencoders .autoencoder_cosmos3_audio import Cosmos3AVAEAudioTokenizer # noqa: E402
34- from diffusers .models .transformers .transformer_cosmos3 import Cosmos3OmniTransformer # noqa: E402
35- from diffusers .pipelines .cosmos .pipeline_cosmos3_omni import Cosmos3OmniPipeline # noqa: E402
3621
3722
# Default AVAE sound-tokenizer settings. `_build_sound_tokenizer` looks values
# up here when the loaded tokenizer config does not provide a given key.
DEFAULT_SOUND_TOKENIZER_CONFIG = {
    # Top-level audio / model settings.
    "model_type": "autoencoder_v2",
    "sampling_rate": 48000,
    "stereo": True,
    "use_wav_as_input": True,
    "normalize_volume": True,
    "hop_size": 1920,
    "input_channels": 1,
    # Encoder settings (`enc_*` keys).
    "enc_type": "spec_convnext",
    "enc_dim": 192,
    "enc_intermediate_dim": 768,
    "enc_num_layers": 12,
    "enc_num_blocks": 2,
    "enc_n_fft": 64,
    "enc_hop_length": 16,
    "enc_latent_dim": 128,
    "enc_c_mults": [1, 2, 4],
    "enc_strides": [4, 5, 6],
    "enc_identity_init": False,
    "enc_use_snake": True,
    # Decoder settings (`dec_*` keys and the vocoder input width).
    "dec_type": "oobleck",
    "vocoder_input_dim": 64,
    "dec_dim": 320,
    "dec_c_mults": [1, 2, 4, 8, 16],
    "dec_strides": [2, 4, 5, 6, 8],
    "dec_use_snake": True,
    "dec_final_tanh": False,
    "dec_out_channels": 2,
    "dec_anti_aliasing": False,
    "dec_use_nearest_upsample": False,
    "dec_use_tanh_at_final": False,
    # Bottleneck, activation and remaining shared settings.
    "bottleneck_type": "vae",
    "bottleneck": {"type": "vae"},
    "activation": "snakebeta",
    "snake_logscale": True,
    "anti_aliasing": False,
    "use_cuda_kernel": False,
    "causal": False,
    "padding_mode": "zeros",
    "latent_mean": None,
    "latent_std": None,
}
4665
4766
@@ -114,8 +133,10 @@ def _sound_tokenizer_strip_per_key_prefixes(state_dict: dict[str, torch.Tensor])
114133 return out
115134
116135
def _sound_tokenizer_filter_supported_modules(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Keep only the entries whose key starts with ``encoder.`` or ``decoder.``.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        A new dict holding the matching entries in their original order. The
        input mapping is not modified; an empty result is returned when no key
        matches.
    """
    # `str.startswith` accepts a tuple of prefixes, so one call covers both modules.
    supported_prefixes = ("encoder.", "decoder.")
    return {key: value for key, value in state_dict.items() if key.startswith(supported_prefixes)}
119140
120141
121142def _sound_tokenizer_infer_num_blocks (state_dict : dict [str , torch .Tensor ]) -> int :
@@ -185,7 +206,11 @@ def _remap(key: str) -> str:
def _sound_tokenizer_reshape_snake_params(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Reshape 1-D ``.alpha`` / ``.beta`` tensors of the encoder and decoder to ``(1, C, 1)``.

    Only entries whose key starts with ``encoder.`` or ``decoder.``, ends with
    ``.alpha`` or ``.beta``, and whose tensor has exactly one dimension are
    reshaped. Every other entry is passed through unchanged.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        A new dict with the same keys in the same order; the input mapping is
        not modified.
    """
    module_prefixes = ("encoder.", "decoder.")
    snake_suffixes = (".alpha", ".beta")
    out: dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        if key.startswith(module_prefixes) and key.endswith(snake_suffixes) and value.ndim == 1:
            # (C,) -> (1, C, 1): add a leading and a trailing singleton axis.
            value = value.unsqueeze(0).unsqueeze(-1).contiguous()
        out[key] = value
    return out
@@ -197,7 +222,11 @@ def _sound_tokenizer_reapply_weight_norm(state_dict: dict[str, torch.Tensor]) ->
197222 candidate_keys = [
198223 key
199224 for key in state_dict
200- if key .endswith (".weight" ) and any (f".{ layer } ." in key for layer in ("conv1" , "conv2" , "conv_t1" ))
225+ if key .endswith (".weight" )
226+ and (
227+ any (f".{ layer } ." in key for layer in ("conv1" , "conv2" , "conv_t1" ))
228+ or re .fullmatch (r"encoder\.layers\.\d+\.weight" , key )
229+ )
201230 ]
202231 for key in candidate_keys :
203232 stem = key [: - len (".weight" )]
@@ -216,8 +245,10 @@ def _sound_tokenizer_reapply_weight_norm(state_dict: dict[str, torch.Tensor]) ->
216245def _remap_avae_state_dict (state_dict : dict [str , torch .Tensor ]) -> dict [str , torch .Tensor ]:
217246 """Convert a legacy AVAE state dict into the Cosmos3AVAEAudioTokenizer state dict."""
218247 state_dict = _sound_tokenizer_strip_per_key_prefixes (state_dict )
219- state_dict = _sound_tokenizer_filter_decoder (state_dict )
248+ state_dict = _sound_tokenizer_filter_supported_modules (state_dict )
220249 if not state_dict :
250+ raise RuntimeError ("Sound tokenizer state dict has no `encoder.*` or `decoder.*` keys after prefix stripping." )
251+ if not any (key .startswith ("decoder." ) for key in state_dict ):
221252 raise RuntimeError ("Sound tokenizer state dict has no `decoder.*` keys after prefix stripping." )
222253 state_dict = _sound_tokenizer_remap_flat_layout (state_dict )
223254 state_dict = _sound_tokenizer_reshape_snake_params (state_dict )
@@ -230,20 +261,67 @@ def _remap_avae_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, tor
230261def _build_sound_tokenizer (
231262 checkpoint_path : pathlib .Path ,
232263 config_path : pathlib .Path | None ,
233- ) -> Cosmos3AVAEAudioTokenizer :
264+ ):
265+ from diffusers .models .autoencoders .autoencoder_cosmos3_audio import Cosmos3AVAEAudioTokenizer
266+
234267 config = _load_sound_tokenizer_config (config_path , fallback_config_path = pathlib .Path ())
235268 print (f"Loading AVAE sound tokenizer weights from { checkpoint_path } …" )
236269 raw_state_dict = _load_sound_tokenizer_state_dict (checkpoint_path )
237270 state_dict = _remap_avae_state_dict (raw_state_dict )
238- print (f" Remapped { len (raw_state_dict )} → { len (state_dict )} decoder keys." )
271+ has_encoder = any (key .startswith ("encoder." ) for key in state_dict )
272+ print (
273+ f" Remapped { len (raw_state_dict )} → { len (state_dict )} tokenizer keys "
274+ f"({ 'encoder+decoder' if has_encoder else 'decoder-only' } )."
275+ )
239276
240277 sound_tokenizer = Cosmos3AVAEAudioTokenizer (
278+ model_type = config .get ("model_type" , DEFAULT_SOUND_TOKENIZER_CONFIG ["model_type" ]),
241279 sampling_rate = config .get ("sampling_rate" , DEFAULT_SOUND_TOKENIZER_CONFIG ["sampling_rate" ]),
280+ stereo = config .get ("stereo" , DEFAULT_SOUND_TOKENIZER_CONFIG ["stereo" ]),
281+ use_wav_as_input = config .get ("use_wav_as_input" , DEFAULT_SOUND_TOKENIZER_CONFIG ["use_wav_as_input" ]),
282+ normalize_volume = config .get ("normalize_volume" , DEFAULT_SOUND_TOKENIZER_CONFIG ["normalize_volume" ]),
283+ hop_size = config .get ("hop_size" , DEFAULT_SOUND_TOKENIZER_CONFIG ["hop_size" ]),
284+ input_channels = config .get ("input_channels" , DEFAULT_SOUND_TOKENIZER_CONFIG ["input_channels" ]),
285+ enc_type = config .get ("enc_type" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_type" ]),
286+ enc_dim = config .get ("enc_dim" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_dim" ]),
287+ enc_intermediate_dim = config .get (
288+ "enc_intermediate_dim" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_intermediate_dim" ]
289+ ),
290+ enc_num_layers = config .get ("enc_num_layers" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_num_layers" ]),
291+ enc_num_blocks = config .get ("enc_num_blocks" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_num_blocks" ]),
292+ enc_n_fft = config .get ("enc_n_fft" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_n_fft" ]),
293+ enc_hop_length = config .get ("enc_hop_length" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_hop_length" ]),
294+ enc_latent_dim = config .get ("enc_latent_dim" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_latent_dim" ]),
295+ enc_c_mults = tuple (config .get ("enc_c_mults" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_c_mults" ])),
296+ enc_strides = tuple (config .get ("enc_strides" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_strides" ])),
297+ enc_identity_init = config .get ("enc_identity_init" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_identity_init" ]),
298+ enc_use_snake = config .get ("enc_use_snake" , DEFAULT_SOUND_TOKENIZER_CONFIG ["enc_use_snake" ]),
299+ dec_type = config .get ("dec_type" , DEFAULT_SOUND_TOKENIZER_CONFIG ["dec_type" ]),
242300 vocoder_input_dim = config .get ("vocoder_input_dim" , DEFAULT_SOUND_TOKENIZER_CONFIG ["vocoder_input_dim" ]),
243301 dec_dim = config .get ("dec_dim" , DEFAULT_SOUND_TOKENIZER_CONFIG ["dec_dim" ]),
244302 dec_c_mults = tuple (config .get ("dec_c_mults" , DEFAULT_SOUND_TOKENIZER_CONFIG ["dec_c_mults" ])),
245303 dec_strides = tuple (config .get ("dec_strides" , DEFAULT_SOUND_TOKENIZER_CONFIG ["dec_strides" ])),
304+ dec_use_snake = config .get ("dec_use_snake" , DEFAULT_SOUND_TOKENIZER_CONFIG ["dec_use_snake" ]),
305+ dec_final_tanh = config .get ("dec_final_tanh" , False ),
246306 dec_out_channels = config .get ("dec_out_channels" , DEFAULT_SOUND_TOKENIZER_CONFIG ["dec_out_channels" ]),
307+ dec_anti_aliasing = config .get ("dec_anti_aliasing" , DEFAULT_SOUND_TOKENIZER_CONFIG ["dec_anti_aliasing" ]),
308+ dec_use_nearest_upsample = config .get (
309+ "dec_use_nearest_upsample" , DEFAULT_SOUND_TOKENIZER_CONFIG ["dec_use_nearest_upsample" ]
310+ ),
311+ dec_use_tanh_at_final = config .get (
312+ "dec_use_tanh_at_final" , DEFAULT_SOUND_TOKENIZER_CONFIG ["dec_use_tanh_at_final" ]
313+ ),
314+ bottleneck_type = config .get ("bottleneck_type" , DEFAULT_SOUND_TOKENIZER_CONFIG ["bottleneck_type" ]),
315+ bottleneck = config .get ("bottleneck" , DEFAULT_SOUND_TOKENIZER_CONFIG ["bottleneck" ]),
316+ activation = config .get ("activation" , DEFAULT_SOUND_TOKENIZER_CONFIG ["activation" ]),
317+ snake_logscale = config .get ("snake_logscale" , DEFAULT_SOUND_TOKENIZER_CONFIG ["snake_logscale" ]),
318+ anti_aliasing = config .get ("anti_aliasing" , DEFAULT_SOUND_TOKENIZER_CONFIG ["anti_aliasing" ]),
319+ use_cuda_kernel = config .get ("use_cuda_kernel" , DEFAULT_SOUND_TOKENIZER_CONFIG ["use_cuda_kernel" ]),
320+ causal = config .get ("causal" , DEFAULT_SOUND_TOKENIZER_CONFIG ["causal" ]),
321+ padding_mode = config .get ("padding_mode" , DEFAULT_SOUND_TOKENIZER_CONFIG ["padding_mode" ]),
322+ latent_mean = config .get ("latent_mean" , DEFAULT_SOUND_TOKENIZER_CONFIG ["latent_mean" ]),
323+ latent_std = config .get ("latent_std" , DEFAULT_SOUND_TOKENIZER_CONFIG ["latent_std" ]),
324+ encoder_enabled = has_encoder ,
247325 )
248326 load_result = sound_tokenizer .load_state_dict (state_dict , strict = True )
249327 if load_result .missing_keys or load_result .unexpected_keys :
@@ -255,8 +333,8 @@ def _build_sound_tokenizer(
255333
256334
257335@contextlib .contextmanager
258- def _skip_source_sound_tokenizer_load ():
259- original_set_up_tokenizers = OmniMoTModel .set_up_tokenizers
336+ def _skip_source_sound_tokenizer_load (omni_mot_model_cls ):
337+ original_set_up_tokenizers = omni_mot_model_cls .set_up_tokenizers
260338
261339 def set_up_tokenizers_without_sound (self ):
262340 if not getattr (self .config , "sound_gen" , False ):
@@ -269,14 +347,28 @@ def set_up_tokenizers_without_sound(self):
269347 finally :
270348 self .config .sound_gen = sound_gen
271349
272- OmniMoTModel .set_up_tokenizers = set_up_tokenizers_without_sound
350+ omni_mot_model_cls .set_up_tokenizers = set_up_tokenizers_without_sound
273351 try :
274352 yield
275353 finally :
276- OmniMoTModel .set_up_tokenizers = original_set_up_tokenizers
354+ omni_mot_model_cls .set_up_tokenizers = original_set_up_tokenizers
277355
278356
279357def main ():
358+ from cosmos3 .common .init import init_script
359+
360+ init_script ()
361+
362+ from accelerate import init_empty_weights
363+ from cosmos3 .args import _CHECKPOINTS
364+ from cosmos3 .model import Cosmos3OmniModel
365+ from projects .cosmos3 .vfm .models .omni_mot_model import OmniMoTModel
366+ from transformers import AutoTokenizer
367+
368+ from diffusers import AutoencoderKLWan , UniPCMultistepScheduler
369+ from diffusers .models .transformers .transformer_cosmos3 import Cosmos3OmniTransformer
370+ from diffusers .pipelines .cosmos .pipeline_cosmos3_omni import Cosmos3OmniPipeline
371+
280372 parser = argparse .ArgumentParser (description = __doc__ )
281373 parser .add_argument (
282374 "--checkpoint-path" ,
@@ -330,7 +422,7 @@ def main():
330422
331423 print ("Instantiating model and loading weights from DCP checkpoint …" )
332424 print ("Skipping source AVAE tokenizer instantiation during converter-only model load …" )
333- with _skip_source_sound_tokenizer_load ():
425+ with _skip_source_sound_tokenizer_load (OmniMoTModel ):
334426 _tmp = Cosmos3OmniModel .from_pretrained_dcp (checkpoint_path ).model
335427
336428 # Extract network components and architecture config from DCP model
0 commit comments