
Fix: use resample-aware bilinear+antialias interpolation for tensor/numpy resize in VaeImageProcessor #13500

Open
GitGlimpse895 wants to merge 4 commits into huggingface:main from GitGlimpse895:fix/tensor-resize-interpolation


@GitGlimpse895 commented Apr 18, 2026

What does this PR do?

VaeImageProcessor exposes a resample config parameter (defaulting to "lanczos")
and correctly applies it when resizing PIL images via PIL_INTERPOLATION. However,
the two torch.nn.functional.interpolate calls handling torch.Tensor and
np.ndarray inputs passed no mode argument, so PyTorch silently defaulted
to "nearest" (nearest-neighbor) interpolation regardless of the configured resample filter.
antialias=True was not set either, which causes aliasing artifacts on downsampling.
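
To make the failure mode concrete, here is a small sketch (not taken from the PR) that compares the default call with an explicit bilinear+antialias call. The checkerboard input and all variable names are illustrative assumptions; only `torch.nn.functional.interpolate` and its `mode`, `align_corners`, and `antialias` arguments come from PyTorch.

```python
import torch
import torch.nn.functional as F

# 8x8 checkerboard in (N, C, H, W) layout, which F.interpolate expects for 2D resizing.
rows = torch.arange(8).view(8, 1)
cols = torch.arange(8).view(1, 8)
checker = ((rows + cols) % 2).float().view(1, 1, 8, 8)

# No mode argument: F.interpolate falls back to its default, "nearest".
default_out = F.interpolate(checker, size=(4, 4))
nearest_out = F.interpolate(checker, size=(4, 4), mode="nearest")

# Bilinear with antialiasing low-pass filters the input before downsampling.
smooth_out = F.interpolate(
    checker, size=(4, 4), mode="bilinear", align_corners=False, antialias=True
)

# The default call and the explicit "nearest" call produce identical tensors.
same_as_nearest = torch.equal(default_out, nearest_out)

# Nearest samples source pixels (0, 2, 4, 6) on each axis, which are all 0-cells
# of this checkerboard, so the pattern collapses to a constant image.
nearest_values = sorted(nearest_out.unique().tolist())

print(same_as_nearest, nearest_values, smooth_out.mean().item())
```

With nearest-neighbor sampling the pattern disappears entirely, while the antialiased result stays close to the mid-gray average of the input, which is the behavior one would expect from a quality downsampling filter.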

This fix:

  • Adds a TORCH_INTERPOLATION dict in pil_utils.py mapping the same resample-string
    keys as PIL_INTERPOLATION to their torch.nn.functional.interpolate equivalents
    (with antialias eligibility). "lanczos" maps to bilinear+antialias, the closest
    high-quality torch substitute.
  • Updates both tensor branches of VaeImageProcessor.resize() to use the mapped mode
    and antialias flag, making tensor/numpy resize quality consistent with the PIL path.
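
The two bullets above could look roughly like the following. This is a hedged reconstruction, not the PR's actual diff: the four non-nearest entries mirror the snippet quoted later in the review thread, while the `"nearest"` entry and the helper name `resize_tensor` are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

# Resample name -> (torch interpolation mode, antialias flag).
# The "nearest" entry is an assumption; the other entries follow the reviewed snippet.
TORCH_INTERPOLATION = {
    "linear": ("bilinear", True),
    "bilinear": ("bilinear", True),
    "bicubic": ("bicubic", True),
    "lanczos": ("bilinear", True),
    "nearest": ("nearest", False),
}


def resize_tensor(image: torch.Tensor, height: int, width: int, resample: str) -> torch.Tensor:
    """Resize an (N, C, H, W) tensor with the mode mapped from `resample`.

    `resize_tensor` is a hypothetical helper, not a diffusers function.
    """
    mode, antialias = TORCH_INTERPOLATION[resample]
    kwargs = {"mode": mode}
    if mode != "nearest":
        # antialias and align_corners are only accepted for the interpolating modes.
        kwargs["antialias"] = antialias
        kwargs["align_corners"] = False
    return F.interpolate(image, size=(height, width), **kwargs)


image = torch.rand(1, 3, 64, 64)
resized = resize_tensor(image, 32, 32, resample="lanczos")
print(resized.shape)
```

Storing the antialias flag next to the mode keeps the lookup to a single dictionary access and avoids passing `antialias=True` to a mode that does not support it.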

This silently affected every pipeline that passes torch.Tensor inputs to
VaeImageProcessor (ControlNet conditioning, IP-Adapter, img2img, etc.).

Fixes # (issue)

Before submitting

Who can review?

@yiyixuxu @sayakpaul @DN6

@github-actions github-actions bot added utils size/S PR with diff < 50 LOC labels Apr 18, 2026
@ParamChordiya

Code Review: Fix resample-aware interpolation in VaeImageProcessor

Summary

This fixes a legitimate quality bug: VaeImageProcessor.resize() correctly applies the configured resample filter for PIL inputs, but the torch.Tensor and np.ndarray branches called F.interpolate with no mode argument, defaulting to "nearest" neighbor — producing blocky/aliased results. The TORCH_INTERPOLATION mapping dict and (mode, antialias) tuple pattern is clean.

Issues

  1. "lanczos" maps to "bilinear": should it be "bicubic"? "bicubic" is a closer approximation to Lanczos in frequency response and sharpness. Since "lanczos" is the default resample value, this mapping affects the vast majority of users. Worth discussing.

  2. No new tests — Existing tests only check output shapes, not interpolation quality or mode. Needs at minimum:

    • A test verifying interpolation is not "nearest" (e.g. checkerboard pattern resize)
    • A test that each key in TORCH_INTERPOLATION works without error
    • A test comparing tensor vs PIL resize output similarity (PSNR/MSE threshold)
  3. Import path inconsistency: PIL_INTERPOLATION is exported via utils/__init__.py, but TORCH_INTERPOLATION is imported directly from utils.pil_utils. Should be consistent.

  4. Silent behavior change — This changes outputs for all existing pipelines passing tensor/numpy inputs. Deserves a changelog entry flagging it as a fix that changes output values.

  5. No KeyError guard on mapping lookup — Invalid resample strings will produce an unhelpful KeyError. A ValueError with supported options listed would be better.

  6. Duplicate code — The TORCH_INTERPOLATION lookup is identical in both the tensor and numpy branches. Could factor it out above with isinstance(image, (torch.Tensor, np.ndarray)).

  7. resize_and_crop_tensor not updated — This static method also calls F.interpolate with hardcoded mode="bilinear" and no antialias. Worth a follow-up.
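
As one possible way to address item 5, the lookup could be wrapped so that an unknown resample string raises a `ValueError` listing the supported options. This is only a sketch: the mapping contents and the function name `lookup_torch_interpolation` are hypothetical and not part of the PR.

```python
from typing import Dict, Tuple

# Hypothetical, shortened mapping for illustration only.
TORCH_INTERPOLATION: Dict[str, Tuple[str, bool]] = {
    "nearest": ("nearest", False),
    "bilinear": ("bilinear", True),
    "bicubic": ("bicubic", True),
}


def lookup_torch_interpolation(resample: str) -> Tuple[str, bool]:
    """Return (mode, antialias) for `resample`, or raise a descriptive ValueError."""
    try:
        return TORCH_INTERPOLATION[resample]
    except KeyError:
        supported = ", ".join(sorted(TORCH_INTERPOLATION))
        # `from None` hides the bare KeyError so users only see the helpful message.
        raise ValueError(
            f"Unsupported resample mode {resample!r} for tensor/ndarray inputs. "
            f"Supported modes: {supported}."
        ) from None


mode, antialias = lookup_torch_interpolation("bicubic")
print(mode, antialias)
```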

Verdict

Request changes — The core fix is correct, but needs tests, the lanczos mapping deserves discussion, and the import path should be consistent.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@yiyixuxu yiyixuxu left a comment


thanks for the PR!
i left one comment

Comment thread src/diffusers/utils/pil_utils.py Outdated
"linear": ("bilinear", True),
"bilinear": ("bilinear", True),
"bicubic": ("bicubic", True),
"lanczos": ("bilinear", True),

ohh, so if this option is not supported in torch, let's not map it to anything
just send a warning that says this resample mode is not supported for tensor/ndarray so it will be ignored (the default nearest is used instead). this way we don't change the default behavior for resize
what do you think?

@github-actions github-actions bot added size/S PR with diff < 50 LOC and removed size/S PR with diff < 50 LOC labels Apr 19, 2026
@GitGlimpse895
Author

Thanks @yiyixuxu — updated! Revised approach:

  • TORCH_INTERPOLATION now only maps modes torch natively supports
    (bilinear, bicubic, nearest).
  • For unsupported modes like "lanczos", the code now emits a
    logger.warning and falls back to "nearest", preserving existing
    default behavior with no silent output change.
  • Also factored the duplicate lookup out of both tensor and numpy
    branches into a single shared isinstance(image, (torch.Tensor, np.ndarray))
    branch, eliminating the code duplication @ParamChordiya flagged.
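
The revised approach described above might be sketched as follows. The PR's updated diff is not shown in this conversation, so everything here is an assumption made for illustration: the function name `resize_array`, the warning text, the exact mapping contents, and the (N, H, W, C) layout for NumPy inputs.

```python
import logging

import numpy as np
import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)

# Assumed revised mapping: only modes torch supports natively.
TORCH_INTERPOLATION = {
    "nearest": ("nearest", False),
    "bilinear": ("bilinear", True),
    "bicubic": ("bicubic", True),
}


def resize_array(image, height: int, width: int, resample: str = "lanczos"):
    """Resize a torch.Tensor (N, C, H, W) or an np.ndarray (assumed N, H, W, C)."""
    if not isinstance(image, (torch.Tensor, np.ndarray)):
        raise TypeError(f"Expected torch.Tensor or np.ndarray, got {type(image)}")

    # Single shared lookup for both input types.
    if resample in TORCH_INTERPOLATION:
        mode, antialias = TORCH_INTERPOLATION[resample]
    else:
        logger.warning(
            "Resample mode %r is not supported for tensor/ndarray inputs; "
            "falling back to 'nearest'.",
            resample,
        )
        mode, antialias = "nearest", False

    is_numpy = isinstance(image, np.ndarray)
    # Convert channels-last NumPy input to the channels-first layout torch expects.
    tensor = torch.from_numpy(image).permute(0, 3, 1, 2).contiguous() if is_numpy else image

    kwargs = {"mode": mode}
    if mode != "nearest":
        kwargs.update(antialias=antialias, align_corners=False)
    resized = F.interpolate(tensor, size=(height, width), **kwargs)

    # Return the same container type and layout that was passed in.
    return resized.permute(0, 2, 3, 1).numpy() if is_numpy else resized


x = torch.rand(1, 3, 8, 8)
fallback = resize_array(x, 4, 4, resample="lanczos")  # logs a warning, uses nearest
print(fallback.shape)
```

Because the fallback uses `"nearest"`, a caller who leaves `resample` at `"lanczos"` gets the same values as the earlier mode-less call, which is the compatibility property requested in the review.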

@GitGlimpse895 GitGlimpse895 force-pushed the fix/tensor-resize-interpolation branch from 799e72c to d3456ea Compare April 19, 2026 02:47
@github-actions github-actions bot added size/S PR with diff < 50 LOC and removed size/S PR with diff < 50 LOC labels Apr 19, 2026
