Skip to content

Commit f39362a

Browse files
committed
Address ACE-Step audio token and APG review
1 parent 23f2353 commit f39362a

9 files changed

Lines changed: 717 additions & 147 deletions

File tree

scripts/convert_ace_step_to_diffusers.py

Lines changed: 75 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,9 @@ def convert_ace_step_weights(checkpoint_dir, dit_config, output_dir, dtype_str="
131131
# =========================================================================
132132
transformer_sd = {}
133133
condition_encoder_sd = {}
134-
other_sd = {} # tokenizer, detokenizer (audio quantization — not used by the text2music pipeline)
134+
audio_tokenizer_sd = {}
135+
audio_token_detokenizer_sd = {}
136+
other_sd = {}
135137

136138
# Rename original ACE-Step attention keys to the diffusers `Attention` +
137139
# `AttnProcessor` convention (`to_q`/`to_k`/`to_v`/`to_out.0`/`norm_q`/`norm_k`).
@@ -174,11 +176,21 @@ def _rename_attn_keys(key: str) -> str:
174176
# Keep it co-located with the condition encoder since that is where the
175177
# pipeline pulls unconditional sequences from.
176178
condition_encoder_sd["null_condition_emb"] = value.to(target_dtype)
179+
elif key.startswith("tokenizer."):
180+
new_key = key[len("tokenizer.") :]
181+
new_key = _rename_attn_keys(new_key)
182+
audio_tokenizer_sd[new_key] = value.to(target_dtype)
183+
elif key.startswith("detokenizer."):
184+
new_key = key[len("detokenizer.") :]
185+
new_key = _rename_attn_keys(new_key)
186+
audio_token_detokenizer_sd[new_key] = value.to(target_dtype)
177187
else:
178188
other_sd[key] = value.to(target_dtype)
179189

180190
print(f" Transformer keys: {len(transformer_sd)}")
181191
print(f" Condition encoder keys: {len(condition_encoder_sd)}")
192+
print(f" Audio tokenizer keys: {len(audio_tokenizer_sd)}")
193+
print(f" Audio token detokenizer keys: {len(audio_token_detokenizer_sd)}")
182194
print(f" Other keys: {len(other_sd)} ({list(other_sd.keys())[:5]}...)")
183195

184196
# =========================================================================
@@ -248,6 +260,47 @@ def _rename_attn_keys(key: str) -> str:
248260
"sliding_window": original_config["sliding_window"],
249261
}
250262

263+
audio_tokenizer_config = {
264+
"_class_name": "AceStepAudioTokenizer",
265+
"_diffusers_version": "0.33.0.dev0",
266+
"hidden_size": encoder_hidden_size,
267+
"intermediate_size": encoder_intermediate_size,
268+
"audio_acoustic_hidden_dim": original_config["audio_acoustic_hidden_dim"],
269+
"pool_window_size": original_config.get("pool_window_size", 5),
270+
"fsq_dim": original_config.get("fsq_dim", encoder_hidden_size),
271+
"fsq_input_levels": original_config.get("fsq_input_levels", [8, 8, 8, 5, 5, 5]),
272+
"fsq_input_num_quantizers": original_config.get("fsq_input_num_quantizers", 1),
273+
"num_attention_pooler_hidden_layers": original_config.get("num_attention_pooler_hidden_layers", 2),
274+
"num_attention_heads": encoder_num_attention_heads,
275+
"num_key_value_heads": encoder_num_key_value_heads,
276+
"head_dim": original_config["head_dim"],
277+
"rope_theta": original_config["rope_theta"],
278+
"attention_bias": original_config["attention_bias"],
279+
"attention_dropout": original_config["attention_dropout"],
280+
"rms_norm_eps": original_config["rms_norm_eps"],
281+
"sliding_window": original_config["sliding_window"],
282+
"layer_types": original_config["layer_types"][: original_config.get("num_attention_pooler_hidden_layers", 2)],
283+
}
284+
285+
audio_token_detokenizer_config = {
286+
"_class_name": "AceStepAudioTokenDetokenizer",
287+
"_diffusers_version": "0.33.0.dev0",
288+
"hidden_size": encoder_hidden_size,
289+
"intermediate_size": encoder_intermediate_size,
290+
"audio_acoustic_hidden_dim": original_config["audio_acoustic_hidden_dim"],
291+
"pool_window_size": original_config.get("pool_window_size", 5),
292+
"num_attention_pooler_hidden_layers": original_config.get("num_attention_pooler_hidden_layers", 2),
293+
"num_attention_heads": encoder_num_attention_heads,
294+
"num_key_value_heads": encoder_num_key_value_heads,
295+
"head_dim": original_config["head_dim"],
296+
"rope_theta": original_config["rope_theta"],
297+
"attention_bias": original_config["attention_bias"],
298+
"attention_dropout": original_config["attention_dropout"],
299+
"rms_norm_eps": original_config["rms_norm_eps"],
300+
"sliding_window": original_config["sliding_window"],
301+
"layer_types": original_config["layer_types"][: original_config.get("num_attention_pooler_hidden_layers", 2)],
302+
}
303+
251304
# =========================================================================
252305
# 3. Bake silence_latent into the condition_encoder state dict.
253306
#
@@ -282,11 +335,19 @@ def _rename_attn_keys(key: str) -> str:
282335
AutoencoderOobleck,
283336
FlowMatchEulerDiscreteScheduler,
284337
)
285-
from diffusers.pipelines.ace_step import AceStepConditionEncoder
338+
from diffusers.pipelines.ace_step import (
339+
AceStepAudioTokenDetokenizer,
340+
AceStepAudioTokenizer,
341+
AceStepConditionEncoder,
342+
)
286343

287344
# Drop metadata keys — they're re-populated by `save_pretrained` at save time.
288345
transformer_init_kwargs = {k: v for k, v in transformer_config.items() if not k.startswith("_")}
289346
condition_encoder_init_kwargs = {k: v for k, v in condition_encoder_config.items() if not k.startswith("_")}
347+
audio_tokenizer_init_kwargs = {k: v for k, v in audio_tokenizer_config.items() if not k.startswith("_")}
348+
audio_token_detokenizer_init_kwargs = {
349+
k: v for k, v in audio_token_detokenizer_config.items() if not k.startswith("_")
350+
}
290351

291352
print("\nConstructing transformer ...")
292353
transformer = AceStepTransformer1DModel(**transformer_init_kwargs).to(target_dtype)
@@ -296,6 +357,14 @@ def _rename_attn_keys(key: str) -> str:
296357
condition_encoder = AceStepConditionEncoder(**condition_encoder_init_kwargs).to(target_dtype)
297358
condition_encoder.load_state_dict(condition_encoder_sd, strict=True)
298359

360+
print("Constructing audio_tokenizer ...")
361+
audio_tokenizer = AceStepAudioTokenizer(**audio_tokenizer_init_kwargs).to(target_dtype)
362+
audio_tokenizer.load_state_dict(audio_tokenizer_sd, strict=True)
363+
364+
print("Constructing audio_token_detokenizer ...")
365+
audio_token_detokenizer = AceStepAudioTokenDetokenizer(**audio_token_detokenizer_init_kwargs).to(target_dtype)
366+
audio_token_detokenizer.load_state_dict(audio_token_detokenizer_sd, strict=True)
367+
299368
print("Loading VAE ...")
300369
vae = AutoencoderOobleck.from_pretrained(vae_dir).to(target_dtype)
301370

@@ -319,6 +388,8 @@ def _rename_attn_keys(key: str) -> str:
319388
transformer=transformer,
320389
condition_encoder=condition_encoder,
321390
scheduler=scheduler,
391+
audio_tokenizer=audio_tokenizer,
392+
audio_token_detokenizer=audio_token_detokenizer,
322393
)
323394

324395
print(f"\nSaving pipeline -> {output_dir}")
@@ -331,18 +402,13 @@ def _rename_attn_keys(key: str) -> str:
331402
shutil.copy2(silence_latent_src, os.path.join(output_dir, "silence_latent.pt"))
332403
print(f" kept raw silence_latent copy at {output_dir}/silence_latent.pt")
333404

334-
# Report other keys that were not saved to transformer or condition_encoder
405+
# Report any keys that were not saved to registered pipeline modules.
335406
if other_sd:
336-
print(f"\nNote: {len(other_sd)} keys were dropped (tokenizer / detokenizer weights):")
407+
print(f"\nNote: {len(other_sd)} keys were dropped:")
337408
for key in sorted(other_sd.keys())[:10]:
338409
print(f" {key}")
339410
if len(other_sd) > 10:
340411
print(f" ... ({len(other_sd) - 10} more)")
341-
print(
342-
"These belong to the audio tokenizer / detokenizer used by the 5Hz LM path "
343-
"(cover / audio-code tasks). The Diffusers text2music pipeline does not "
344-
"currently expose them."
345-
)
346412

347413
print(f"\nConversion complete! Output saved to: {output_dir}")
348414
print("\nTo load the pipeline:")

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -487,6 +487,8 @@
487487
)
488488
_import_structure["pipelines"].extend(
489489
[
490+
"AceStepAudioTokenDetokenizer",
491+
"AceStepAudioTokenizer",
490492
"AceStepConditionEncoder",
491493
"AceStepPipeline",
492494
"AllegroPipeline",
@@ -1277,6 +1279,8 @@
12771279
ZImageModularPipeline,
12781280
)
12791281
from .pipelines import (
1282+
AceStepAudioTokenDetokenizer,
1283+
AceStepAudioTokenizer,
12801284
AceStepConditionEncoder,
12811285
AceStepPipeline,
12821286
AllegroPipeline,

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,9 @@ class AdaptiveProjectedGuidance(BaseGuidance):
4040
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
4141
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
4242
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
43+
adaptive_projected_guidance_norm_dim (`int` or `tuple[int]`, *optional*):
44+
Dimension(s) over which to compute the APG norm and projection. If omitted, all non-batch dimensions are
45+
used, preserving the original behavior.
4346
guidance_rescale (`float`, defaults to `0.0`):
4447
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
4548
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
@@ -62,6 +65,7 @@ def __init__(
6265
guidance_scale: float = 7.5,
6366
adaptive_projected_guidance_momentum: float | None = None,
6467
adaptive_projected_guidance_rescale: float = 15.0,
68+
adaptive_projected_guidance_norm_dim: int | tuple[int, ...] | None = None,
6569
eta: float = 1.0,
6670
guidance_rescale: float = 0.0,
6771
use_original_formulation: bool = False,
@@ -74,6 +78,7 @@ def __init__(
7478
self.guidance_scale = guidance_scale
7579
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
7680
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
81+
self.adaptive_projected_guidance_norm_dim = adaptive_projected_guidance_norm_dim
7782
self.eta = eta
7883
self.guidance_rescale = guidance_rescale
7984
self.use_original_formulation = use_original_formulation
@@ -117,6 +122,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = No
117122
self.eta,
118123
self.adaptive_projected_guidance_rescale,
119124
self.use_original_formulation,
125+
self.adaptive_projected_guidance_norm_dim,
120126
)
121127

122128
if self.guidance_rescale > 0.0:
@@ -210,9 +216,15 @@ def normalized_guidance(
210216
eta: float = 1.0,
211217
norm_threshold: float = 0.0,
212218
use_original_formulation: bool = False,
219+
norm_dim: int | tuple[int, ...] | None = None,
213220
):
214221
diff = pred_cond - pred_uncond
215-
dim = [-i for i in range(1, len(diff.shape))]
222+
if norm_dim is None:
223+
dim = [-i for i in range(1, len(diff.shape))]
224+
elif isinstance(norm_dim, int):
225+
dim = [norm_dim]
226+
else:
227+
dim = list(norm_dim)
216228

217229
if momentum_buffer is not None:
218230
momentum_buffer.update(diff)
@@ -224,11 +236,15 @@ def normalized_guidance(
224236
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
225237
diff = diff * scale_factor
226238

227-
v0, v1 = diff.double(), pred_cond.double()
239+
if diff.device.type in {"mps", "npu"}:
240+
v0, v1 = diff.cpu().double(), pred_cond.cpu().double()
241+
else:
242+
v0, v1 = diff.double(), pred_cond.double()
228243
v1 = torch.nn.functional.normalize(v1, dim=dim)
229244
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
230245
v0_orthogonal = v0 - v0_parallel
231-
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
246+
diff_parallel = v0_parallel.to(device=diff.device, dtype=diff.dtype)
247+
diff_orthogonal = v0_orthogonal.to(device=diff.device, dtype=diff.dtype)
232248
normalized_update = diff_orthogonal + eta * diff_parallel
233249

234250
pred = pred_cond if use_original_formulation else pred_uncond

src/diffusers/pipelines/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,8 @@
150150
]
151151
)
152152
_import_structure["ace_step"] = [
153+
"AceStepAudioTokenDetokenizer",
154+
"AceStepAudioTokenizer",
153155
"AceStepConditionEncoder",
154156
"AceStepPipeline",
155157
]
@@ -578,7 +580,12 @@
578580
except OptionalDependencyNotAvailable:
579581
from ..utils.dummy_torch_and_transformers_objects import *
580582
else:
581-
from .ace_step import AceStepConditionEncoder, AceStepPipeline
583+
from .ace_step import (
584+
AceStepAudioTokenDetokenizer,
585+
AceStepAudioTokenizer,
586+
AceStepConditionEncoder,
587+
AceStepPipeline,
588+
)
582589
from .allegro import AllegroPipeline
583590
from .animatediff import (
584591
AnimateDiffControlNetPipeline,

src/diffusers/pipelines/ace_step/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,11 @@
2222

2323
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2424
else:
25-
_import_structure["modeling_ace_step"] = ["AceStepConditionEncoder"]
25+
_import_structure["modeling_ace_step"] = [
26+
"AceStepAudioTokenDetokenizer",
27+
"AceStepAudioTokenizer",
28+
"AceStepConditionEncoder",
29+
]
2630
_import_structure["pipeline_ace_step"] = ["AceStepPipeline"]
2731

2832

@@ -34,7 +38,7 @@
3438
from ...utils.dummy_torch_and_transformers_objects import *
3539

3640
else:
37-
from .modeling_ace_step import AceStepConditionEncoder
41+
from .modeling_ace_step import AceStepAudioTokenDetokenizer, AceStepAudioTokenizer, AceStepConditionEncoder
3842
from .pipeline_ace_step import AceStepPipeline
3943

4044
else:

0 commit comments

Comments (0)