1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import re
16+
1517import torch
1618from transformers import ByT5Tokenizer , Qwen2_5_VLTextModel , Qwen2TokenizerFast , T5EncoderModel
1719
1820from ...configuration_utils import FrozenDict
1921from ...guiders import ClassifierFreeGuidance
20- from ...pipelines .hunyuan_video1_5 .pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline
2122from ...utils import logging
2223from ..modular_pipeline import ModularPipelineBlocks , PipelineState
2324from ..modular_pipeline_utils import ComponentSpec , InputParam , OutputParam
2728logger = logging .get_logger (__name__ )
2829
2930
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.format_text_input
def format_text_input(prompt, system_message):
    """Build one chat-template conversation per prompt string.

    Each conversation is a two-turn list: the shared system message followed by
    the user prompt. An empty/falsy prompt is replaced by a single space so the
    tokenizer never receives empty user content.
    """
    conversations = []
    for user_text in prompt:
        conversations.append(
            [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_text if user_text else " "},
            ]
        )
    return conversations
36+
37+
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.extract_glyph_texts
def extract_glyph_texts(prompt):
    """Extract quoted glyph texts from *prompt* and format them for the ByT5 encoder.

    Matches text wrapped in ASCII double quotes ("...") or full-width curly
    quotes (“...”), de-duplicates while preserving first-seen order, and joins
    the results as `Text "..."` clauses.

    Returns:
        A string like `'Text "foo". Text "bar". '`, or ``None`` when the prompt
        contains no quoted text.
    """
    # BUGFIX: the two regex alternatives were identical (both ASCII quotes),
    # leaving the second alternative dead; the upstream pipeline this is copied
    # from matches full-width curly quotes “…” in the second alternative.
    pattern = r"\"(.*?)\"|“(.*?)”"
    matches = re.findall(pattern, prompt)
    # Each findall tuple has one populated group depending on the quote style.
    result = [match[0] or match[1] for match in matches]
    # dict.fromkeys preserves insertion order while dropping duplicates.
    result = list(dict.fromkeys(result)) if len(result) > 1 else result
    if result:
        formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". "
    else:
        formatted_result = None
    return formatted_result
49+
50+
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_mllm_prompt_embeds
def _get_mllm_prompt_embeds(
    text_encoder,
    tokenizer,
    prompt,
    device,
    tokenizer_max_length=1000,
    num_hidden_layers_to_skip=2,
    # fmt: off
    system_message="You are a helpful assistant. Describe the video by detailing the following aspects: \
1. The main content and theme of the video. \
2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
4. background environment, light, style and atmosphere. \
5. camera angles, movements, and transitions used in the video.",
    # fmt: on
    crop_start=108,
):
    """Encode prompt(s) with the Qwen2.5-VL text model via its chat template.

    The prompt is wrapped in a system/user conversation, tokenized with
    ``apply_chat_template`` (padded/truncated to ``tokenizer_max_length`` plus
    the template prefix length), and encoded. Hidden states from
    ``num_hidden_layers_to_skip`` layers before the last are used, and the
    leading ``crop_start`` template tokens are cropped off.

    Args:
        text_encoder: Qwen2.5-VL text model (must support ``output_hidden_states``).
        tokenizer: chat-template tokenizer paired with ``text_encoder``.
        prompt: a string or list of strings.
        device: target device for the returned tensors.
        tokenizer_max_length: max prompt tokens AFTER cropping the template prefix.
        num_hidden_layers_to_skip: how many final layers to skip; the embedding
            is ``hidden_states[-(num_hidden_layers_to_skip + 1)]``.
        system_message: system turn prepended to every prompt.
        crop_start: number of leading template tokens to drop — presumably the
            tokenized length of the chat-template prefix (TODO confirm against
            the template used).

    Returns:
        Tuple ``(prompt_embeds, prompt_attention_mask)`` with the template
        prefix removed from both.
    """
    # Normalize to a list so format_text_input builds one conversation per prompt.
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt = format_text_input(prompt, system_message)

    text_inputs = tokenizer.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        padding="max_length",
        # Pad past the template prefix so crop_start tokens can be removed
        # while still leaving tokenizer_max_length usable positions.
        max_length=tokenizer_max_length + crop_start,
        truncation=True,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids.to(device=device)
    prompt_attention_mask = text_inputs.attention_mask.to(device=device)

    # Take an intermediate hidden state rather than the final layer output.
    prompt_embeds = text_encoder(
        input_ids=text_input_ids,
        attention_mask=prompt_attention_mask,
        output_hidden_states=True,
    ).hidden_states[-(num_hidden_layers_to_skip + 1)]

    # Drop the chat-template prefix tokens from both embeddings and mask.
    if crop_start is not None and crop_start > 0:
        prompt_embeds = prompt_embeds[:, crop_start:]
        prompt_attention_mask = prompt_attention_mask[:, crop_start:]

    return prompt_embeds, prompt_attention_mask
97+
98+
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_byt5_prompt_embeds
def _get_byt5_prompt_embeds(tokenizer, text_encoder, prompt, device, tokenizer_max_length=256):
    """Encode the quoted glyph texts of each prompt with the ByT5 encoder.

    Glyph texts are extracted per prompt via ``extract_glyph_texts``. Prompts
    without any quoted text contribute an all-zeros embedding and an all-zeros
    attention mask of the same padded length, so batch shapes stay uniform.

    Returns:
        Tuple ``(embeds, mask)`` concatenated over the batch dimension.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt

    embeds_per_prompt = []
    masks_per_prompt = []

    for glyph_text in (extract_glyph_texts(p) for p in prompts):
        if glyph_text is None:
            # No quoted text: zero placeholder embedding and zero mask.
            embeds = torch.zeros(
                (1, tokenizer_max_length, text_encoder.config.d_model),
                device=device,
                dtype=text_encoder.dtype,
            )
            mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64)
        else:
            tokens = tokenizer(
                glyph_text,
                padding="max_length",
                max_length=tokenizer_max_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            ).to(device)

            # ByT5 expects a float attention mask here.
            encoder_out = text_encoder(
                input_ids=tokens.input_ids,
                attention_mask=tokens.attention_mask.float(),
            )
            embeds = encoder_out[0].to(device=device)
            mask = tokens.attention_mask.to(device=device)

        embeds_per_prompt.append(embeds)
        masks_per_prompt.append(mask)

    return torch.cat(embeds_per_prompt, dim=0), torch.cat(masks_per_prompt, dim=0)
134+
135+
30136class HunyuanVideo15TextEncoderStep (ModularPipelineBlocks ):
31137 model_name = "hunyuan-video-1.5"
32138
@@ -78,38 +184,29 @@ def intermediate_outputs(self) -> list[OutputParam]:
78184 OutputParam ("negative_prompt_embeds_mask_2" , type_hint = torch .Tensor , kwargs_type = "denoiser_input_fields" ),
79185 ]
80186
81- # Copied from HunyuanVideo15Pipeline.encode_prompt
82- @torch .no_grad ()
83- def __call__ (self , components : HunyuanVideo15ModularPipeline , state : PipelineState ) -> PipelineState :
84- block_state = self .get_block_state (state )
85- device = components ._execution_device
86- dtype = components .transformer .dtype
87-
88- prompt = block_state .prompt
89- negative_prompt = block_state .negative_prompt
90- num_videos_per_prompt = block_state .num_videos_per_prompt
91-
92- if prompt is not None and isinstance (prompt , str ):
93- batch_size = 1
94- elif prompt is not None and isinstance (prompt , list ):
95- batch_size = len (prompt )
96- elif getattr (block_state , "prompt_embeds" , None ) is not None :
97- batch_size = block_state .prompt_embeds .shape [0 ]
98- else :
99- batch_size = 1
100-
101- # Encode positive prompt - copied from HunyuanVideo15Pipeline.encode_prompt
102- prompt_embeds = getattr (block_state , "prompt_embeds" , None )
103- prompt_embeds_mask = getattr (block_state , "prompt_embeds_mask" , None )
104- prompt_embeds_2 = getattr (block_state , "prompt_embeds_2" , None )
105- prompt_embeds_mask_2 = getattr (block_state , "prompt_embeds_mask_2" , None )
187+ # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.encode_prompt with self->components
188+ @staticmethod
189+ def encode_prompt (
190+ components ,
191+ prompt ,
192+ device = None ,
193+ dtype = None ,
194+ batch_size = 1 ,
195+ num_videos_per_prompt = 1 ,
196+ prompt_embeds = None ,
197+ prompt_embeds_mask = None ,
198+ prompt_embeds_2 = None ,
199+ prompt_embeds_mask_2 = None ,
200+ ):
201+ device = device or components ._execution_device
202+ dtype = dtype or components .text_encoder .dtype
106203
107204 if prompt is None :
108205 prompt = ["" ] * batch_size
109206 prompt = [prompt ] if isinstance (prompt , str ) else prompt
110207
111208 if prompt_embeds is None :
112- prompt_embeds , prompt_embeds_mask = HunyuanVideo15Pipeline . _get_mllm_prompt_embeds (
209+ prompt_embeds , prompt_embeds_mask = _get_mllm_prompt_embeds (
113210 tokenizer = components .tokenizer ,
114211 text_encoder = components .text_encoder ,
115212 prompt = prompt ,
@@ -120,7 +217,7 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta
120217 )
121218
122219 if prompt_embeds_2 is None :
123- prompt_embeds_2 , prompt_embeds_mask_2 = HunyuanVideo15Pipeline . _get_byt5_prompt_embeds (
220+ prompt_embeds_2 , prompt_embeds_mask_2 = _get_byt5_prompt_embeds (
124221 tokenizer = components .tokenizer_2 ,
125222 text_encoder = components .text_encoder_2 ,
126223 prompt = prompt ,
@@ -136,57 +233,69 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta
136233 prompt_embeds_2 = prompt_embeds_2 .repeat (1 , num_videos_per_prompt , 1 ).view (batch_size * num_videos_per_prompt , seq_len_2 , - 1 )
137234 prompt_embeds_mask_2 = prompt_embeds_mask_2 .repeat (1 , num_videos_per_prompt , 1 ).view (batch_size * num_videos_per_prompt , seq_len_2 )
138235
139- block_state .prompt_embeds = prompt_embeds .to (dtype = dtype , device = device )
140- block_state .prompt_embeds_mask = prompt_embeds_mask .to (dtype = dtype , device = device )
141- block_state .prompt_embeds_2 = prompt_embeds_2 .to (dtype = dtype , device = device )
142- block_state .prompt_embeds_mask_2 = prompt_embeds_mask_2 .to (dtype = dtype , device = device )
236+ prompt_embeds = prompt_embeds .to (dtype = dtype , device = device )
237+ prompt_embeds_mask = prompt_embeds_mask .to (dtype = dtype , device = device )
238+ prompt_embeds_2 = prompt_embeds_2 .to (dtype = dtype , device = device )
239+ prompt_embeds_mask_2 = prompt_embeds_mask_2 .to (dtype = dtype , device = device )
240+
241+ return prompt_embeds , prompt_embeds_mask , prompt_embeds_2 , prompt_embeds_mask_2
242+
243+ @torch .no_grad ()
244+ def __call__ (self , components : HunyuanVideo15ModularPipeline , state : PipelineState ) -> PipelineState :
245+ block_state = self .get_block_state (state )
246+ device = components ._execution_device
247+ dtype = components .transformer .dtype
248+
249+ prompt = block_state .prompt
250+ negative_prompt = block_state .negative_prompt
251+ num_videos_per_prompt = block_state .num_videos_per_prompt
252+
253+ if prompt is not None and isinstance (prompt , str ):
254+ batch_size = 1
255+ elif prompt is not None and isinstance (prompt , list ):
256+ batch_size = len (prompt )
257+ elif getattr (block_state , "prompt_embeds" , None ) is not None :
258+ batch_size = block_state .prompt_embeds .shape [0 ]
259+ else :
260+ batch_size = 1
261+
262+ (
263+ block_state .prompt_embeds ,
264+ block_state .prompt_embeds_mask ,
265+ block_state .prompt_embeds_2 ,
266+ block_state .prompt_embeds_mask_2 ,
267+ ) = self .encode_prompt (
268+ components ,
269+ prompt = prompt ,
270+ device = device ,
271+ dtype = dtype ,
272+ batch_size = batch_size ,
273+ num_videos_per_prompt = num_videos_per_prompt ,
274+ prompt_embeds = getattr (block_state , "prompt_embeds" , None ),
275+ prompt_embeds_mask = getattr (block_state , "prompt_embeds_mask" , None ),
276+ prompt_embeds_2 = getattr (block_state , "prompt_embeds_2" , None ),
277+ prompt_embeds_mask_2 = getattr (block_state , "prompt_embeds_mask_2" , None ),
278+ )
143279
144- # Encode negative prompt if guider needs it
145280 if components .requires_unconditional_embeds :
146- neg_prompt_embeds = getattr (block_state , "negative_prompt_embeds" , None )
147- neg_prompt_embeds_mask = getattr (block_state , "negative_prompt_embeds_mask" , None )
148- neg_prompt_embeds_2 = getattr (block_state , "negative_prompt_embeds_2" , None )
149- neg_prompt_embeds_mask_2 = getattr (block_state , "negative_prompt_embeds_mask_2" , None )
150-
151- neg_prompt = negative_prompt
152- if neg_prompt is None :
153- neg_prompt = ["" ] * batch_size
154- neg_prompt = [neg_prompt ] if isinstance (neg_prompt , str ) else neg_prompt
155-
156- if neg_prompt_embeds is None :
157- neg_prompt_embeds , neg_prompt_embeds_mask = HunyuanVideo15Pipeline ._get_mllm_prompt_embeds (
158- tokenizer = components .tokenizer ,
159- text_encoder = components .text_encoder ,
160- prompt = neg_prompt ,
161- device = device ,
162- tokenizer_max_length = components .tokenizer_max_length ,
163- system_message = components .system_message ,
164- crop_start = components .prompt_template_encode_start_idx ,
165- )
166-
167- if neg_prompt_embeds_2 is None :
168- neg_prompt_embeds_2 , neg_prompt_embeds_mask_2 = HunyuanVideo15Pipeline ._get_byt5_prompt_embeds (
169- tokenizer = components .tokenizer_2 ,
170- text_encoder = components .text_encoder_2 ,
171- prompt = neg_prompt ,
172- device = device ,
173- tokenizer_max_length = components .tokenizer_2_max_length ,
174- )
175-
176- _ , seq_len , _ = neg_prompt_embeds .shape
177- neg_prompt_embeds = neg_prompt_embeds .repeat (1 , num_videos_per_prompt , 1 ).view (batch_size * num_videos_per_prompt , seq_len , - 1 )
178- neg_prompt_embeds_mask = neg_prompt_embeds_mask .repeat (1 , num_videos_per_prompt , 1 ).view (batch_size * num_videos_per_prompt , seq_len )
179-
180- _ , seq_len_2 , _ = neg_prompt_embeds_2 .shape
181- neg_prompt_embeds_2 = neg_prompt_embeds_2 .repeat (1 , num_videos_per_prompt , 1 ).view (batch_size * num_videos_per_prompt , seq_len_2 , - 1 )
182- neg_prompt_embeds_mask_2 = neg_prompt_embeds_mask_2 .repeat (1 , num_videos_per_prompt , 1 ).view (batch_size * num_videos_per_prompt , seq_len_2 )
183-
184- block_state .negative_prompt_embeds = neg_prompt_embeds .to (dtype = dtype , device = device )
185- block_state .negative_prompt_embeds_mask = neg_prompt_embeds_mask .to (dtype = dtype , device = device )
186- block_state .negative_prompt_embeds_2 = neg_prompt_embeds_2 .to (dtype = dtype , device = device )
187- block_state .negative_prompt_embeds_mask_2 = neg_prompt_embeds_mask_2 .to (dtype = dtype , device = device )
188-
189- # Pass batch_size downstream
281+ (
282+ block_state .negative_prompt_embeds ,
283+ block_state .negative_prompt_embeds_mask ,
284+ block_state .negative_prompt_embeds_2 ,
285+ block_state .negative_prompt_embeds_mask_2 ,
286+ ) = self .encode_prompt (
287+ components ,
288+ prompt = negative_prompt ,
289+ device = device ,
290+ dtype = dtype ,
291+ batch_size = batch_size ,
292+ num_videos_per_prompt = num_videos_per_prompt ,
293+ prompt_embeds = getattr (block_state , "negative_prompt_embeds" , None ),
294+ prompt_embeds_mask = getattr (block_state , "negative_prompt_embeds_mask" , None ),
295+ prompt_embeds_2 = getattr (block_state , "negative_prompt_embeds_2" , None ),
296+ prompt_embeds_mask_2 = getattr (block_state , "negative_prompt_embeds_mask_2" , None ),
297+ )
298+
190299 state .set ("batch_size" , batch_size )
191300
192301 self .set_block_state (state , block_state )
0 commit comments