Skip to content

Commit 26d8bc0

Browse files
committed
Use VaeImageProcessor.postprocess in standard and modular ernie
1 parent 1176735 commit 26d8bc0

2 files changed

Lines changed: 11 additions & 20 deletions

File tree

src/diffusers/modular_pipelines/ernie_image/decoders.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import numpy as np
1615
import torch
17-
from PIL import Image
1816

1917
from ...configuration_utils import FrozenDict
18+
from ...image_processor import VaeImageProcessor
2019
from ...models import AutoencoderKLFlux2
2120
from ...utils import logging
2221
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
@@ -44,6 +43,12 @@ def expected_components(self) -> list[ComponentSpec]:
4443
config=FrozenDict({"patch_size": 2}),
4544
default_creation_method="from_config",
4645
),
46+
ComponentSpec(
47+
"image_processor",
48+
VaeImageProcessor,
49+
config=FrozenDict({"vae_scale_factor": 16}),
50+
default_creation_method="from_config",
51+
),
4752
]
4853

4954
@property
@@ -81,18 +86,7 @@ def __call__(self, components: ErnieImageModularPipeline, state: PipelineState)
8186
latents = components.pachifier.unpack_latents(latents)
8287

8388
images = vae.decode(latents.to(vae.dtype), return_dict=False)[0]
84-
images = (images.clamp(-1, 1) + 1) / 2
85-
86-
output_type = block_state.output_type
87-
if output_type == "pt":
88-
block_state.images = images
89-
elif output_type == "np":
90-
block_state.images = images.cpu().permute(0, 2, 3, 1).float().numpy()
91-
elif output_type == "pil":
92-
images_np = images.cpu().permute(0, 2, 3, 1).float().numpy()
93-
block_state.images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images_np]
94-
else:
95-
raise ValueError(f"Unsupported `output_type`: {output_type!r}. Expected one of 'pil', 'np', 'pt'.")
89+
block_state.images = components.image_processor.postprocess(images, output_type=block_state.output_type)
9690

9791
self.set_block_state(state, block_state)
9892
return components, state

src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
from typing import Callable, List, Optional, Union
2121

2222
import torch
23-
from PIL import Image
2423
from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model
2524

25+
from ...image_processor import VaeImageProcessor
2626
from ...models import AutoencoderKLFlux2
2727
from ...models.transformers import ErnieImageTransformer2DModel
2828
from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -68,6 +68,7 @@ def __init__(
6868
pe_tokenizer=pe_tokenizer,
6969
)
7070
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
71+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
7172

7273
@property
7374
def guidance_scale(self):
@@ -379,11 +380,7 @@ def __call__(
379380
images = self.vae.decode(latents, return_dict=False)[0]
380381

381382
# Post-process
382-
images = (images.clamp(-1, 1) + 1) / 2
383-
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
384-
385-
if output_type == "pil":
386-
images = [Image.fromarray((img * 255).astype("uint8")) for img in images]
383+
images = self.image_processor.postprocess(images, output_type=output_type)
387384

388385
# Offload all models
389386
self.maybe_free_model_hooks()

0 commit comments

Comments (0)