Skip to content

Commit 0a5c98e

Browse files
authored
Add from_single_file support to ErnieImageTransformer2DModel (#13727)
* Add from_single_file support to ErnieImageTransformer2DModel * drop redundant copy loop
1 parent 86dab15 commit 0a5c98e

3 files changed

Lines changed: 17 additions & 2 deletions

File tree

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
convert_chroma_transformer_checkpoint_to_diffusers,
3838
convert_controlnet_checkpoint,
3939
convert_cosmos_transformer_checkpoint_to_diffusers,
40+
convert_ernie_image_transformer_checkpoint_to_diffusers,
4041
convert_flux2_transformer_checkpoint_to_diffusers,
4142
convert_flux_transformer_checkpoint_to_diffusers,
4243
convert_hidream_transformer_to_diffusers,
@@ -118,6 +119,10 @@
118119
"checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
119120
"default_subfolder": "transformer",
120121
},
122+
"ErnieImageTransformer2DModel": {
123+
"checkpoint_mapping_fn": convert_ernie_image_transformer_checkpoint_to_diffusers,
124+
"default_subfolder": "transformer",
125+
},
121126
"LTXVideoTransformer3DModel": {
122127
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
123128
"default_subfolder": "transformer",

src/diffusers/loaders/single_file_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4170,3 +4170,13 @@ def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
41704170
update_state_dict_inplace(converted_state_dict, key, new_key)
41714171

41724172
return converted_state_dict
4173+
4174+
4175+
def convert_ernie_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
4176+
keys = list(checkpoint.keys())
4177+
4178+
for k in keys:
4179+
if "model.diffusion_model." in k:
4180+
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
4181+
4182+
return checkpoint

src/diffusers/models/transformers/transformer_ernie_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import torch.nn.functional as F
2626

2727
from ...configuration_utils import ConfigMixin, register_to_config
28-
from ...loaders import PeftAdapterMixin
28+
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
2929
from ...utils import BaseOutput, logging
3030
from ..attention import AttentionModuleMixin
3131
from ..attention_dispatch import dispatch_attention_fn
@@ -289,7 +289,7 @@ def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
289289
return x
290290

291291

292-
class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
292+
class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
293293
_supports_gradient_checkpointing = True
294294
_repeated_blocks = ["ErnieImageSharedAdaLNBlock"]
295295

0 commit comments

Comments
 (0)