Skip to content

Commit 60fa326

Browse files
rootterarachang
authored andcommitted
autocast_fp32 default to False and only set to True at training
1 parent e985d88 commit 60fa326

2 files changed

Lines changed: 9 additions & 8 deletions

File tree

examples/cosmos/train_cosmos_predict25_lora.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def __getitem__(self, index: int) -> dict | Any:
423423

424424
# Load caption based on format
425425
video_path = self.video_paths[index]
426-
video_basename = os.path.basename(video_path).replace(".mp4", "")
426+
video_basename = os.path.splitext(os.path.basename(video_path))[0]
427427

428428
if self.caption_format == "json":
429429
caption_path = os.path.join(self.caption_dir, f"{video_basename}.json")
@@ -550,6 +550,7 @@ def main():
550550
vae = pipe.vae
551551
text_encoder = pipe.text_encoder
552552

553+
dit.set_autocast_fp32(True)
553554
dit.requires_grad_(False)
554555
vae.requires_grad_(False)
555556
text_encoder.requires_grad_(False)

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
6767

6868

6969
class CosmosEmbedding(nn.Module):
70-
def __init__(self, embedding_dim: int, condition_dim: int, autocast_fp32: bool = True) -> None:
70+
def __init__(self, embedding_dim: int, condition_dim: int, autocast_fp32: bool = False) -> None:
7171
super().__init__()
7272

7373
self.autocast_fp32 = autocast_fp32
@@ -116,7 +116,7 @@ def forward(
116116

117117

118118
class CosmosAdaLayerNormZero(nn.Module):
119-
def __init__(self, in_features: int, hidden_features: int | None = None, autocast_fp32: bool = True) -> None:
119+
def __init__(self, in_features: int, hidden_features: int | None = None, autocast_fp32: bool = False) -> None:
120120
super().__init__()
121121

122122
self.autocast_fp32 = autocast_fp32
@@ -158,7 +158,7 @@ def forward(
158158

159159

160160
class CosmosAttnProcessor2_0:
161-
def __init__(self, autocast_fp32: bool = True):
161+
def __init__(self, autocast_fp32: bool = False):
162162
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
163163
raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
164164
self.autocast_fp32 = autocast_fp32
@@ -228,7 +228,7 @@ def __call__(
228228

229229

230230
class CosmosAttnProcessor2_5:
231-
def __init__(self, autocast_fp32: bool = True):
231+
def __init__(self, autocast_fp32: bool = False):
232232
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
233233
raise ImportError("CosmosAttnProcessor2_5 requires PyTorch 2.0. Please upgrade PyTorch to 2.0 or newer.")
234234
self.autocast_fp32 = autocast_fp32
@@ -373,7 +373,7 @@ def __init__(
373373
img_context: bool = False,
374374
before_proj: bool = False,
375375
after_proj: bool = False,
376-
autocast_fp32: bool = True,
376+
autocast_fp32: bool = False,
377377
) -> None:
378378
super().__init__()
379379

@@ -622,7 +622,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin,
622622
img_context_dim_out (`int`):
623623
The output dimension of the image context projection layer. If `img_context_dim_in` is not provided, then
624624
this parameter is ignored.
625-
autocast_fp32 (`bool`, defaults to `True`):
625+
autocast_fp32 (`bool`, defaults to `False`):
626626
Whether to cast certain computations (AdaLN, timestep embedding, RoPE, final norm and projection) to
627627
float32 for numerical stability. Set to `False` to disable autocasting (e.g., when the model is already
628628
running in float32 or when autocasting is handled externally).
@@ -656,7 +656,7 @@ def __init__(
656656
img_context_dim_in: int | None = None,
657657
img_context_num_tokens: int = 256,
658658
img_context_dim_out: int = 2048,
659-
autocast_fp32: bool = True,
659+
autocast_fp32: bool = False,
660660
) -> None:
661661
super().__init__()
662662
hidden_size = num_attention_heads * attention_head_dim

0 commit comments

Comments
 (0)