Skip to content

Commit 49f5b35

Browse files
author
Ting-Yun Chang
committed
Use _keep_in_fp32_modules instead of autocast
1 parent c8513c2 commit 49f5b35

1 file changed

Lines changed: 12 additions & 16 deletions

File tree

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 12 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -75,10 +75,9 @@ def __init__(self, embedding_dim: int, condition_dim: int) -> None:
7575
self.norm = RMSNorm(embedding_dim, eps=1e-6, elementwise_affine=True)
7676

7777
def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
78-
with torch.amp.autocast(timestep.device.type, dtype=torch.float32):
79-
timesteps_proj = self.time_proj(timestep).to(torch.float32)
80-
temb = self.t_embedder(timesteps_proj)
81-
embedded_timestep = self.norm(timesteps_proj)
78+
timesteps_proj = self.time_proj(timestep.float())
79+
temb = self.t_embedder(timesteps_proj)
80+
embedded_timestep = self.norm(timesteps_proj)
8281
return temb, embedded_timestep
8382

8483

@@ -134,14 +133,12 @@ def forward(
134133
temb: torch.Tensor | None = None,
135134
) -> torch.Tensor:
136135
original_dtype = hidden_states.dtype
137-
with torch.amp.autocast(hidden_states.device.type, dtype=torch.float32):
138-
embedded_timestep = self.activation(embedded_timestep)
139-
embedded_timestep = self.linear_1(embedded_timestep)
140-
embedded_timestep = self.linear_2(embedded_timestep)
141-
142-
if temb is not None:
143-
embedded_timestep = embedded_timestep + temb
144-
shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
136+
embedded_timestep = self.activation(embedded_timestep.float())
137+
embedded_timestep = self.linear_1(embedded_timestep)
138+
embedded_timestep = self.linear_2(embedded_timestep)
139+
if temb is not None:
140+
embedded_timestep = embedded_timestep + temb.float()
141+
shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
145142
shift = shift.to(original_dtype)
146143
scale = scale.to(original_dtype)
147144
gate = gate.to(original_dtype)
@@ -611,7 +608,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin,
611608
_supports_gradient_checkpointing = True
612609
_skip_layerwise_casting_patterns = ["patch_embed", "final_layer", "norm"]
613610
_no_split_modules = ["CosmosTransformerBlock"]
614-
_keep_in_fp32_modules = ["learnable_pos_embed"]
611+
_keep_in_fp32_modules = ["learnable_pos_embed", "time_embed", "norm1", "norm2", "norm3", "norm_out", "proj_out"]
615612

616613
@register_to_config
617614
def __init__(
@@ -809,9 +806,8 @@ def forward(
809806
)
810807

811808
# 8. Output norm & projection & unpatchify
812-
with torch.amp.autocast(hidden_states.device.type, dtype=torch.float32):
813-
hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
814-
hidden_states = self.proj_out(hidden_states)
809+
hidden_states = self.norm_out(hidden_states.float(), embedded_timestep, temb)
810+
hidden_states = self.proj_out(hidden_states)
815811
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
816812
hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
817813
# NOTE: The permutation order here is not the inverse operation of what happens when patching as usually expected.

0 commit comments

Comments
 (0)