diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index 4bf00f749f25..1a08f9425f4e 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -290,6 +290,7 @@ def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True + _repeated_blocks = ["ErnieImageSharedAdaLNBlock"] @register_to_config def __init__(