Skip to content

Commit 13cfbb0

Browse files
author
Ting-Yun Chang
committed
Revert fp32 upcast and support batch size (bs) > 1
1 parent 770abee commit 13cfbb0

2 files changed

Lines changed: 15 additions & 22 deletions

File tree

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ def __init__(self, embedding_dim: int, condition_dim: int) -> None:
7474
self.t_embedder = CosmosTimestepEmbedding(embedding_dim, condition_dim)
7575
self.norm = RMSNorm(embedding_dim, eps=1e-6, elementwise_affine=True)
7676

77-
def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
78-
timesteps_proj = self.time_proj(timestep.float())
77+
def forward(self, hidden_states: torch.Tensor, timestep: torch.LongTensor) -> torch.Tensor:
78+
timesteps_proj = self.time_proj(timestep).type_as(hidden_states)
7979
temb = self.t_embedder(timesteps_proj)
8080
embedded_timestep = self.norm(timesteps_proj)
8181
return temb, embedded_timestep
@@ -102,7 +102,6 @@ def forward(
102102
embedded_timestep = embedded_timestep + temb[..., : 2 * self.embedding_dim]
103103

104104
shift, scale = embedded_timestep.chunk(2, dim=-1)
105-
106105
hidden_states = self.norm(hidden_states)
107106

108107
if embedded_timestep.ndim == 2:
@@ -132,16 +131,14 @@ def forward(
132131
embedded_timestep: torch.Tensor,
133132
temb: torch.Tensor | None = None,
134133
) -> torch.Tensor:
135-
original_dtype = hidden_states.dtype
136-
embedded_timestep = self.activation(embedded_timestep.float())
134+
embedded_timestep = self.activation(embedded_timestep)
137135
embedded_timestep = self.linear_1(embedded_timestep)
138136
embedded_timestep = self.linear_2(embedded_timestep)
137+
139138
if temb is not None:
140-
embedded_timestep = embedded_timestep + temb.float()
139+
embedded_timestep = embedded_timestep + temb
140+
141141
shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
142-
shift = shift.to(original_dtype)
143-
scale = scale.to(original_dtype)
144-
gate = gate.to(original_dtype)
145142
hidden_states = self.norm(hidden_states)
146143

147144
if embedded_timestep.ndim == 2:
@@ -184,11 +181,8 @@ def __call__(
184181
if image_rotary_emb is not None:
185182
from ..embeddings import apply_rotary_emb
186183

187-
original_dtype = query.dtype
188-
query = apply_rotary_emb(query.to(torch.float32), image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
189-
key = apply_rotary_emb(key.to(torch.float32), image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
190-
query = query.to(original_dtype)
191-
key = key.to(original_dtype)
184+
query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
185+
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
192186

193187
# 4. Prepare for GQA
194188
if torch.onnx.is_in_onnx_export():
@@ -254,11 +248,8 @@ def __call__(
254248
if image_rotary_emb is not None:
255249
from ..embeddings import apply_rotary_emb
256250

257-
original_dtype = query.dtype
258-
query = apply_rotary_emb(query.to(torch.float32), image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
259-
key = apply_rotary_emb(key.to(torch.float32), image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
260-
query = query.to(original_dtype)
261-
key = key.to(original_dtype)
251+
query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
252+
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
262253

263254
if torch.onnx.is_in_onnx_export():
264255
query_idx = torch.tensor(query.size(3), device=query.device)
@@ -608,7 +599,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin,
608599
_supports_gradient_checkpointing = True
609600
_skip_layerwise_casting_patterns = ["patch_embed", "final_layer", "norm"]
610601
_no_split_modules = ["CosmosTransformerBlock"]
611-
_keep_in_fp32_modules = ["learnable_pos_embed", "time_embed", "norm1", "norm2", "norm3", "norm_out", "proj_out"]
602+
_keep_in_fp32_modules = ["learnable_pos_embed"]
612603

613604
@register_to_config
614605
def __init__(
@@ -806,7 +797,7 @@ def forward(
806797
)
807798

808799
# 8. Output norm & projection & unpatchify
809-
hidden_states = self.norm_out(hidden_states.float(), embedded_timestep, temb)
800+
hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
810801
hidden_states = self.proj_out(hidden_states)
811802
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
812803
hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,9 @@ def __call__(
781781
# NOTE: assumes sigma(t) \in [0, 1]
782782
sigma_t = self.scheduler.sigmas[i].expand(batch_size).to(device=device, dtype=torch.float32)
783783
if conditional_frame_timestep >= 0:
784-
in_timestep = cond_indicator * conditional_frame_timestep + (1 - cond_indicator) * sigma_t
784+
in_timestep = cond_indicator * conditional_frame_timestep + (1 - cond_indicator) * sigma_t.view(
785+
batch_size, 1, 1, 1, 1
786+
)
785787
else:
786788
in_timestep = sigma_t
787789
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents

0 commit comments

Comments
 (0)