@@ -74,8 +74,8 @@ def __init__(self, embedding_dim: int, condition_dim: int) -> None:
7474 self .t_embedder = CosmosTimestepEmbedding (embedding_dim , condition_dim )
7575 self .norm = RMSNorm (embedding_dim , eps = 1e-6 , elementwise_affine = True )
7676
def forward(
    self, hidden_states: torch.Tensor, timestep: torch.LongTensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the two timestep-conditioning tensors used downstream.

    Args:
        hidden_states: Only consulted as the type reference for the
            projected timestep (through ``Tensor.type_as``); its values
            are not read here.
        timestep: Timestep tensor handed to ``self.time_proj``.

    Returns:
        A ``(temb, embedded_timestep)`` pair: ``temb`` is
        ``self.t_embedder`` applied to the projected timestep, and
        ``embedded_timestep`` is ``self.norm`` applied to that same
        projection.
    """
    # No explicit float32 upcast: the projection is cast to whatever type
    # ``hidden_states`` carries, so both outputs follow the caller's dtype.
    timesteps_proj = self.time_proj(timestep).type_as(hidden_states)

    # Both outputs are derived from the projection, not from each other.
    temb = self.t_embedder(timesteps_proj)
    embedded_timestep = self.norm(timesteps_proj)
    return temb, embedded_timestep
@@ -102,7 +102,6 @@ def forward(
102102 embedded_timestep = embedded_timestep + temb [..., : 2 * self .embedding_dim ]
103103
104104 shift , scale = embedded_timestep .chunk (2 , dim = - 1 )
105-
106105 hidden_states = self .norm (hidden_states )
107106
108107 if embedded_timestep .ndim == 2 :
@@ -132,16 +131,14 @@ def forward(
132131 embedded_timestep : torch .Tensor ,
133132 temb : torch .Tensor | None = None ,
134133 ) -> torch .Tensor :
135- original_dtype = hidden_states .dtype
136- embedded_timestep = self .activation (embedded_timestep .float ())
134+ embedded_timestep = self .activation (embedded_timestep )
137135 embedded_timestep = self .linear_1 (embedded_timestep )
138136 embedded_timestep = self .linear_2 (embedded_timestep )
137+
139138 if temb is not None :
140- embedded_timestep = embedded_timestep + temb .float ()
139+ embedded_timestep = embedded_timestep + temb
140+
141141 shift , scale , gate = embedded_timestep .chunk (3 , dim = - 1 )
142- shift = shift .to (original_dtype )
143- scale = scale .to (original_dtype )
144- gate = gate .to (original_dtype )
145142 hidden_states = self .norm (hidden_states )
146143
147144 if embedded_timestep .ndim == 2 :
@@ -184,11 +181,8 @@ def __call__(
184181 if image_rotary_emb is not None :
185182 from ..embeddings import apply_rotary_emb
186183
187- original_dtype = query .dtype
188- query = apply_rotary_emb (query .to (torch .float32 ), image_rotary_emb , use_real = True , use_real_unbind_dim = - 2 )
189- key = apply_rotary_emb (key .to (torch .float32 ), image_rotary_emb , use_real = True , use_real_unbind_dim = - 2 )
190- query = query .to (original_dtype )
191- key = key .to (original_dtype )
184+ query = apply_rotary_emb (query , image_rotary_emb , use_real = True , use_real_unbind_dim = - 2 )
185+ key = apply_rotary_emb (key , image_rotary_emb , use_real = True , use_real_unbind_dim = - 2 )
192186
193187 # 4. Prepare for GQA
194188 if torch .onnx .is_in_onnx_export ():
@@ -254,11 +248,8 @@ def __call__(
254248 if image_rotary_emb is not None :
255249 from ..embeddings import apply_rotary_emb
256250
257- original_dtype = query .dtype
258- query = apply_rotary_emb (query .to (torch .float32 ), image_rotary_emb , use_real = True , use_real_unbind_dim = - 2 )
259- key = apply_rotary_emb (key .to (torch .float32 ), image_rotary_emb , use_real = True , use_real_unbind_dim = - 2 )
260- query = query .to (original_dtype )
261- key = key .to (original_dtype )
251+ query = apply_rotary_emb (query , image_rotary_emb , use_real = True , use_real_unbind_dim = - 2 )
252+ key = apply_rotary_emb (key , image_rotary_emb , use_real = True , use_real_unbind_dim = - 2 )
262253
263254 if torch .onnx .is_in_onnx_export ():
264255 query_idx = torch .tensor (query .size (3 ), device = query .device )
@@ -608,7 +599,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin,
608599 _supports_gradient_checkpointing = True
609600 _skip_layerwise_casting_patterns = ["patch_embed" , "final_layer" , "norm" ]
610601 _no_split_modules = ["CosmosTransformerBlock" ]
611- _keep_in_fp32_modules = ["learnable_pos_embed" , "time_embed" , "norm1" , "norm2" , "norm3" , "norm_out" , "proj_out" ]
602+ _keep_in_fp32_modules = ["learnable_pos_embed" ]
612603
613604 @register_to_config
614605 def __init__ (
@@ -806,7 +797,7 @@ def forward(
806797 )
807798
808799 # 8. Output norm & projection & unpatchify
809- hidden_states = self .norm_out (hidden_states . float () , embedded_timestep , temb )
800+ hidden_states = self .norm_out (hidden_states , embedded_timestep , temb )
810801 hidden_states = self .proj_out (hidden_states )
811802 hidden_states = hidden_states .unflatten (2 , (p_h , p_w , p_t , - 1 ))
812803 hidden_states = hidden_states .unflatten (1 , (post_patch_num_frames , post_patch_height , post_patch_width ))
0 commit comments