diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index fdaff9b0af8a..a36d8d1393c5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -818,8 +818,10 @@ def __call__( ) if control_mode is not None: + if not isinstance(control_mode, int): + raise ValueError("For `FluxControlNet`, `control_mode` should be an `int` or `None`") control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) + control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1) elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index eed671152bc9..27cdd0f14bf9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -958,8 +958,10 @@ def __call__( # set control mode if control_mode is not None: + if not isinstance(control_mode, int): + raise ValueError("For `FluxControlNet`, `control_mode` should be an `int` or `None`") control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) + control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1) elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = []