Commit fa7ba13

feat: Enable global attention for Gemma3/Gemma4 drafter models
1 parent 9567c0a

3 files changed: 55 additions & 33 deletions

examples/run_gemma3_27b_eagle3_online.sh

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ torchrun \
     --draft-model-config $ROOT_DIR/configs/gemma3-27b-eagle3.json \
     --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \
     --output-dir $ROOT_DIR/outputs/gemma3-27b-eagle3-ultrachat \
+    --eval-holdout-ratio 0.03 \
     --num-epochs 10 \
     --batch-size 8 \
     --tp-size $TP_SIZE \
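
The added --eval-holdout-ratio flag presumably carves a small evaluation split out of the training JSONL instead of requiring a separate eval file. As a rough sketch of what a 0.03 holdout means (the helper below is illustrative only, not part of SpecForge):

import json
import random

def split_holdout(jsonl_path: str, ratio: float = 0.03, seed: int = 0):
    # Illustrative helper: hold out a fraction of a JSONL dataset for evaluation.
    with open(jsonl_path) as f:
        rows = [json.loads(line) for line in f]
    random.Random(seed).shuffle(rows)
    n_eval = max(1, int(len(rows) * ratio))
    return rows[n_eval:], rows[:n_eval]  # (train split, eval split)

train_rows, eval_rows = split_holdout("ultrachat_train.jsonl", ratio=0.03)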

scripts/train_eagle3.py

Lines changed: 6 additions & 0 deletions
@@ -402,6 +402,11 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]
         # Use provided config file
         draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config)

+        # if the target model is gemma, we should use global attention for the draft model
+        if "gemma" in getattr(draft_model_config, "target_model_type", "").lower():
+            draft_model_config.use_global_attention = True
+            print_on_rank0("Using global attention for draft model.")
+
     # Handle base ckpt, config file
     draft_model_last_checkpoint = None
     is_resume_checkpoint = False
@@ -427,6 +432,7 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]
     if draft_model_last_checkpoint:
         draft_model = AutoEagle3DraftModel.from_pretrained(
             draft_model_last_checkpoint,
+            config=draft_model_config,
             attention_backend=args.attention_backend,
             torch_dtype=torch.bfloat16,
         ).cuda()
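
The net effect is that global attention is toggled entirely through the draft-model config: build_draft_model() inspects target_model_type, sets a flag, and passing config=draft_model_config into from_pretrained ensures a resumed checkpoint also sees that flag. A minimal sketch of the pattern, using a simplified stand-in for AutoDraftModelConfig:

class DraftConfig:
    # Simplified stand-in for the draft-model config object; not the SpecForge class.
    def __init__(self, target_model_type: str):
        self.target_model_type = target_model_type

config = DraftConfig(target_model_type="gemma3_text")  # illustrative model type string

# Mirrors the check added in build_draft_model().
if "gemma" in getattr(config, "target_model_type", "").lower():
    config.use_global_attention = True

print(getattr(config, "use_global_attention", False))  # True for Gemma targets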

specforge/modeling/draft/llama3_eagle.py

Lines changed: 48 additions & 33 deletions
@@ -523,6 +523,7 @@ def __init__(self, config):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
+        self.use_global_attention = getattr(config, "use_global_attention", False)

         self.q_proj = nn.Linear(
             self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
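
Reading the flag with a getattr default of False keeps older draft-model configs working unchanged: any config that never defines use_global_attention falls back to the existing causal behavior. A tiny illustration:

from types import SimpleNamespace

old_config = SimpleNamespace()                           # flag absent, as in existing configs
new_config = SimpleNamespace(use_global_attention=True)  # flag set by the Gemma path

print(getattr(old_config, "use_global_attention", False))  # False -> causal mask as before
print(getattr(new_config, "use_global_attention", False))  # True  -> global attention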
@@ -760,6 +761,10 @@ class LlamaFlexAttention(LlamaAttention):
     - past_key_values: dynamic cache used for storing past key and value states.
     """

+    def __init__(self, config):
+        super().__init__(config)
+        self.use_global_attention = getattr(config, "use_global_attention", False)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -821,39 +826,45 @@ def forward(
                 cache_kwargs=cache_kwargs,
             )

-        seq_lengths = attention_mask.sum(dim=-1)
-        # Shrink the attention mask to align with the padding to the right.
-        # This is equivalent to the shrinking logic in eagle3.py
-        seq_lengths -= lck
-        # TODO: Remove the usage of uncompiled create_block_mask after
-        # https://github.com/pytorch/pytorch/issues/160018
-        if q_len <= 128:
-            create_block_mask_func = create_block_mask
-            flex_attention_func = flex_attention
+        if self.use_global_attention:
+            block_mask = None  # Enables full attention
         else:
-            create_block_mask_func = compile_friendly_create_block_mask
-            flex_attention_func = compile_friendly_flex_attention
-
-        block_mask = create_block_mask_func(
-            mask_mod=generate_eagle3_mask(
-                seq_lengths=seq_lengths,
-                Q_LEN=q_len,
-                KV_LEN=key_cache.shape[-2],
-                lck=lck,
-            ),
-            B=bsz,
-            H=1,  # Rely on broadcast
-            Q_LEN=q_len,
-            KV_LEN=key_cache.shape[-2],
-            device=query_states.device,
-        )
-        attn_output = flex_attention_func(
-            query=query_states,
-            key=key_cache.contiguous(),
-            value=value_cache.contiguous(),
-            block_mask=block_mask,
-            enable_gqa=True,
-        )
+            seq_lengths = attention_mask.sum(dim=-1)
+            # Shrink the attention mask to align with the padding to the right.
+            # This is equivalent to the shrinking logic in eagle3.py
+            seq_lengths -= lck
+            # TODO: Remove the usage of uncompiled create_block_mask after
+            # https://github.com/pytorch/pytorch/issues/160018
+            if q_len <= 128:
+                create_block_mask_func = create_block_mask
+                flex_attention_func = flex_attention
+            else:
+                create_block_mask_func = compile_friendly_create_block_mask
+                flex_attention_func = compile_friendly_flex_attention
+
+            if self.use_global_attention:
+                block_mask = None  # This will result in dense attention
+            else:
+                block_mask = create_block_mask_func(
+                    mask_mod=generate_eagle3_mask(
+                        seq_lengths=seq_lengths,
+                        Q_LEN=q_len,
+                        KV_LEN=key_cache.shape[-2],
+                        lck=lck,
+                    ),
+                    B=bsz,
+                    H=1,  # Rely on broadcast
+                    Q_LEN=q_len,
+                    KV_LEN=key_cache.shape[-2],
+                    device=query_states.device,
+                )
+            attn_output = flex_attention_func(
+                query=query_states,
+                key=key_cache.contiguous(),
+                value=value_cache.contiguous(),
+                block_mask=block_mask,
+                enable_gqa=True,
+            )
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads)
         attn_output = self.o_proj(attn_output)
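
For the FlexAttention backend, the key detail is that a block_mask of None makes flex_attention attend densely over the whole key/value sequence, so no Eagle3 mask is built in that case. A small standalone sketch of the difference, assuming a recent PyTorch (2.5+) where FlexAttention is available:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda" if torch.cuda.is_available() else "cpu"
B, H, S, D = 1, 2, 128, 16
q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, S, D, device=device)
v = torch.randn(B, H, S, D, device=device)

# block_mask=None: every query attends to every key (full / "global" attention).
full_out = flex_attention(q, k, v, block_mask=None)

# A causal block mask restricts each query to keys at or before its own position.
def causal_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

causal_mask = create_block_mask(causal_mod, B=B, H=H, Q_LEN=S, KV_LEN=S, device=device)
causal_out = flex_attention(q, k, v, block_mask=causal_mask)
print(full_out.shape, causal_out.shape)  # both (1, 2, 128, 16)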
@@ -869,6 +880,10 @@ class LlamaFlashAttention(LlamaAttention):
     - cache_hidden: manual cache used for storing past key and value states
     """

+    def __init__(self, config):
+        super().__init__(config)
+        self.use_global_attention = getattr(config, "use_global_attention", False)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -934,7 +949,7 @@ def forward(
                 v0,
                 dropout_p=0.0,
                 softmax_scale=1.0 / math.sqrt(self.head_dim),
-                causal=True,
+                causal=not self.use_global_attention,  # Set causal based on the flag
                 return_attn_probs=True,
             )
             lse = lse.transpose(1, 2)
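
For the flash-attention backend the same switch is expressed through the causal flag: causal=True applies the usual lower-triangular mask, while causal=False lets every query position attend to the full sequence. The snippet below illustrates that semantics with torch's scaled_dot_product_attention rather than the flash_attn kernel used in this file, purely to show what flipping the flag changes:

import torch
import torch.nn.functional as F

B, H, S, D = 1, 2, 8, 16
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # default drafter behavior
global_out = F.scaled_dot_product_attention(q, k, v, is_causal=False)  # use_global_attention=True

# At the first query position the causal mask allows only one key, so the two
# results generally differ there.
print(torch.allclose(causal_out[..., 0, :], global_out[..., 0, :]))  # typically False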
