|
57 | 57 | ``` |
58 | 58 | """ |
59 | 59 |
|
| 60 | +UPSAMPLING_MAX_IMAGE_SIZE = 768**2 |
60 | 61 |
|
61 | 62 | # Adapted from |
62 | 63 | # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 |
@@ -95,8 +96,6 @@ def format_input( |
95 | 96 | ] |
96 | 97 | else: |
97 | 98 | assert len(images) == len(prompts), "Number of images must match number of prompts" |
98 | | - images = _validate_and_process_images(images) |
99 | | - |
100 | 99 | messages = [ |
101 | 100 | [ |
102 | 101 | { |
@@ -149,7 +148,7 @@ def _validate_and_process_images( |
149 | 148 | # cap the pixels |
150 | 149 | images = [ |
151 | 150 | [ |
152 | | - image_processor._resize_to_target_area(img_i, upsampling_max_image_size, return_if_small_image=True) |
| 151 | + image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size) |
153 | 152 | for img_i in img_i |
154 | 153 | ] |
155 | 154 | for img_i in images |
@@ -301,7 +300,7 @@ def __init__( |
301 | 300 | self.system_message = SYSTEM_MESSAGE |
302 | 301 | self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I |
303 | 302 | self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I |
304 | | - self.upsampling_max_image_size = 768**2 |
| 303 | + self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE |
305 | 304 |
|
306 | 305 | @staticmethod |
307 | 306 | def _get_mistral_3_small_prompt_embeds( |
@@ -525,6 +524,10 @@ def upsample_prompt( |
525 | 524 | else: |
526 | 525 | system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I |
527 | 526 |
|
| 527 | + # Validate and process the input images |
| 528 | + if images: |
| 529 | + images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size) |
| 530 | + |
528 | 531 | # Format input messages |
529 | 532 | messages_batch = format_input(prompts=prompt, system_message=system_message, images=images) |
530 | 533 |
|
|
0 commit comments