Skip to content

Commit ef832c4

Browse files
committed
address comments
Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 48cc15f commit ef832c4

1 file changed

Lines changed: 12 additions & 21 deletions

File tree

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -908,14 +908,11 @@ def forward(
908908
# ====Run eagle forward with extra training-time-test steps====
909909
for ttt_step in range(self.eagle_ttt_steps):
910910
# TODO: (hg) during cp training, this mask is not used. Maybe turn it off then.
911-
if self.eagle_mix_hidden_states:
912-
eagle_attention_mask = eagle_attn_mask_0
913-
else:
914-
eagle_attention_mask = (
915-
eagle_attn_mask_0
916-
if ttt_step == 0
917-
else self._get_ttt_attention_mask(b, seq_length, ttt_step)
918-
)
911+
eagle_attention_mask = (
912+
eagle_attn_mask_0
913+
if self.eagle_mix_hidden_states or ttt_step == 0
914+
else self._get_ttt_attention_mask(b, seq_length, ttt_step)
915+
)
919916
with (
920917
enable_cp_ttt_patch()
921918
if self.training and not self.eagle_mix_hidden_states
@@ -935,20 +932,14 @@ def forward(
935932
num_to_replace = max(1, seq_len_s // (2**ttt_step + 1))
936933

937934
# Randomly select positions for each batch to replace
938-
rand_indices = torch.stack(
939-
[
940-
torch.randperm(seq_len_s, device=eagle_input_hiddens.device)[
941-
:num_to_replace
942-
]
943-
for _ in range(batch_size)
944-
],
945-
dim=0,
946-
)
935+
rand_indices = torch.rand(
936+
batch_size, seq_len_s, device=eagle_input_hiddens.device
937+
).argsort(dim=1)[:, :num_to_replace]
947938

948-
for batch_idx in range(batch_size):
949-
eagle_input_hiddens[batch_idx, rand_indices[batch_idx], :] = (
950-
eagle_output_hiddens[batch_idx, rand_indices[batch_idx], :]
951-
)
939+
batch_indices = torch.arange(batch_size)[:, None]
940+
eagle_input_hiddens[batch_indices, rand_indices] = eagle_output_hiddens[
941+
batch_indices, rand_indices
942+
]
952943
else:
953944
eagle_input_hiddens = eagle_output_hiddens
954945

0 commit comments

Comments
 (0)