
Commit d2a2621

upd
Signed-off-by: Lancer <maruixiang6688@gmail.com>
1 parent: 63d874d

File tree

3 files changed (+206, -126 lines)


docs/source/en/api/pipelines/longcat_audio_dit.md

Lines changed: 9 additions & 11 deletions
````diff
@@ -26,34 +26,32 @@ This pipeline was adapted from the LongCat-AudioDiT reference implementation: ht
 ## Usage
 
 ```py
+import soundfile as sf
 import torch
 from diffusers import LongCatAudioDiTPipeline
 
-repo_id = "<longcat-audio-dit-repo-id>"
-tokenizer_path = os.environ["LONGCAT_AUDIO_DIT_TOKENIZER_PATH"]
-
-pipe = LongCatAudioDiTPipeline.from_pretrained(
-    repo_id,
-    tokenizer=tokenizer_path,
+pipeline = LongCatAudioDiTPipeline.from_pretrained(
+    "meituan-longcat/LongCat-AudioDiT-1B",
     torch_dtype=torch.float16,
-    local_files_only=True,
 )
-pipe = pipe.to("cuda")
+pipeline = pipeline.to("cuda")
 
-audio = pipe(
+audio = pipeline(
     prompt="A calm ocean wave ambience with soft wind in the background.",
-    audio_end_in_s=2.0,
+    audio_end_in_s=5.0,
     num_inference_steps=16,
     guidance_scale=4.0,
     output_type="pt",
 ).audios
+
+output = audio[0, 0].float().cpu().numpy()
+sf.write("longcat.wav", output, pipeline.sample_rate)
 ```
 
 ## Tips
 
 - `audio_end_in_s` is the most direct way to control output duration.
 - `output_type="pt"` returns a PyTorch tensor shaped `(batch, channels, samples)`.
-- If your tokenizer path is local-only, pass it explicitly to `from_pretrained(...)`.
 
 ## LongCatAudioDiTPipeline
````

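For reference, a batched variant of the updated example. This is a minimal sketch, assuming `prompt` accepts a list (the batch handling added to `__call__` below suggests it does) and that `.audios` is indexed as `(batch, channels, samples)`:

```py
import soundfile as sf
import torch

from diffusers import LongCatAudioDiTPipeline

pipeline = LongCatAudioDiTPipeline.from_pretrained(
    "meituan-longcat/LongCat-AudioDiT-1B", torch_dtype=torch.float16
).to("cuda")

# Assumed: a list of prompts is treated as a batch, one audio clip per prompt.
prompts = ["rain on a tin roof", "distant thunder over a field"]
audio = pipeline(
    prompt=prompts,
    audio_end_in_s=5.0,
    num_inference_steps=16,
    guidance_scale=4.0,
    output_type="pt",
).audios  # (batch, channels, samples)

for i in range(len(prompts)):
    sf.write(f"longcat_{i}.wav", audio[i, 0].float().cpu().numpy(), pipeline.sample_rate)
```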
src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py

Lines changed: 25 additions & 6 deletions
```diff
@@ -26,7 +26,7 @@
 from huggingface_hub.utils import validate_hf_hub_args
 from safetensors.torch import load_file
 from torch.nn.utils.rnn import pad_sequence
-from transformers import PreTrainedTokenizerBase, T5Tokenizer, UMT5Config, UMT5EncoderModel
+from transformers import AutoTokenizer, PreTrainedTokenizerBase, UMT5Config, UMT5EncoderModel
 
 from ...models import LongCatAudioDiTTransformer, LongCatAudioDiTVae
 from ...utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
```
```diff
@@ -105,7 +105,7 @@ def _load_longcat_tokenizer(
     tokenizer_kwargs = {"local_files_only": local_files_only}
     if not isinstance(tokenizer_source, Path) and tokenizer_source == pretrained_model_name_or_path and subfolder:
         tokenizer_kwargs["subfolder"] = subfolder
-    return T5Tokenizer.from_pretrained(tokenizer_source, **tokenizer_kwargs)
+    return AutoTokenizer.from_pretrained(tokenizer_source, **tokenizer_kwargs)
 
 
 def _resolve_longcat_file(
```
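Switching from a hard-coded `T5Tokenizer` to `AutoTokenizer` lets the loader resolve whichever tokenizer class the checkpoint's `tokenizer_config.json` declares, rather than assuming a slow SentencePiece T5 tokenizer. A minimal sketch of the effect; the `subfolder="tokenizer"` layout is an assumption for illustration, not confirmed by this diff:

```py
from transformers import AutoTokenizer

# AutoTokenizer reads tokenizer_config.json and instantiates the concrete
# class recorded there, so repos shipping a fast tokenizer load correctly
# where a hard-coded T5Tokenizer would fail or require sentencepiece.
tokenizer = AutoTokenizer.from_pretrained(
    "meituan-longcat/LongCat-AudioDiT-1B",
    subfolder="tokenizer",  # hypothetical repo layout
)
print(type(tokenizer).__name__)
```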
```diff
@@ -278,6 +278,10 @@ def from_pretrained(
         transformer = transformer.to(dtype=torch_dtype)
         vae = vae.to(dtype=torch_dtype)
 
+        text_encoder.eval()
+        transformer.eval()
+        vae.eval()
+
         pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer)
         pipe.sample_rate = config.get("sampling_rate", pipe.sample_rate)
         pipe.latent_hop = config.get("latent_hop", pipe.latent_hop)
```
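Putting the submodules in eval mode at load time disables training-only behavior such as dropout, so inference is deterministic up to the sampler's own randomness. A quick sanity check, assuming a pipeline loaded as in the docs example above:

```py
# After from_pretrained, every submodule should report inference mode.
assert not pipeline.text_encoder.training
assert not pipeline.transformer.training
assert not pipeline.vae.training
```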
```diff
@@ -322,15 +326,24 @@ def prepare_latents(
         dtype: torch.dtype,
         generator: torch.Generator | list[torch.Generator] | None = None,
     ) -> torch.Tensor:
+        if isinstance(generator, list):
+            if len(generator) != batch_size:
+                raise ValueError(
+                    f"Expected {batch_size} generators for batch size {batch_size}, but got {len(generator)}."
+                )
+            generators = generator
+        else:
+            generators = [generator] * batch_size
+
         latents = [
             torch.randn(
                 duration,
                 self.latent_dim,
                 device=device,
                 dtype=dtype,
-                generator=generator if isinstance(generator, torch.Generator) else None,
+                generator=generators[idx],
             )
-            for _ in range(batch_size)
+            for idx in range(batch_size)
         ]
         return pad_sequence(latents, padding_value=0.0, batch_first=True)
```
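With this change, a list of generators assigns one `torch.Generator` per batch element (previously a list was silently ignored and the latents were unseeded). Each sample's initial noise now depends only on its own generator, so results are reproducible per prompt. A usage sketch; that `__call__` forwards `generator` to `prepare_latents` is assumed from the signature, not shown in this diff:

```py
import torch

# One seeded generator per prompt: batch element i depends only on seed i,
# regardless of what else is in the batch.
seeds = [0, 1, 2]
generators = [torch.Generator(device="cuda").manual_seed(s) for s in seeds]

audio = pipeline(
    prompt=["rain", "wind", "thunder"],
    generator=generators,  # length must equal the batch size, else ValueError
    output_type="pt",
).audios
```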

```diff
@@ -381,6 +394,12 @@ def __call__(
         else:
             if isinstance(negative_prompt, str):
                 negative_prompt = [negative_prompt] * batch_size
+            else:
+                negative_prompt = list(negative_prompt)
+                if len(negative_prompt) != batch_size:
+                    raise ValueError(
+                        f"`negative_prompt` must have batch size {batch_size}, but got {len(negative_prompt)} prompts."
+                    )
             neg_text, neg_text_len = self.encode_prompt(negative_prompt, device)
             neg_text_mask = _lens_to_mask(neg_text_len, length=neg_text.shape[1])
```
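A sequence of negative prompts is now normalized to a list and checked against the batch size up front, turning a confusing downstream shape error into an immediate `ValueError`. A hedged sketch of the failure mode:

```py
# Two prompts but three negative prompts: the new validation raises at once.
try:
    pipeline(
        prompt=["rain", "wind"],
        negative_prompt=["muffled", "distorted", "clipping"],
    )
except ValueError as err:
    print(err)  # `negative_prompt` must have batch size 2, but got 3 prompts.
```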

```diff
@@ -399,7 +418,7 @@ def model_step(curr_t: torch.Tensor, current_sample: torch.Tensor) -> torch.Tensor:
                 attention_mask=mask,
                 latent_cond=latent_cond,
             ).sample
-            if guidance_scale < 1e-5:
+            if guidance_scale <= 1.0:
                 return pred
             null_pred = self.transformer(
                 hidden_states=current_sample,
```
```diff
@@ -409,7 +428,7 @@ def model_step(curr_t: torch.Tensor, current_sample: torch.Tensor) -> torch.Tensor:
                 attention_mask=mask,
                 latent_cond=latent_cond,
             ).sample
-            return pred + (pred - null_pred) * guidance_scale
+            return null_pred + (pred - null_pred) * guidance_scale
 
         for idx in range(len(timesteps) - 1):
             curr_t = timesteps[idx]
```
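These last two hunks correct the classifier-free guidance combination. The old expression, `pred + (pred - null_pred) * guidance_scale`, adds the guidance offset on top of the conditional prediction, over-weighting it; standard CFG blends from the unconditional prediction, `null_pred + (pred - null_pred) * guidance_scale`, which reduces exactly to `pred` at a scale of 1.0. That identity also explains the new `guidance_scale <= 1.0` early return: at or below 1.0 the unconditional forward pass can be skipped. A numeric sanity check:

```py
import torch

pred = torch.tensor(2.0)       # conditional prediction
null_pred = torch.tensor(0.5)  # unconditional prediction

# Corrected CFG: at scale 1.0 the guided output equals the conditional one.
guided = null_pred + (pred - null_pred) * 1.0
assert torch.isclose(guided, pred)

# The old formula did not: pred + (pred - null_pred) * 1.0 == 3.5, not 2.0.
assert not torch.isclose(pred + (pred - null_pred) * 1.0, pred)
```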
