@@ -663,22 +663,23 @@ def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, t
class WanVideoUnit_FunControl(PipelineUnit):
    """Pipeline unit for Fun-Control conditioning.

    Encodes the control video into VAE latents and prepends them to the
    conditioning tensor ``y`` so the DiT receives exactly ``pipe.dit.in_dim``
    channels (control latents + ``y`` + denoised ``latents``).
    """

    def __init__(self):
        super().__init__(
            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents):
        """Encode ``control_video`` and pack it with ``y`` and ``clip_feature``.

        Returns an empty dict when no control video is supplied; otherwise a
        dict with updated ``clip_feature`` and ``y`` tensors.
        """
        if control_video is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        control_video = pipe.preprocess_video(control_video)
        # Single dtype/device conversion — the original converted twice in a
        # row, which was redundant.
        control_latents = pipe.vae.encode(
            control_video, device=pipe.device,
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
        ).to(dtype=pipe.torch_dtype, device=pipe.device)
        # Channels remaining for `y` after the control latents and the denoised
        # latents are accounted for in the DiT input width.
        # Assumes dim 1 is the channel axis for all three — TODO confirm.
        y_dim = pipe.dit.in_dim - control_latents.shape[1] - latents.shape[1]
        if clip_feature is None or y is None:
            clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
            y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
        else:
            # Keep only the last y_dim channels. An explicit non-negative start
            # avoids the `-0:` pitfall: `y[:, -y_dim:]` with y_dim == 0 would
            # keep ALL channels instead of none.
            y = y[:, y.shape[1] - y_dim:]
        y = torch.concat([control_latents, y], dim=1)
        return {"clip_feature": clip_feature, "y": y}
684685
@@ -698,6 +699,8 @@ def process(self, pipe: WanVideoPipeline, reference_image, height, width):
698699 reference_image = reference_image .resize ((width , height ))
699700 reference_latents = pipe .preprocess_video ([reference_image ])
700701 reference_latents = pipe .vae .encode (reference_latents , device = pipe .device )
702+ if pipe .image_encoder is None :
703+ return {"reference_latents" : reference_latents }
701704 clip_feature = pipe .preprocess_image (reference_image )
702705 clip_feature = pipe .image_encoder .encode_image ([clip_feature ])
703706 return {"reference_latents" : reference_latents , "clip_feature" : clip_feature }
@@ -707,13 +710,14 @@ def process(self, pipe: WanVideoPipeline, reference_image, height, width):
707710class WanVideoUnit_FunCameraControl (PipelineUnit ):
708711 def __init__ (self ):
709712 super ().__init__ (
710- input_params = ("height" , "width" , "num_frames" , "camera_control_direction" , "camera_control_speed" , "camera_control_origin" , "latents" , "input_image" ),
713+ input_params = ("height" , "width" , "num_frames" , "camera_control_direction" , "camera_control_speed" , "camera_control_origin" , "latents" , "input_image" , "tiled" , "tile_size" , "tile_stride" ),
711714 onload_model_names = ("vae" ,)
712715 )
713716
714- def process (self , pipe : WanVideoPipeline , height , width , num_frames , camera_control_direction , camera_control_speed , camera_control_origin , latents , input_image ):
717+ def process (self , pipe : WanVideoPipeline , height , width , num_frames , camera_control_direction , camera_control_speed , camera_control_origin , latents , input_image , tiled , tile_size , tile_stride ):
715718 if camera_control_direction is None :
716719 return {}
720+ pipe .load_models_to_device (self .onload_model_names )
717721 camera_control_plucker_embedding = pipe .dit .control_adapter .process_camera_coordinates (
718722 camera_control_direction , num_frames , height , width , camera_control_speed , camera_control_origin )
719723
@@ -728,14 +732,27 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_cont
728732 control_camera_latents = control_camera_latents .contiguous ().view (b , f // 4 , 4 , c , h , w ).transpose (2 , 3 )
729733 control_camera_latents = control_camera_latents .contiguous ().view (b , f // 4 , c * 4 , h , w ).transpose (1 , 2 )
730734 control_camera_latents_input = control_camera_latents .to (device = pipe .device , dtype = pipe .torch_dtype )
731-
735+
732736 input_image = input_image .resize ((width , height ))
733737 input_latents = pipe .preprocess_video ([input_image ])
734- pipe .load_models_to_device (self .onload_model_names )
735738 input_latents = pipe .vae .encode (input_latents , device = pipe .device )
736739 y = torch .zeros_like (latents ).to (pipe .device )
737740 y [:, :, :1 ] = input_latents
738741 y = y .to (dtype = pipe .torch_dtype , device = pipe .device )
742+
743+ if y .shape [1 ] != pipe .dit .in_dim - latents .shape [1 ]:
744+ image = pipe .preprocess_image (input_image .resize ((width , height ))).to (pipe .device )
745+ vae_input = torch .concat ([image .transpose (0 , 1 ), torch .zeros (3 , num_frames - 1 , height , width ).to (image .device )], dim = 1 )
746+ y = pipe .vae .encode ([vae_input .to (dtype = pipe .torch_dtype , device = pipe .device )], device = pipe .device , tiled = tiled , tile_size = tile_size , tile_stride = tile_stride )[0 ]
747+ y = y .to (dtype = pipe .torch_dtype , device = pipe .device )
748+ msk = torch .ones (1 , num_frames , height // 8 , width // 8 , device = pipe .device )
749+ msk [:, 1 :] = 0
750+ msk = torch .concat ([torch .repeat_interleave (msk [:, 0 :1 ], repeats = 4 , dim = 1 ), msk [:, 1 :]], dim = 1 )
751+ msk = msk .view (1 , msk .shape [1 ] // 4 , 4 , height // 8 , width // 8 )
752+ msk = msk .transpose (1 , 2 )[0 ]
753+ y = torch .cat ([msk ,y ])
754+ y = y .unsqueeze (0 )
755+ y = y .to (dtype = pipe .torch_dtype , device = pipe .device )
739756 return {"control_camera_latents_input" : control_camera_latents_input , "y" : y }
740757
741758
@@ -1048,7 +1065,7 @@ def model_fn_wan_video(
10481065 if clip_feature is not None and dit .require_clip_embedding :
10491066 clip_embdding = dit .img_emb (clip_feature )
10501067 context = torch .cat ([clip_embdding , context ], dim = 1 )
1051-
1068+
10521069 # Add camera control
10531070 x , (f , h , w ) = dit .patchify (x , control_camera_latents_input )
10541071
0 commit comments