Skip to content

Commit 7448258

Browse files
authored
Add FLUX.2 Klein Inpaint Pipeline (#13050)
* Add Flux2KleinInpaintPipeline * Fixed mask channel mismatch and a bit of cleaning * Added tests and minor refactors * Added support for reference images for inpainting * Style fixes * Fixed the example docstring * Corrected mask latent preparation for correct dimensional alignment * replace masked_image_latents context with clean_source_latents, fix mask spatial alignment and remove unused VAE encoding * Fix T-coordinate collision for conditioning * Changed the default strength from 0.6 to 0.8 * Added reference image test and updated the frozenset * Validated ref image, latent passing support and fixed ref image preprocessing * Refined preprocessing with 1MP resolution cap and timestep tracking * Updated typing, improved validation and changed the example docstring * Style fixes * Fixed batch inference discrepancy and addressed review comments * Fixed a typo Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com> * Apply suggestion from @asomoza Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com> * Reused encoded latents and fix channel check consistency * fixed pre-encoded latent preprocessing for source and ref images * Apply style fixes * Updated the docstring with the shape requirements * Apply style fixes * Fixed copies
1 parent 160852d commit 7448258

7 files changed

Lines changed: 1503 additions & 2 deletions

File tree

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@
533533
"EasyAnimateInpaintPipeline",
534534
"EasyAnimatePipeline",
535535
"ErnieImagePipeline",
536+
"Flux2KleinInpaintPipeline",
536537
"Flux2KleinKVPipeline",
537538
"Flux2KleinPipeline",
538539
"Flux2Pipeline",
@@ -1317,6 +1318,7 @@
13171318
EasyAnimateInpaintPipeline,
13181319
EasyAnimatePipeline,
13191320
ErnieImagePipeline,
1321+
Flux2KleinInpaintPipeline,
13201322
Flux2KleinKVPipeline,
13211323
Flux2KleinPipeline,
13221324
Flux2Pipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,12 @@
160160
]
161161
_import_structure["bria"] = ["BriaPipeline"]
162162
_import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
163-
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinKVPipeline"]
163+
_import_structure["flux2"] = [
164+
"Flux2Pipeline",
165+
"Flux2KleinPipeline",
166+
"Flux2KleinInpaintPipeline",
167+
"Flux2KleinKVPipeline",
168+
]
164169
_import_structure["flux"] = [
165170
"FluxControlPipeline",
166171
"FluxControlInpaintPipeline",
@@ -697,7 +702,7 @@
697702
FluxPriorReduxPipeline,
698703
ReduxImageEncoder,
699704
)
700-
from .flux2 import Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline
705+
from .flux2 import Flux2KleinInpaintPipeline, Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline
701706
from .glm_image import GlmImagePipeline
702707
from .helios import HeliosPipeline, HeliosPyramidPipeline
703708
from .hidream_image import HiDreamImagePipeline

src/diffusers/pipelines/flux2/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
else:
2525
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
2626
_import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
27+
_import_structure["pipeline_flux2_klein_inpaint"] = ["Flux2KleinInpaintPipeline"]
2728
_import_structure["pipeline_flux2_klein_kv"] = ["Flux2KleinKVPipeline"]
2829
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
2930
try:
@@ -34,6 +35,7 @@
3435
else:
3536
from .pipeline_flux2 import Flux2Pipeline
3637
from .pipeline_flux2_klein import Flux2KleinPipeline
38+
from .pipeline_flux2_klein_inpaint import Flux2KleinInpaintPipeline
3739
from .pipeline_flux2_klein_kv import Flux2KleinKVPipeline
3840
else:
3941
import sys

src/diffusers/pipelines/flux2/image_processor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ class Flux2ImageProcessor(VaeImageProcessor):
3535
VAE latent channels.
3636
do_normalize (`bool`, *optional*, defaults to `True`):
3737
Whether to normalize the image to [-1,1].
38+
do_binarize (`bool`, *optional*, defaults to `False`):
39+
Whether to binarize the image to 0/1.
3840
do_convert_rgb (`bool`, *optional*, defaults to be `True`):
3941
Whether to convert the images to RGB format.
42+
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
43+
Whether to convert the images to grayscale format.
4044
"""
4145

4246
@register_to_config
@@ -46,14 +50,18 @@ def __init__(
4650
vae_scale_factor: int = 16,
4751
vae_latent_channels: int = 32,
4852
do_normalize: bool = True,
53+
do_binarize: bool = False,
4954
do_convert_rgb: bool = True,
55+
do_convert_grayscale: bool = False,
5056
):
5157
super().__init__(
5258
do_resize=do_resize,
5359
vae_scale_factor=vae_scale_factor,
5460
vae_latent_channels=vae_latent_channels,
5561
do_normalize=do_normalize,
62+
do_binarize=do_binarize,
5663
do_convert_rgb=do_convert_rgb,
64+
do_convert_grayscale=do_convert_grayscale,
5765
)
5866

5967
@staticmethod

0 commit comments

Comments
 (0)