Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@ def train():
model_args, data_args, training_args, medusa_args, eagle_args = (
parser.parse_args_into_dataclasses()
)
training_args.parallelism_config = ParallelismConfig(
cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
)
if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this bugfix is unrelated to the main change in this PR; it addresses #1045 (it does not fully solve the issue, it is only a single-GPU workaround).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed in Slack, this issue is due to a transformers version mismatch. It should be fixed after updating transformers.

training_args.parallelism_config = ParallelismConfig(
cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
)
if training_args.cp_size > 1:
patch_ring_attention_for_ttt()
# Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
Expand Down
73 changes: 50 additions & 23 deletions modelopt/torch/speculative/plugins/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import transformers
from packaging.version import Version
from torch import nn
from torch.cuda import nvtx
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
from torch.nn import CrossEntropyLoss
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel
Expand Down Expand Up @@ -292,6 +293,10 @@ def __init__(self, config, decoder_layer_cls, bias=False):
num_layers=self.config.parallel_draft_heads_num_layers,
)

def _maybe_init_rope(self):
    """Create ``self.rotary_emb`` on first use for the ``"llama"`` decoder type.

    Does nothing when ``self.config.eagle_decoder_type`` is not ``"llama"`` or
    when a ``rotary_emb`` attribute is already present on this object.
    """
    # Only the llama decoder type gets a rotary embedding here.
    if self.config.eagle_decoder_type != "llama":
        return
    # Already built on an earlier call; keep the existing instance.
    if hasattr(self, "rotary_emb"):
        return
    self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

def _expand_first_attn_in_dim(self, first_layer_attn):
"""Modify qkv projection in first layer to accept 2h hidden size."""
# Find Linear modules to expand
Expand Down Expand Up @@ -372,11 +377,6 @@ def forward(
self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)

if self.config.eagle_decoder_type == "llama":
# Lazy init rope to avoid save/load meta tensor error
Comment thread
benchislett marked this conversation as resolved.
if not hasattr(self, "rotary_emb"):
self.rotary_emb = LlamaRotaryEmbedding(
config=self.config, device=hidden_states.device
)
position_embeddings = self.rotary_emb(hidden_states, position_ids)
else:
position_embeddings = None
Expand Down Expand Up @@ -657,6 +657,8 @@ def _prepare_decoder_attention_mask(

return combined_attention_mask

@nvtx.range("prepare_eagle_inputs")
@torch.compile(dynamic=False)
def _prepare_eagle_inputs(
self,
input_ids,
Expand Down Expand Up @@ -716,7 +718,20 @@ def _prepare_eagle_inputs(
else:
eagle_position_ids = position_ids.view(-1, seq_length).long()

return eagle_input_embeds, eagle_input_hiddens, eagle_attention_mask, eagle_position_ids
base_model_logits = base_outputs.logits
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
base_output_predict_tok = base_model_logits.argmax(dim=-1).detach()
base_output_softmax_logits = torch.softmax(base_model_logits, dim=2).detach()

return (
eagle_input_embeds,
eagle_input_hiddens,
eagle_attention_mask,
eagle_position_ids,
base_output_predict_tok,
base_output_softmax_logits,
)

def _compute_ttt_attention_mask(
self, batch_size, seq_length, ttt_step
Expand Down Expand Up @@ -746,6 +761,7 @@ def _compute_ttt_attention_mask(
tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1)
return tensor_mask

@nvtx.range("base_model_forward")
def _base_model_forward(
self,
input_ids,
Expand Down Expand Up @@ -794,6 +810,8 @@ def _map_logits_to_draft_vocab(self, full_logits):
)
return full_logits[:, :, reverse_mapping]

@nvtx.range("eagle_forward")
@torch.compile(dynamic=False, mode="max-autotune")
def _eagle_forward(
self,
eagle_input_hidden_states,
Expand Down Expand Up @@ -890,13 +908,17 @@ def forward(

# ====Prepare inputs for the first eagle forward pass====
eagle_loss = None
train_accs = [[] for _ in range(self.eagle_config.parallel_draft_step)]
num_parallel = self.eagle_config.parallel_draft_step
num_ttt = self.eagle_ttt_steps
train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)
Comment thread
benchislett marked this conversation as resolved.
b, seq_length, _ = base_outputs.out_hiddens.shape
(
eagle_input_embeds,
eagle_input_hiddens,
eagle_attn_mask_0,
eagle_position_ids,
base_output_predict_tok,
base_output_softmax_logits,
) = self._prepare_eagle_inputs(
input_ids,
attention_mask,
Expand All @@ -905,6 +927,8 @@ def forward(
base_outputs,
)

self.eagle_module._maybe_init_rope()

Comment thread
benchislett marked this conversation as resolved.
# ====Run eagle forward with extra training-time-test steps====
for ttt_step in range(self.eagle_ttt_steps):
# TODO: (hg) during cp training, this mask is not used. Maybe turn it off then.
Expand Down Expand Up @@ -949,7 +973,8 @@ def forward(
# base model predict +1 tok, while eagle predict +2
# so we shift base model outputs compared to eagle outputs
# additionally, we mask the first n tok of eagle outputs at nth TTT step
base_outputs.logits[:, 1 + i + ttt_step :],
base_output_softmax_logits[:, 1 + i + ttt_step :],
base_output_predict_tok[:, 1 + i + ttt_step :],
eagle_logit[:, ttt_step : -(1 + i)],
loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i],
)
Expand All @@ -958,10 +983,13 @@ def forward(
eagle_loss = (
classification_loss if eagle_loss is None else eagle_loss + classification_loss
)
train_accs[i].append(acc)
train_accs[i, ttt_step] = acc
if not self.training:
break

# Slice by actual number of steps taken, in case of early return
train_accs = train_accs[:, : ttt_step + 1].tolist()

# Merge base model loss and eagle loss
if base_outputs.loss is None and eagle_loss is None:
loss = None
Expand All @@ -977,29 +1005,27 @@ def forward(
train_acc=train_accs,
)

@nvtx.range("eagle_loss")
@torch.compile(dynamic=False, fullgraph=True)
def _eagle_loss(
self,
base_model_logits,
base_output_softmax_logits,
base_output_predict_tok,
eagle_logits,
loss_mask,
):
"""Function for EAGLE loss computing."""
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
loss_mask = loss_mask[:, : eagle_logits.shape[1], None]
classification_loss = nn.Softmax(dim=2)(base_model_logits) * nn.LogSoftmax(dim=2)(
eagle_logits
)
classification_loss = -torch.sum(torch.sum(loss_mask * classification_loss, 2)) / (
loss_mask.sum() + 1e-5
)
# Compute accuracy
base_predict_tok = base_model_logits.clone().detach().argmax(dim=-1)
eagle_predict_tok = eagle_logits.clone().detach().argmax(dim=-1)
eagle_logsoft = torch.log_softmax(eagle_logits, dim=2)
classification_loss = -torch.sum(
torch.sum(loss_mask * base_output_softmax_logits * eagle_logsoft, 2)
) / (loss_mask.sum() + 1e-5)
# Compute accuracy (returned as a tensor to avoid a sync; converted with .tolist() after the TTT loop)
eagle_predict_tok = eagle_logits.detach().argmax(dim=-1)
valid = loss_mask[:, :, 0].bool()
correct = (base_predict_tok == eagle_predict_tok) & valid
correct = (base_output_predict_tok == eagle_predict_tok) & valid
denom = valid.sum().clamp_min(1).float()
accuracy = round(correct.sum().float().div(denom).item(), 3)
accuracy = correct.sum().float() / denom

return classification_loss, accuracy

Expand Down Expand Up @@ -1039,6 +1065,7 @@ def pseudo_speculative_generate(
else:
eagle_input_hidden_states = base_model_hidden_states

self.eagle_module._maybe_init_rope()
draft_tokens = []
for step in range(steps):
b, seq_length = eagle_ids.shape
Expand Down
Loading