Skip to content

Commit a538f2e

Browse files
rohansjoshi and claude authored
Fix skip softmax calibration memory issue (#923)
Fix OOM issue when running skip softmax calibration.

Test:

    python examples/llm_sparsity/attention_sparsity/hf_sa.py \
        --pyt_ckpt_path Qwen/Qwen3-30B-Instruct-A3B-2507 \
        --sparse_attn skip_softmax_calib

Works with >= 96 GB GPU memory.

Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 35e6099 commit a538f2e

File tree

3 files changed

+18
-17
lines changed

3 files changed

+18
-17
lines changed

modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def forward_loop(model: nn.Module) -> None:
133133
use_cache=True,
134134
)
135135
past_key_values = outputs.past_key_values
136+
del outputs # Free logits between chunks
136137

137138
# Clean up KV cache
138139
del past_key_values
@@ -182,15 +183,14 @@ def forward_loop(model: nn.Module) -> None:
182183
model.config._attn_implementation = "flash_attention_2"
183184
outputs = model(input_ids, use_cache=True)
184185
past_key_values = outputs.past_key_values
186+
next_token = outputs.logits[:, -1:, :].argmax(dim=-1)
187+
del outputs # Free large prefill logits [B, seqlen, vocab] before decode loop
185188

186189
# Step 2: Switch to eager for decode (enables softmax hook)
187190
model.config._attn_implementation = "eager"
188191

189192
# Step 3: Manual decode loop for explicit control over token generation
190193
# model.generate() method is not used here because it doesn't allow explicit control over KV cache
191-
# Get the last token's logits and sample next token
192-
next_token = outputs.logits[:, -1:, :].argmax(dim=-1)
193-
194194
for _ in range(num_decode_tokens):
195195
outputs = model(
196196
next_token,
@@ -199,6 +199,7 @@ def forward_loop(model: nn.Module) -> None:
199199
)
200200
past_key_values = outputs.past_key_values
201201
next_token = outputs.logits[:, -1:, :].argmax(dim=-1)
202+
del outputs # Free decode logits between steps
202203
finally:
203204
# Restore original attention implementation
204205
model.config._attn_implementation = original_attn_impl

modelopt/torch/sparsity/attention_sparsity/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,9 +399,9 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
399399
SKIP_SOFTMAX_CALIB = {
400400
"sparse_cfg": {
401401
"calibration": {
402-
"target_sparse_ratio": {"prefill": 0.9, "decode": 0.9},
402+
"target_sparse_ratio": {"prefill": 0.5, "decode": 0.5},
403403
"samples": 64,
404-
"max_seqlen": 65536,
404+
"max_seqlen": 16384,
405405
"chunk_size": 4096,
406406
},
407407
"*attn*": {

modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -173,17 +173,18 @@ def calc_correction_factor_and_p(
173173
block_max_larger = torch.ones_like(block_max)
174174
block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1]
175175
correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item()
176+
del block_max, block_max_larger
176177

177-
# Step 4: Normalize attention scores by cumulative max
178-
# p represents log-space difference: log(score) - log(cummax)
179-
p = blocked_attn - block_max_cummax[..., None]
178+
# Step 4 & 5: Compute threshold mask directly without storing p.
179+
# Fusing the subtraction and comparison avoids allocating a second
180+
# full attention-matrix-sized tensor alongside blocked_attn.
181+
p_larger_than_thresh = (blocked_attn - block_max_cummax[..., None]) > log_threshold
182+
del block_max_cummax
180183

181-
# Step 5: Apply threshold and create block-level mask
182-
# Keep blocks where at least one element exceeds log(threshold)
183-
p_larger_than_thresh = p > log_threshold
184184
# Reduce over bc (128 cols), then br (128 rows) to get block-level decision
185185
# Result: [batch, heads, block_rows, block_cols]
186186
block_mask = p_larger_than_thresh.any(dim=-1).any(dim=-2)
187+
del p_larger_than_thresh
187188

188189
# Step 6: Expand block mask back to element level
189190
# All 128x128 elements in a block share the same mask value
@@ -227,15 +228,14 @@ def calc_correction_factor_and_p(
227228
block_max_larger = torch.ones_like(block_max)
228229
block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1]
229230
correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item()
231+
del block_max, block_max_larger
230232

231-
# Step 4: Normalize scores by cumulative max
232-
# p = log(score) - log(cummax) in log-space
233-
p = blocked_attn - block_max_cummax[..., None]
233+
# Step 4 & 5: Compute threshold mask directly without storing p.
234+
p_larger_than_thresh = (blocked_attn - block_max_cummax[..., None]) > log_threshold
235+
del block_max_cummax
234236

235-
# Step 5: Apply threshold and create block mask
236-
# Keep blocks where at least one element exceeds threshold
237-
p_larger_than_thresh = p > log_threshold
238237
block_mask = p_larger_than_thresh.any(dim=-1, keepdim=False)
238+
del p_larger_than_thresh
239239

240240
# Step 6: Expand to element level and remove padding
241241
element_mask = block_mask[..., None].expand_as(blocked_attn)

0 commit comments

Comments (0)