@@ -184,19 +184,32 @@ class HunyuanVideo15TimeEmbedding(nn.Module):
184184 The dimension of the output embedding.
185185 """
186186
187- def __init__ (self , embedding_dim : int ):
187+ def __init__ (self , embedding_dim : int , use_meanflow : bool = False ):
188188 super ().__init__ ()
189189
190190 self .time_proj = Timesteps (num_channels = 256 , flip_sin_to_cos = True , downscale_freq_shift = 0 )
191191 self .timestep_embedder = TimestepEmbedding (in_channels = 256 , time_embed_dim = embedding_dim )
192192
193+ self .use_meanflow = use_meanflow
194+ self .time_proj_r = None
195+ self .timestep_embedder_r = None
196+ if use_meanflow :
197+ self .time_proj_r = Timesteps (num_channels = 256 , flip_sin_to_cos = True , downscale_freq_shift = 0 )
198+ self .timestep_embedder_r = TimestepEmbedding (in_channels = 256 , time_embed_dim = embedding_dim )
199+
193200 def forward (
194201 self ,
195202 timestep : torch .Tensor ,
203+ timestep_r : Optional [torch .Tensor ] = None ,
196204 ) -> torch .Tensor :
197205 timesteps_proj = self .time_proj (timestep )
198206 timesteps_emb = self .timestep_embedder (timesteps_proj .to (dtype = timestep .dtype ))
199207
208+ if timestep_r is not None :
209+ timesteps_proj_r = self .time_proj_r (timestep_r )
210+ timesteps_emb_r = self .timestep_embedder_r (timesteps_proj_r .to (dtype = timestep .dtype ))
211+ timesteps_emb = timesteps_emb + timesteps_emb_r
212+
200213 return timesteps_emb
201214
202215
@@ -567,6 +580,7 @@ def __init__(
567580 # YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205
568581 target_size : int = 640 , # did not name sample_size since it is in pixel spaces
569582 task_type : str = "i2v" ,
583+ use_meanflow : bool = False ,
570584 ) -> None :
571585 super ().__init__ ()
572586
@@ -582,7 +596,7 @@ def __init__(
582596 )
583597 self .context_embedder_2 = HunyuanVideo15ByT5TextProjection (text_embed_2_dim , 2048 , inner_dim )
584598
585- self .time_embed = HunyuanVideo15TimeEmbedding (inner_dim )
599+ self .time_embed = HunyuanVideo15TimeEmbedding (inner_dim , use_meanflow = use_meanflow )
586600
587601 self .cond_type_embed = nn .Embedding (3 , inner_dim )
588602
@@ -612,6 +626,7 @@ def forward(
612626 timestep : torch .LongTensor ,
613627 encoder_hidden_states : torch .Tensor ,
614628 encoder_attention_mask : torch .Tensor ,
629+ timestep_r : Optional [torch .LongTensor ] = None ,
615630 encoder_hidden_states_2 : Optional [torch .Tensor ] = None ,
616631 encoder_attention_mask_2 : Optional [torch .Tensor ] = None ,
617632 image_embeds : Optional [torch .Tensor ] = None ,
@@ -643,7 +658,7 @@ def forward(
643658 image_rotary_emb = self .rope (hidden_states )
644659
645660 # 2. Conditional embeddings
646- temb = self .time_embed (timestep )
661+ temb = self .time_embed (timestep , timestep_r = timestep_r )
647662
648663 hidden_states = self .x_embedder (hidden_states )
649664
0 commit comments