Skip to content

Commit 4e605fa

Browse files
Xinyi Wangfacebook-github-bot
authored andcommitted
fix dram_kv.hit_rate_pct normalization (#5777)
Summary: X-link: facebookresearch/FBGEMM#2706 The `dram_kv.hit_rate_pct` metric in `SSDTableBatchedEmbeddingBags` was computed as `dram_read_hit_count / (dram_read_hit_count + dram_read_miss_count)` — the denominator only counts requests that reached DRAM, i.e. L1 misses. When `l1_cache_size` grows, L1 absorbs more keys and only the long-tail keys fall through to DRAM, so the L1-conditional DRAM hit rate drops mechanically even though the system is doing more — not less — work in the cheaper tier. This diff changes `dram_kv.hit_rate_pct` to be normalized against `num_unique` (total unique indices in the batch, captured from the L1 reporting path): hit_rate_pct = 100.0 * (num_unique - dram_read_miss_count) / num_unique Semantically this is now the overall (L1 + DRAM) hit rate — the fraction of unique requests that did not miss at DRAM. The value stays stable as cache sizes shift between tiers. Algebraically equivalent to the expanded form `L1_hit + (1 - L1_hit) * DRAM_hit_conditional` under the assumption that every L1 miss reaches DRAM (the only path today). A code comment documents this caveat in case a future SSD-bypass path is added. Implementation: - `_report_uvm_cache_stats` stashes `num_unique` into `_last_l1_num_unique` so `_report_dram_kv_perf_stats` can use it as the normalization denominator without re-reading L1 counters. Both reporters fire from the same `should_report(self.step)` cadence, so the stashed value corresponds to the same reporting window. - `l1_hit_rate_pct` and `l2_cache.hit_rate_pct` are untouched. Differential Revision: D105727013
1 parent 0fcdfc0 commit 4e605fa

1 file changed

Lines changed: 22 additions & 2 deletions

File tree

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,6 +1233,10 @@ def __init__(
12331233
# 4: N_conflict_unique_misses, 5: N_conflict_misses
12341234
self.last_reported_ssd_stats: list[float] = []
12351235
self.last_reported_step = 0
1236+
# Stashed by _report_uvm_cache_stats so _report_l2_cache_perf_stats
1237+
# can normalize DRAM hit rate against total unique indices instead
1238+
# of only L1-miss lookups. See T272139146.
1239+
self._last_l1_num_unique: float = 0.0
12361240

12371241
self.register_buffer(
12381242
"ssd_cache_stats",
@@ -4185,6 +4189,7 @@ def _report_ssd_l1_cache_stats(self) -> None:
41854189
# L1 cache hit rate
41864190
num_unique = ssd_cache_stats_delta[UVMCacheStatsIndex.num_unique_indices]
41874191
num_misses = ssd_cache_stats_delta[UVMCacheStatsIndex.num_unique_misses]
4192+
self._last_l1_num_unique = num_unique
41884193
if num_unique > 0:
41894194
l1_hit_rate_pct = 100.0 * (num_unique - num_misses) / num_unique
41904195
# Per-TBE L1 hit rate
@@ -4847,8 +4852,23 @@ def _report_dram_kv_perf_stats(self) -> None:
48474852
data_bytes=dram_read_miss_count,
48484853
enable_tb_metrics=True,
48494854
)
4850-
if dram_read_total > 0:
4851-
hit_rate_pct = 100.0 * dram_read_hit_count / dram_read_total
4855+
# Hit rate normalized to total unique requests (L1 hits + DRAM
4856+
# hits) / total. Stable across different l1_cache_size — the
4857+
# previous formula (dram_hits / dram_calls) dropped mechanically
4858+
# as L1 grew because only long-tail keys reached DRAM.
4859+
# See T272139146.
4860+
#
4861+
# Algebraically equivalent to the expanded form
4862+
# L1_hit_rate + (1 - L1_hit_rate) * DRAM_hit_rate_conditional
4863+
# under the assumption that every L1 miss reaches DRAM, i.e.
4864+
# num_misses == dram_read_hit_count + dram_read_miss_count.
4865+
# This holds today since DRAM is the only tier behind L1. If a
4866+
# future code path lets L1 misses bypass DRAM (e.g. direct SSD
4867+
# read), this simplified form will silently diverge from the
4868+
# explicit two-term form — revisit then.
4869+
num_unique = self._last_l1_num_unique
4870+
if num_unique > 0:
4871+
hit_rate_pct = 100.0 * (num_unique - dram_read_miss_count) / num_unique
48524872
# Per-TBE hit rate
48534873
stats_reporter.report_data_amount(
48544874
iteration_step=self.step,

0 commit comments

Comments
 (0)