Skip to content

Commit c943837

Browse files
rmatifgithub-actions[bot]yiyixuxu
authored
Add Anima modular pipeline (#13732)
* Add Anima pipeline * Fix empty Anima negative prompts * Fix Anima registration * Clean up Anima conditioner * Refactor Anima to modular * Use modular loader in Anima docs * Move Anima text conditioner * Apply style fixes * Address Anima review nits * Fix Anima rotary autocast * Clean up Anima LoRA checks * Update Anima model id * Document Anima LoRA loader * Document Anima conditioner forward --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent de6b049 commit c943837

24 files changed

Lines changed: 2525 additions & 1 deletion

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,8 @@
481481
title: Stable Audio
482482
title: Audio
483483
- sections:
484+
- local: api/pipelines/anima
485+
title: Anima
484486
- local: api/pipelines/animatediff
485487
title: AnimateDiff
486488
- local: api/pipelines/aura_flow

docs/source/en/api/loaders/lora.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
3030
- [`SkyReelsV2LoraLoaderMixin`] provides similar functions for [SkyReels-V2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/skyreels_v2).
3131
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
3232
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
33+
- [`AnimaLoraLoaderMixin`] provides similar functions for [Anima](https://huggingface.co/docs/diffusers/main/en/api/pipelines/anima).
3334
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
3435
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
3536
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
@@ -120,6 +121,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
120121

121122
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
122123

124+
## AnimaLoraLoaderMixin
125+
126+
[[autodoc]] loaders.lora_pipeline.AnimaLoraLoaderMixin
127+
123128
## HiDreamImageLoraLoaderMixin
124129

125130
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
@@ -141,4 +146,4 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
141146

142147
## LoraBaseMixin
143148

144-
[[autodoc]] loaders.lora_base.LoraBaseMixin
149+
[[autodoc]] loaders.lora_base.LoraBaseMixin
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
<!-- Copyright 2026 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License. -->
14+
15+
# Anima
16+
17+
Anima is a text-to-image model that reuses the [`CosmosTransformer3DModel`] with a Qwen3 text encoder, a T5-token text conditioner, and the [`AutoencoderKLQwenImage`] VAE.
18+
19+
```python
20+
import torch
21+
from diffusers import ModularPipeline
22+
23+
pipe = ModularPipeline.from_pretrained("circlestone-labs/Anima-Base-v1.0-Diffusers")
24+
pipe.load_components(torch_dtype=torch.bfloat16)
25+
pipe.to("cuda")
26+
27+
image = pipe(prompt="masterpiece, best quality, 1girl, solo, city lights").images[0]
28+
```
29+
30+
## AnimaModularPipeline
31+
32+
[[autodoc]] AnimaModularPipeline
33+
34+
## AnimaAutoBlocks
35+
36+
[[autodoc]] AnimaAutoBlocks
37+
38+
## AnimaTextConditioner
39+
40+
[[autodoc]] AnimaTextConditioner
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
"""
2+
Convert Anima checkpoints to Diffusers format.
3+
4+
Example:
5+
```bash
6+
python scripts/convert_anima_to_diffusers.py \
7+
--transformer_ckpt_path anima_model/anima-preview3-base.safetensors \
8+
--text_encoder_ckpt_path anima_model/qwen_3_06b_base.safetensors \
9+
--vae_ckpt_path anima_model/qwen_image_vae.safetensors \
10+
--qwen_tokenizer_path path/to/qwen25_tokenizer \
11+
--t5_tokenizer_path path/to/t5_tokenizer \
12+
--output_path anima_model/anima-preview3-diffusers \
13+
--save_pipeline
14+
```
15+
"""
16+
17+
import argparse
18+
import pathlib
19+
import sys
20+
from typing import Any
21+
22+
import torch
23+
from accelerate import init_empty_weights
24+
from convert_cosmos_to_diffusers import convert_transformer
25+
from safetensors.torch import load_file
26+
from transformers import AutoTokenizer, Qwen3Config, Qwen3Model, T5TokenizerFast
27+
28+
from diffusers import (
29+
AnimaAutoBlocks,
30+
AnimaTextConditioner,
31+
AutoencoderKLQwenImage,
32+
FlowMatchEulerDiscreteScheduler,
33+
)
34+
35+
36+
DTYPE_MAPPING = {
37+
"fp32": torch.float32,
38+
"fp16": torch.float16,
39+
"bf16": torch.bfloat16,
40+
}
41+
42+
43+
def rename_residual_key(key: str) -> str:
44+
replacements = {
45+
".residual.0.": ".norm1.",
46+
".residual.2.": ".conv1.",
47+
".residual.3.": ".norm2.",
48+
".residual.6.": ".conv2.",
49+
".shortcut.": ".conv_shortcut.",
50+
}
51+
for old, new in replacements.items():
52+
key = key.replace(old, new)
53+
return key
54+
55+
56+
def rename_mid_key(key: str) -> str:
57+
replacements = {
58+
".middle.0.": ".mid_block.resnets.0.",
59+
".middle.1.": ".mid_block.attentions.0.",
60+
".middle.2.": ".mid_block.resnets.1.",
61+
}
62+
for old, new in replacements.items():
63+
key = key.replace(old, new)
64+
return rename_residual_key(key)
65+
66+
67+
def rename_decoder_upsample_key(key: str) -> str:
68+
prefix = "decoder.upsamples."
69+
suffix = key.removeprefix(prefix)
70+
index_str, rest = suffix.split(".", 1)
71+
index = int(index_str)
72+
73+
if index in (3, 7, 11):
74+
block_index = (index - 3) // 4
75+
new_key = f"decoder.up_blocks.{block_index}.upsamplers.0.{rest}"
76+
else:
77+
block_index = index // 4
78+
resnet_index = index % 4
79+
new_key = f"decoder.up_blocks.{block_index}.resnets.{resnet_index}.{rest}"
80+
81+
return rename_residual_key(new_key)
82+
83+
84+
def convert_qwen_image_vae_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
85+
converted_state_dict = {}
86+
for key, value in state_dict.items():
87+
if key.startswith("conv1."):
88+
new_key = key.replace("conv1.", "quant_conv.", 1)
89+
elif key.startswith("conv2."):
90+
new_key = key.replace("conv2.", "post_quant_conv.", 1)
91+
elif key.startswith("encoder.conv1."):
92+
new_key = key.replace("encoder.conv1.", "encoder.conv_in.", 1)
93+
elif key.startswith("decoder.conv1."):
94+
new_key = key.replace("decoder.conv1.", "decoder.conv_in.", 1)
95+
elif key.startswith("encoder.downsamples."):
96+
new_key = rename_residual_key(key.replace("encoder.downsamples.", "encoder.down_blocks.", 1))
97+
elif key.startswith("decoder.upsamples."):
98+
new_key = rename_decoder_upsample_key(key)
99+
elif key.startswith("encoder.middle.") or key.startswith("decoder.middle."):
100+
new_key = rename_mid_key(key)
101+
elif key.startswith("encoder.head.0."):
102+
new_key = key.replace("encoder.head.0.", "encoder.norm_out.", 1)
103+
elif key.startswith("encoder.head.2."):
104+
new_key = key.replace("encoder.head.2.", "encoder.conv_out.", 1)
105+
elif key.startswith("decoder.head.0."):
106+
new_key = key.replace("decoder.head.0.", "decoder.norm_out.", 1)
107+
elif key.startswith("decoder.head.2."):
108+
new_key = key.replace("decoder.head.2.", "decoder.conv_out.", 1)
109+
else:
110+
new_key = rename_residual_key(key)
111+
112+
if new_key in converted_state_dict:
113+
raise ValueError(f"Duplicate converted VAE key: {new_key}")
114+
converted_state_dict[new_key] = value
115+
116+
return converted_state_dict
117+
118+
119+
def convert_qwen_image_vae(state_dict: dict[str, torch.Tensor]) -> AutoencoderKLQwenImage:
120+
converted_state_dict = convert_qwen_image_vae_state_dict(state_dict)
121+
with init_empty_weights():
122+
vae = AutoencoderKLQwenImage()
123+
124+
expected_keys = set(vae.state_dict().keys())
125+
converted_keys = set(converted_state_dict.keys())
126+
missing_keys = expected_keys - converted_keys
127+
unexpected_keys = converted_keys - expected_keys
128+
if missing_keys or unexpected_keys:
129+
if missing_keys:
130+
print(f"ERROR: missing VAE keys ({len(missing_keys)}):", file=sys.stderr)
131+
for key in sorted(missing_keys):
132+
print(key, file=sys.stderr)
133+
if unexpected_keys:
134+
print(f"ERROR: unexpected VAE keys ({len(unexpected_keys)}):", file=sys.stderr)
135+
for key in sorted(unexpected_keys):
136+
print(key, file=sys.stderr)
137+
sys.exit(1)
138+
139+
vae.load_state_dict(converted_state_dict, strict=True, assign=True)
140+
return vae
141+
142+
143+
def infer_text_conditioner_config(state_dict: dict[str, torch.Tensor]) -> dict[str, Any]:
144+
model_dim = state_dict["blocks.0.self_attn.q_proj.weight"].shape[0]
145+
source_dim = state_dict["blocks.0.cross_attn.k_proj.weight"].shape[1]
146+
target_vocab_size, target_dim = state_dict["embed.weight"].shape
147+
attention_head_dim = state_dict["blocks.0.self_attn.q_norm.weight"].shape[0]
148+
num_layers = 1 + max(int(key.split(".")[1]) for key in state_dict if key.startswith("blocks."))
149+
150+
return {
151+
"source_dim": source_dim,
152+
"target_dim": target_dim,
153+
"model_dim": model_dim,
154+
"num_layers": num_layers,
155+
"num_attention_heads": model_dim // attention_head_dim,
156+
"target_vocab_size": target_vocab_size,
157+
}
158+
159+
160+
def convert_text_conditioner(state_dict: dict[str, torch.Tensor]) -> AnimaTextConditioner:
161+
config = infer_text_conditioner_config(state_dict)
162+
with init_empty_weights():
163+
text_conditioner = AnimaTextConditioner(**config)
164+
165+
expected_keys = set(text_conditioner.state_dict().keys())
166+
converted_keys = set(state_dict.keys())
167+
missing_keys = expected_keys - converted_keys
168+
unexpected_keys = converted_keys - expected_keys
169+
if missing_keys or unexpected_keys:
170+
if missing_keys:
171+
print(f"ERROR: missing text conditioner keys ({len(missing_keys)}):", file=sys.stderr)
172+
for key in sorted(missing_keys):
173+
print(key, file=sys.stderr)
174+
if unexpected_keys:
175+
print(f"ERROR: unexpected text conditioner keys ({len(unexpected_keys)}):", file=sys.stderr)
176+
for key in sorted(unexpected_keys):
177+
print(key, file=sys.stderr)
178+
sys.exit(1)
179+
180+
text_conditioner.load_state_dict(state_dict, strict=True, assign=True)
181+
return text_conditioner
182+
183+
184+
def infer_qwen3_config(state_dict: dict[str, torch.Tensor]) -> Qwen3Config:
185+
vocab_size, hidden_size = state_dict["embed_tokens.weight"].shape
186+
intermediate_size = state_dict["layers.0.mlp.gate_proj.weight"].shape[0]
187+
num_hidden_layers = 1 + max(int(key.split(".")[1]) for key in state_dict if key.startswith("layers."))
188+
head_dim = state_dict["layers.0.self_attn.q_norm.weight"].shape[0]
189+
num_attention_heads = state_dict["layers.0.self_attn.q_proj.weight"].shape[0] // head_dim
190+
num_key_value_heads = state_dict["layers.0.self_attn.k_proj.weight"].shape[0] // head_dim
191+
192+
return Qwen3Config(
193+
vocab_size=vocab_size,
194+
hidden_size=hidden_size,
195+
intermediate_size=intermediate_size,
196+
num_hidden_layers=num_hidden_layers,
197+
num_attention_heads=num_attention_heads,
198+
num_key_value_heads=num_key_value_heads,
199+
max_position_embeddings=32768,
200+
rms_norm_eps=1e-6,
201+
rope_theta=1000000.0,
202+
head_dim=head_dim,
203+
attention_bias=False,
204+
tie_word_embeddings=False,
205+
)
206+
207+
208+
def convert_text_encoder(state_dict: dict[str, torch.Tensor]) -> Qwen3Model:
209+
state_dict = {key.removeprefix("model."): value for key, value in state_dict.items()}
210+
config = infer_qwen3_config(state_dict)
211+
with init_empty_weights():
212+
text_encoder = Qwen3Model(config)
213+
214+
expected_keys = set(text_encoder.state_dict().keys())
215+
converted_keys = set(state_dict.keys())
216+
missing_keys = expected_keys - converted_keys
217+
unexpected_keys = converted_keys - expected_keys
218+
if missing_keys or unexpected_keys:
219+
if missing_keys:
220+
print(f"ERROR: missing Qwen3 keys ({len(missing_keys)}):", file=sys.stderr)
221+
for key in sorted(missing_keys):
222+
print(key, file=sys.stderr)
223+
if unexpected_keys:
224+
print(f"ERROR: unexpected Qwen3 keys ({len(unexpected_keys)}):", file=sys.stderr)
225+
for key in sorted(unexpected_keys):
226+
print(key, file=sys.stderr)
227+
sys.exit(1)
228+
229+
text_encoder.load_state_dict(state_dict, strict=True, assign=True)
230+
return text_encoder
231+
232+
233+
def split_anima_transformer_checkpoint(
234+
state_dict: dict[str, torch.Tensor],
235+
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
236+
transformer_state_dict = {}
237+
text_conditioner_state_dict = {}
238+
adapter_prefix = "net.llm_adapter."
239+
240+
for key, value in state_dict.items():
241+
if key.startswith(adapter_prefix):
242+
text_conditioner_state_dict[key.removeprefix(adapter_prefix)] = value
243+
else:
244+
transformer_state_dict[key] = value
245+
246+
return transformer_state_dict, text_conditioner_state_dict
247+
248+
249+
def save_pipeline(args, transformer, text_conditioner, text_encoder, vae):
250+
tokenizer = AutoTokenizer.from_pretrained(args.qwen_tokenizer_path)
251+
t5_tokenizer = T5TokenizerFast.from_pretrained(args.t5_tokenizer_path)
252+
scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0)
253+
254+
pipe = AnimaAutoBlocks().init_pipeline()
255+
pipe.update_components(
256+
text_encoder=text_encoder,
257+
tokenizer=tokenizer,
258+
t5_tokenizer=t5_tokenizer,
259+
text_conditioner=text_conditioner,
260+
transformer=transformer,
261+
vae=vae,
262+
scheduler=scheduler,
263+
)
264+
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size=args.max_shard_size)
265+
266+
267+
def get_args():
268+
parser = argparse.ArgumentParser()
269+
parser.add_argument("--transformer_ckpt_path", type=str, required=True, help="Path to Anima DiT safetensors")
270+
parser.add_argument("--text_encoder_ckpt_path", type=str, required=True, help="Path to Qwen3 text encoder")
271+
parser.add_argument("--vae_ckpt_path", type=str, required=True, help="Path to Qwen-Image VAE safetensors")
272+
parser.add_argument("--qwen_tokenizer_path", type=str, default=None)
273+
parser.add_argument("--t5_tokenizer_path", type=str, default=None)
274+
parser.add_argument("--output_path", type=str, required=True)
275+
parser.add_argument("--save_pipeline", action="store_true")
276+
parser.add_argument("--dtype", default="bf16", choices=list(DTYPE_MAPPING.keys()))
277+
parser.add_argument("--max_shard_size", default="5GB")
278+
return parser.parse_args()
279+
280+
281+
if __name__ == "__main__":
282+
args = get_args()
283+
output_path = pathlib.Path(args.output_path)
284+
dtype = DTYPE_MAPPING[args.dtype]
285+
286+
raw_transformer_state_dict = load_file(args.transformer_ckpt_path, device="cpu")
287+
transformer_state_dict, text_conditioner_state_dict = split_anima_transformer_checkpoint(
288+
raw_transformer_state_dict
289+
)
290+
transformer = convert_transformer(
291+
"Cosmos-2.0-Diffusion-2B-Text2Image", state_dict=transformer_state_dict, weights_only=True
292+
).to(dtype=dtype)
293+
text_conditioner = convert_text_conditioner(text_conditioner_state_dict).to(dtype=dtype)
294+
295+
text_encoder_state_dict = load_file(args.text_encoder_ckpt_path, device="cpu")
296+
text_encoder = convert_text_encoder(text_encoder_state_dict).to(dtype=dtype)
297+
298+
vae_state_dict = load_file(args.vae_ckpt_path, device="cpu")
299+
vae = convert_qwen_image_vae(vae_state_dict).to(dtype=dtype)
300+
301+
if args.save_pipeline:
302+
if args.qwen_tokenizer_path is None or args.t5_tokenizer_path is None:
303+
raise ValueError("`--qwen_tokenizer_path` and `--t5_tokenizer_path` are required with `--save_pipeline`.")
304+
save_pipeline(args, transformer, text_conditioner, text_encoder, vae)
305+
else:
306+
output_path.mkdir(parents=True, exist_ok=True)
307+
transformer.save_pretrained(
308+
output_path / "transformer", safe_serialization=True, max_shard_size=args.max_shard_size
309+
)
310+
text_conditioner.save_pretrained(
311+
output_path / "text_conditioner", safe_serialization=True, max_shard_size=args.max_shard_size
312+
)
313+
text_encoder.save_pretrained(
314+
output_path / "text_encoder", safe_serialization=True, max_shard_size=args.max_shard_size
315+
)
316+
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size=args.max_shard_size)

0 commit comments

Comments
 (0)