Skip to content

Commit 6397a67

Browse files
committed
up
1 parent 82685f2 commit 6397a67

2 files changed

Lines changed: 18 additions & 12 deletions

File tree

src/diffusers/pipelines/flux2/image_processor.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,21 +96,24 @@ def check_image_input(
9696
)
9797

9898
return image
99-
99+
100100
@staticmethod
101-
def _resize_to_target_area(
102-
image: PIL.Image.Image, target_area: int = 1024 * 1024, return_if_small_image: bool = False
103-
) -> PIL.Image.Image:
101+
def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
104102
image_width, image_height = image.size
105-
pixel_count = image_width * image_height
106-
if return_if_small_image and pixel_count <= target_area:
107-
return image
108103

109-
scale = math.sqrt(target_area / pixel_count)
104+
scale = math.sqrt(target_area / (image_width * image_height))
110105
width = int(image_width * scale)
111106
height = int(image_height * scale)
112107

113108
return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
109+
110+
@staticmethod
111+
def _resize_if_exceeds_area(image, target_area=1024 * 1024) -> PIL.Image.Image:
112+
image_width, image_height = image.size
113+
pixel_count = image_width * image_height
114+
if pixel_count <= target_area:
115+
return image
116+
return Flux2ImageProcessor._resize_to_target_area(image, target_area)
114117

115118
def _resize_and_crop(
116119
self,

src/diffusers/pipelines/flux2/pipeline_flux2.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
```
5858
"""
5959

60+
UPSAMPLING_MAX_IMAGE_SIZE = 768**2
6061

6162
# Adapted from
6263
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68
@@ -95,8 +96,6 @@ def format_input(
9596
]
9697
else:
9798
assert len(images) == len(prompts), "Number of images must match number of prompts"
98-
images = _validate_and_process_images(images)
99-
10099
messages = [
101100
[
102101
{
@@ -149,7 +148,7 @@ def _validate_and_process_images(
149148
# cap the pixels
150149
images = [
151150
[
152-
image_processor._resize_to_target_area(img_i, upsampling_max_image_size, return_if_small_image=True)
151+
image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size)
153152
for img_i in img_i
154153
]
155154
for img_i in images
@@ -301,7 +300,7 @@ def __init__(
301300
self.system_message = SYSTEM_MESSAGE
302301
self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I
303302
self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I
304-
self.upsampling_max_image_size = 768**2
303+
self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE
305304

306305
@staticmethod
307306
def _get_mistral_3_small_prompt_embeds(
@@ -525,6 +524,10 @@ def upsample_prompt(
525524
else:
526525
system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
527526

527+
# Validate and process the input images
528+
if images:
529+
images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size)
530+
528531
# Format input messages
529532
messages_batch = format_input(prompts=prompt, system_message=system_message, images=images)
530533

0 commit comments

Comments
 (0)