Skip to content

Commit 4871b45

Browse files
spcyppt authored and meta-codesync[bot] committed
Add OOB benchmarking and V2 bounds check mode support to bounds_check_indices benchmark (#5797)
Summary:

X-link: https://github.com/facebookresearch/FBGEMM/pull/2725

Pull Request resolved: #5797

Add `--oob` flag (0-100) to the `bounds_check_indices` benchmark to set a percentage of indices out of bounds, enabling measurement of the atomic contention overhead in WARNING/IGNORE modes across v1 and v2 kernels.

Add V2 bounds check mode support (`V2_IGNORE=4`, `V2_WARNING=5`, `V2_FATAL=6`) to the benchmark. V2 modes are decomposed into their base mode + `bounds_check_version=2`, mirroring the logic in `SplitTableBatchedEmbeddingBagsCodegen`. The `bounds_check_version` is now passed through to `torch.ops.fbgemm.bounds_check_indices`.

Update `run_bounds_check_benchmark.sh` to accept `--oob` and to loop over multiple `--bounds-check-mode` values (e.g., `"1 2 5 4"` for v1/v2 x WARNING/IGNORE). Trace URLs now include mode and OOB percentage.

Add a convenience sweep script for the OOB experiment.

Reviewed By: q10

Differential Revision: D106606582

fbshipit-source-id: 58abe82d6793fbfdda130490563bc7f6d13e9818
1 parent 6c71acd commit 4871b45

3 files changed

Lines changed: 35 additions & 4 deletions

File tree

fbgemm_gpu/bench/tbe/tbe_utils_benchmark.py

Lines changed: 33 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -282,7 +282,10 @@ def pruned_array_lookup( # noqa C901
282282
help=f"Available modes: FATAL={BoundsCheckMode.FATAL.value}, "
283283
f"WARNING={BoundsCheckMode.WARNING.value}, "
284284
f"IGNORE={BoundsCheckMode.IGNORE.value}, "
285-
f"NONE={BoundsCheckMode.NONE.value}",
285+
f"NONE={BoundsCheckMode.NONE.value}, "
286+
f"V2_IGNORE={BoundsCheckMode.V2_IGNORE.value}, "
287+
f"V2_WARNING={BoundsCheckMode.V2_WARNING.value}, "
288+
f"V2_FATAL={BoundsCheckMode.V2_FATAL.value}",
286289
)
287290
@click.option("--requests_data_file", type=str, default=None)
288291
@click.option("--tables", type=str, default=None)
@@ -299,6 +302,13 @@ def pruned_array_lookup( # noqa C901
299302
type=str,
300303
default="bounds_check_indices_trace_{ospid}.json",
301304
)
305+
@click.option(
306+
"--oob",
307+
type=int,
308+
default=0,
309+
help="Percentage of indices to set out of bounds (0 to 100). "
310+
"Use with WARNING or IGNORE mode (FATAL will crash).",
311+
)
302312
def bounds_check_indices( # noqa C901
303313
bag_size: int,
304314
batch_size: int,
@@ -312,6 +322,7 @@ def bounds_check_indices( # noqa C901
312322
batch_sizes: str,
313323
export_trace: bool,
314324
trace_url: str,
325+
oob: int,
315326
) -> None:
316327
np.random.seed(42)
317328
torch.manual_seed(42)
@@ -358,9 +369,27 @@ def bounds_check_indices( # noqa C901
358369
offset_dtype=torch.long,
359370
)
360371

372+
if oob > 0:
373+
for req in requests:
374+
num_indices = req.indices.numel()
375+
num_oob = int(num_indices * oob / 100)
376+
oob_positions = torch.randperm(num_indices)[:num_oob]
377+
req.indices[oob_positions] = E
378+
361379
warning = torch.tensor([0]).long().to(get_device())
362380
rows_per_table = torch.tensor([E for _ in range(T)]).long().to(get_device())
363381

382+
bc_mode = BoundsCheckMode(bounds_check_mode)
383+
bounds_check_version = 1
384+
if bc_mode.name.startswith("V2_"):
385+
bounds_check_version = 2
386+
if bc_mode == BoundsCheckMode.V2_IGNORE:
387+
bc_mode = BoundsCheckMode.IGNORE
388+
elif bc_mode == BoundsCheckMode.V2_WARNING:
389+
bc_mode = BoundsCheckMode.WARNING
390+
elif bc_mode == BoundsCheckMode.V2_FATAL:
391+
bc_mode = BoundsCheckMode.FATAL
392+
364393
def _kineto_trace_handler(p: profile) -> None:
365394
p.export_chrome_trace(trace_url.format(ospid=os.getpid()))
366395

@@ -411,20 +440,22 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
411440
rows_per_table,
412441
indices,
413442
offsets,
414-
BoundsCheckMode(bounds_check_mode),
443+
bc_mode,
415444
warning,
416445
B_offsets=B_offsets,
417446
max_B=max_B,
418447
b_t_map=b_t_map,
419448
info_B_num_bits=info_B_num_bits,
420449
info_B_mask=info_B_mask,
450+
bounds_check_version=bounds_check_version,
421451
),
422452
num_warmups=warmup_runs,
423453
)
424454

425455
logging.info(
426456
f"Bounds Check Indices: Bs: {Bs}, "
427457
f"E: {E}, T: {T}, L: {L}, "
458+
f"mode: {bc_mode.name}, v: {bounds_check_version}, "
428459
f"BW: {(8 * total_B * L + 8 * (total_B + 1)) / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
429460
f"T: {time_per_iter * 1.0e6:.0f}us"
430461
)

fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828
)
2929
from fbgemm_gpu.utils.loader import load_torch_module
3030

31-
from .cache_config import CacheAlgorithm
31+
from .cache_config import CacheAlgorithm # usort:skip
3232

3333
try:
3434
load_torch_module(

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@
5454
TBEStatsReporterConfig,
5555
)
5656

57-
from .ssd_config import BackendType, EvictionPolicy, KVZCHParams
57+
from .ssd_config import BackendType, EvictionPolicy, KVZCHParams # usort:skip
5858
from torch import distributed as dist, nn, Tensor # usort:skip
5959
import sys
6060
from dataclasses import dataclass

0 commit comments

Comments
 (0)