
Add RAE Diffusion Transformer inference/preliminary training pipelines#13231

Open
plugyawn wants to merge 18 commits into huggingface:main from plugyawn:rae-dit-training

Conversation

@plugyawn

@plugyawn plugyawn commented Mar 9, 2026

What does this PR do?

This PR adds support for Diffusion Transformers with Representation Autoencoders in Diffusers.

It implements the Stage-2 side of the RAE setup:

  • RAEDiT2DModel
  • RAEDiTPipeline
  • checkpoint conversion for published upstream Stage-2 checkpoints
  • API docs
  • a small examples/research_projects/rae_dit/ training scaffold

This addresses #13225.

Reference implementation: byteriper's repository

Validation

Inference parity with the official implementation is high. For matched class label / initial latent noise / schedule, I measured:

  • max_abs_error=0.00001717
  • mean_abs_error=0.00000122

Qualitative parity artifacts used during validation:

  • same published Stage-2 checkpoint
  • same class label
  • same initial latent noise
  • same 25-step shifted Euler schedule

Inference is also slightly faster in the current Diffusers port on a 40GB A100:

| Precision | CFG | Steps | Diffusers sec/img | Upstream sec/img | Diffusers img/s | Delta  |
|-----------|-----|-------|-------------------|------------------|-----------------|--------|
| bf16      | 1.0 | 25    | 0.817             | 0.913            | 1.225           | +11.8% |
| bf16      | 4.0 | 25    | 0.852             | 0.931            | 1.174           | +9.3%  |
| bf16      | 1.0 | 50    | 1.568             | 1.761            | 0.638           | +12.3% |
| bf16      | 4.0 | 50    | 1.649             | 1.853            | 0.606           | +12.4% |

Notes

  • This PR intentionally does not add upstream autoguidance / guidance-model support.
  • The training script is a research-project scaffold under examples/research_projects, not a claim of full upstream training parity.
  • AutoencoderRAE.from_pretrained() is used for the Stage-1 component so the packaged RAEDiTPipeline.from_pretrained(...) path works with published RAE checkpoints.

Before submitting

@plugyawn plugyawn changed the title Add Stage-2 RAE DiT support with pipeline, conversion, and training tooling RAE DiT inference, checkpoint conversion, and preliminary training tooling Mar 9, 2026
@plugyawn plugyawn changed the title RAE DiT inference, checkpoint conversion, and preliminary training tooling Add RAE Diffusion Transformer inference/preliminary training pipelines Mar 9, 2026
@plugyawn plugyawn marked this pull request as draft March 9, 2026 05:46
@plugyawn plugyawn marked this pull request as ready for review March 9, 2026 05:51
@plugyawn
Author

plugyawn commented Mar 9, 2026

@kashif @sayakpaul would be great if you could review. Please note the no_init_weights() fix (details in the PR body); if you prefer, that could be a separate PR, but considering diffusers is supposed to be an extension to torch, I guess it makes sense?

@sayakpaul
Member

Thanks for the PR. To keep the scope manageable, could we break it down into separate PRs?

For example,

there is also a change to no_init_weights(). Specifically, it makes Diffusers' skip-weight-init behave more like plain PyTorch. Currently, when no_init_weights() is active, the patched torch.nn.init.* functions stop returning the tensor they were called on (for reference, PyTorch's versions do return it). Most models never notice this, but the RAE-DiT implementation does rely on the return value during construction, which can make otherwise valid checkpoints fail to load through the standard from_pretrained() path.

could be a separate PR.
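The behavioral difference at issue can be sketched with plain-Python stand-ins (these names and the list stand-in for a tensor are illustrative, not the actual diffusers patch):

```python
# Two stand-in stubs for a patched torch.nn.init.* function under a
# skip-weight-init context. Names are illustrative, not diffusers API.

def stub_swallow(*args, **kwargs):
    """Current behavior: skip the init and implicitly return None."""
    pass

def stub_preserve_contract(*args, **kwargs):
    """Proposed behavior: skip the init but return the tensor that was
    passed in, matching PyTorch's torch.nn.init.* return contract."""
    return args[0] if args else None

weight = [0.0, 0.0]  # stands in for a tensor
assert stub_swallow(weight) is None              # chaining on this breaks
assert stub_preserve_contract(weight) is weight  # chaining still works
```

Code that only relies on the in-place side effect works with either stub; code that chains on the return value only survives the second.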

Member

@sayakpaul sayakpaul left a comment


Thanks!

I left some initial comments, let me know if they make sense.

Comment on lines +13 to +16
- `examples/dreambooth/train_dreambooth_flux.py`
for the flow-matching training loop structure, checkpoint resume flow, and `accelerate.save_state(...)` hooks.
- `examples/flux-control/train_control_flux.py`
for the transformer-only save layout and SD3-style flow-matching timestep weighting helpers.
Member


Doesn't belong here.

Comment thread src/diffusers/models/modeling_utils.py Outdated
Comment on lines +218 to +221
# Preserve the `torch.nn.init.*` return contract so third-party model
# constructors that chain on the returned tensor still work under
# `no_init_weights()`.
return args[0] if len(args) > 0 else None
Member


Can you provide an example?

super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)

@unittest.skip(
    "RAEDiT initializes the output head to zeros, so cosine-based layerwise casting checks are uninformative."
)
Member


I don't think this is the case? We can always skip layerwise casting for certain layer or layer groups here:

_skip_layerwise_casting_patterns = None

model.final_layer.linear.bias.data.normal_(mean=0.0, std=0.02)


class RAEDiT2DModelTests(ModelTesterMixin, unittest.TestCase):
Member


Test should use the newly added model tester mixins. You can find an example in #13046

Comment on lines +48 to +49
if shift is None:
shift = torch.zeros_like(scale)
Member


This is a small function, so it's probably fine to inline it at the call sites?

We also probably don't need _repeat_to_length().

Comment on lines +466 to +470
if self.use_pos_embed:
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1], int(sqrt(self.pos_embed.shape[1])), output_type="pt"
)
self.pos_embed.data.copy_(pos_embed.float().unsqueeze(0))
Member


Can we use how #13046 initialized the position embeddings?

Author


Yeah, that makes sense, will do that.

)
return hidden_states

def _run_block(
Member


We don't need this. Let's instead follow this pattern:

for index_block, block in enumerate(self.transformer_blocks):


return class_labels

def _prepare_latents(
Member


It should be called prepare_latents() similar to other pipelines.

Comment on lines +247 to +252
if output_type == "pt":
output = images
else:
output = images.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
output = self.numpy_to_pil(output)
Member


We should use an image processor instead here. See:

image = self.image_processor.postprocess(image, output_type=output_type)

if not return_dict:
return (output,)

return ImagePipelineOutput(images=output)
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's give this pipeline a separate output class: RAEDiTPipelineOutput.

@sayakpaul sayakpaul requested review from dg845 and kashif March 9, 2026 11:33
@plugyawn
Author

plugyawn commented Mar 10, 2026

@sayakpaul, from what I understand the RAE checkpoint -> DiT checkpoint -> generation pipeline necessarily requires the no_init_weights() change (otherwise the semantics become a bit muddled, imo).

Would it make more sense to open a PR for handling no_init_weights() behavior before this one?

@sayakpaul
Member

Could you explain why that's needed? I am still not sure about that, actually. Please provide specific examples that fail without the init change.

@plugyawn
Author

plugyawn commented Mar 10, 2026

Not sure how to link files, but it seems to be related to changes introduced in #13046.

A specific example,

  • AutoencoderRAE consturcts DinoV2WithRegistersModel.
  • ModelMixin.from_pretrained() does this construction under no_init_weights( ) first, before low_cpu_mem_usage kicks in (modelling_utils.py, around line 1300)
  • AutoencoderRAE constructs Dinov2WithRegistersModel(config) in _build_encoder:84, and
    ModelMixin.from_pretrained() always does that construction under no_init_weights() first, even
    before low_cpu_mem_usage matters; see modeling_utils.py:1270. In current transformers, DINOv2-
    with-registers has init code like this in modeling_dinov2_with_registers.py:464:
  module.weight.data = nn.init.trunc_normal_(
      module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
  ).to(module.weight.dtype)

Under today’s no_init_weights(), nn.init.trunc_normal_ is replaced with a stub that just passes
and returns None, so that becomes None.to(...) and fails with an AttributeError: 'NoneType' object has no attribute 'to'.

Codex has a better summary, I think:

failing example: AutoencoderRAE builds Dinov2WithRegistersModel(config) in its encoder path, and ModelMixin.from_pretrained() always instantiates models under no_init_weights() first. In current transformers, DINOv2’s init_weights() assigns the return value of nn.init.trunc_normal_(...) and then calls .to(...) on it. With the current no_init_weights() stub, that return value becomes None, so construction fails with AttributeError: 'NoneType' object has no attribute 'to'. The proposed change keeps skip-init behavior intact, but restores the normal PyTorch return contract so these constructors remain compatible.
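The failure mode can be reproduced without torch at all; in this sketch, FakeTensor, real_trunc_normal_, and skip_init_stub are dependency-free stand-ins for the real tensor, init function, and stub:

```python
# Minimal, torch-free reproduction of the AttributeError described above.
class FakeTensor:
    """Stands in for a torch.Tensor; only .to() matters here."""
    def to(self, dtype):
        return self

def real_trunc_normal_(tensor, mean=0.0, std=1.0):
    # Real torch.nn.init.trunc_normal_ mutates in place AND returns the tensor.
    return tensor

def skip_init_stub(*args, **kwargs):
    # The current no_init_weights() replacement: does nothing, returns None.
    pass

weight = FakeTensor()
# DINOv2-style init chains .to(...) on the return value:
assert real_trunc_normal_(weight).to("float32") is weight  # fine

try:
    skip_init_stub(weight).to("float32")  # None.to(...) blows up
except AttributeError as exc:
    assert "'NoneType' object has no attribute 'to'" in str(exc)
```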

Re: #13046, note test_models_autoencoder_rae.py:45; the unit tests seem a little off to me, and I'm not sure they're aligned with the real encoder path:

# ---------------------------------------------------------------------------
# Tiny test encoder for fast unit tests (no transformers dependency)
# ---------------------------------------------------------------------------


class _TinyTestEncoderModule(torch.nn.Module):
    """Minimal encoder that mimics the patch-token interface without any HF model."""

    def __init__(self, hidden_size: int = 16, patch_size: int = 8, **kwargs):
        super().__init__()
        self.patch_size = patch_size
        self.hidden_size = hidden_size

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        pooled = F.avg_pool2d(images.mean(dim=1, keepdim=True), kernel_size=self.patch_size, stride=self.patch_size)
        tokens = pooled.flatten(2).transpose(1, 2).contiguous()
        return tokens.repeat(1, 1, self.hidden_size)


def _tiny_test_encoder_forward(model, images):
    return model(images)


def _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
    return _TinyTestEncoderModule(hidden_size=hidden_size, patch_size=patch_size)


# Monkey-patch the dispatch tables so "tiny_test" is recognised by AutoencoderRAE
_ENCODER_FORWARD_FNS["tiny_test"] = _tiny_test_encoder_forward
_original_build_encoder = _build_encoder


def _patched_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
    if encoder_type == "tiny_test":
        return _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)
    return _original_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)


_rae_module._build_encoder = _patched_build_encoder

I'm new to diffusers idioms, but I was confused about why this appeared to be a problem only now, and asked GPT:

no_init_weights() only becomes a problem when all of these are true at once:

  • a diffusers ModelMixin.from_pretrained() call is constructing the model
  • that model’s __init__() instantiates another model internally
  • that internal model uses torch.nn.init.* and also relies on its return value

RAE is unusual because it does exactly that. Inside autoencoder_rae.py, the AutoencoderRAE constructor directly builds a transformers vision backbone:

  • Dinov2WithRegistersModel:98
  • SiglipVisionModel:111
  • ViTMAEModel:124

That is not how most other diffusers integrations are structured. Most of the repo does one of these instead:

  • native diffusers models in src/diffusers/models, whose init code only relies on side effects
  • pipelines that accept transformers models as separate top-level components, rather than constructing them inside a ModelMixin

So other work usually does not run a transformers constructor inside diffusers’ patched no_init_weights() context.

@sayakpaul
Member

Not sure how to link files

Yes, we can link files and I think it's better this way. For example, it's much better to refer to specific lines like

def get_parameter_device(parameter: torch.nn.Module) -> torch.device:

instead of plain text.

Overall, I don't think the explanation you provided in the above comment is that helpful. We need some specific (preferably very minimal) code snippet, with and without that change, to better understand what's happening and why.

For this kind of PR, it's expected that contributors will take some time to understand the library code.

@plugyawn
Author

Hi @sayakpaul! My bad, I'll update the PR today.

@sayakpaul
Member

@plugyawn do you want to take another crack?

@plugyawn
Author

Yess, I'm reading through the diffusers codebase in more detail (I had used it quite a bit but not dived as deep as I had for transformers, hehe), and waiting for some compute to come through (I ran out, should be back by this week).

Thank you so much for waiting! I'll update the PR by this week.

@plugyawn
Author

plugyawn commented Mar 30, 2026

@sayakpaul, I removed the old no_init_weights() change; as you said, there is a workaround for now.
The nested-init issue is now handled locally in AutoencoderRAE.

The diff is pretty different from the last round, so if you have time, it'd be great if you could review it again! Sorry for the wait!

@sayakpaul
Member

@kashif possible to do a round of review?

@plugyawn
Author

plugyawn commented Mar 30, 2026

Also, I'm pretty certain the merged RAE implementation has some faulty testing.
I'm not sure how it works with the no_init_weights() problem unfixed; the tests also use a fake encoder. Might have gone under the radar?

@plugyawn
Author

plugyawn commented Apr 2, 2026

@kashif, @sayakpaul would you prefer a separate PR (or here) to fix the testing issues on the RAE?

@kashif
Contributor

kashif commented Apr 2, 2026

yes please, separate!

@plugyawn
Author

plugyawn commented Apr 2, 2026

I'll take this as a yes for a separate PR!

@plugyawn plugyawn requested a review from sayakpaul April 8, 2026 16:42
@sayakpaul
Member

@kashif could you do an initial review please?

@plugyawn
Author

@kashif @sayakpaul is there a timeline on this?

@plugyawn
Author

plugyawn commented Apr 19, 2026

@sayakpaul, just a small bump. Have a free week, would love to get this to a mergeable state.

@sayakpaul
Member

@claude could you do a review here?

@github-actions
Contributor

github-actions bot commented Apr 19, 2026

Claude finished @sayakpaul's task in 2m 37s


Code Review: RAE Diffusion Transformer Pipeline

  • Read review guidelines (AGENTS.md, models.md, modular.md, parity-testing)
  • Read all changed files under src/diffusers/
  • Review model code (transformer_rae_dit.py)
  • Review pipeline code (pipeline_rae_dit.py, pipeline_output.py, __init__.py)
  • Review autoencoder changes (autoencoder_rae.py)
  • Review init/export changes
  • Post comprehensive review

Overall this is a solid contribution with good inference parity numbers. The code is clean and readable. Below are the issues I found, grouped by severity.


Must-fix

1. Attention does not follow the diffusers attention pattern

RAEDiTAttention (transformer_rae_dit.py:154-199) calls F.scaled_dot_product_attention directly. Per models.md, all new models must use dispatch_attention_fn via a processor class and inherit AttentionModuleMixin. Every other recent transformer in the repo (Flux, Wan, Bria, LTX, PRX, etc.) follows this pattern. This is required for attention backend switching (e.g. xformers, flash-attn, ring attention for sequence parallelism).

The fix is to:

  • Create a RAEDiTAttnProcessor class with _attention_backend / _parallel_config attributes and move the forward logic there, using dispatch_attention_fn
  • Make RAEDiTAttention inherit from (nn.Module, AttentionModuleMixin), set _default_processor_cls = RAEDiTAttnProcessor, and delegate to self.processor(...) in forward()

See transformer_flux.py:75-139 (processor) and transformer_flux.py:275-325 (attention class) for the canonical reference.
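The delegation shape being requested can be sketched with pure-Python stand-ins; the RAEDiTAttnProcessor/RAEDiTAttention names come from the review itself, but the mixin, the dispatch call, and the compute() placeholder below are illustrative only, not diffusers code:

```python
class RAEDiTAttnProcessor:
    """Holds the attention math; in real code this is where
    dispatch_attention_fn would be called so backends can be swapped."""
    _attention_backend = None
    _parallel_config = None

    def __call__(self, attn, hidden_states):
        # Real code: project to q/k/v, then dispatch to the active backend.
        return attn.compute(hidden_states)


class RAEDiTAttention:
    """Stands in for an nn.Module + AttentionModuleMixin subclass."""
    _default_processor_cls = RAEDiTAttnProcessor

    def __init__(self):
        self.processor = self._default_processor_cls()

    def compute(self, hidden_states):
        return hidden_states  # placeholder for the SDPA output

    def forward(self, hidden_states):
        # Delegate to the processor; swapping self.processor swaps backends.
        return self.processor(self, hidden_states)


attn = RAEDiTAttention()
assert attn.forward([1, 2, 3]) == [1, 2, 3]
```

The point of the indirection is that the module owns the weights while the processor owns the attention computation, so the computation can be replaced per-instance without touching the module.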


2. Missing _no_split_modules on RAEDiT2DModel

RAEDiT2DModel (transformer_rae_dit.py:312) is missing _no_split_modules. Every other transformer model in the repo declares this for correct device placement with accelerate. Should be:

_no_split_modules = ["RAEDiTBlock"]

3. unpatchify uses torch.einsum — potential torch.compile graph break

transformer_rae_dit.py:470:

hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)

Per models.md: "Avoid graph breaks for torch.compile compatibility." In practice einsum is compile-safe, and this exact pattern is already used by other models (DiT, SD3, PixArt), so it is acceptable here; models.md's stricter warning concerns NumPy operations in forward implementations, and string-based einsum is pure PyTorch. This is a soft pass, flagged only for awareness.


Should-fix

4. wo_shift path allocates unnecessary zero tensors

In RAEDiTBlock.forward() (transformer_rae_dit.py:264-267):

if shift_msa is None:
    shift_msa = torch.zeros_like(scale_msa)
if shift_mlp is None:
    shift_mlp = torch.zeros_like(scale_mlp)

When wo_shift=True, this allocates two zero tensors every forward pass just to add zero. The subsequent modulation (norm * (1 + scale) + shift) with shift=0 is mathematically just norm * (1 + scale). Consider branching the modulation logic instead:

norm_hidden_states = self.norm1(hidden_states)
if shift_msa is not None:
    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
else:
    norm_hidden_states = norm_hidden_states * (1 + scale_msa)

5. GaussianFourierEmbedding.W naming convention

transformer_rae_dit.py:76: The parameter is named self.W (uppercase). While this matches the reference implementation, diffusers convention uses snake_case for parameters/attributes. Consider renaming to self.weight or self.fourier_weight.

6. _prepare_timesteps in pipeline is boilerplate that could use the scheduler directly

pipeline_rae_dit.py:117-133: The _prepare_timesteps method duplicates timestep preparation logic. Since FlowMatchEulerDiscreteScheduler already provides timesteps as tensors from set_timesteps(), the conversion logic for MPS/NPU is only needed if raw floats are passed. This works but is unnecessarily complex — the timestep from self.scheduler.timesteps is already a tensor on the correct device.

7. Missing copyright header in transformer_rae_dit.py

The new file transformer_rae_dit.py has no Apache 2.0 license header, unlike pipeline_output.py and autoencoder_rae.py. All new source files in diffusers should include the standard header.

8. pos_embed registered as persistent=False but used for sincos embeddings

transformer_rae_dit.py:403:

self.register_buffer("pos_embed", pos_embed.unsqueeze(0).float(), persistent=False)

Non-persistent buffers are not saved with state_dict(). If this is intentional (regenerated from config on load), that's fine. But if users save/load with save_pretrained/from_pretrained, the buffer will be regenerated correctly only if the code re-runs __init__. This seems correct here since the pos embed is deterministic from config — just confirming the intent.


Nits / Minor

9. Pipeline __init__ has unused _guidance_scale attribute

pipeline_rae_dit.py:55: self._guidance_scale = 1.0 is set in __init__ but immediately overwritten in __call__ at line 191. This is a common pattern in diffusers pipelines, so it's fine — just noting it's redundant initialization.

10. _expand_conditioning_tokens is a module-level helper

This was flagged in a previous review. While sayakpaul suggested inlining small helpers, this function is used in 3 places (RAEDiTBlock.forward for MSA and MLP expansion, RAEDiTFinalLayer.forward), so keeping it as a helper is reasonable given its 20+ lines. The current approach is acceptable.

11. _swap_swiglu_projection_halves accesses .data directly

transformer_rae_dit.py:60-68: Directly manipulating .data bypasses autograd. Since this is called only during initialize_weights() (construction time), not during forward, it's safe. Just noting the pattern.

12. pipeline_rae_dit.py:232 — timestep normalization coupling

timestep_input = timestep_input / self.scheduler.config.num_train_timesteps

This hardcodes the assumption that the scheduler provides timesteps in [0, num_train_timesteps] range and the model expects [0, 1]. If a different scheduler is used, this will silently produce wrong results. Consider adding a comment or assertion documenting this coupling.


Init / Export changes

The lazy import changes in __init__.py, models/__init__.py, pipelines/__init__.py, transformers/__init__.py, and dummy_pt_objects.py all look correct and follow alphabetical ordering. The _import_structure entries match the TYPE_CHECKING imports.

Autoencoder changes (autoencoder_rae.py)

The _preserve_init_return_tensors() context manager is a clean localized solution to the no_init_weights() issue discussed in the PR comments. It wraps torch.nn.init.* functions to always return the tensor, avoiding the NoneType error when constructing transformers models under diffusers' weight-init suppression. The approach is sound — it doesn't modify global behavior permanently and restores original functions in the finally block.
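A minimal sketch of what such a localized context manager might look like (the name, the wrapped-function list, and the SimpleNamespace demo are assumptions for illustration; the PR's actual helper may differ):

```python
import contextlib
import types


@contextlib.contextmanager
def preserve_init_return_tensors(init_module, fn_names):
    """Temporarily wrap the named init functions so they always return
    their first argument, then restore the originals on exit.
    Illustrative sketch only, not the PR's exact implementation."""
    originals = {name: getattr(init_module, name) for name in fn_names}

    def make_wrapper(fn):
        def wrapper(tensor, *args, **kwargs):
            fn(tensor, *args, **kwargs)  # keep skip-init side effects
            return tensor                # restore the return contract
        return wrapper

    try:
        for name, fn in originals.items():
            setattr(init_module, name, make_wrapper(fn))
        yield
    finally:
        for name, fn in originals.items():
            setattr(init_module, name, fn)


# Demo with a fake "torch.nn.init" whose stub returns None:
fake_init = types.SimpleNamespace(trunc_normal_=lambda t, *a, **k: None)
t = object()
assert fake_init.trunc_normal_(t) is None       # broken contract
with preserve_init_return_tensors(fake_init, ["trunc_normal_"]):
    assert fake_init.trunc_normal_(t) is t      # contract restored
assert fake_init.trunc_normal_(t) is None       # originals restored
```

The try/finally structure is what makes the fix safe: even if construction raises inside the with-block, the original functions are put back.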

Summary

The main blocker is issue #1 (attention pattern). This is a hard requirement per models.md and is needed for the model to work with diffusers' attention backend infrastructure. Issue #2 (_no_split_modules) is also important for correct multi-GPU behavior. The remaining items are improvements that could be addressed incrementally.


@plugyawn
Author

Good bot.

@github-actions github-actions bot added documentation Improvements or additions to documentation models tests utils pipelines examples size/L PR with diff > 200 LOC labels Apr 19, 2026
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 20, 2026