@@ -908,14 +908,11 @@ def forward(
908908 # ====Run eagle forward with extra training-time-test steps====
909909 for ttt_step in range (self .eagle_ttt_steps ):
910910 # TODO: (hg) during cp training, this mask is not used. Maybe turn it off then.
911- if self .eagle_mix_hidden_states :
912- eagle_attention_mask = eagle_attn_mask_0
913- else :
914- eagle_attention_mask = (
915- eagle_attn_mask_0
916- if ttt_step == 0
917- else self ._get_ttt_attention_mask (b , seq_length , ttt_step )
918- )
911+ eagle_attention_mask = (
912+ eagle_attn_mask_0
913+ if self .eagle_mix_hidden_states or ttt_step == 0
914+ else self ._get_ttt_attention_mask (b , seq_length , ttt_step )
915+ )
919916 with (
920917 enable_cp_ttt_patch ()
921918 if self .training and not self .eagle_mix_hidden_states
@@ -935,20 +932,14 @@ def forward(
935932 num_to_replace = max (1 , seq_len_s // (2 ** ttt_step + 1 ))
936933
937934 # Randomly select positions for each batch to replace
938- rand_indices = torch .stack (
939- [
940- torch .randperm (seq_len_s , device = eagle_input_hiddens .device )[
941- :num_to_replace
942- ]
943- for _ in range (batch_size )
944- ],
945- dim = 0 ,
946- )
935+ rand_indices = torch .rand (
936+ batch_size , seq_len_s , device = eagle_input_hiddens .device
937+ ).argsort (dim = 1 )[:, :num_to_replace ]
947938
948- for batch_idx in range (batch_size ):
949- eagle_input_hiddens [batch_idx , rand_indices [ batch_idx ], :] = (
950- eagle_output_hiddens [ batch_idx , rand_indices [ batch_idx ], :]
951- )
939+ batch_indices = torch . arange (batch_size )[:, None ]
940+ eagle_input_hiddens [batch_indices , rand_indices ] = eagle_output_hiddens [
941+ batch_indices , rand_indices
942+ ]
952943 else :
953944 eagle_input_hiddens = eagle_output_hiddens
954945
0 commit comments