@@ -501,24 +501,45 @@ def encode_prompt(
501501 negative_prompt_embeds : jax .Array = None ,
502502 ):
503503 prompt = [prompt ] if isinstance (prompt , str ) else prompt
504- if prompt_embeds is None :
505- prompt_embeds = self ._get_t5_prompt_embeds (
506- prompt = prompt ,
507- num_videos_per_prompt = num_videos_per_prompt ,
508- max_sequence_length = max_sequence_length ,
509- )
510- prompt_embeds = jnp .array (prompt_embeds .detach ().float ().numpy (), dtype = jnp .float32 )
511-
512- if negative_prompt_embeds is None :
513- batch_size = len (prompt_embeds )
514- negative_prompt = negative_prompt or ""
515- negative_prompt = batch_size * [negative_prompt ] if isinstance (negative_prompt , str ) else negative_prompt
516- negative_prompt_embeds = self ._get_t5_prompt_embeds (
517- prompt = negative_prompt ,
504+ batch_size = len (prompt )
505+
506+ if negative_prompt is None :
507+ negative_prompt = ["" ] * batch_size
508+ elif isinstance (negative_prompt , str ):
509+ negative_prompt = [negative_prompt ] * batch_size
510+
511+ use_batched_text_encoder = getattr (self .config , "use_batched_text_encoder" , False )
512+ if use_batched_text_encoder and prompt_embeds is None and negative_prompt_embeds is None :
513+ # Batch both together
514+ combined_prompts = prompt + negative_prompt
515+ combined_embeds = self ._get_t5_prompt_embeds (
516+ prompt = combined_prompts ,
518517 num_videos_per_prompt = num_videos_per_prompt ,
519518 max_sequence_length = max_sequence_length ,
520519 )
521- negative_prompt_embeds = jnp .array (negative_prompt_embeds .detach ().float ().numpy (), dtype = jnp .float32 )
520+ combined_embeds = jnp .array (combined_embeds .detach ().float ().numpy (), dtype = jnp .float32 )
521+
522+ # Split back
523+ prompt_embeds = combined_embeds [: batch_size * num_videos_per_prompt ]
524+ negative_prompt_embeds = combined_embeds [batch_size * num_videos_per_prompt :]
525+
526+ else :
527+ # Fallback to separate encoding when batching is disabled or one of the embeddings is already provided
528+ if prompt_embeds is None :
529+ prompt_embeds = self ._get_t5_prompt_embeds (
530+ prompt = prompt ,
531+ num_videos_per_prompt = num_videos_per_prompt ,
532+ max_sequence_length = max_sequence_length ,
533+ )
534+ prompt_embeds = jnp .array (prompt_embeds .detach ().float ().numpy (), dtype = jnp .float32 )
535+
536+ if negative_prompt_embeds is None :
537+ negative_prompt_embeds = self ._get_t5_prompt_embeds (
538+ prompt = negative_prompt ,
539+ num_videos_per_prompt = num_videos_per_prompt ,
540+ max_sequence_length = max_sequence_length ,
541+ )
542+ negative_prompt_embeds = jnp .array (negative_prompt_embeds .detach ().float ().numpy (), dtype = jnp .float32 )
522543
523544 return prompt_embeds , negative_prompt_embeds
524545
0 commit comments