Skip to content

Commit 84e030c

Browse files
committed
update, remove intermediate_inputs
1 parent e4393fa commit 84e030c

6 files changed

Lines changed: 198 additions & 187 deletions

File tree

src/diffusers/modular_pipelines/wan/before_denoise.py

Lines changed: 15 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,6 @@ def description(self) -> str:
112112
def inputs(self) -> List[InputParam]:
113113
return [
114114
InputParam("num_videos_per_prompt", default=1),
115-
]
116-
117-
@property
118-
def intermediate_inputs(self) -> List[str]:
119-
return [
120115
InputParam(
121116
"prompt_embeds",
122117
required=True,
@@ -143,18 +138,6 @@ def intermediate_outputs(self) -> List[str]:
143138
type_hint=torch.dtype,
144139
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
145140
),
146-
OutputParam(
147-
"prompt_embeds",
148-
type_hint=torch.Tensor,
149-
kwargs_type="denoiser_input_fields", # already in intermediate state but declared here again for denoiser_input_fields
150-
description="text embeddings used to guide the image generation",
151-
),
152-
OutputParam(
153-
"negative_prompt_embeds",
154-
type_hint=torch.Tensor,
155-
kwargs_type="denoiser_input_fields", # already in intermediate state but declared here again for denoiser_input_fields
156-
description="negative text embeddings used to guide the image generation",
157-
),
158141
]
159142

160143
def check_inputs(self, components, block_state):
@@ -215,26 +198,16 @@ def inputs(self) -> List[InputParam]:
215198
InputParam("sigmas"),
216199
]
217200

218-
@property
219-
def intermediate_outputs(self) -> List[OutputParam]:
220-
return [
221-
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
222-
OutputParam(
223-
"num_inference_steps",
224-
type_hint=int,
225-
description="The number of denoising steps to perform at inference time",
226-
),
227-
]
228201

229202
@torch.no_grad()
230203
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
231204
block_state = self.get_block_state(state)
232-
block_state.device = components._execution_device
205+
device = components._execution_device
233206

234207
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
235208
components.scheduler,
236209
block_state.num_inference_steps,
237-
block_state.device,
210+
device,
238211
block_state.timesteps,
239212
block_state.sigmas,
240213
)
@@ -246,10 +219,6 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
246219
class WanPrepareLatentsStep(ModularPipelineBlocks):
247220
model_name = "wan"
248221

249-
@property
250-
def expected_components(self) -> List[ComponentSpec]:
251-
return []
252-
253222
@property
254223
def description(self) -> str:
255224
return "Prepare latents step that prepares the latents for the text-to-video generation process"
@@ -262,11 +231,6 @@ def inputs(self) -> List[InputParam]:
262231
InputParam("num_frames", type_hint=int),
263232
InputParam("latents", type_hint=Optional[torch.Tensor]),
264233
InputParam("num_videos_per_prompt", type_hint=int, default=1),
265-
]
266-
267-
@property
268-
def intermediate_inputs(self) -> List[InputParam]:
269-
return [
270234
InputParam("generator"),
271235
InputParam(
272236
"batch_size",
@@ -337,27 +301,26 @@ def prepare_latents(
337301
@torch.no_grad()
338302
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
339303
block_state = self.get_block_state(state)
304+
self.check_inputs(components, block_state)
305+
306+
device = components._execution_device
307+
dtype = torch.float32 # Wan latents should be torch.float32 for best quality
340308

341309
block_state.height = block_state.height or components.default_height
342310
block_state.width = block_state.width or components.default_width
343311
block_state.num_frames = block_state.num_frames or components.default_num_frames
344-
block_state.device = components._execution_device
345-
block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality
346-
block_state.num_channels_latents = components.num_channels_latents
347-
348-
self.check_inputs(components, block_state)
349312

350313
block_state.latents = self.prepare_latents(
351314
components,
352-
block_state.batch_size * block_state.num_videos_per_prompt,
353-
block_state.num_channels_latents,
354-
block_state.height,
355-
block_state.width,
356-
block_state.num_frames,
357-
block_state.dtype,
358-
block_state.device,
359-
block_state.generator,
360-
block_state.latents,
315+
batch_size=block_state.batch_size * block_state.num_videos_per_prompt,
316+
num_channels_latents=components.num_channels_latents,
317+
height=block_state.height,
318+
width=block_state.width,
319+
num_frames=block_state.num_frames,
320+
dtype=dtype,
321+
device=device,
322+
generator=block_state.generator,
323+
latents=block_state.latents,
361324
)
362325

363326
self.set_block_state(state, block_state)

src/diffusers/modular_pipelines/wan/decoders.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,6 @@ def description(self) -> str:
5050

5151
@property
5252
def inputs(self) -> List[Tuple[str, Any]]:
53-
return [
54-
InputParam("output_type", default="pil"),
55-
]
56-
57-
@property
58-
def intermediate_inputs(self) -> List[str]:
5953
return [
6054
InputParam(
6155
"latents",
@@ -80,24 +74,21 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
8074
block_state = self.get_block_state(state)
8175
vae_dtype = components.vae.dtype
8276

83-
if not block_state.output_type == "latent":
84-
latents = block_state.latents
85-
latents_mean = (
86-
torch.tensor(components.vae.config.latents_mean)
87-
.view(1, components.vae.config.z_dim, 1, 1, 1)
88-
.to(latents.device, latents.dtype)
89-
)
90-
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
91-
1, components.vae.config.z_dim, 1, 1, 1
92-
).to(latents.device, latents.dtype)
93-
latents = latents / latents_std + latents_mean
94-
latents = latents.to(vae_dtype)
95-
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
96-
else:
97-
block_state.videos = block_state.latents
77+
latents = block_state.latents
78+
latents_mean = (
79+
torch.tensor(components.vae.config.latents_mean)
80+
.view(1, components.vae.config.z_dim, 1, 1, 1)
81+
.to(latents.device, latents.dtype)
82+
)
83+
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
84+
1, components.vae.config.z_dim, 1, 1, 1
85+
).to(latents.device, latents.dtype)
86+
latents = latents / latents_std + latents_mean
87+
latents = latents.to(vae_dtype)
88+
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
9889

9990
block_state.videos = components.video_processor.postprocess_video(
100-
block_state.videos, output_type=block_state.output_type
91+
block_state.videos, output_type="np"
10192
)
10293

10394
self.set_block_state(state, block_state)

src/diffusers/modular_pipelines/wan/denoise.py

Lines changed: 84 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,14 @@
2727
ModularPipelineBlocks,
2828
PipelineState,
2929
)
30-
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
30+
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec
3131
from .modular_pipeline import WanModularPipeline
3232

3333

3434
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3535

3636

37+
3738
class WanLoopDenoiser(ModularPipelineBlocks):
3839
model_name = "wan"
3940

@@ -61,11 +62,6 @@ def description(self) -> str:
6162
def inputs(self) -> List[Tuple[str, Any]]:
6263
return [
6364
InputParam("attention_kwargs"),
64-
]
65-
66-
@property
67-
def intermediate_inputs(self) -> List[str]:
68-
return [
6965
InputParam(
7066
"latents",
7167
required=True,
@@ -78,14 +74,8 @@ def intermediate_inputs(self) -> List[str]:
7874
type_hint=int,
7975
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
8076
),
81-
InputParam(
82-
kwargs_type="denoiser_input_fields",
83-
description=(
84-
"All conditional model inputs that need to be prepared with guider. "
85-
"It should contain prompt_embeds/negative_prompt_embeds. "
86-
"Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
87-
),
88-
),
77+
InputParam("prompt_embeds", required=True, type_hint=torch.Tensor),
78+
InputParam("negative_prompt_embeds", required=True, type_hint=torch.Tensor),
8979
]
9080

9181
@torch.no_grad()
@@ -95,10 +85,7 @@ def __call__(
9585
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
9686
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
9787
guider_inputs = {
98-
"prompt_embeds": (
99-
getattr(block_state, "prompt_embeds", None),
100-
getattr(block_state, "negative_prompt_embeds", None),
101-
),
88+
"encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds),
10289
}
10390
transformer_dtype = components.transformer.dtype
10491

@@ -118,16 +105,15 @@ def __call__(
118105
for guider_state_batch in guider_state:
119106
components.guider.prepare_models(components.transformer)
120107
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
121-
prompt_embeds = cond_kwargs.pop("prompt_embeds")
122108

123109
# Predict the noise residual
124110
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
125111
guider_state_batch.noise_pred = components.transformer(
126112
hidden_states=block_state.latents.to(transformer_dtype),
127-
timestep=t.flatten(),
128-
encoder_hidden_states=prompt_embeds,
113+
timestep=t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype),
129114
attention_kwargs=block_state.attention_kwargs,
130115
return_dict=False,
116+
**cond_kwargs,
131117
)[0]
132118
components.guider.cleanup_models(components.transformer)
133119

@@ -154,19 +140,6 @@ def description(self) -> str:
154140
"object (e.g. `WanDenoiseLoopWrapper`)"
155141
)
156142

157-
@property
158-
def inputs(self) -> List[Tuple[str, Any]]:
159-
return []
160-
161-
@property
162-
def intermediate_inputs(self) -> List[str]:
163-
return [
164-
InputParam("generator"),
165-
]
166-
167-
@property
168-
def intermediate_outputs(self) -> List[OutputParam]:
169-
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
170143

171144
@torch.no_grad()
172145
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
@@ -198,18 +171,11 @@ def description(self) -> str:
198171
@property
199172
def loop_expected_components(self) -> List[ComponentSpec]:
200173
return [
201-
ComponentSpec(
202-
"guider",
203-
ClassifierFreeGuidance,
204-
config=FrozenDict({"guidance_scale": 5.0}),
205-
default_creation_method="from_config",
206-
),
207174
ComponentSpec("scheduler", UniPCMultistepScheduler),
208-
ComponentSpec("transformer", WanTransformer3DModel),
209175
]
210176

211177
@property
212-
def loop_intermediate_inputs(self) -> List[InputParam]:
178+
def loop_inputs(self) -> List[InputParam]:
213179
return [
214180
InputParam(
215181
"timesteps",
@@ -246,6 +212,81 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
246212
return components, state
247213

248214

215+
# class Wan22DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
216+
# model_name = "wan"
217+
218+
# @property
219+
# def description(self) -> str:
220+
# return (
221+
# "Pipeline block that iteratively denoise the latents over `timesteps`. "
222+
# "The specific steps with each iteration can be customized with `sub_blocks` attributes"
223+
# )
224+
225+
# @property
226+
# def loop_expected_configs(self) -> List[ConfigSpec]:
227+
# return [
228+
# ConfigSpec(
229+
# "boundary_ratio",
230+
# type_hint=float,
231+
# description="The ratio of the total timesteps to use as the boundary for switching between transformers in two-stage denoising.",
232+
# ),
233+
# ]
234+
235+
# @property
236+
# def loop_expected_components(self) -> List[ComponentSpec]:
237+
# return [
238+
# ComponentSpec("scheduler", UniPCMultistepScheduler),
239+
# ]
240+
241+
# @property
242+
# def loop_inputs(self) -> List[InputParam]:
243+
# return [
244+
# InputParam(
245+
# "timesteps",
246+
# required=True,
247+
# type_hint=torch.Tensor,
248+
# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
249+
# ),
250+
# InputParam(
251+
# "num_inference_steps",
252+
# required=True,
253+
# type_hint=int,
254+
# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
255+
# ),
256+
# ]
257+
258+
# @torch.no_grad()
259+
# def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
260+
# block_state = self.get_block_state(state)
261+
262+
# block_state.num_warmup_steps = max(
263+
# len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
264+
# )
265+
266+
# block_state.boundary_timestep = components.config.boundary_ratio * components.scheduler.config.num_train_timesteps
267+
268+
# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
269+
# for i, t in enumerate(block_state.timesteps):
270+
271+
# if t > block_state.boundary_timestep:
272+
# # high-noise stage
273+
# block_state.current_model = components.transformer
274+
# block_state.current_guider = components.guider
275+
# else:
276+
# # low-noise stage
277+
# block_state.current_model = components.transformer_2
278+
# block_state.current_guider = components.guider_2
279+
# components, block_state = self.loop_step(components, block_state, i=i, t=t)
280+
# if i == len(block_state.timesteps) - 1 or (
281+
# (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
282+
# ):
283+
# progress_bar.update()
284+
285+
# self.set_block_state(state, block_state)
286+
287+
# return components, state
288+
289+
249290
class WanDenoiseStep(WanDenoiseLoopWrapper):
250291
block_classes = [
251292
WanLoopDenoiser,
@@ -261,5 +302,5 @@ def description(self) -> str:
261302
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
262303
" - `WanLoopDenoiser`\n"
263304
" - `WanLoopAfterDenoiser`\n"
264-
"This block supports both text2vid tasks."
305+
"This block supports text-to-video tasks."
265306
)

0 commit comments

Comments
 (0)