@@ -131,7 +131,9 @@ def convert_ace_step_weights(checkpoint_dir, dit_config, output_dir, dtype_str="
131131 # =========================================================================
132132 transformer_sd = {}
133133 condition_encoder_sd = {}
134- other_sd = {} # tokenizer, detokenizer (audio quantization — not used by the text2music pipeline)
134+ audio_tokenizer_sd = {}
135+ audio_token_detokenizer_sd = {}
136+ other_sd = {}
135137
136138 # Rename original ACE-Step attention keys to the diffusers `Attention` +
137139 # `AttnProcessor` convention (`to_q`/`to_k`/`to_v`/`to_out.0`/`norm_q`/`norm_k`).
@@ -174,11 +176,21 @@ def _rename_attn_keys(key: str) -> str:
174176 # Keep it co-located with the condition encoder since that is where the
175177 # pipeline pulls unconditional sequences from.
176178 condition_encoder_sd ["null_condition_emb" ] = value .to (target_dtype )
179+ elif key .startswith ("tokenizer." ):
180+ new_key = key [len ("tokenizer." ) :]
181+ new_key = _rename_attn_keys (new_key )
182+ audio_tokenizer_sd [new_key ] = value .to (target_dtype )
183+ elif key .startswith ("detokenizer." ):
184+ new_key = key [len ("detokenizer." ) :]
185+ new_key = _rename_attn_keys (new_key )
186+ audio_token_detokenizer_sd [new_key ] = value .to (target_dtype )
177187 else :
178188 other_sd [key ] = value .to (target_dtype )
179189
180190 print (f" Transformer keys: { len (transformer_sd )} " )
181191 print (f" Condition encoder keys: { len (condition_encoder_sd )} " )
192+ print (f" Audio tokenizer keys: { len (audio_tokenizer_sd )} " )
193+ print (f" Audio token detokenizer keys: { len (audio_token_detokenizer_sd )} " )
182194 print (f" Other keys: { len (other_sd )} ({ list (other_sd .keys ())[:5 ]} ...)" )
183195
184196 # =========================================================================
@@ -248,6 +260,47 @@ def _rename_attn_keys(key: str) -> str:
248260 "sliding_window" : original_config ["sliding_window" ],
249261 }
250262
263+ audio_tokenizer_config = {
264+ "_class_name" : "AceStepAudioTokenizer" ,
265+ "_diffusers_version" : "0.33.0.dev0" ,
266+ "hidden_size" : encoder_hidden_size ,
267+ "intermediate_size" : encoder_intermediate_size ,
268+ "audio_acoustic_hidden_dim" : original_config ["audio_acoustic_hidden_dim" ],
269+ "pool_window_size" : original_config .get ("pool_window_size" , 5 ),
270+ "fsq_dim" : original_config .get ("fsq_dim" , encoder_hidden_size ),
271+ "fsq_input_levels" : original_config .get ("fsq_input_levels" , [8 , 8 , 8 , 5 , 5 , 5 ]),
272+ "fsq_input_num_quantizers" : original_config .get ("fsq_input_num_quantizers" , 1 ),
273+ "num_attention_pooler_hidden_layers" : original_config .get ("num_attention_pooler_hidden_layers" , 2 ),
274+ "num_attention_heads" : encoder_num_attention_heads ,
275+ "num_key_value_heads" : encoder_num_key_value_heads ,
276+ "head_dim" : original_config ["head_dim" ],
277+ "rope_theta" : original_config ["rope_theta" ],
278+ "attention_bias" : original_config ["attention_bias" ],
279+ "attention_dropout" : original_config ["attention_dropout" ],
280+ "rms_norm_eps" : original_config ["rms_norm_eps" ],
281+ "sliding_window" : original_config ["sliding_window" ],
282+ "layer_types" : original_config ["layer_types" ][: original_config .get ("num_attention_pooler_hidden_layers" , 2 )],
283+ }
284+
285+ audio_token_detokenizer_config = {
286+ "_class_name" : "AceStepAudioTokenDetokenizer" ,
287+ "_diffusers_version" : "0.33.0.dev0" ,
288+ "hidden_size" : encoder_hidden_size ,
289+ "intermediate_size" : encoder_intermediate_size ,
290+ "audio_acoustic_hidden_dim" : original_config ["audio_acoustic_hidden_dim" ],
291+ "pool_window_size" : original_config .get ("pool_window_size" , 5 ),
292+ "num_attention_pooler_hidden_layers" : original_config .get ("num_attention_pooler_hidden_layers" , 2 ),
293+ "num_attention_heads" : encoder_num_attention_heads ,
294+ "num_key_value_heads" : encoder_num_key_value_heads ,
295+ "head_dim" : original_config ["head_dim" ],
296+ "rope_theta" : original_config ["rope_theta" ],
297+ "attention_bias" : original_config ["attention_bias" ],
298+ "attention_dropout" : original_config ["attention_dropout" ],
299+ "rms_norm_eps" : original_config ["rms_norm_eps" ],
300+ "sliding_window" : original_config ["sliding_window" ],
301+ "layer_types" : original_config ["layer_types" ][: original_config .get ("num_attention_pooler_hidden_layers" , 2 )],
302+ }
303+
251304 # =========================================================================
252305 # 3. Bake silence_latent into the condition_encoder state dict.
253306 #
@@ -282,11 +335,19 @@ def _rename_attn_keys(key: str) -> str:
282335 AutoencoderOobleck ,
283336 FlowMatchEulerDiscreteScheduler ,
284337 )
285- from diffusers .pipelines .ace_step import AceStepConditionEncoder
338+ from diffusers .pipelines .ace_step import (
339+ AceStepAudioTokenDetokenizer ,
340+ AceStepAudioTokenizer ,
341+ AceStepConditionEncoder ,
342+ )
286343
287344 # Drop metadata keys — they're re-populated by `save_pretrained` at save time.
288345 transformer_init_kwargs = {k : v for k , v in transformer_config .items () if not k .startswith ("_" )}
289346 condition_encoder_init_kwargs = {k : v for k , v in condition_encoder_config .items () if not k .startswith ("_" )}
347+ audio_tokenizer_init_kwargs = {k : v for k , v in audio_tokenizer_config .items () if not k .startswith ("_" )}
348+ audio_token_detokenizer_init_kwargs = {
349+ k : v for k , v in audio_token_detokenizer_config .items () if not k .startswith ("_" )
350+ }
290351
291352 print ("\n Constructing transformer ..." )
292353 transformer = AceStepTransformer1DModel (** transformer_init_kwargs ).to (target_dtype )
@@ -296,6 +357,14 @@ def _rename_attn_keys(key: str) -> str:
296357 condition_encoder = AceStepConditionEncoder (** condition_encoder_init_kwargs ).to (target_dtype )
297358 condition_encoder .load_state_dict (condition_encoder_sd , strict = True )
298359
360+ print ("Constructing audio_tokenizer ..." )
361+ audio_tokenizer = AceStepAudioTokenizer (** audio_tokenizer_init_kwargs ).to (target_dtype )
362+ audio_tokenizer .load_state_dict (audio_tokenizer_sd , strict = True )
363+
364+ print ("Constructing audio_token_detokenizer ..." )
365+ audio_token_detokenizer = AceStepAudioTokenDetokenizer (** audio_token_detokenizer_init_kwargs ).to (target_dtype )
366+ audio_token_detokenizer .load_state_dict (audio_token_detokenizer_sd , strict = True )
367+
299368 print ("Loading VAE ..." )
300369 vae = AutoencoderOobleck .from_pretrained (vae_dir ).to (target_dtype )
301370
@@ -319,6 +388,8 @@ def _rename_attn_keys(key: str) -> str:
319388 transformer = transformer ,
320389 condition_encoder = condition_encoder ,
321390 scheduler = scheduler ,
391+ audio_tokenizer = audio_tokenizer ,
392+ audio_token_detokenizer = audio_token_detokenizer ,
322393 )
323394
324395 print (f"\n Saving pipeline -> { output_dir } " )
@@ -331,18 +402,13 @@ def _rename_attn_keys(key: str) -> str:
331402 shutil .copy2 (silence_latent_src , os .path .join (output_dir , "silence_latent.pt" ))
332403 print (f" kept raw silence_latent copy at { output_dir } /silence_latent.pt" )
333404
334- # Report other keys that were not saved to transformer or condition_encoder
405+ # Report any keys that were not saved to registered pipeline modules.
335406 if other_sd :
336- print (f"\n Note: { len (other_sd )} keys were dropped (tokenizer / detokenizer weights) :" )
407+ print (f"\n Note: { len (other_sd )} keys were dropped:" )
337408 for key in sorted (other_sd .keys ())[:10 ]:
338409 print (f" { key } " )
339410 if len (other_sd ) > 10 :
340411 print (f" ... ({ len (other_sd ) - 10 } more)" )
341- print (
342- "These belong to the audio tokenizer / detokenizer used by the 5Hz LM path "
343- "(cover / audio-code tasks). The Diffusers text2music pipeline does not "
344- "currently expose them."
345- )
346412
347413 print (f"\n Conversion complete! Output saved to: { output_dir } " )
348414 print ("\n To load the pipeline:" )