
import torch
import torch.nn as nn
+import torch.nn.functional as F

from ...configuration_utils import ConfigMixin
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .autoencoder_kl_wan import (
-    WanAttentionBlock as AttentionBlock,
-)
-from .autoencoder_kl_wan import (
-    WanCausalConv3d as CausalConv3d,
-)
-from .autoencoder_kl_wan import (
-    WanResample as Resample,
-)
-from .autoencoder_kl_wan import (
-    WanResidualBlock as ResidualBlock,
-)
-from .autoencoder_kl_wan import (
-    WanRMS_norm as RMS_norm,
-)
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


logger = logging.get_logger(__name__)
CACHE_T = 2


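+# The Wan VAE building blocks are copied in below, replacing the previous imports from
+# autoencoder_kl_wan. For chunked, frame-causal inference they thread a shared `feat_cache`
+# list (indexed via the mutable `feat_idx` counter) through every causal convolution; each
+# entry keeps the last CACHE_T frames of the previous chunk so the next chunk can be
+# processed without recomputing earlier timesteps.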
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanCausalConv3d
+class CausalConv3d(nn.Conv3d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int | tuple[int, int, int],
+        stride: int | tuple[int, int, int] = 1,
+        padding: int | tuple[int, int, int] = 0,
+    ) -> None:
+        super().__init__(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+        )
+
+        # Causal padding: all temporal padding (2 * pad_t) is applied on the past side; spatial padding stays symmetric.
+        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
+        self.padding = (0, 0, 0)
+
+    def forward(self, x, cache_x=None):
+        padding = list(self._padding)
+        if cache_x is not None and self._padding[4] > 0:
+            # Prepend cached frames from the previous chunk instead of zero padding.
+            cache_x = cache_x.to(x.device)
+            x = torch.cat([cache_x, x], dim=2)
+            padding[4] -= cache_x.shape[2]
+        x = F.pad(x, padding)
+        return super().forward(x)
+
+
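+# Channel-wise RMS normalization with a learned gain (and optional bias); inputs in half or
+# low-precision float formats are normalized in float32 and cast back.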
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanRMS_norm
+class RMS_norm(nn.Module):
+    def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+        super().__init__()
+        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+        shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+        self.channel_first = channel_first
+        self.scale = dim**0.5
+        self.gamma = nn.Parameter(torch.ones(shape))
+        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+    def forward(self, x):
+        needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
+            t in str(x.dtype) for t in ("float4_", "float8_")
+        )
+        normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
+            x.dtype
+        )
+
+        return normalized * self.scale * self.gamma + self.bias
+
+
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanUpsample
+class Upsample(nn.Upsample):
+    def forward(self, x):
+        # Interpolate in float32 for numerical stability, then cast back to the input dtype.
+        return super().forward(x.float()).type_as(x)
+
+
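+# Spatial 2x up/downsampling with optional temporal resampling. In "upsample3d" mode a causal
+# time convolution doubles the channels and the result is re-interleaved along the time axis,
+# doubling the frame count; "downsample3d" halves the frame count with a strided causal
+# convolution. The sentinel value "Rep" marks the first chunk, for which the time convolution
+# uses its built-in causal padding instead of cached frames.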
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanResample
+class Resample(nn.Module):
+    def __init__(self, dim: int, mode: str, upsample_out_dim: int | None = None) -> None:
+        super().__init__()
+        self.dim = dim
+        self.mode = mode
+
+        # Default output dim for the upsampling modes is half the input dim.
+        if upsample_out_dim is None:
+            upsample_out_dim = dim // 2
+
+        if mode == "upsample2d":
+            self.resample = nn.Sequential(
+                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
+            )
+        elif mode == "upsample3d":
+            self.resample = nn.Sequential(
+                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
+            )
+            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+        elif mode == "downsample2d":
+            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+        elif mode == "downsample3d":
+            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+            self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+        else:
+            self.resample = nn.Identity()
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        b, c, t, h, w = x.size()
+        if self.mode == "upsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    feat_cache[idx] = "Rep"
+                    feat_idx[0] += 1
+                else:
+                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+                        cache_x = torch.cat(
+                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+                        )
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+                    if feat_cache[idx] == "Rep":
+                        x = self.time_conv(x)
+                    else:
+                        x = self.time_conv(x, feat_cache[idx])
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+
+                    # Re-interleave the doubled channels along the time axis: (b, 2c, t, h, w) -> (b, c, 2t, h, w).
+                    x = x.reshape(b, 2, c, t, h, w)
+                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+                    x = x.reshape(b, c, t * 2, h, w)
+        t = x.shape[2]
+        # Fold time into the batch dimension for the 2D spatial resampling layers.
+        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+        x = self.resample(x)
+        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)

+        if self.mode == "downsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    feat_cache[idx] = x.clone()
+                    feat_idx[0] += 1
+                else:
+                    cache_x = x[:, :, -1:, :, :].clone()
+                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+        return x
+
+
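+# Pre-norm residual block: RMS norm -> SiLU -> causal 3x3x3 conv, applied twice with dropout,
+# plus a 1x1x1 shortcut projection when the channel count changes.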
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanResidualBlock
+class ResidualBlock(nn.Module):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        dropout: float = 0.0,
+        non_linearity: str = "silu",
+    ) -> None:
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.nonlinearity = get_activation(non_linearity)
+
+        self.norm1 = RMS_norm(in_dim, images=False)
+        self.conv1 = CausalConv3d(in_dim, out_dim, 3, padding=1)
+        self.norm2 = RMS_norm(out_dim, images=False)
+        self.dropout = nn.Dropout(dropout)
+        self.conv2 = CausalConv3d(out_dim, out_dim, 3, padding=1)
+        self.conv_shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        h = self.conv_shortcut(x)
+
+        x = self.norm1(x)
+        x = self.nonlinearity(x)
+
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+            x = self.conv1(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv1(x)
+
+        x = self.norm2(x)
+        x = self.nonlinearity(x)
+        x = self.dropout(x)
+
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+            x = self.conv2(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv2(x)
+
+        return x + h
+
+
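+# Single-head self-attention applied per frame: time is folded into the batch dimension and
+# attention runs over the spatial (height * width) positions of each frame.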
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanAttentionBlock
+class AttentionBlock(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+        self.norm = RMS_norm(dim)
+        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+        self.proj = nn.Conv2d(dim, dim, 1)
+
+    def forward(self, x):
+        identity = x
+        batch_size, channels, time, height, width = x.size()
+
+        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
+        x = self.norm(x)
+
+        qkv = self.to_qkv(x)
+        qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+        qkv = qkv.permute(0, 1, 3, 2).contiguous()
+        q, k, v = qkv.chunk(3, dim=-1)
+
+        x = F.scaled_dot_product_attention(q, k, v)
+
+        x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
+        x = self.proj(x)
+        x = x.view(batch_size, time, channels, height, width)
+        x = x.permute(0, 2, 1, 3, 4)
+
+        return x + identity
+
+
class Encoder3d(nn.Module):

    def __init__(self,