VAE Decoder next(iter(..)) causes graph break #12501

@ppadjinTT

Description

Describe the bug

When compiling the decoder part of the VAE with torch.compile, a graph break is hit. This makes it impossible to compile the model as a single graph.

These are the exact lines in the diffusers library that cause the graph break:

diffusers/src/diffusers/models/autoencoders/vae.py:287-289

        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

In particular, the next(iter(...)) call forces evaluation of the lazy tensors that precede this point in the graph, so the conv_in module ends up being compiled on its own.

Ideally, upscale_dtype would be inferred in a different way that does not trigger this graph break.
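One possible direction, sketched on a hypothetical toy module rather than the real diffusers Decoder: read the dtype from a specific, statically known parameter through plain attribute access instead of iterating over a parameter generator. `ToyDecoder`, its layer sizes, and the choice of `up_blocks[0].weight` are illustrative assumptions. Whether this actually avoids the break on a given backend would still need to be checked, for example by compiling with `torch.compile(..., fullgraph=True)`, which raises an error on a graph break instead of silently splitting the graph.

```python
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for the VAE decoder; not the diffusers implementation."""

    def __init__(self, in_channels: int = 4, hidden: int = 8):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1)
        self.up_blocks = nn.ModuleList(
            [nn.Conv2d(hidden, hidden, kernel_size=3, padding=1) for _ in range(2)]
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        sample = self.conv_in(sample)
        # Instead of next(iter(self.up_blocks.parameters())).dtype, read the
        # dtype from one fixed parameter via ordinary attribute access.
        upscale_dtype = self.up_blocks[0].weight.dtype
        sample = sample.to(upscale_dtype)
        for block in self.up_blocks:
            sample = block(sample)
        return sample


# Build in float32, then cast, mirroring model.to(torch.bfloat16) in the repro.
model = ToyDecoder().eval().to(torch.bfloat16)
with torch.no_grad():
    out = model(torch.randn(1, 4, 16, 16, dtype=torch.bfloat16))
print(out.dtype, tuple(out.shape))
```

Reading the dtype at call time, rather than caching it once in `__init__`, keeps it correct after a later `model.to(...)` cast.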

Thanks!

Reproduction

This is the repro code:

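import torch
from diffusers import AutoencoderKL

# Load the SDXL VAE; only its decoder is compiled below.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="vae",
    torch_dtype=torch.float16,
)

model = vae.decoder
model = model.eval()
# Cast the decoder weights to bfloat16 to match the input latents.
model = model.to(torch.bfloat16)
device = torch.device("cpu")
# Latent input: batch of 1, 4 latent channels, 64x64 spatial size.
sample_img = torch.randn(1, 4, 64, 64, dtype=torch.bfloat16)

# nn.Module.compile() compiles this module in place via torch.compile().
model.compile()

sample_img = sample_img.to(device)
model = model.to(device)

with torch.no_grad():
    output = model(sample_img)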

Logs

System Info

  • 🤗 Diffusers version: 0.35.1
  • Platform: Linux-5.4.0-212-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.11.13
  • PyTorch version (GPU?): 2.7.0+cpu (False)
  • Flax version (CPU?/GPU?/TPU?): 0.10.4 (cpu)
  • Jax version: 0.7.1
  • JaxLib version: 0.7.1
  • Huggingface_hub version: 0.35.3
  • Transformers version: 4.52.4
  • Accelerate version: 1.10.1
  • PEFT version: 0.17.1
  • Bitsandbytes version: not installed
  • Safetensors version: 0.6.2
  • xFormers version: not installed
  • Accelerator: NA
  • Using GPU in script?: no, just cpu
  • Using distributed or parallel set-up in script?: no, single core cpu

Who can help?

@sayakpaul
@DN6

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests