@@ -75,10 +75,9 @@ def __init__(self, embedding_dim: int, condition_dim: int) -> None:
7575 self .norm = RMSNorm (embedding_dim , eps = 1e-6 , elementwise_affine = True )
7676
7777 def forward (self , hidden_states : torch .Tensor , timestep : torch .Tensor ) -> torch .Tensor :
78- with torch .amp .autocast (timestep .device .type , dtype = torch .float32 ):
79- timesteps_proj = self .time_proj (timestep ).to (torch .float32 )
80- temb = self .t_embedder (timesteps_proj )
81- embedded_timestep = self .norm (timesteps_proj )
78+ timesteps_proj = self .time_proj (timestep .float ())
79+ temb = self .t_embedder (timesteps_proj )
80+ embedded_timestep = self .norm (timesteps_proj )
8281 return temb , embedded_timestep
8382
8483
@@ -134,14 +133,12 @@ def forward(
134133 temb : torch .Tensor | None = None ,
135134 ) -> torch .Tensor :
136135 original_dtype = hidden_states .dtype
137- with torch .amp .autocast (hidden_states .device .type , dtype = torch .float32 ):
138- embedded_timestep = self .activation (embedded_timestep )
139- embedded_timestep = self .linear_1 (embedded_timestep )
140- embedded_timestep = self .linear_2 (embedded_timestep )
141-
142- if temb is not None :
143- embedded_timestep = embedded_timestep + temb
144- shift , scale , gate = embedded_timestep .chunk (3 , dim = - 1 )
136+ embedded_timestep = self .activation (embedded_timestep .float ())
137+ embedded_timestep = self .linear_1 (embedded_timestep )
138+ embedded_timestep = self .linear_2 (embedded_timestep )
139+ if temb is not None :
140+ embedded_timestep = embedded_timestep + temb .float ()
141+ shift , scale , gate = embedded_timestep .chunk (3 , dim = - 1 )
145142 shift = shift .to (original_dtype )
146143 scale = scale .to (original_dtype )
147144 gate = gate .to (original_dtype )
@@ -611,7 +608,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin,
611608 _supports_gradient_checkpointing = True
612609 _skip_layerwise_casting_patterns = ["patch_embed" , "final_layer" , "norm" ]
613610 _no_split_modules = ["CosmosTransformerBlock" ]
614- _keep_in_fp32_modules = ["learnable_pos_embed" ]
611+ _keep_in_fp32_modules = ["learnable_pos_embed" , "time_embed" , "norm1" , "norm2" , "norm3" , "norm_out" , "proj_out" ]
615612
616613 @register_to_config
617614 def __init__ (
@@ -809,9 +806,8 @@ def forward(
809806 )
810807
811808 # 8. Output norm & projection & unpatchify
812- with torch .amp .autocast (hidden_states .device .type , dtype = torch .float32 ):
813- hidden_states = self .norm_out (hidden_states , embedded_timestep , temb )
814- hidden_states = self .proj_out (hidden_states )
809+ hidden_states = self .norm_out (hidden_states .float (), embedded_timestep , temb )
810+ hidden_states = self .proj_out (hidden_states )
815811 hidden_states = hidden_states .unflatten (2 , (p_h , p_w , p_t , - 1 ))
816812 hidden_states = hidden_states .unflatten (1 , (post_patch_num_frames , post_patch_height , post_patch_width ))
817813 # NOTE: The permutation order here is not the inverse operation of what happens when patching as usually expected.
0 commit comments