Skip to content

Commit 5fc2bd2

Browse files
Authored by: YangKai0616, github-actions[bot], and dg845
Stabilize low-precision custom autoencoder RMS normalization (#13316)
* Stabilize low-precision custom autoencoder RMS normalization
* Add fp8/4
* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent 6350a76 commit 5fc2bd2

File tree

4 files changed

+32
-4
lines changed

4 files changed

+32
-4
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,14 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi
8787
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
8888

8989
def forward(self, x):
90-
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
90+
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
91+
t in str(x.dtype) for t in ("float4_", "float8_")
92+
)
93+
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
94+
x.dtype
95+
)
96+
97+
return normalized * self.scale * self.gamma + self.bias
9198

9299

93100
class HunyuanImageRefinerAttnBlock(nn.Module):

src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,14 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi
8787
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
8888

8989
def forward(self, x):
90-
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
90+
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
91+
t in str(x.dtype) for t in ("float4_", "float8_")
92+
)
93+
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
94+
x.dtype
95+
)
96+
97+
return normalized * self.scale * self.gamma + self.bias
9198

9299

93100
class HunyuanVideo15AttnBlock(nn.Module):

src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,14 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi
105105
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
106106

107107
def forward(self, x):
108-
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
108+
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
109+
t in str(x.dtype) for t in ("float4_", "float8_")
110+
)
111+
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
112+
x.dtype
113+
)
114+
115+
return normalized * self.scale * self.gamma + self.bias
109116

110117

111118
class QwenImageUpsample(nn.Upsample):

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,14 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi
196196
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
197197

198198
def forward(self, x):
199-
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
199+
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
200+
t in str(x.dtype) for t in ("float4_", "float8_")
201+
)
202+
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
203+
x.dtype
204+
)
205+
206+
return normalized * self.scale * self.gamma + self.bias
200207

201208

202209
class WanUpsample(nn.Upsample):

0 commit comments

Comments (0)