 from huggingface_hub.utils import validate_hf_hub_args
 from safetensors.torch import load_file
 from torch.nn.utils.rnn import pad_sequence
-from transformers import PreTrainedTokenizerBase, T5Tokenizer, UMT5Config, UMT5EncoderModel
+from transformers import AutoTokenizer, PreTrainedTokenizerBase, UMT5Config, UMT5EncoderModel
 
 from ...models import LongCatAudioDiTTransformer, LongCatAudioDiTVae
 from ...utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
@@ -105,7 +105,7 @@ def _load_longcat_tokenizer(
     tokenizer_kwargs = {"local_files_only": local_files_only}
     if not isinstance(tokenizer_source, Path) and tokenizer_source == pretrained_model_name_or_path and subfolder:
         tokenizer_kwargs["subfolder"] = subfolder
-    return T5Tokenizer.from_pretrained(tokenizer_source, **tokenizer_kwargs)
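+    # AutoTokenizer resolves the concrete tokenizer class from the checkpoint config instead of hard-coding T5.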
+    return AutoTokenizer.from_pretrained(tokenizer_source, **tokenizer_kwargs)
 
 
 def _resolve_longcat_file(
@@ -278,6 +278,10 @@ def from_pretrained(
         transformer = transformer.to(dtype=torch_dtype)
         vae = vae.to(dtype=torch_dtype)
 
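+        # Inference-only pipeline: eval() disables dropout and other train-time behavior.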
+        text_encoder.eval()
+        transformer.eval()
+        vae.eval()
+
         pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer)
         pipe.sample_rate = config.get("sampling_rate", pipe.sample_rate)
         pipe.latent_hop = config.get("latent_hop", pipe.latent_hop)
@@ -322,15 +326,24 @@ def prepare_latents(
         dtype: torch.dtype,
         generator: torch.Generator | list[torch.Generator] | None = None,
     ) -> torch.Tensor:
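+        # Normalize generator to one torch.Generator (or None) per batch entry so each latent draw is reproducible.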
+        if isinstance(generator, list):
+            if len(generator) != batch_size:
+                raise ValueError(
+                    f"Expected {batch_size} generators for batch size {batch_size}, but got {len(generator)}."
+                )
+            generators = generator
+        else:
+            generators = [generator] * batch_size
+
         latents = [
             torch.randn(
                 duration,
                 self.latent_dim,
                 device=device,
                 dtype=dtype,
-                generator=generator if isinstance(generator, torch.Generator) else None,
+                generator=generators[idx],
             )
-            for _ in range(batch_size)
+            for idx in range(batch_size)
         ]
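+        # pad_sequence right-pads the variable-duration latents to the longest sequence in the batch.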
         return pad_sequence(latents, padding_value=0.0, batch_first=True)
 
@@ -381,6 +394,12 @@ def __call__(
         else:
             if isinstance(negative_prompt, str):
                 negative_prompt = [negative_prompt] * batch_size
+            else:
+                negative_prompt = list(negative_prompt)
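+            # A list of negative prompts must match the prompt batch size one-to-one.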
+            if len(negative_prompt) != batch_size:
+                raise ValueError(
+                    f"`negative_prompt` must have batch size {batch_size}, but got {len(negative_prompt)} prompts."
+                )
             neg_text, neg_text_len = self.encode_prompt(negative_prompt, device)
             neg_text_mask = _lens_to_mask(neg_text_len, length=neg_text.shape[1])
 
@@ -399,7 +418,7 @@ def model_step(curr_t: torch.Tensor, current_sample: torch.Tensor) -> torch.Tensor:
                 attention_mask=mask,
                 latent_cond=latent_cond,
             ).sample
-            if guidance_scale < 1e-5:
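+            # guidance_scale <= 1.0 applies no guidance amplification, so skip the second (unconditional) forward pass.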
+            if guidance_scale <= 1.0:
                 return pred
             null_pred = self.transformer(
                 hidden_states=current_sample,
@@ -409,7 +428,7 @@ def model_step(curr_t: torch.Tensor, current_sample: torch.Tensor) -> torch.Tensor:
                 attention_mask=mask,
                 latent_cond=latent_cond,
             ).sample
-            return pred + (pred - null_pred) * guidance_scale
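+            # Standard classifier-free guidance: start from the unconditional prediction and scale the conditional delta.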
+            return null_pred + (pred - null_pred) * guidance_scale
 
         for idx in range(len(timesteps) - 1):
             curr_t = timesteps[idx]