@@ -67,7 +67,7 @@ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
6767
6868
6969class CosmosEmbedding (nn .Module ):
70- def __init__ (self , embedding_dim : int , condition_dim : int , autocast_fp32 : bool = True ) -> None :
70+ def __init__ (self , embedding_dim : int , condition_dim : int , autocast_fp32 : bool = False ) -> None :
7171 super ().__init__ ()
7272
7373 self .autocast_fp32 = autocast_fp32
@@ -116,7 +116,7 @@ def forward(
116116
117117
118118class CosmosAdaLayerNormZero (nn .Module ):
119- def __init__ (self , in_features : int , hidden_features : int | None = None , autocast_fp32 : bool = True ) -> None :
119+ def __init__ (self , in_features : int , hidden_features : int | None = None , autocast_fp32 : bool = False ) -> None :
120120 super ().__init__ ()
121121
122122 self .autocast_fp32 = autocast_fp32
@@ -158,7 +158,7 @@ def forward(
158158
159159
160160class CosmosAttnProcessor2_0 :
161- def __init__ (self , autocast_fp32 : bool = True ):
161+ def __init__ (self , autocast_fp32 : bool = False ):
162162 if not hasattr (torch .nn .functional , "scaled_dot_product_attention" ):
163163 raise ImportError ("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." )
164164 self .autocast_fp32 = autocast_fp32
@@ -228,7 +228,7 @@ def __call__(
228228
229229
230230class CosmosAttnProcessor2_5 :
231- def __init__ (self , autocast_fp32 : bool = True ):
231+ def __init__ (self , autocast_fp32 : bool = False ):
232232 if not hasattr (torch .nn .functional , "scaled_dot_product_attention" ):
233233 raise ImportError ("CosmosAttnProcessor2_5 requires PyTorch 2.0. Please upgrade PyTorch to 2.0 or newer." )
234234 self .autocast_fp32 = autocast_fp32
@@ -373,7 +373,7 @@ def __init__(
373373 img_context : bool = False ,
374374 before_proj : bool = False ,
375375 after_proj : bool = False ,
376- autocast_fp32 : bool = True ,
376+ autocast_fp32 : bool = False ,
377377 ) -> None :
378378 super ().__init__ ()
379379
@@ -622,7 +622,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin,
622622 img_context_dim_out (`int`):
623623 The output dimension of the image context projection layer. If `img_context_dim_in` is not provided, then
624624 this parameter is ignored.
625- autocast_fp32 (`bool`, defaults to `True `):
625+ autocast_fp32 (`bool`, defaults to `False `):
626626 Whether to cast certain computations (AdaLN, timestep embedding, RoPE, final norm and projection) to
627627 float32 for numerical stability. Set to `False` to disable autocasting (e.g., when the model is already
628628 running in float32 or when autocasting is handled externally).
@@ -656,7 +656,7 @@ def __init__(
656656 img_context_dim_in : int | None = None ,
657657 img_context_num_tokens : int = 256 ,
658658 img_context_dim_out : int = 2048 ,
659- autocast_fp32 : bool = True ,
659+ autocast_fp32 : bool = False ,
660660 ) -> None :
661661 super ().__init__ ()
662662 hidden_size = num_attention_heads * attention_head_dim
0 commit comments