Commit f0d09c3

Authored by 4pointoh, Your Name, kappacommit, Pfannkuchensack, and claude
feat: add Anima model support (invoke-ai#8961)
* feat: add Anima model support
* schema
* image to image
* regional guidance
* loras
* last fixes
* tests
* fix attributions
* fix attributions
* refactor to use diffusers reference
* fix an additional lora type
* some adjustments to follow flux 2 paper implementation
* use t5 from model manager instead of downloading
* make lora identification more reliable
* fix: resolve lint errors in anima module
  Remove unused variable, fix import ordering, inline dict() call, and address minor lint issues across anima-related files.
* chore: Ruff format again
* fix regional guidance error
* fix(anima): validate unexpected keys after strict=False checkpoint loading
  Capture the load_state_dict result and raise RuntimeError on unexpected keys (indicating a corrupted or incompatible checkpoint), while logging a warning for missing keys (expected for inv_freq buffers regenerated at runtime).
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* fix(anima): make model loader submodel fields required instead of Optional
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* fix(anima): add Classification.Prototype to LoRA loaders, fix exception types
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* fix(anima): fix replace-all in key conversion, warn on DoRA+LoKR, unify grouping functions
  - Use key.replace(old, new, 1) in _convert_kohya_unet_key and _convert_kohya_te_key to avoid replacing multiple occurrences
  - Upgrade DoRA+LoKR dora_scale strip from logger.debug to logger.warning since it represents data loss
  - Replace _group_kohya_keys and _group_by_layer with a single _group_keys_by_layer function parameterized by extra_suffixes, with _KOHYA_KNOWN_SUFFIXES and _PEFT_EXTRA_SUFFIXES constants
  - Add test_empty_state_dict_returns_empty_model to verify empty input produces a model with no layers
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* fix(anima): add safety cap for Qwen3 sequence length to prevent OOM
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* fix(anima): add denoising range validation, fix closure capture, add edge case tests
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* fix(anima): add T5 to metadata, fix dead code, decouple scheduler type guard
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* fix(anima): update VAE field description for required field
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* chore: regenerate frontend types after upstream merge
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* chore: ruff format anima_denoise.py
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* fix(anima): add T5 encoder metadata recall handler
  The T5 encoder was added to generation metadata but had no recall handler, so it wasn't restored when recalling from metadata.
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* chore(frontend): add regression test for buildAnimaGraph
  Add tests for CFG gating (negative conditioning omitted when cfgScale <= 1) and basic graph structure (model loader, text encoder, denoise nodes).
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* only show 0.6b for anima
* don't show 0.6b for other models
* schema
* Anima preview 3
* fix ci

---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: kappacommit <samwolfe40@gmail.com>
Co-authored-by: Alexander Eichhorn <alex@eichhorn.dev>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
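One of the fixes listed above, "fix replace-all in key conversion", hinges on the count argument of Python's `str.replace`. A minimal sketch with a hypothetical key name (the real mappings live in `_convert_kohya_unet_key` and `_convert_kohya_te_key`):

```python
# Hypothetical LoRA key for illustration; real keys come from kohya checkpoints.
key = "lora_unet_blocks_0_attn_to_attn_out"

# Replacing every occurrence corrupts the second "attn":
assert key.replace("attn", "attention") == "lora_unet_blocks_0_attention_to_attention_out"

# Passing count=1 replaces only the first occurrence, which is what the fix does:
assert key.replace("attn", "attention", 1) == "lora_unet_blocks_0_attention_to_attn_out"
```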
1 parent 60d0bcd commit f0d09c3

68 files changed

Lines changed: 6116 additions & 55 deletions


invokeai/app/api/dependencies.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -46,6 +46,7 @@
 from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
 from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    AnimaConditioningInfo,
     BasicConditioningInfo,
     CogView4ConditioningInfo,
     ConditioningFieldData,
@@ -140,6 +141,7 @@ def initialize(
                 SD3ConditioningInfo,
                 CogView4ConditioningInfo,
                 ZImageConditioningInfo,
+                AnimaConditioningInfo,
             ],
             ephemeral=True,
         ),
```

invokeai/app/invocations/anima_denoise.py

Lines changed: 715 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 119 additions & 0 deletions
```python
"""Anima image-to-latents invocation.

Encodes an image to latent space using the Anima VAE (AutoencoderKLWan or FLUX VAE).

For Wan VAE (AutoencoderKLWan):
- Input image is converted to 5D tensor [B, C, T, H, W] with T=1
- After encoding, latents are normalized: (latents - mean) / std
  (inverse of the denormalization in anima_latents_to_image.py)

For FLUX VAE (AutoEncoder):
- Encoding is handled internally by the FLUX VAE
"""

from typing import Union

import einops
import torch
from diffusers.models.autoencoders import AutoencoderKLWan

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    Input,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux

AnimaVAE = Union[AutoencoderKLWan, FluxAutoEncoder]


@invocation(
    "anima_i2l",
    title="Image to Latents - Anima",
    tags=["image", "latents", "vae", "i2l", "anima"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class AnimaImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates latents from an image using the Anima VAE (supports Wan 2.1 and FLUX VAE)."""

    image: ImageField = InputField(description="The image to encode.")
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    @staticmethod
    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
            raise TypeError(
                f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
            )

        estimated_working_memory = estimate_vae_working_memory_flux(
            operation="encode",
            image_tensor=image_tensor,
            vae=vae_info.model,
        )

        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
            if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)):
                raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.")

            vae_dtype = next(iter(vae.parameters())).dtype
            image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

            with torch.inference_mode():
                if isinstance(vae, FluxAutoEncoder):
                    # FLUX VAE handles scaling internally
                    generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
                    latents = vae.encode(image_tensor, sample=True, generator=generator)
                else:
                    # AutoencoderKLWan expects 5D input [B, C, T, H, W]
                    if image_tensor.ndim == 4:
                        image_tensor = image_tensor.unsqueeze(2)  # [B, C, H, W] -> [B, C, 1, H, W]

                    encoded = vae.encode(image_tensor, return_dict=False)[0]
                    latents = encoded.sample().to(dtype=vae_dtype)

                    # Normalize to denoiser space: (latents - mean) / std
                    # This is the inverse of the denormalization in anima_latents_to_image.py
                    latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents)
                    latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents)
                    latents = (latents - latents_mean) / latents_std

                    # Remove temporal dimension: [B, C, 1, H, W] -> [B, C, H, W]
                    if latents.ndim == 5:
                        latents = latents.squeeze(2)

        return latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        image = context.images.get_pil(self.image.image_name)

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        vae_info = context.models.load(self.vae.vae)
        if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
            raise TypeError(
                f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
            )

        context.util.signal_progress("Running Anima VAE encode")
        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

        latents = latents.to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
```
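The Wan-VAE normalization in vae_encode and the matching denormalization in anima_latents_to_image.py are exact inverses. A toy round-trip check with made-up per-channel statistics standing in for `vae.config.latents_mean` / `latents_std`:

```python
import torch

# Toy per-channel statistics for illustration only; the real values come
# from the AutoencoderKLWan config.
latents_mean = torch.tensor([0.5, -0.2, 0.1, 0.0]).view(1, -1, 1, 1, 1)
latents_std = torch.tensor([1.5, 0.8, 1.1, 0.9]).view(1, -1, 1, 1, 1)

raw = torch.randn(1, 4, 1, 8, 8)  # [B, C, T, H, W] with T=1

normalized = (raw - latents_mean) / latents_std          # encode path (this file)
denormalized = normalized * latents_std + latents_mean   # decode path (anima_latents_to_image.py)

assert torch.allclose(raw, denormalized, atol=1e-5)
```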
Lines changed: 108 additions & 0 deletions
```python
"""Anima latents-to-image invocation.

Decodes Anima latents using the QwenImage VAE (AutoencoderKLWan) or
compatible FLUX VAE as fallback.

Latents from the denoiser are in normalized space (zero-centered). Before
VAE decode, they must be denormalized using the Wan 2.1 per-channel
mean/std: latents = latents * std + mean (matching diffusers WanPipeline).

The VAE expects 5D latents [B, C, T, H, W]; for single images, T=1.
"""

import torch
from diffusers.models.autoencoders import AutoencoderKLWan
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    Input,
    InputField,
    LatentsField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux


@invocation(
    "anima_l2i",
    title="Latents to Image - Anima",
    tags=["latents", "image", "vae", "l2i", "anima"],
    category="latents",
    version="1.0.2",
    classification=Classification.Prototype,
)
class AnimaLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents using the Anima VAE.

    Supports the Wan 2.1 QwenImage VAE (AutoencoderKLWan) with explicit
    latent denormalization, and FLUX VAE as fallback.
    """

    latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.tensors.load(self.latents.latents_name)

        vae_info = context.models.load(self.vae.vae)
        if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
            raise TypeError(
                f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
            )

        estimated_working_memory = estimate_vae_working_memory_flux(
            operation="decode",
            image_tensor=latents,
            vae=vae_info.model,
        )

        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
            context.util.signal_progress("Running Anima VAE decode")
            if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)):
                raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.")

            vae_dtype = next(iter(vae.parameters())).dtype
            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

            TorchDevice.empty_cache()

            with torch.inference_mode():
                if isinstance(vae, FluxAutoEncoder):
                    # FLUX VAE handles scaling internally, expects 4D [B, C, H, W]
                    img = vae.decode(latents)
                else:
                    # Expects 5D latents [B, C, T, H, W]
                    if latents.ndim == 4:
                        latents = latents.unsqueeze(2)  # [B, C, H, W] -> [B, C, 1, H, W]

                    # Denormalize from denoiser space to raw VAE space
                    # (same as diffusers WanPipeline and ComfyUI Wan21.process_out)
                    latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents)
                    latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents)
                    latents = latents * latents_std + latents_mean

                    decoded = vae.decode(latents, return_dict=False)[0]

                    # Output is 5D [B, C, T, H, W]; squeeze temporal dim
                    if decoded.ndim == 5:
                        decoded = decoded.squeeze(2)
                    img = decoded

            img = img.clamp(-1, 1)
            img = rearrange(img[0], "c h w -> h w c")
            img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        TorchDevice.empty_cache()

        image_dto = context.images.save(image=img_pil)
        return ImageOutput.build(image_dto)
```
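The final clamp-and-scale step maps denoiser output in [-1, 1] onto 8-bit pixel values; `.byte()` truncates toward zero, so -1.0 maps to 0, the midpoint 0.0 to 127, and 1.0 to 255. A quick sketch of that mapping in isolation:

```python
import torch

# Same conversion as the Image.fromarray((127.5 * (img + 1.0)).byte()...) call
# in the invocation, applied to three representative values.
img = torch.tensor([-1.0, 0.0, 1.0])
pixels = (127.5 * (img.clamp(-1, 1) + 1.0)).byte()
assert pixels.tolist() == [0, 127, 255]
```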
