@@ -663,25 +663,23 @@ def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, t
663663class WanVideoUnit_FunControl (PipelineUnit ):
664664 def __init__ (self ):
665665 super ().__init__ (
666- input_params = ("control_video" , "num_frames" , "height" , "width" , "tiled" , "tile_size" , "tile_stride" , "clip_feature" , "y" ),
666+ input_params = ("control_video" , "num_frames" , "height" , "width" , "tiled" , "tile_size" , "tile_stride" , "clip_feature" , "y" , "latents" ),
667667 onload_model_names = ("vae" ,)
668668 )
669669
670- def process (self , pipe : WanVideoPipeline , control_video , num_frames , height , width , tiled , tile_size , tile_stride , clip_feature , y ):
670+ def process (self , pipe : WanVideoPipeline , control_video , num_frames , height , width , tiled , tile_size , tile_stride , clip_feature , y , latents ):
671671 if control_video is None :
672672 return {}
673673 pipe .load_models_to_device (self .onload_model_names )
674674 control_video = pipe .preprocess_video (control_video )
675675 control_latents = pipe .vae .encode (control_video , device = pipe .device , tiled = tiled , tile_size = tile_size , tile_stride = tile_stride ).to (dtype = pipe .torch_dtype , device = pipe .device )
676676 control_latents = control_latents .to (dtype = pipe .torch_dtype , device = pipe .device )
677+ y_dim = pipe .dit .in_dim - control_latents .shape [1 ]- latents .shape [1 ]
677678 if clip_feature is None or y is None :
678679 clip_feature = torch .zeros ((1 , 257 , 1280 ), dtype = pipe .torch_dtype , device = pipe .device )
679- y = torch .zeros ((1 , 16 , (num_frames - 1 ) // 4 + 1 , height // 8 , width // 8 ), dtype = pipe .torch_dtype , device = pipe .device )
680- if pipe .dit2 is not None :
681- y = torch .zeros ((1 , 20 , (num_frames - 1 ) // 4 + 1 , height // 8 , width // 8 ), dtype = pipe .torch_dtype , device = pipe .device )
680+ y = torch .zeros ((1 , y_dim , (num_frames - 1 ) // 4 + 1 , height // 8 , width // 8 ), dtype = pipe .torch_dtype , device = pipe .device )
682681 else :
683- if pipe .dit2 is None :
684- y = y [:, - 16 :]
682+ y = y [:, - y_dim :]
685683 y = torch .concat ([control_latents , y ], dim = 1 )
686684 return {"clip_feature" : clip_feature , "y" : y }
687685
@@ -734,13 +732,19 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_cont
734732 control_camera_latents = control_camera_latents .contiguous ().view (b , f // 4 , 4 , c , h , w ).transpose (2 , 3 )
735733 control_camera_latents = control_camera_latents .contiguous ().view (b , f // 4 , c * 4 , h , w ).transpose (1 , 2 )
736734 control_camera_latents_input = control_camera_latents .to (device = pipe .device , dtype = pipe .torch_dtype )
737-
738- image = pipe .preprocess_image (input_image .resize ((width , height ))).to (pipe .device )
739- vae_input = torch .concat ([image .transpose (0 , 1 ), torch .zeros (3 , num_frames - 1 , height , width ).to (image .device )], dim = 1 )
740- y = pipe .vae .encode ([vae_input .to (dtype = pipe .torch_dtype , device = pipe .device )], device = pipe .device , tiled = tiled , tile_size = tile_size , tile_stride = tile_stride )[0 ]
735+
736+ input_image = input_image .resize ((width , height ))
737+ input_latents = pipe .preprocess_video ([input_image ])
738+ input_latents = pipe .vae .encode (input_latents , device = pipe .device )
739+ y = torch .zeros_like (latents ).to (pipe .device )
740+ y [:, :, :1 ] = input_latents
741741 y = y .to (dtype = pipe .torch_dtype , device = pipe .device )
742742
743- if pipe .dit2 is not None :
743+ if y .shape [1 ] != pipe .dit .in_dim - latents .shape [1 ]:
744+ image = pipe .preprocess_image (input_image .resize ((width , height ))).to (pipe .device )
745+ vae_input = torch .concat ([image .transpose (0 , 1 ), torch .zeros (3 , num_frames - 1 , height , width ).to (image .device )], dim = 1 )
746+ y = pipe .vae .encode ([vae_input .to (dtype = pipe .torch_dtype , device = pipe .device )], device = pipe .device , tiled = tiled , tile_size = tile_size , tile_stride = tile_stride )[0 ]
747+ y = y .to (dtype = pipe .torch_dtype , device = pipe .device )
744748 msk = torch .ones (1 , num_frames , height // 8 , width // 8 , device = pipe .device )
745749 msk [:, 1 :] = 0
746750 msk = torch .concat ([torch .repeat_interleave (msk [:, 0 :1 ], repeats = 4 , dim = 1 ), msk [:, 1 :]], dim = 1 )
@@ -1061,7 +1065,7 @@ def model_fn_wan_video(
10611065 if clip_feature is not None and dit .require_clip_embedding :
10621066 clip_embdding = dit .img_emb (clip_feature )
10631067 context = torch .cat ([clip_embdding , context ], dim = 1 )
1064-
1068+
10651069 # Add camera control
10661070 x , (f , h , w ) = dit .patchify (x , control_camera_latents_input )
10671071
0 commit comments