Skip to content

Commit 9374e3f

Browse files
EddyLXJmeta-codesync[bot]
authored andcommitted
Skip scratch pad eviction data in enrichment mode to avoid cudaFree overhead (pytorch#5645)
Summary: Pull Request resolved: pytorch#5645 X-link: https://github.com/facebookresearch/FBGEMM/pull/2593 CONTEXT: In KVZCH enrichment mode (_enrichment_enabled), the ssd_scratch_pad_eviction_data list accumulates UVA tensors every forward pass via _prefetch. The backward hook _evict_from_scratch_pad pops entries but does nothing useful (evict() is skipped in embedding_cache_mode, RES is disabled). The .clear() call in enrichment_query_id then triggers expensive cudaFree calls when releasing those UVA tensors, causing GPU stalls visible in Perfetto traces. WHAT: Skip appending to ssd_scratch_pad_eviction_data in _prefetch when _enrichment_enabled is True. Add early return in _evict_from_scratch_pad for enrichment mode. Remove the now-unnecessary .clear() in enrichment_query_id since the list is always empty. Reviewed By: emlin Differential Revision: D101102800 fbshipit-source-id: 16dcd8d32d55f77478235f4a27a3be10f692e288
1 parent 9637997 commit 9374e3f

1 file changed

Lines changed: 12 additions & 8 deletions

File tree

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1652,6 +1652,9 @@ def _evict_from_scratch_pad(self, grad: Tensor) -> None:
16521652
(`inserted_rows`) on the `ssd_eviction_stream`. This is a hook
16531653
that is invoked right after TBE backward.
16541654
1655+
In enrichment mode, scratch pad eviction data is not populated
1656+
(skipped in _prefetch), so this hook returns early.
1657+
16551658
Conflict missed indices are specified in
16561659
`post_bwd_evicted_indices_cpu`. Indices that are not -1 and
16571660
their positions < `actions_count_cpu` (i.e., rows
@@ -1665,6 +1668,11 @@ def _evict_from_scratch_pad(self, grad: Tensor) -> None:
16651668
None
16661669
"""
16671670
with record_function("## ssd_evict_from_scratch_pad_pipeline ##"):
1671+
# In enrichment mode, scratch pad eviction data is not populated
1672+
# (_prefetch skips the append), so nothing to do here.
1673+
if self._enrichment_enabled:
1674+
return
1675+
16681676
current_stream = torch.cuda.current_stream()
16691677
current_stream.record_event(self.ssd_event_backward)
16701678

@@ -2421,7 +2429,10 @@ def _prefetch( # noqa C901
24212429

24222430
# Store scratch pad info for post backward eviction only for training
24232431
# for eval job, no backward pass, so no need to store this info
2424-
if self.training:
2432+
# Skip for enrichment mode: the backward hook only pops without
2433+
# evicting (embedding_cache_mode skips evict), and the .clear()
2434+
# in enrichment_query_id triggers expensive cudaFree on UVA tensors.
2435+
if self.training and not self._enrichment_enabled:
24252436
self.ssd_scratch_pad_eviction_data.append(
24262437
(
24272438
inserted_rows,
@@ -5228,13 +5239,6 @@ def enrichment_query_id(
52285239
dedup_linear_indices = sorted_linear_indices[mask]
52295240
dedup_weights = sorted_weights[mask]
52305241

5231-
if len(self.ssd_scratch_pad_eviction_data) > 0:
5232-
# IMPORTANT: Clear ALL accumulated scratch pad data, not just one!
5233-
# _prefetch appends one element per forward, but enrichment_query_id
5234-
# may not be called every forward. This prevents memory leak from
5235-
# accumulated GPU tensors (inserted_rows is a UVA tensor).
5236-
self.ssd_scratch_pad_eviction_data.clear()
5237-
52385242
# D2H copy on the same stream (already on enrichment_query_stream)
52395243
linear_cache_indices_cpu = self.to_pinned_cpu(dedup_linear_indices)
52405244
dedup_weights_cpu = self.to_pinned_cpu(dedup_weights)

0 commit comments

Comments
 (0)