26 | 26 | from huggingface_hub.utils import validate_hf_hub_args
27 | 27 | from safetensors.torch import load_file
28 | 28 | from torch.nn.utils.rnn import pad_sequence
29 |    | -from transformers import PreTrainedTokenizerBase, T5Tokenizer, UMT5Config, UMT5EncoderModel
   | 29 | +from transformers import AutoTokenizer, PreTrainedTokenizerBase, UMT5Config, UMT5EncoderModel
30 | 30 |
31 | 31 | from ...models import LongCatAudioDiTTransformer, LongCatAudioDiTVae
32 | 32 | from ...utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
@@ -105,7 +105,7 @@ def _load_longcat_tokenizer(
105 | 105 |     tokenizer_kwargs = {"local_files_only": local_files_only}
106 | 106 |     if not isinstance(tokenizer_source, Path) and tokenizer_source == pretrained_model_name_or_path and subfolder:
107 | 107 |         tokenizer_kwargs["subfolder"] = subfolder
108 |     | -    return T5Tokenizer.from_pretrained(tokenizer_source, **tokenizer_kwargs)
    | 108 | +    return AutoTokenizer.from_pretrained(tokenizer_source, **tokenizer_kwargs)
109 | 109 |
110 | 110 |
111 | 111 | def _resolve_longcat_file(
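
Switching from the hard-coded `T5Tokenizer` to `AutoTokenizer` lets the checkpoint's `tokenizer_config.json` choose the concrete tokenizer class (including a fast, Rust-backed variant when one is available) rather than forcing the slow sentencepiece implementation. A minimal sketch of that resolution, using `google/umt5-small` purely as an illustrative checkpoint:

```python
# Illustrative only: "google/umt5-small" stands in for whatever checkpoint the
# pipeline actually loads; AutoTokenizer picks the class from the checkpoint's
# tokenizer_config.json instead of hard-coding T5Tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
print(type(tokenizer).__name__)  # e.g. "T5TokenizerFast" when a fast class is available
```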
@@ -278,6 +278,10 @@ def from_pretrained(
278 | 278 |         transformer = transformer.to(dtype=torch_dtype)
279 | 279 |         vae = vae.to(dtype=torch_dtype)
280 | 280 |
    | 281 | +        text_encoder.eval()
    | 282 | +        transformer.eval()
    | 283 | +        vae.eval()
    | 284 | +
281 | 285 |         pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer)
282 | 286 |         pipe.sample_rate = config.get("sampling_rate", pipe.sample_rate)
283 | 287 |         pipe.latent_hop = config.get("latent_hop", pipe.latent_hop)
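
Calling `.eval()` on the freshly loaded modules switches them to inference behaviour: dropout layers become identities and normalisation layers stop updating running statistics, so repeated forward passes give identical results. A minimal sketch with a hypothetical toy module (not part of this pipeline):

```python
import torch

# Toy module with dropout, used only to demonstrate train() vs eval().
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
x = torch.ones(1, 4)

model.train()
a, b = model(x), model(x)  # dropout active: the two outputs (almost surely) differ

model.eval()
c, d = model(x), model(x)  # dropout disabled: outputs are deterministic
assert torch.equal(c, d)
```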
@@ -322,15 +326,24 @@ def prepare_latents(
322 | 326 |         dtype: torch.dtype,
323 | 327 |         generator: torch.Generator | list[torch.Generator] | None = None,
324 | 328 |     ) -> torch.Tensor:
    | 329 | +        if isinstance(generator, list):
    | 330 | +            if len(generator) != batch_size:
    | 331 | +                raise ValueError(
    | 332 | +                    f"Expected {batch_size} generators for batch size {batch_size}, but got {len(generator)}."
    | 333 | +                )
    | 334 | +            generators = generator
    | 335 | +        else:
    | 336 | +            generators = [generator] * batch_size
    | 337 | +
325 | 338 |         latents = [
326 | 339 |             torch.randn(
327 | 340 |                 duration,
328 | 341 |                 self.latent_dim,
329 | 342 |                 device=device,
330 | 343 |                 dtype=dtype,
331 |     | -                generator=generator if isinstance(generator, torch.Generator) else None,
    | 344 | +                generator=generators[idx],
332 | 345 |             )
333 |     | -            for _ in range(batch_size)
    | 346 | +            for idx in range(batch_size)
334 | 347 |         ]
335 | 348 |         return pad_sequence(latents, padding_value=0.0, batch_first=True)
336 | 349 |
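
With this change each batch item draws its noise from its own `torch.Generator`, so sample `i` is reproducible regardless of what else is in the batch, while a single generator (or `None`) is simply broadcast to every item. A standalone sketch of the behaviour, with illustrative shapes:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# One seeded generator per batch item (durations and latent_dim are illustrative).
generators = [torch.Generator().manual_seed(seed) for seed in (0, 1, 2)]
durations, latent_dim = [5, 3, 4], 8

latents = [
    torch.randn(durations[idx], latent_dim, generator=generators[idx])
    for idx in range(len(generators))
]
batch = pad_sequence(latents, padding_value=0.0, batch_first=True)
print(batch.shape)  # torch.Size([3, 5, 8]) -- padded to the longest duration
```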
@@ -409,7 +422,7 @@ def model_step(curr_t: torch.Tensor, current_sample: torch.Tensor) -> torch.Tensor:
409 | 422 |                 attention_mask=mask,
410 | 423 |                 latent_cond=latent_cond,
411 | 424 |             ).sample
412 |     | -            return pred + (pred - null_pred) * guidance_scale
    | 425 | +            return null_pred + (pred - null_pred) * guidance_scale
413 | 426 |
414 | 427 |         for idx in range(len(timesteps) - 1):
415 | 428 |             curr_t = timesteps[idx]
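
The one-line fix in `model_step` restores the standard classifier-free guidance combination, `null_pred + (pred - null_pred) * guidance_scale`: at a guidance scale of 1.0 it reduces exactly to the conditional prediction, whereas the old expression amplified the guidance direction even at scale 1. A quick numeric check:

```python
import torch

pred = torch.tensor([2.0])       # conditional prediction (illustrative values)
null_pred = torch.tensor([1.0])  # unconditional / null-prompt prediction

def cfg_fixed(scale: float) -> torch.Tensor:
    # Corrected: extrapolate from the unconditional branch.
    return null_pred + (pred - null_pred) * scale

def cfg_old(scale: float) -> torch.Tensor:
    # Previous behaviour: extrapolates from the conditional branch.
    return pred + (pred - null_pred) * scale

print(cfg_fixed(1.0))  # tensor([2.]) -- recovers the conditional prediction
print(cfg_old(1.0))    # tensor([3.]) -- over-guided even at scale 1.0
```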