diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index c38a716d79..dbc4d5bc27 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -133,6 +133,7 @@ def forward_loop(model: nn.Module) -> None: use_cache=True, ) past_key_values = outputs.past_key_values + del outputs # Free logits between chunks # Clean up KV cache del past_key_values @@ -182,15 +183,14 @@ def forward_loop(model: nn.Module) -> None: model.config._attn_implementation = "flash_attention_2" outputs = model(input_ids, use_cache=True) past_key_values = outputs.past_key_values + next_token = outputs.logits[:, -1:, :].argmax(dim=-1) + del outputs # Free large prefill logits [B, seqlen, vocab] before decode loop # Step 2: Switch to eager for decode (enables softmax hook) model.config._attn_implementation = "eager" # Step 3: Manual decode loop for explicit control over token generation # model.generate() method is not used here because it doesn't allow explicit control over KV cache - # Get the last token's logits and sample next token - next_token = outputs.logits[:, -1:, :].argmax(dim=-1) - for _ in range(num_decode_tokens): outputs = model( next_token, @@ -199,6 +199,7 @@ def forward_loop(model: nn.Module) -> None: ) past_key_values = outputs.past_key_values next_token = outputs.logits[:, -1:, :].argmax(dim=-1) + del outputs # Free decode logits between steps finally: # Restore original attention implementation model.config._attn_implementation = original_attn_impl diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index d2d3b1078b..2d73f13ad7 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -399,9 +399,9 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): SKIP_SOFTMAX_CALIB = { "sparse_cfg": { "calibration": { - "target_sparse_ratio": {"prefill": 0.9, "decode": 0.9}, + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5}, "samples": 64, - "max_seqlen": 65536, + "max_seqlen": 16384, "chunk_size": 4096, }, "*attn*": { diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index e575de4da0..f911b95f79 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -173,17 +173,18 @@ def calc_correction_factor_and_p( block_max_larger = torch.ones_like(block_max) block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item() + del block_max, block_max_larger - # Step 4: Normalize attention scores by cumulative max - # p represents log-space difference: log(score) - log(cummax) - p = blocked_attn - block_max_cummax[..., None] + # Step 4 & 5: Compute threshold mask directly without storing p. + # Fusing the subtraction and comparison avoids allocating a second + # full attention-matrix-sized tensor alongside blocked_attn. + p_larger_than_thresh = (blocked_attn - block_max_cummax[..., None]) > log_threshold + del block_max_cummax - # Step 5: Apply threshold and create block-level mask - # Keep blocks where at least one element exceeds log(threshold) - p_larger_than_thresh = p > log_threshold # Reduce over bc (128 cols), then br (128 rows) to get block-level decision # Result: [batch, heads, block_rows, block_cols] block_mask = p_larger_than_thresh.any(dim=-1).any(dim=-2) + del p_larger_than_thresh # Step 6: Expand block mask back to element level # All 128x128 elements in a block share the same mask value @@ -227,15 +228,14 @@ def calc_correction_factor_and_p( block_max_larger = torch.ones_like(block_max) block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item() + del block_max, block_max_larger - # Step 4: Normalize scores by cumulative max - # p = log(score) - log(cummax) in log-space - p = blocked_attn - block_max_cummax[..., None] + # Step 4 & 5: Compute threshold mask directly without storing p. + p_larger_than_thresh = (blocked_attn - block_max_cummax[..., None]) > log_threshold + del block_max_cummax - # Step 5: Apply threshold and create block mask - # Keep blocks where at least one element exceeds threshold - p_larger_than_thresh = p > log_threshold block_mask = p_larger_than_thresh.any(dim=-1, keepdim=False) + del p_larger_than_thresh # Step 6: Expand to element level and remove padding element_mask = block_mask[..., None].expand_as(blocked_attn)