Skip to content

Commit 8b9f817

Browse files
adi776borate, github-actions[bot], hlky
authored
Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning (#12814)
* Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning * Apply style fixes * Fix: Remove import-time autocast in Kandinsky to prevent warnings - Removed @torch.autocast decorator from Kandinsky classes. - Implemented manual F.linear casting to ensure numerical parity with FP32. - Verified bit-exact output matches main branch. Co-authored-by: hlky <hlky@hlky.ac> * Used _keep_in_fp32_modules to align with standards --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: hlky <hlky@hlky.ac>
1 parent b1f06b7 commit 8b9f817

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

src/diffusers/models/transformers/transformer_kandinsky.py

Lines changed: 2 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -165,9 +165,8 @@ def __init__(self, model_dim, time_dim, max_period=10000.0):
165165
self.activation = nn.SiLU()
166166
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
167167

168-
@torch.autocast(device_type="cuda", dtype=torch.float32)
169168
def forward(self, time):
170-
args = torch.outer(time, self.freqs.to(device=time.device))
169+
args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
171170
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
172171
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
173172
return time_embed
@@ -269,7 +268,6 @@ def __init__(self, time_dim, model_dim, num_params):
269268
self.out_layer.weight.data.zero_()
270269
self.out_layer.bias.data.zero_()
271270

272-
@torch.autocast(device_type="cuda", dtype=torch.float32)
273271
def forward(self, x):
274272
return self.out_layer(self.activation(x))
275273

@@ -525,6 +523,7 @@ class Kandinsky5Transformer3DModel(
525523
"Kandinsky5TransformerEncoderBlock",
526524
"Kandinsky5TransformerDecoderBlock",
527525
]
526+
_keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"]
528527
_supports_gradient_checkpointing = True
529528

530529
@register_to_config

0 commit comments

Comments (0)