Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def forward_loop(model: nn.Module) -> None:
use_cache=True,
)
past_key_values = outputs.past_key_values
del outputs # Free logits between chunks

# Clean up KV cache
del past_key_values
Expand Down Expand Up @@ -182,15 +183,14 @@ def forward_loop(model: nn.Module) -> None:
model.config._attn_implementation = "flash_attention_2"
outputs = model(input_ids, use_cache=True)
past_key_values = outputs.past_key_values
next_token = outputs.logits[:, -1:, :].argmax(dim=-1)
del outputs # Free large prefill logits [B, seqlen, vocab] before decode loop

# Step 2: Switch to eager for decode (enables softmax hook)
model.config._attn_implementation = "eager"

# Step 3: Manual decode loop for explicit control over token generation
# model.generate() method is not used here because it doesn't allow explicit control over KV cache
# Get the last token's logits and sample next token
next_token = outputs.logits[:, -1:, :].argmax(dim=-1)

for _ in range(num_decode_tokens):
outputs = model(
next_token,
Expand All @@ -199,6 +199,7 @@ def forward_loop(model: nn.Module) -> None:
)
past_key_values = outputs.past_key_values
next_token = outputs.logits[:, -1:, :].argmax(dim=-1)
del outputs # Free decode logits between steps
finally:
# Restore original attention implementation
model.config._attn_implementation = original_attn_impl
Expand Down
4 changes: 2 additions & 2 deletions modelopt/torch/sparsity/attention_sparsity/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,9 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
SKIP_SOFTMAX_CALIB = {
"sparse_cfg": {
"calibration": {
"target_sparse_ratio": {"prefill": 0.9, "decode": 0.9},
"target_sparse_ratio": {"prefill": 0.5, "decode": 0.5},
"samples": 64,
"max_seqlen": 65536,
"max_seqlen": 16384,
"chunk_size": 4096,
},
"*attn*": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,18 @@ def calc_correction_factor_and_p(
block_max_larger = torch.ones_like(block_max)
block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1]
correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item()
del block_max, block_max_larger

# Step 4: Normalize attention scores by cumulative max
# p represents log-space difference: log(score) - log(cummax)
p = blocked_attn - block_max_cummax[..., None]
# Step 4 & 5: Compute threshold mask directly without storing p.
# Fusing the subtraction and comparison avoids allocating a second
# full attention-matrix-sized tensor alongside blocked_attn.
p_larger_than_thresh = (blocked_attn - block_max_cummax[..., None]) > log_threshold
del block_max_cummax

# Step 5: Apply threshold and create block-level mask
# Keep blocks where at least one element exceeds log(threshold)
p_larger_than_thresh = p > log_threshold
# Reduce over bc (128 cols), then br (128 rows) to get block-level decision
# Result: [batch, heads, block_rows, block_cols]
block_mask = p_larger_than_thresh.any(dim=-1).any(dim=-2)
del p_larger_than_thresh

# Step 6: Expand block mask back to element level
# All 128x128 elements in a block share the same mask value
Expand Down Expand Up @@ -227,15 +228,14 @@ def calc_correction_factor_and_p(
block_max_larger = torch.ones_like(block_max)
block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1]
correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item()
del block_max, block_max_larger

# Step 4: Normalize scores by cumulative max
# p = log(score) - log(cummax) in log-space
p = blocked_attn - block_max_cummax[..., None]
# Step 4 & 5: Compute threshold mask directly without storing p.
p_larger_than_thresh = (blocked_attn - block_max_cummax[..., None]) > log_threshold
del block_max_cummax

# Step 5: Apply threshold and create block mask
# Keep blocks where at least one element exceeds threshold
p_larger_than_thresh = p > log_threshold
block_mask = p_larger_than_thresh.any(dim=-1, keepdim=False)
del p_larger_than_thresh

# Step 6: Expand to element level and remove padding
element_mask = block_mask[..., None].expand_as(blocked_attn)
Expand Down