Skip to content

Commit 2b45074

Browse files
author
VERIGEN
committed
Add ERNIE Image model support
1 parent dfec2f2 commit 2b45074

8 files changed

Lines changed: 37 additions & 2 deletions

File tree

ai_diffusion/comfy_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ def _find_text_encoder_models(model_list: Sequence[str]):
751751
kind = ResourceKind.text_encoder
752752
return {
753753
resource_id(kind, Arch.all, te): _find_model(model_list, kind, Arch.all, te)
754-
for te in ["clip_l", "clip_g", "t5", "qwen", "qwen_3_06b", "qwen_3_4b", "qwen_3_8b"]
754+
for te in ["clip_l", "clip_g", "t5", "qwen", "qwen_3_06b", "qwen_3_4b", "qwen_3_8b", "ministral"]
755755
}
756756

757757

ai_diffusion/comfy_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ def empty_latent_image(self, extent: Extent, arch: Arch, batch_size=1):
613613
w, h = extent.width, extent.height
614614
if arch.is_flux_like or arch.is_qwen_like or arch in (Arch.sd3, Arch.chroma, Arch.zimage):
615615
return self.add("EmptySD3LatentImage", 1, width=w, height=h, batch_size=batch_size)
616-
if arch.is_flux2:
616+
if arch.is_flux2 or arch is Arch.ernie:
617617
return self.add("EmptyFlux2LatentImage", 1, width=w, height=h, batch_size=batch_size)
618618
else:
619619
return self.add("EmptyLatentImage", 1, width=w, height=h, batch_size=batch_size)
Lines changed: 10 additions & 0 deletions
Loading
Lines changed: 10 additions & 0 deletions
Loading

ai_diffusion/resources.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class Arch(Enum):
9797
qwen_l = "Qwen Layered"
9898
anima = "Anima"
9999
zimage = "Z-Image"
100+
ernie = "ERNIE Image"
100101

101102
auto = "Automatic"
102103
all = "All"
@@ -139,6 +140,8 @@ def from_string(string: str, model_type: str = "eps", filename: str | None = Non
139140
return Arch.anima
140141
if string in {"z-image", "zimage"}:
141142
return Arch.zimage
143+
if string in {"ernie-image", "ernie_image"}:
144+
return Arch.ernie
142145
return None
143146

144147
@staticmethod
@@ -244,6 +247,8 @@ def text_encoders(self):
244247
return ["qwen_3_06b"]
245248
case Arch.zimage:
246249
return ["qwen_3_4b"]
250+
case Arch.ernie:
251+
return ["ministral"]
247252
raise ValueError(f"Unsupported architecture: {self}")
248253

249254
@staticmethod
@@ -265,6 +270,7 @@ def list():
265270
Arch.qwen_l,
266271
Arch.anima,
267272
Arch.zimage,
273+
Arch.ernie,
268274
]
269275

270276

@@ -799,6 +805,7 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
799805
resource_id(ResourceKind.text_encoder, Arch.all, "qwen_3_4b"): ["qwen_3_4b", "qwen3-4b", "qwen3_4b", "qwen_3", "qwen-3"],
800806
resource_id(ResourceKind.text_encoder, Arch.all, "qwen_3_8b"): ["qwen_3_8b", "qwen3-8b", "qwen3_8b"],
801807
resource_id(ResourceKind.text_encoder, Arch.all, "qwen_3_06b"): ["qwen_3_06b", "qwen3-06b", "qwen3_06b"],
808+
resource_id(ResourceKind.text_encoder, Arch.all, "ministral"): ["ministral-3-3b", "ministral"],
802809
resource_id(ResourceKind.vae, Arch.sd15, "default"): ["vae-ft-mse-840000-ema"],
803810
resource_id(ResourceKind.vae, Arch.sdxl, "default"): ["sdxl_vae"],
804811
resource_id(ResourceKind.vae, Arch.illu, "default"): ["sdxl_vae"],
@@ -815,6 +822,7 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
815822
resource_id(ResourceKind.vae, Arch.qwen_l, "default"): ["qwen_image_layered_vae"],
816823
resource_id(ResourceKind.vae, Arch.anima, "default"): ["qwen_image"],
817824
resource_id(ResourceKind.vae, Arch.zimage, "default"): ["z-image", "flux-", "flux_", "flux/", "flux1", "ae.s"],
825+
resource_id(ResourceKind.vae, Arch.ernie, "default"): ["flux2"],
818826
}
819827
# fmt: on
820828

@@ -848,6 +856,8 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
848856
ResourceId(ResourceKind.vae, Arch.zimage, "default"),
849857
ResourceId(ResourceKind.vae, Arch.flux2_4b, "default"),
850858
ResourceId(ResourceKind.vae, Arch.flux2_9b, "default"),
859+
ResourceId(ResourceKind.text_encoder, Arch.ernie, "ministral"),
860+
ResourceId(ResourceKind.vae, Arch.ernie, "default"),
851861
}
852862

853863
recommended_resource_ids = [

ai_diffusion/ui/theme.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def checkpoint_icon(arch: Arch, format: FileFormat | None = None, client: Client
8383
return icon("sd-version-z-image")
8484
elif arch is Arch.anima:
8585
return icon("sd-version-anima")
86+
elif arch is Arch.ernie:
87+
return icon("sd-version-ernie")
8688
else:
8789
log.warning(f"Unresolved SD version {arch}, cannot fetch icon")
8890
return icon("warning")

ai_diffusion/workflow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
167167
clip = w.load_clip(te["qwen_3_06b"], type="omnigen2")
168168
case Arch.zimage:
169169
clip = w.load_clip(te["qwen_3_4b"], type="lumina2")
170+
case Arch.ernie:
171+
clip = w.load_clip(te["ministral"], type="flux2")
170172
case _:
171173
raise RuntimeError(f"No text encoder for model architecture {arch.name}")
172174

tests/test_resources.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_resource_ids_exist():
5454
Arch.qwen_e_p,
5555
Arch.flux2_9b,
5656
Arch.anima,
57+
Arch.ernie,
5758
):
5859
continue # no model downloads yet
5960
model = res.find_resource(resource_id)

0 commit comments

Comments
 (0)