Skip to content

Commit ba83675

Browse files
rohansjoshi authored and claude committed
Fix OOM in attention sparsity calibration by reducing peak GPU memory
Three sources of unnecessary memory allocation during calibration: 1. flash_skip_softmax.py: In calc_correction_factor_and_p, `p` (full attention-matrix-sized float tensor) and `p_larger_than_thresh` (same size boolean tensor) were both alive simultaneously alongside blocked_attn. Fuse the subtraction and comparison into a single expression to avoid materializing `p`, and explicitly del block_max, block_max_larger, block_max_cummax, and p_larger_than_thresh as soon as each is no longer needed. Applies to both prefill and decode paths. 2. calibrate.py (chunked prefill): del outputs after extracting past_key_values in each chunk to free logits between chunks. 3. calibrate.py (decode loop): del outputs after the prefill step to free the large [B, seqlen, vocab] logits tensor before the decode loop, and del outputs inside each decode step. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
1 parent 2f27cfa commit ba83675

2 files changed

Lines changed: 16 additions & 15 deletions

File tree

modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def forward_loop(model: nn.Module) -> None:
133133
use_cache=True,
134134
)
135135
past_key_values = outputs.past_key_values
136+
del outputs # Free logits between chunks
136137

137138
# Clean up KV cache
138139
del past_key_values
@@ -182,15 +183,14 @@ def forward_loop(model: nn.Module) -> None:
182183
model.config._attn_implementation = "flash_attention_2"
183184
outputs = model(input_ids, use_cache=True)
184185
past_key_values = outputs.past_key_values
186+
next_token = outputs.logits[:, -1:, :].argmax(dim=-1)
187+
del outputs # Free large prefill logits [B, seqlen, vocab] before decode loop
185188

186189
# Step 2: Switch to eager for decode (enables softmax hook)
187190
model.config._attn_implementation = "eager"
188191

189192
# Step 3: Manual decode loop for explicit control over token generation
190193
# model.generate() method is not used here because it doesn't allow explicit control over KV cache
191-
# Get the last token's logits and sample next token
192-
next_token = outputs.logits[:, -1:, :].argmax(dim=-1)
193-
194194
for _ in range(num_decode_tokens):
195195
outputs = model(
196196
next_token,
@@ -199,6 +199,7 @@ def forward_loop(model: nn.Module) -> None:
199199
)
200200
past_key_values = outputs.past_key_values
201201
next_token = outputs.logits[:, -1:, :].argmax(dim=-1)
202+
del outputs # Free decode logits between steps
202203
finally:
203204
# Restore original attention implementation
204205
model.config._attn_implementation = original_attn_impl

modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -173,17 +173,18 @@ def calc_correction_factor_and_p(
173173
block_max_larger = torch.ones_like(block_max)
174174
block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1]
175175
correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item()
176+
del block_max, block_max_larger
176177

177-
# Step 4: Normalize attention scores by cumulative max
178-
# p represents log-space difference: log(score) - log(cummax)
179-
p = blocked_attn - block_max_cummax[..., None]
178+
# Step 4 & 5: Compute threshold mask directly without storing p.
179+
# Fusing the subtraction and comparison avoids allocating a second
180+
# full attention-matrix-sized tensor alongside blocked_attn.
181+
p_larger_than_thresh = (blocked_attn - block_max_cummax[..., None]) > log_threshold
182+
del block_max_cummax
180183

181-
# Step 5: Apply threshold and create block-level mask
182-
# Keep blocks where at least one element exceeds log(threshold)
183-
p_larger_than_thresh = p > log_threshold
184184
# Reduce over bc (128 cols), then br (128 rows) to get block-level decision
185185
# Result: [batch, heads, block_rows, block_cols]
186186
block_mask = p_larger_than_thresh.any(dim=-1).any(dim=-2)
187+
del p_larger_than_thresh
187188

188189
# Step 6: Expand block mask back to element level
189190
# All 128x128 elements in a block share the same mask value
@@ -227,15 +228,14 @@ def calc_correction_factor_and_p(
227228
block_max_larger = torch.ones_like(block_max)
228229
block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1]
229230
correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item()
231+
del block_max, block_max_larger
230232

231-
# Step 4: Normalize scores by cumulative max
232-
# p = log(score) - log(cummax) in log-space
233-
p = blocked_attn - block_max_cummax[..., None]
233+
# Step 4 & 5: Compute threshold mask directly without storing p.
234+
p_larger_than_thresh = (blocked_attn - block_max_cummax[..., None]) > log_threshold
235+
del block_max_cummax
234236

235-
# Step 5: Apply threshold and create block mask
236-
# Keep blocks where at least one element exceeds threshold
237-
p_larger_than_thresh = p > log_threshold
238237
block_mask = p_larger_than_thresh.any(dim=-1, keepdim=False)
238+
del p_larger_than_thresh
239239

240240
# Step 6: Expand to element level and remove padding
241241
element_mask = block_mask[..., None].expand_as(blocked_attn)

0 commit comments

Comments (0)