Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 110 additions & 1 deletion fbgemm_gpu/bench/jagged_tensor_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,14 +562,37 @@ def _export_kineto_trace(
default=False,
help="Use manual seed for reproduction.",
)
@click.option(
"--device",
type=click.Choice(["cpu", "cuda"]),
default="cuda",
help="Device to run the benchmark on. Default is cuda.",
)
@click.option(
"--export-trace",
is_flag=True,
default=False,
help="Enable export of trace for profiling.",
)
@click.option(
"--trace-url",
type=str,
default="batched_dense_vec_jagged_2d_mul_trace_{ospid}.json",
)
def batched_dense_vec_jagged_2d_mul(
batch_size: int,
h_dim: int,
embedding_dim: int,
max_len: int,
elem_type: str,
manual_seed: bool,
device: str,
export_trace: bool,
trace_url: str,
) -> None:
if device == "cuda" and not torch.cuda.is_available():
raise click.UsageError("CUDA requested but not available.")

# set manual seed for reproducibility
if manual_seed:
torch.manual_seed(42)
Expand All @@ -586,7 +609,7 @@ def batched_dense_vec_jagged_2d_mul(
# pyre-fixme[6]: For 1st param expected `int` but got `Union[bool, float, int]`.
values_2d = torch.rand(total_lengths, h_dim * embedding_dim, dtype=dtype)
dense = torch.rand(batch_size * h_dim, max_len, dtype=dtype)
if torch.cuda.is_available():
if device == "cuda":
offsets = offsets.cuda()
values_2d = values_2d.cuda()
dense = dense.cuda()
Expand All @@ -605,6 +628,41 @@ def batched_dense_vec_jagged_2d_mul(
f"batched_dense_vec_jagged_2d_mul {time} sec {num_flops / time / 1e9} GFLOP/s"
)

# Optional profiling pass: only runs when the --export-trace flag is set.
# It re-invokes the benchmarked op under the profiler and exports a Chrome
# trace to the path built from --trace-url.
if export_trace:
# --device is restricted to "cpu" or "cuda", so anything that is not "cpu"
# is handled as a CUDA run.
is_cuda = device != "cpu"

# pyre-fixme[53]: Captured variable `dense` is not annotated.
# pyre-fixme[53]: Captured variable `values_2d` is not annotated.
# pyre-fixme[53]: Captured variable `offsets` is not annotated.
# Zero-argument wrapper around the op; it captures the input tensors built
# earlier in this command so the loops below all issue the same call.
def fn() -> torch.Tensor:
return torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul(
dense, values_2d, offsets
)

# Call the op 100 times before the profiler is started.
for _ in range(100):
fn()
# On CUDA, block until outstanding work on the device has finished.
if is_cuda:
torch.cuda.synchronize(device)
# CPU activity is always recorded.
# pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`.
activities = [torch.profiler.ProfilerActivity.CPU]
# CUDA activity is recorded in addition when running on the GPU.
if is_cuda:
# pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`.
activities.append(torch.profiler.ProfilerActivity.CUDA)
# Number of active profiler steps; also the iteration count of the loop below.
num_active = 5
with profile(
activities=activities,
# No wait steps, no warmup steps, `num_active` recorded steps, one cycle.
schedule=schedule(wait=0, warmup=0, active=num_active, repeat=1),
record_shapes=True,
# The "{ospid}" placeholder in trace_url is filled with this process's PID.
on_trace_ready=lambda p: p.export_chrome_trace(
trace_url.format(ospid=os.getpid())
),
) as prof:
for _ in range(num_active):
fn()
if is_cuda:
torch.cuda.synchronize(device)
# Signal the profiler that a step has completed.
# NOTE(review): presumably called once per loop iteration so the schedule
# sees `num_active` steps - indentation is not visible here, confirm.
prof.step()


@cli.command()
@click.option("--batch-size", type=int, default=1024)
Expand All @@ -615,12 +673,37 @@ def batched_dense_vec_jagged_2d_mul(
default=False,
help="Use manual seed for reproduction.",
)
@click.option(
"--device",
type=click.Choice(["cpu"]),
default="cpu",
help="Device to run the benchmark on. CPU-only (no CUDA kernel exists).",
)
@click.option(
"--export-trace",
is_flag=True,
default=False,
help="Enable export of trace for profiling.",
)
@click.option(
"--trace-url",
type=str,
default="jagged_1d_to_truncated_values_trace_{ospid}.json",
)
def jagged_1d_to_truncated_values(
batch_size: int,
max_len: int,
dtype: str,
manual_seed: bool,
device: str,
export_trace: bool,
trace_url: str,
) -> None:
if max_len <= 0:
raise click.UsageError("max_len must be positive.")
if batch_size <= 0:
raise click.UsageError("batch_size must be positive.")

# set manual seed for reproducibility
if manual_seed:
torch.manual_seed(42)
Expand Down Expand Up @@ -664,6 +747,32 @@ def ref(values: torch.Tensor, lengths: torch.Tensor, max_len: int) -> torch.Tens
logging.info(f"reference {time_ref} sec {bytes / time_ref / 1e9} GB/s")
logging.info(f"truncate_jagged_1d {time} sec {bytes / time / 1e9} GB/s")

# Optional profiling pass: only runs when the --export-trace flag is set.
# It re-invokes the benchmarked op under the profiler and exports a Chrome
# trace to the path built from --trace-url.
if export_trace:

# pyre-fixme[53]: Captured variable `values` is not annotated.
# pyre-fixme[53]: Captured variable `lengths` is not annotated.
# Zero-argument wrapper around the op; it captures `values` and `lengths`
# from the enclosing command together with the `max_len` argument.
def fn() -> torch.Tensor:
return torch.ops.fbgemm.jagged_1d_to_truncated_values(
values, lengths, max_len
)

# Call the op 100 times before the profiler is started.
for _ in range(100):
fn()
# Only CPU activity is recorded: this command's --device option accepts
# just "cpu", so no CUDA activity or synchronization is set up here.
# pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`.
activities = [torch.profiler.ProfilerActivity.CPU]
# Number of active profiler steps; also the iteration count of the loop below.
num_active = 5
with profile(
activities=activities,
# No wait steps, no warmup steps, `num_active` recorded steps, one cycle.
schedule=schedule(wait=0, warmup=0, active=num_active, repeat=1),
record_shapes=True,
# The "{ospid}" placeholder in trace_url is filled with this process's PID.
on_trace_ready=lambda p: p.export_chrome_trace(
trace_url.format(ospid=os.getpid())
),
) as prof:
for _ in range(num_active):
fn()
# Signal the profiler that a step has completed.
# NOTE(review): presumably called once per loop iteration so the schedule
# sees `num_active` steps - indentation is not visible here, confirm.
prof.step()


@cli.command()
@click.option("--batch-size", type=int, default=1024)
Expand Down
Loading