Skip to content

Commit e8176d2

Browse files
committed
Inline all standard pipeline methods, remove runtime dependency
1 parent 3953a25 commit e8176d2

2 files changed

Lines changed: 212 additions & 100 deletions

File tree

src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import torch
1919

2020
from ...models import HunyuanVideo15Transformer3DModel
21-
from ...pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline
2221
from ...schedulers import FlowMatchEulerDiscreteScheduler
2322
from ...utils import logging
2423
from ...utils.torch_utils import randn_tensor
@@ -169,14 +168,13 @@ def intermediate_outputs(self) -> list[OutputParam]:
169168
OutputParam("image_embeds", type_hint=torch.Tensor),
170169
]
171170

172-
# Copied from pipeline_hunyuan_video1_5.py lines 652-655, 706-725
171+
# Copied from pipeline_hunyuan_video1_5.py lines 652-655, 477-524, 706-725 with self->components
173172
@torch.no_grad()
174173
def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
175174
block_state = self.get_block_state(state)
176175
device = components._execution_device
177176
dtype = block_state.dtype
178177

179-
# Calculate default height/width if not provided (line 652-655)
180178
height = block_state.height
181179
width = block_state.width
182180
if height is None and width is None:
@@ -187,28 +185,33 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta
187185
batch_size = block_state.batch_size * block_state.num_videos_per_prompt
188186
num_frames = block_state.num_frames
189187

190-
# Copied from HunyuanVideo15Pipeline.prepare_latents (lines 477-505, 707-717)
191-
block_state.latents = HunyuanVideo15Pipeline.prepare_latents(
192-
components,
193-
batch_size,
194-
components.num_channels_latents,
195-
height,
196-
width,
197-
num_frames,
198-
dtype,
199-
device,
200-
block_state.generator,
201-
block_state.latents,
202-
)
188+
# Copied from HunyuanVideo15Pipeline.prepare_latents with self->components
189+
latents = block_state.latents
190+
if latents is not None:
191+
latents = latents.to(device=device, dtype=dtype)
192+
else:
193+
shape = (
194+
batch_size,
195+
components.num_channels_latents,
196+
(num_frames - 1) // components.vae_scale_factor_temporal + 1,
197+
int(height) // components.vae_scale_factor_spatial,
198+
int(width) // components.vae_scale_factor_spatial,
199+
)
200+
if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size:
201+
raise ValueError(
202+
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
203+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
204+
)
205+
latents = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype)
203206

204-
# Copied from HunyuanVideo15Pipeline.prepare_cond_latents_and_mask (lines 508-524, 718)
205-
cond_latents_concat, mask_concat = HunyuanVideo15Pipeline.prepare_cond_latents_and_mask(
206-
components, block_state.latents, dtype, device
207-
)
208-
block_state.cond_latents_concat = cond_latents_concat
209-
block_state.mask_concat = mask_concat
207+
block_state.latents = latents
208+
209+
# Copied from HunyuanVideo15Pipeline.prepare_cond_latents_and_mask with self->components
210+
b, c, f, h, w = latents.shape
211+
block_state.cond_latents_concat = torch.zeros(b, c, f, h, w, dtype=dtype, device=device)
212+
block_state.mask_concat = torch.zeros(b, 1, f, h, w, dtype=dtype, device=device)
210213

211-
# T2V: zero image_embeds (line 719-725)
214+
# T2V: zero image_embeds
212215
block_state.image_embeds = torch.zeros(
213216
block_state.batch_size,
214217
components.vision_num_semantic_tokens,

src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py

Lines changed: 186 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import re
16+
1517
import torch
1618
from transformers import ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2TokenizerFast, T5EncoderModel
1719

1820
from ...configuration_utils import FrozenDict
1921
from ...guiders import ClassifierFreeGuidance
20-
from ...pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline
2122
from ...utils import logging
2223
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
2324
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -27,6 +28,111 @@
2728
logger = logging.get_logger(__name__)
2829

2930

31+
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.format_text_input
32+
def format_text_input(prompt, system_message):
    """Build one two-turn chat conversation per prompt.

    Args:
        prompt: Iterable of user prompts. A falsy entry (e.g. ``""`` or
            ``None``) is replaced by a single space.
        system_message: Content placed in the ``system`` turn of every
            conversation.

    Returns:
        A list with one conversation per prompt; each conversation is a list
        of two dicts, the ``system`` turn followed by the ``user`` turn.
    """
    conversations = []
    for user_text in prompt:
        system_turn = {"role": "system", "content": system_message}
        # Falsy prompts are sent as a single space instead of empty content.
        user_turn = {"role": "user", "content": user_text or " "}
        conversations.append([system_turn, user_turn])
    return conversations
36+
37+
38+
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.extract_glyph_texts
39+
def extract_glyph_texts(prompt):
    """Collect the quoted substrings of ``prompt`` into one glyph-text sentence.

    Text enclosed in straight double quotes (``"..."``) or in curly double
    quotes (U+201C ... U+201D) is extracted non-greedily. Duplicates are
    dropped while keeping first-seen order.

    Args:
        prompt: The prompt string to scan.

    Returns:
        A string of the form ``'Text "a". Text "b". '`` with one ``Text "..."``
        entry per unique quoted substring, or ``None`` when the prompt
        contains no quoted text.
    """
    # Two alternatives: straight quotes, then curly quotes. The curly pair is
    # written with \u escapes (understood by ``re``) so the pattern survives
    # tools that normalise typographic quotes to ASCII; when that happens both
    # alternatives become identical and curly-quoted text is silently ignored.
    pattern = r"\"(.*?)\"|\u201c(.*?)\u201d"
    matches = re.findall(pattern, prompt)

    # At most one of the two groups is non-empty for any given match.
    quoted = [straight or curly for straight, curly in matches]

    # dict.fromkeys de-duplicates while preserving insertion order.
    unique = list(dict.fromkeys(quoted))
    if not unique:
        return None
    return ". ".join(f'Text "{text}"' for text in unique) + ". "
49+
50+
51+
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_mllm_prompt_embeds
52+
def _get_mllm_prompt_embeds(
53+
text_encoder,
54+
tokenizer,
55+
prompt,
56+
device,
57+
tokenizer_max_length=1000,
58+
num_hidden_layers_to_skip=2,
59+
# fmt: off
# NOTE(review): the whitespace on the backslash-continued lines below is part
# of the default string; the original indentation is not visible here - confirm
# against the reference pipeline before re-typing this literal.
60+
system_message="You are a helpful assistant. Describe the video by detailing the following aspects: \
61+
1. The main content and theme of the video. \
62+
2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
63+
3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
64+
4. background environment, light, style and atmosphere. \
65+
5. camera angles, movements, and transitions used in the video.",
66+
# fmt: on
67+
crop_start=108,
68+
):
# Encode `prompt` with the chat-template tokenizer and `text_encoder`.
#
# A single string prompt is wrapped into a list, each prompt is turned into a
# system + user conversation via `format_text_input`, and the conversations are
# tokenized with padding/truncation to `tokenizer_max_length + crop_start`.
# The embeddings are the encoder hidden state at index
# `-(num_hidden_layers_to_skip + 1)`. When `crop_start` is positive, the first
# `crop_start` positions are dropped from both embeddings and attention mask.
#
# Returns the tuple `(prompt_embeds, prompt_attention_mask)`.
# Presumably `crop_start` equals the number of template tokens preceding the
# user text - TODO confirm against the tokenizer's chat template.
69+
prompt = [prompt] if isinstance(prompt, str) else prompt
70+
prompt = format_text_input(prompt, system_message)
71+
72+
text_inputs = tokenizer.apply_chat_template(
73+
prompt,
74+
add_generation_prompt=True,
75+
tokenize=True,
76+
return_dict=True,
77+
padding="max_length",
78+
# NOTE(review): `crop_start` is added unconditionally here, while the crop
# below guards against `crop_start is None`; passing None would fail on this
# addition before reaching that guard.
max_length=tokenizer_max_length + crop_start,
79+
truncation=True,
80+
return_tensors="pt",
81+
)
82+
83+
text_input_ids = text_inputs.input_ids.to(device=device)
84+
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
85+
86+
prompt_embeds = text_encoder(
87+
input_ids=text_input_ids,
88+
attention_mask=prompt_attention_mask,
89+
output_hidden_states=True,
90+
# Select a hidden state counted from the end: skipping N layers picks index -(N + 1).
).hidden_states[-(num_hidden_layers_to_skip + 1)]
91+
92+
if crop_start is not None and crop_start > 0:
93+
# Drop the leading `crop_start` positions along dim 1 from embeddings and mask alike.
prompt_embeds = prompt_embeds[:, crop_start:]
94+
prompt_attention_mask = prompt_attention_mask[:, crop_start:]
95+
96+
return prompt_embeds, prompt_attention_mask
97+
98+
99+
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_byt5_prompt_embeds
100+
def _get_byt5_prompt_embeds(tokenizer, text_encoder, prompt, device, tokenizer_max_length=256):
    """Encode the quoted ("glyph") text of each prompt with ``text_encoder``.

    Args:
        tokenizer: Callable tokenizer used for prompts that contain glyph text.
        text_encoder: Encoder whose first output is used as the embeddings; its
            ``config.d_model`` and ``dtype`` size the placeholder for prompts
            without glyph text.
        prompt: A single prompt string or a list of prompt strings.
        device: Device on which the returned tensors are placed.
        tokenizer_max_length: Padding/truncation length for the tokenizer and
            sequence length of the zero placeholder.

    Returns:
        A tuple ``(embeds, mask)`` with the per-prompt results concatenated
        along dim 0.
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    glyph_texts = [extract_glyph_texts(single_prompt) for single_prompt in prompt]

    embeds_chunks = []
    mask_chunks = []

    for glyph_text in glyph_texts:
        if glyph_text is None:
            # No quoted text in this prompt: all-zero embeddings and an all-zero mask.
            embeds_shape = (1, tokenizer_max_length, text_encoder.config.d_model)
            embeds_chunks.append(torch.zeros(embeds_shape, device=device, dtype=text_encoder.dtype))
            mask_chunks.append(torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64))
            continue

        encoded = tokenizer(
            glyph_text,
            padding="max_length",
            max_length=tokenizer_max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        ).to(device)
        attention_mask = encoded.attention_mask

        # The encoder is fed a float mask; the integer mask is what gets returned.
        encoder_output = text_encoder(
            input_ids=encoded.input_ids,
            attention_mask=attention_mask.float(),
        )
        embeds_chunks.append(encoder_output[0].to(device=device))
        mask_chunks.append(attention_mask.to(device=device))

    return torch.cat(embeds_chunks, dim=0), torch.cat(mask_chunks, dim=0)
134+
135+
30136
class HunyuanVideo15TextEncoderStep(ModularPipelineBlocks):
31137
model_name = "hunyuan-video-1.5"
32138

@@ -78,38 +184,29 @@ def intermediate_outputs(self) -> list[OutputParam]:
78184
OutputParam("negative_prompt_embeds_mask_2", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
79185
]
80186

81-
# Copied from HunyuanVideo15Pipeline.encode_prompt
82-
@torch.no_grad()
83-
def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
84-
block_state = self.get_block_state(state)
85-
device = components._execution_device
86-
dtype = components.transformer.dtype
87-
88-
prompt = block_state.prompt
89-
negative_prompt = block_state.negative_prompt
90-
num_videos_per_prompt = block_state.num_videos_per_prompt
91-
92-
if prompt is not None and isinstance(prompt, str):
93-
batch_size = 1
94-
elif prompt is not None and isinstance(prompt, list):
95-
batch_size = len(prompt)
96-
elif getattr(block_state, "prompt_embeds", None) is not None:
97-
batch_size = block_state.prompt_embeds.shape[0]
98-
else:
99-
batch_size = 1
100-
101-
# Encode positive prompt - copied from HunyuanVideo15Pipeline.encode_prompt
102-
prompt_embeds = getattr(block_state, "prompt_embeds", None)
103-
prompt_embeds_mask = getattr(block_state, "prompt_embeds_mask", None)
104-
prompt_embeds_2 = getattr(block_state, "prompt_embeds_2", None)
105-
prompt_embeds_mask_2 = getattr(block_state, "prompt_embeds_mask_2", None)
187+
# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.encode_prompt with self->components
188+
@staticmethod
189+
def encode_prompt(
190+
components,
191+
prompt,
192+
device=None,
193+
dtype=None,
194+
batch_size=1,
195+
num_videos_per_prompt=1,
196+
prompt_embeds=None,
197+
prompt_embeds_mask=None,
198+
prompt_embeds_2=None,
199+
prompt_embeds_mask_2=None,
200+
):
201+
device = device or components._execution_device
202+
dtype = dtype or components.text_encoder.dtype
106203

107204
if prompt is None:
108205
prompt = [""] * batch_size
109206
prompt = [prompt] if isinstance(prompt, str) else prompt
110207

111208
if prompt_embeds is None:
112-
prompt_embeds, prompt_embeds_mask = HunyuanVideo15Pipeline._get_mllm_prompt_embeds(
209+
prompt_embeds, prompt_embeds_mask = _get_mllm_prompt_embeds(
113210
tokenizer=components.tokenizer,
114211
text_encoder=components.text_encoder,
115212
prompt=prompt,
@@ -120,7 +217,7 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta
120217
)
121218

122219
if prompt_embeds_2 is None:
123-
prompt_embeds_2, prompt_embeds_mask_2 = HunyuanVideo15Pipeline._get_byt5_prompt_embeds(
220+
prompt_embeds_2, prompt_embeds_mask_2 = _get_byt5_prompt_embeds(
124221
tokenizer=components.tokenizer_2,
125222
text_encoder=components.text_encoder_2,
126223
prompt=prompt,
@@ -136,57 +233,69 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta
136233
prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2, -1)
137234
prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2)
138235

139-
block_state.prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
140-
block_state.prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device)
141-
block_state.prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device)
142-
block_state.prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device)
236+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
237+
prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device)
238+
prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device)
239+
prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device)
240+
241+
return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
242+
243+
@torch.no_grad()
244+
def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
245+
block_state = self.get_block_state(state)
246+
device = components._execution_device
247+
dtype = components.transformer.dtype
248+
249+
prompt = block_state.prompt
250+
negative_prompt = block_state.negative_prompt
251+
num_videos_per_prompt = block_state.num_videos_per_prompt
252+
253+
if prompt is not None and isinstance(prompt, str):
254+
batch_size = 1
255+
elif prompt is not None and isinstance(prompt, list):
256+
batch_size = len(prompt)
257+
elif getattr(block_state, "prompt_embeds", None) is not None:
258+
batch_size = block_state.prompt_embeds.shape[0]
259+
else:
260+
batch_size = 1
261+
262+
(
263+
block_state.prompt_embeds,
264+
block_state.prompt_embeds_mask,
265+
block_state.prompt_embeds_2,
266+
block_state.prompt_embeds_mask_2,
267+
) = self.encode_prompt(
268+
components,
269+
prompt=prompt,
270+
device=device,
271+
dtype=dtype,
272+
batch_size=batch_size,
273+
num_videos_per_prompt=num_videos_per_prompt,
274+
prompt_embeds=getattr(block_state, "prompt_embeds", None),
275+
prompt_embeds_mask=getattr(block_state, "prompt_embeds_mask", None),
276+
prompt_embeds_2=getattr(block_state, "prompt_embeds_2", None),
277+
prompt_embeds_mask_2=getattr(block_state, "prompt_embeds_mask_2", None),
278+
)
143279

144-
# Encode negative prompt if guider needs it
145280
if components.requires_unconditional_embeds:
146-
neg_prompt_embeds = getattr(block_state, "negative_prompt_embeds", None)
147-
neg_prompt_embeds_mask = getattr(block_state, "negative_prompt_embeds_mask", None)
148-
neg_prompt_embeds_2 = getattr(block_state, "negative_prompt_embeds_2", None)
149-
neg_prompt_embeds_mask_2 = getattr(block_state, "negative_prompt_embeds_mask_2", None)
150-
151-
neg_prompt = negative_prompt
152-
if neg_prompt is None:
153-
neg_prompt = [""] * batch_size
154-
neg_prompt = [neg_prompt] if isinstance(neg_prompt, str) else neg_prompt
155-
156-
if neg_prompt_embeds is None:
157-
neg_prompt_embeds, neg_prompt_embeds_mask = HunyuanVideo15Pipeline._get_mllm_prompt_embeds(
158-
tokenizer=components.tokenizer,
159-
text_encoder=components.text_encoder,
160-
prompt=neg_prompt,
161-
device=device,
162-
tokenizer_max_length=components.tokenizer_max_length,
163-
system_message=components.system_message,
164-
crop_start=components.prompt_template_encode_start_idx,
165-
)
166-
167-
if neg_prompt_embeds_2 is None:
168-
neg_prompt_embeds_2, neg_prompt_embeds_mask_2 = HunyuanVideo15Pipeline._get_byt5_prompt_embeds(
169-
tokenizer=components.tokenizer_2,
170-
text_encoder=components.text_encoder_2,
171-
prompt=neg_prompt,
172-
device=device,
173-
tokenizer_max_length=components.tokenizer_2_max_length,
174-
)
175-
176-
_, seq_len, _ = neg_prompt_embeds.shape
177-
neg_prompt_embeds = neg_prompt_embeds.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len, -1)
178-
neg_prompt_embeds_mask = neg_prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len)
179-
180-
_, seq_len_2, _ = neg_prompt_embeds_2.shape
181-
neg_prompt_embeds_2 = neg_prompt_embeds_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2, -1)
182-
neg_prompt_embeds_mask_2 = neg_prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2)
183-
184-
block_state.negative_prompt_embeds = neg_prompt_embeds.to(dtype=dtype, device=device)
185-
block_state.negative_prompt_embeds_mask = neg_prompt_embeds_mask.to(dtype=dtype, device=device)
186-
block_state.negative_prompt_embeds_2 = neg_prompt_embeds_2.to(dtype=dtype, device=device)
187-
block_state.negative_prompt_embeds_mask_2 = neg_prompt_embeds_mask_2.to(dtype=dtype, device=device)
188-
189-
# Pass batch_size downstream
281+
(
282+
block_state.negative_prompt_embeds,
283+
block_state.negative_prompt_embeds_mask,
284+
block_state.negative_prompt_embeds_2,
285+
block_state.negative_prompt_embeds_mask_2,
286+
) = self.encode_prompt(
287+
components,
288+
prompt=negative_prompt,
289+
device=device,
290+
dtype=dtype,
291+
batch_size=batch_size,
292+
num_videos_per_prompt=num_videos_per_prompt,
293+
prompt_embeds=getattr(block_state, "negative_prompt_embeds", None),
294+
prompt_embeds_mask=getattr(block_state, "negative_prompt_embeds_mask", None),
295+
prompt_embeds_2=getattr(block_state, "negative_prompt_embeds_2", None),
296+
prompt_embeds_mask_2=getattr(block_state, "negative_prompt_embeds_mask_2", None),
297+
)
298+
190299
state.set("batch_size", batch_size)
191300

192301
self.set_block_state(state, block_state)

0 commit comments

Comments
 (0)