2727 ModularPipelineBlocks ,
2828 PipelineState ,
2929)
30- from ..modular_pipeline_utils import ComponentSpec , InputParam , OutputParam
30+ from ..modular_pipeline_utils import ComponentSpec , InputParam , OutputParam , ConfigSpec
3131from .modular_pipeline import WanModularPipeline
3232
3333
3434logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
3535
3636
37+
3738class WanLoopDenoiser (ModularPipelineBlocks ):
3839 model_name = "wan"
3940
@@ -61,11 +62,6 @@ def description(self) -> str:
6162 def inputs (self ) -> List [Tuple [str , Any ]]:
6263 return [
6364 InputParam ("attention_kwargs" ),
64- ]
65-
66- @property
67- def intermediate_inputs (self ) -> List [str ]:
68- return [
6965 InputParam (
7066 "latents" ,
7167 required = True ,
@@ -78,14 +74,8 @@ def intermediate_inputs(self) -> List[str]:
7874 type_hint = int ,
7975 description = "The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." ,
8076 ),
81- InputParam (
82- kwargs_type = "denoiser_input_fields" ,
83- description = (
84- "All conditional model inputs that need to be prepared with guider. "
85- "It should contain prompt_embeds/negative_prompt_embeds. "
86- "Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
87- ),
88- ),
77+ InputParam ("prompt_embeds" , required = True , type_hint = torch .Tensor ),
78+ InputParam ("negative_prompt_embeds" , required = True , type_hint = torch .Tensor ),
8979 ]
9080
9181 @torch .no_grad ()
@@ -95,10 +85,7 @@ def __call__(
9585 # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.encoder_hidden_states)
9686 # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
9787 guider_inputs = {
98- "prompt_embeds" : (
99- getattr (block_state , "prompt_embeds" , None ),
100- getattr (block_state , "negative_prompt_embeds" , None ),
101- ),
88+ "encoder_hidden_states" : (block_state .prompt_embeds , block_state .negative_prompt_embeds ),
10289 }
10390 transformer_dtype = components .transformer .dtype
10491
@@ -118,16 +105,15 @@ def __call__(
118105 for guider_state_batch in guider_state :
119106 components .guider .prepare_models (components .transformer )
120107 cond_kwargs = {input_name : getattr (guider_state_batch , input_name ) for input_name in guider_inputs .keys ()}
121- prompt_embeds = cond_kwargs .pop ("prompt_embeds" )
122108
123109 # Predict the noise residual
124110 # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
125111 guider_state_batch .noise_pred = components .transformer (
126112 hidden_states = block_state .latents .to (transformer_dtype ),
127- timestep = t .flatten (),
128- encoder_hidden_states = prompt_embeds ,
113+ timestep = t .expand (block_state .latents .shape [0 ]).to (block_state .latents .dtype ),
129114 attention_kwargs = block_state .attention_kwargs ,
130115 return_dict = False ,
116+ ** cond_kwargs ,
131117 )[0 ]
132118 components .guider .cleanup_models (components .transformer )
133119
@@ -154,19 +140,6 @@ def description(self) -> str:
154140 "object (e.g. `WanDenoiseLoopWrapper`)"
155141 )
156142
157- @property
158- def inputs (self ) -> List [Tuple [str , Any ]]:
159- return []
160-
161- @property
162- def intermediate_inputs (self ) -> List [str ]:
163- return [
164- InputParam ("generator" ),
165- ]
166-
167- @property
168- def intermediate_outputs (self ) -> List [OutputParam ]:
169- return [OutputParam ("latents" , type_hint = torch .Tensor , description = "The denoised latents" )]
170143
171144 @torch .no_grad ()
172145 def __call__ (self , components : WanModularPipeline , block_state : BlockState , i : int , t : torch .Tensor ):
@@ -198,18 +171,11 @@ def description(self) -> str:
198171 @property
199172 def loop_expected_components (self ) -> List [ComponentSpec ]:
200173 return [
201- ComponentSpec (
202- "guider" ,
203- ClassifierFreeGuidance ,
204- config = FrozenDict ({"guidance_scale" : 5.0 }),
205- default_creation_method = "from_config" ,
206- ),
207174 ComponentSpec ("scheduler" , UniPCMultistepScheduler ),
208- ComponentSpec ("transformer" , WanTransformer3DModel ),
209175 ]
210176
211177 @property
212- def loop_intermediate_inputs (self ) -> List [InputParam ]:
178+ def loop_inputs (self ) -> List [InputParam ]:
213179 return [
214180 InputParam (
215181 "timesteps" ,
@@ -246,6 +212,81 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
246212 return components , state
247213
248214
215+ # class Wan22DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
216+ # model_name = "wan"
217+
218+ # @property
219+ # def description(self) -> str:
220+ # return (
221+ # "Pipeline block that iteratively denoise the latents over `timesteps`. "
222+ # "The specific steps with each iteration can be customized with `sub_blocks` attributes"
223+ # )
224+
225+ # @property
226+ # def loop_expected_configs(self) -> List[ConfigSpec]:
227+ # return [
228+ # ConfigSpec(
229+ # "boundary_ratio",
230+ # type_hint=float,
231+ # description="The ratio of the total timesteps to use as the boundary for switching between transformers in two-stage denoising.",
232+ # ),
233+ # ]
234+
235+ # @property
236+ # def loop_expected_components(self) -> List[ComponentSpec]:
237+ # return [
238+ # ComponentSpec("scheduler", UniPCMultistepScheduler),
239+ # ]
240+
241+ # @property
242+ # def loop_inputs(self) -> List[InputParam]:
243+ # return [
244+ # InputParam(
245+ # "timesteps",
246+ # required=True,
247+ # type_hint=torch.Tensor,
248+ # description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
249+ # ),
250+ # InputParam(
251+ # "num_inference_steps",
252+ # required=True,
253+ # type_hint=int,
254+ # description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
255+ # ),
256+ # ]
257+
258+ # @torch.no_grad()
259+ # def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
260+ # block_state = self.get_block_state(state)
261+
262+ # block_state.num_warmup_steps = max(
263+ # len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
264+ # )
265+
266+ # block_state.boundary_timestep = components.config.boundary_ratio * components.scheduler.config.num_train_timesteps
267+
268+ # with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
269+ # for i, t in enumerate(block_state.timesteps):
270+
271+ # if t > block_state.boundary_timestep:
272+ # # high-noise stage
273+ # block_state.current_model = components.transformer
274+ # block_state.current_guider = components.guider
275+ # else:
276+ # # low-noise stage
277+ # block_state.current_model = components.transformer_2
278+ # block_state.current_guider = components.guider_2
279+ # components, block_state = self.loop_step(components, block_state, i=i, t=t)
280+ # if i == len(block_state.timesteps) - 1 or (
281+ # (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
282+ # ):
283+ # progress_bar.update()
284+
285+ # self.set_block_state(state, block_state)
286+
287+ # return components, state
288+
289+
249290class WanDenoiseStep (WanDenoiseLoopWrapper ):
250291 block_classes = [
251292 WanLoopDenoiser ,
@@ -261,5 +302,5 @@ def description(self) -> str:
261302 "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n "
262303 " - `WanLoopDenoiser`\n "
263304 " - `WanLoopAfterDenoiser`\n "
264- "This block supports both text2vid tasks."
305+ "This block supports text-to-video tasks."
265306 )
0 commit comments