Skip to content

Commit fcc0881

Browse files
Authored by Your Name and claude
fix(mm): support ComfyUI bundled checkpoint format for Anima model identification (invoke-ai#9113)
Anima finetunes packaged in ComfyUI format use `model.diffusion_model.*` prefixed keys instead of bare or `net.*` prefixed keys. Update the probe and loader to recognize and handle this format. Co-authored-by: Your Name <you@example.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 51f528c commit fcc0881

3 files changed

Lines changed: 185 additions & 11 deletions

File tree

invokeai/backend/model_manager/configs/main.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,12 +1120,20 @@ def _has_anima_keys(state_dict: dict[str | int, Any]) -> bool:
11201120
(unique to Anima - the LLM Adapter that bridges Qwen3 text encoder to the Cosmos DiT)
11211121
alongside Cosmos Predict2 DiT keys (blocks, t_embedder, x_embedder, final_layer).
11221122
1123-
The checkpoint keys may have a `net.` prefix (e.g. `net.llm_adapter.`, `net.blocks.`).
1123+
The checkpoint keys may have a `net.` prefix (e.g. `net.llm_adapter.`, `net.blocks.`)
1124+
or a `model.diffusion_model.` prefix (ComfyUI bundled checkpoint format).
11241125
"""
11251126
has_llm_adapter = False
11261127
has_cosmos_dit = False
11271128

1128-
# Cosmos DiT key prefixes — support both with and without `net.` prefix
1129+
# LLM adapter key prefixes — support bare, `net.`, and `model.diffusion_model.` prefixes
1130+
llm_adapter_prefixes = (
1131+
"llm_adapter.",
1132+
"net.llm_adapter.",
1133+
"model.diffusion_model.llm_adapter.",
1134+
)
1135+
1136+
# Cosmos DiT key prefixes — support bare, `net.`, and `model.diffusion_model.` prefixes
11291137
cosmos_prefixes = (
11301138
"blocks.",
11311139
"t_embedder.",
@@ -1135,16 +1143,19 @@ def _has_anima_keys(state_dict: dict[str | int, Any]) -> bool:
11351143
"net.t_embedder.",
11361144
"net.x_embedder.",
11371145
"net.final_layer.",
1146+
"model.diffusion_model.blocks.",
1147+
"model.diffusion_model.t_embedder.",
1148+
"model.diffusion_model.x_embedder.",
1149+
"model.diffusion_model.final_layer.",
11381150
)
11391151

11401152
for key in state_dict.keys():
11411153
if isinstance(key, int):
11421154
continue
1143-
if key.startswith("llm_adapter.") or key.startswith("net.llm_adapter."):
1155+
if any(key.startswith(p) for p in llm_adapter_prefixes):
11441156
has_llm_adapter = True
1145-
for prefix in cosmos_prefixes:
1146-
if key.startswith(prefix):
1147-
has_cosmos_dit = True
1157+
if any(key.startswith(p) for p in cosmos_prefixes):
1158+
has_cosmos_dit = True
11481159
if has_llm_adapter and has_cosmos_dit:
11491160
return True
11501161

invokeai/backend/model_manager/load/model_loaders/anima.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,12 @@ def _load_from_singlefile(
6767
# Load the state dict from safetensors
6868
sd = load_file(model_path)
6969

70-
# Strip the `net.` prefix that all Anima checkpoint keys have
71-
# e.g., "net.blocks.0.self_attn.q_proj.weight" -> "blocks.0.self_attn.q_proj.weight"
70+
# Handle different checkpoint packaging formats:
71+
# - Official format: keys prefixed with `net.` (e.g. `net.blocks.0...`)
72+
# - ComfyUI bundled format: transformer keys prefixed with `model.diffusion_model.`
73+
# alongside `first_stage_model.*` (VAE) and `cond_stage_model.*` (text encoder)
7274
prefix_to_strip = None
73-
for prefix in ["net."]:
75+
for prefix in ["model.diffusion_model.", "net."]:
7476
if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
7577
prefix_to_strip = prefix
7678
break
@@ -80,8 +82,7 @@ def _load_from_singlefile(
8082
for key, value in sd.items():
8183
if isinstance(key, str) and key.startswith(prefix_to_strip):
8284
stripped_sd[key[len(prefix_to_strip) :]] = value
83-
else:
84-
stripped_sd[key] = value
85+
# Skip non-transformer keys from bundled checkpoints (VAE, text encoder)
8586
sd = stripped_sd
8687

8788
# Create an empty AnimaTransformer with Anima's default architecture parameters
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import pytest
2+
3+
from invokeai.backend.model_manager.configs.main import _has_anima_keys
4+
5+
6+
def _make_state_dict(prefixes: list[str], keys: list[str]) -> dict[str, object]:
7+
"""Build a minimal fake state dict with the given prefixes applied to the given keys."""
8+
return {f"{prefix}{key}": None for prefix in prefixes for key in keys}
9+
10+
11+
# Smallest key sets that trip both halves of the Anima heuristic:
# one LLM-adapter key, plus one key per Cosmos DiT component
# (blocks, t_embedder, x_embedder, final_layer).
ANIMA_LLM_ADAPTER_KEYS = ["llm_adapter.blocks.0.cross_attn.k_norm.weight"]
ANIMA_COSMOS_DIT_KEYS = [
    "blocks.0.adaln_modulation_cross_attn.1.weight",
    "t_embedder.1.linear_1.weight",
    "x_embedder.proj.1.weight",
    "final_layer.adaln_modulation.1.weight",
]
19+
20+
21+
class TestHasAnimaKeys:
    """Exercise the _has_anima_keys heuristic across the supported checkpoint key layouts."""

    def test_bare_keys(self):
        """Unprefixed keys are accepted."""
        sd = _make_state_dict([""], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
        assert _has_anima_keys(sd) is True

    def test_net_prefix(self):
        """Keys in the official `net.`-prefixed layout are accepted."""
        sd = _make_state_dict(["net."], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
        assert _has_anima_keys(sd) is True

    def test_comfyui_bundled_prefix(self):
        """Keys in the ComfyUI `model.diffusion_model.` layout are accepted."""
        sd = _make_state_dict(["model.diffusion_model."], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
        assert _has_anima_keys(sd) is True

    def test_comfyui_bundled_with_extra_keys(self):
        """Bundled VAE / text-encoder keys alongside the DiT keys do not break detection."""
        sd = _make_state_dict(["model.diffusion_model."], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
        # Keys that a bundled checkpoint carries for the VAE and text encoder.
        for extra in (
            "first_stage_model.conv1.weight",
            "first_stage_model.encoder.downsamples.0.weight",
            "cond_stage_model.qwen3_06b.transformer.model.embed_tokens.weight",
        ):
            sd[extra] = None
        assert _has_anima_keys(sd) is True

    def test_missing_llm_adapter_keys(self):
        """Cosmos DiT keys alone (no llm_adapter) are rejected."""
        sd = _make_state_dict([""], ANIMA_COSMOS_DIT_KEYS)
        assert _has_anima_keys(sd) is False

    def test_missing_cosmos_dit_keys(self):
        """llm_adapter keys alone (no Cosmos DiT) are rejected."""
        sd = _make_state_dict([""], ANIMA_LLM_ADAPTER_KEYS)
        assert _has_anima_keys(sd) is False

    def test_empty_state_dict(self):
        """An empty state dict is rejected."""
        assert _has_anima_keys({}) is False

    def test_unrelated_keys(self):
        """Keys from an unrelated (SD-style) checkpoint are rejected."""
        sd = dict.fromkeys(
            [
                "model.diffusion_model.input_blocks.0.0.weight",
                "model.diffusion_model.output_blocks.0.0.weight",
                "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
            ]
        )
        assert _has_anima_keys(sd) is False

    @pytest.mark.parametrize("prefix", ["", "net.", "model.diffusion_model."])
    def test_all_prefixes_parametrized(self, prefix: str):
        """Every supported prefix format is recognized."""
        sd = _make_state_dict([prefix], ANIMA_LLM_ADAPTER_KEYS + ANIMA_COSMOS_DIT_KEYS)
        assert _has_anima_keys(sd) is True
80+
81+
class TestAnimaDoesNotConflictWithOtherModels:
    """Guard against false positives on checkpoints from architecturally similar models."""

    def test_flux_bundled_checkpoint(self):
        """FLUX uses double_blocks/single_blocks rather than plain blocks — rejected."""
        sd = dict.fromkeys(
            [
                "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
                "model.diffusion_model.double_blocks.0.img_attn.proj.weight",
                "model.diffusion_model.single_blocks.0.linear1.weight",
                "model.diffusion_model.context_embedder.weight",
                "model.diffusion_model.img_in.weight",
            ]
        )
        assert _has_anima_keys(sd) is False

    def test_sd1_bundled_checkpoint(self):
        """SD1/SD2/SDXL use input_blocks/output_blocks — rejected."""
        sd = dict.fromkeys(
            [
                "model.diffusion_model.input_blocks.0.0.weight",
                "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
                "model.diffusion_model.output_blocks.0.0.weight",
                "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight",
                "first_stage_model.encoder.down.0.block.0.conv1.weight",
                "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
            ]
        )
        assert _has_anima_keys(sd) is False

    def test_raw_cosmos_dit_without_llm_adapter(self):
        """A raw Cosmos Predict2 DiT lacks Anima's LLM adapter — rejected."""
        sd = dict.fromkeys(
            [
                "blocks.0.adaln_modulation_cross_attn.1.weight",
                "blocks.0.self_attn.q_proj.weight",
                "t_embedder.1.linear_1.weight",
                "x_embedder.proj.1.weight",
                "final_layer.adaln_modulation.1.weight",
            ]
        )
        assert _has_anima_keys(sd) is False

    def test_z_image_checkpoint(self):
        """Z-Image has blocks/t_embedder/x_embedder but no llm_adapter — rejected."""
        sd = dict.fromkeys(
            [
                "model.diffusion_model.blocks.0.attn.to_q.weight",
                "model.diffusion_model.blocks.0.attn.to_k.weight",
                "model.diffusion_model.cap_embedder.0.weight",
                "model.diffusion_model.context_refiner.blocks.0.weight",
                "model.diffusion_model.t_embedder.mlp.0.weight",
                "model.diffusion_model.x_embedder.proj.weight",
            ]
        )
        assert _has_anima_keys(sd) is False

    def test_qwen_image_checkpoint(self):
        """QwenImage uses txt_in/txt_norm/img_in — rejected."""
        sd = dict.fromkeys(
            [
                "txt_in.weight",
                "txt_norm.weight",
                "img_in.weight",
                "double_blocks.0.img_attn.proj.weight",
                "single_blocks.0.linear1.weight",
            ]
        )
        assert _has_anima_keys(sd) is False

    def test_flux_lora_does_not_match(self):
        """FLUX LoRA weights are rejected."""
        sd = dict.fromkeys(
            [
                "double_blocks.0.img_attn.proj.lora_down.weight",
                "double_blocks.0.img_attn.proj.lora_up.weight",
                "single_blocks.0.linear1.lora_down.weight",
            ]
        )
        assert _has_anima_keys(sd) is False

    def test_cosmos_dit_bundled_without_llm_adapter(self):
        """Bundled (model.diffusion_model.) Cosmos DiT with no llm_adapter — rejected."""
        sd = dict.fromkeys(
            [
                "model.diffusion_model.blocks.0.self_attn.q_proj.weight",
                "model.diffusion_model.t_embedder.1.linear_1.weight",
                "model.diffusion_model.x_embedder.proj.1.weight",
                "model.diffusion_model.final_layer.adaln_modulation.1.weight",
                "first_stage_model.encoder.downsamples.0.weight",
                "cond_stage_model.transformer.model.embed_tokens.weight",
            ]
        )
        assert _has_anima_keys(sd) is False

0 commit comments

Comments
 (0)