
Commit 5ebbbd7

upd

Signed-off-by: Lancer <maruixiang6688@gmail.com>

1 parent: 8190e3f
3 files changed: 255 additions & 233 deletions


src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1007,7 +1007,6 @@
         AutoencoderKLMagvit,
         AutoencoderKLMochi,
         AutoencoderKLQwenImage,
-        JoyAIImageVAE,
         AutoencoderKLTemporalDecoder,
         AutoencoderKLWan,
         AutoencoderOobleck,
@@ -1044,11 +1043,12 @@
         HunyuanDiT2DModel,
         HunyuanDiT2DMultiControlNetModel,
         HunyuanImageTransformer2DModel,
-        JoyAIImageTransformer3DModel,
         HunyuanVideo15Transformer3DModel,
         HunyuanVideoFramepackTransformer3DModel,
         HunyuanVideoTransformer3DModel,
         I2VGenXLUNet,
+        JoyAIImageTransformer3DModel,
+        JoyAIImageVAE,
         Kandinsky3UNet,
         Kandinsky5Transformer3DModel,
         LatteTransformer3DModel,
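
Net effect of this file: the two JoyAI names move from non-alphabetical positions to their sorted slots in the top-level import list; nothing is added or removed, so both classes remain importable from the package root. A trivial smoke test:

import diffusers
from diffusers import JoyAIImageTransformer3DModel, JoyAIImageVAE

# Both names resolve from the package root after the reorder.
print(JoyAIImageTransformer3DModel.__name__, JoyAIImageVAE.__name__)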

src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py

Lines changed: 226 additions & 15 deletions
@@ -4,35 +4,246 @@

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin
 from ...loaders import FromOriginalModelMixin
 from ...utils import logging
 from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
-from .autoencoder_kl_wan import (
-    WanAttentionBlock as AttentionBlock,
-)
-from .autoencoder_kl_wan import (
-    WanCausalConv3d as CausalConv3d,
-)
-from .autoencoder_kl_wan import (
-    WanResample as Resample,
-)
-from .autoencoder_kl_wan import (
-    WanResidualBlock as ResidualBlock,
-)
-from .autoencoder_kl_wan import (
-    WanRMS_norm as RMS_norm,
-)
 from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


 logger = logging.get_logger(__name__)
 CACHE_T = 2


+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanCausalConv3d
+class CausalConv3d(nn.Conv3d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int | tuple[int, int, int],
+        stride: int | tuple[int, int, int] = 1,
+        padding: int | tuple[int, int, int] = 0,
+    ) -> None:
+        super().__init__(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+        )
+
+        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
+        self.padding = (0, 0, 0)
+
+    def forward(self, x, cache_x=None):
+        padding = list(self._padding)
+        if cache_x is not None and self._padding[4] > 0:
+            cache_x = cache_x.to(x.device)
+            x = torch.cat([cache_x, x], dim=2)
+            padding[4] -= cache_x.shape[2]
+        x = F.pad(x, padding)
+        return super().forward(x)
+
+
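Note on `CausalConv3d`: the commit replaces the `autoencoder_kl_wan` imports with inlined copies, each tagged with diffusers' `# Copied from` convention so the copy checker keeps them in sync. In `F.pad`, the last two entries pad the time axis, and `(2 * self.padding[0], 0)` puts all temporal padding on the left (past) side, so a kernel position never reads future frames; when `cache_x` carries trailing frames from the previous chunk, they stand in for that padding. A quick check, assuming the module from this diff is importable:

import torch
from diffusers.models.autoencoders.autoencoder_kl_joyai_image import CausalConv3d

conv = CausalConv3d(in_channels=4, out_channels=4, kernel_size=3, padding=1)
x = torch.randn(1, 4, 5, 8, 8)  # (batch, channels, frames, height, width)
y = conv(x)
print(y.shape)  # torch.Size([1, 4, 5, 8, 8]) -- frame count preserved

# Causality check: perturbing the last frame leaves all earlier outputs intact.
x2 = x.clone()
x2[:, :, -1] += 1.0
print(torch.allclose(conv(x2)[:, :, :-1], y[:, :, :-1]))  # True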
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanRMS_norm
+class RMS_norm(nn.Module):
+    def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+        super().__init__()
+        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+        shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+        self.channel_first = channel_first
+        self.scale = dim**0.5
+        self.gamma = nn.Parameter(torch.ones(shape))
+        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+    def forward(self, x):
+        needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
+            t in str(x.dtype) for t in ("float4_", "float8_")
+        )
+        normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
+            x.dtype
+        )
+
+        return normalized * self.scale * self.gamma + self.bias
+
+
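`RMS_norm` upcasts half- and sub-byte float inputs before normalizing for numerical stability. Since `F.normalize` divides by the L2 norm along the channel axis, multiplying by `scale = dim**0.5` yields the standard RMS normalization x / sqrt(mean_c x^2), followed by a learned per-channel gain. A standalone check of that identity:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 16, 16)  # (batch, channels=8, frames, height, width)
rms = F.normalize(x, dim=1) * (8**0.5)           # what RMS_norm computes (gamma=1, bias=0)
ref = x / x.pow(2).mean(dim=1, keepdim=True).sqrt()
print(torch.allclose(rms, ref, atol=1e-5))       # True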
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanUpsample
+class Upsample(nn.Upsample):
+    def forward(self, x):
+        return super().forward(x.float()).type_as(x)
+
+
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanResample
+class Resample(nn.Module):
+    def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
+        super().__init__()
+        self.dim = dim
+        self.mode = mode
+
+        if upsample_out_dim is None:
+            upsample_out_dim = dim // 2
+
+        if mode == "upsample2d":
+            self.resample = nn.Sequential(
+                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
+            )
+        elif mode == "upsample3d":
+            self.resample = nn.Sequential(
+                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
+            )
+            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+        elif mode == "downsample2d":
+            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+        elif mode == "downsample3d":
+            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+            self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+        else:
+            self.resample = nn.Identity()
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        b, c, t, h, w = x.size()
+        if self.mode == "upsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    feat_cache[idx] = "Rep"
+                    feat_idx[0] += 1
+                else:
+                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+                        cache_x = torch.cat(
+                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+                        )
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+                    if feat_cache[idx] == "Rep":
+                        x = self.time_conv(x)
+                    else:
+                        x = self.time_conv(x, feat_cache[idx])
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+
+                    x = x.reshape(b, 2, c, t, h, w)
+                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+                    x = x.reshape(b, c, t * 2, h, w)
+        t = x.shape[2]
+        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+        x = self.resample(x)
+        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+        if self.mode == "downsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    feat_cache[idx] = x.clone()
+                    feat_idx[0] += 1
+                else:
+                    cache_x = x[:, :, -1:, :, :].clone()
+                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+        return x
+
+
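The `feat_cache`/`feat_idx` pair implements chunked, causal streaming: each cached conv owns one slot in a shared list, claimed in call order, holding the last `CACHE_T` frames (or the `"Rep"` sentinel before the first chunk) so the next chunk continues where the last one stopped. A minimal sketch of the calling pattern, with an invented chunk size and a single slot, assuming the classes above are importable:

import torch
from diffusers.models.autoencoders.autoencoder_kl_joyai_image import Resample

layer = Resample(dim=16, mode="upsample3d")
feat_cache = [None]                    # one slot per cached conv, filled in call order
video = torch.randn(1, 16, 8, 32, 32)  # (batch, dim, frames, height, width)

chunks = []
for chunk in video.split(2, dim=2):     # stream two frames at a time
    # A fresh feat_idx per call makes each layer reuse its own slot in order.
    chunks.append(layer(chunk, feat_cache=feat_cache, feat_idx=[0]))
out = torch.cat(chunks, dim=2)
print(out.shape)  # torch.Size([1, 8, 14, 64, 64]); the first chunk is not temporally doubled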
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanResidualBlock
+class ResidualBlock(nn.Module):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        dropout: float = 0.0,
+        non_linearity: str = "silu",
+    ) -> None:
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.nonlinearity = get_activation(non_linearity)
+
+        self.norm1 = RMS_norm(in_dim, images=False)
+        self.conv1 = CausalConv3d(in_dim, out_dim, 3, padding=1)
+        self.norm2 = RMS_norm(out_dim, images=False)
+        self.dropout = nn.Dropout(dropout)
+        self.conv2 = CausalConv3d(out_dim, out_dim, 3, padding=1)
+        self.conv_shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        h = self.conv_shortcut(x)
+
+        x = self.norm1(x)
+        x = self.nonlinearity(x)
+
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+            x = self.conv1(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv1(x)
+
+        x = self.norm2(x)
+        x = self.nonlinearity(x)
+        x = self.dropout(x)
+
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+            x = self.conv2(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv2(x)
+
+        return x + h
+
+
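Without a cache, `ResidualBlock` is a plain pre-norm residual unit: a shortcut (1x1 causal conv when the channel count changes) added to norm -> SiLU -> conv, twice. Frame count is preserved because `CausalConv3d` pads only into the past. A quick shape check under the same import assumption:

import torch
from diffusers.models.autoencoders.autoencoder_kl_joyai_image import ResidualBlock

block = ResidualBlock(in_dim=32, out_dim=64)
x = torch.randn(1, 32, 5, 16, 16)  # (batch, channels, frames, height, width)
print(block(x).shape)              # torch.Size([1, 64, 5, 16, 16])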
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanAttentionBlock
+class AttentionBlock(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+        self.norm = RMS_norm(dim)
+        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+        self.proj = nn.Conv2d(dim, dim, 1)
+
+    def forward(self, x):
+        identity = x
+        batch_size, channels, time, height, width = x.size()
+
+        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
+        x = self.norm(x)
+
+        qkv = self.to_qkv(x)
+        qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+        qkv = qkv.permute(0, 1, 3, 2).contiguous()
+        q, k, v = qkv.chunk(3, dim=-1)
+
+        x = F.scaled_dot_product_attention(q, k, v)
+
+        x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
+        x = self.proj(x)
+        x = x.view(batch_size, time, channels, height, width)
+        x = x.permute(0, 2, 1, 3, 4)
+
+        return x + identity
+
+
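`AttentionBlock` is spatial-only attention: frames are folded into the batch dimension before `scaled_dot_product_attention`, so each frame's height x width grid attends to itself and no information crosses time (temporal mixing comes from the causal convs instead). A shape check under the same import assumption:

import torch
from diffusers.models.autoencoders.autoencoder_kl_joyai_image import AttentionBlock

attn = AttentionBlock(dim=64)
x = torch.randn(2, 64, 3, 8, 8)  # (batch, channels, frames, height, width)
print(attn(x).shape)             # torch.Size([2, 64, 3, 8, 8])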
 class Encoder3d(nn.Module):

     def __init__(self,
