@@ -11,6 +11,7 @@
 from typing import Any, Optional
 
 import torch.cuda
+from torch.cuda.nvtx import range as nvtx_range
 
 from utils import set_seed, clear_l2_cache
 
@@ -499,26 +500,126 @@ def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: l |
         return 112
 
 
-def run_single_profile(test: TestCase) -> str:
+def _run_single_profile(test: TestCase) -> str:
     """
     Runs a single test case. Do not call directly
     """
     from submission import custom_kernel
-    from torch.profiler import profile, record_function, ProfilerActivity
-    data = generate_input(**test.args)
-    torch.cuda.synchronize()
+    from torch.profiler import profile, ProfilerActivity
 
-    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
-        submission_output = custom_kernel(_clone_data(data, 0))
+    with nvtx_range("generate input"):
+        data = generate_input(**test.args)
         torch.cuda.synchronize()
+
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+        with nvtx_range("custom_kernel"):
+            submission_output = custom_kernel(_clone_data(data, 0))
+            torch.cuda.synchronize()
+
     return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20)
 
 
-def run_profiling(logger: PopcornOutput, tests: list[TestCase]):
+def _run_distributed_profile(test: TestCase, rank: int) -> "EventList":
+    """
+    Runs a single profiling case. Do not call directly
+    """
+    from submission import custom_kernel
+    from torch.profiler import profile, ProfilerActivity
+    import torch.distributed as dist
+
+    with nvtx_range(f"init nccl, rank {rank}"):
+        world_size = test.args["world_size"]
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = "12356"
+        dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size, device_id=torch.device(f'cuda:{rank}'))
+
+    try:
+        with nvtx_range(f"generate input, rank {rank}"):
+            data = generate_input(**test.args, rank=rank)
+            data = _clone_data(data, rank)
+            torch.cuda.synchronize()
+            dist.barrier()
+
+        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+            with nvtx_range(f"custom_kernel, rank {rank}"):
+                submission_output = custom_kernel(data)
+                torch.cuda.synchronize()
+                dist.barrier()
+
+        return prof.events()
+
+    finally:
+        dist.destroy_process_group()
+
+
+def _combine_traces(traces: list["EventList"]) -> "EventList":
+    """
+    Combine multiple event traces obtained from multiple (distributed) torch.profiler
+    activities. This function aggregates the data just as `prof.key_averages()` does,
+    except over multiple traces. Most of this function is reimplemented
+    from `torch.autograd.profiler_util.EventList.key_averages()`.
+    """
+    from torch.autograd.profiler_util import FunctionEventAvg, EventList
+    from collections import defaultdict
+
+    def get_key(event) -> tuple[str, ...]:
+        return (
+            str(event.key),
+            str(event.node_id),
+            str(event.device_type),
+            str(event.is_legacy),
+            str(event.is_user_annotation),
+        )
+
+    stats: dict[tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)
+
+    for events in traces:
+        for event in events:
+            stats[get_key(event)].add(event)
+
+    avg_list = EventList(stats.values())
+    for event in avg_list:
+        event.stack = []
+        event.input_shapes = ""
+        event.overload_name = ""
+
+    return avg_list
+
+
+def run_multi_gpu_profile(pool: multiprocessing.Pool, test: TestCase, world_size: int) -> str:
+    """
+    Runs a single profiling case in worker processes, one per rank.
+    """
+    rets = []
+    # world_size is a mandatory argument for multi-gpu tests
+    for i in range(world_size):
+        rets.append(
+            pool.apply_async(
+                _run_distributed_profile,
+                args=(test, i),
+            )
+        )
+
+    rets = [el.get(120) for el in rets]
+    return _combine_traces(rets).table(sort_by="self_cuda_time_total", row_limit=20)
+
+
+def run_single_profile(test: TestCase, pool: multiprocessing.Pool) -> str:
+    """
+    Runs a single profiling activity in another process.
+    """
+    world_size = test.args.get("world_size", None)
+    if world_size is None:
+        return pool.apply(_run_single_profile, (test,))
+    else:
+        return run_multi_gpu_profile(pool, test, world_size)
+
+
+def run_profiling(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]):
     logger.log("benchmark-count", len(tests))
     for idx, test in enumerate(tests):
         logger.log(f"benchmark.{idx}.spec", test.spec)
-        report = run_single_profile(test)
+        report = run_single_profile(test, pool)
         logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"))
     logger.log("check", "pass")
     return 0
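
Aside: `_combine_traces` above consumes the raw `prof.events()` lists returned by each rank. As a rough, CPU-only illustration of the same aggregation outside the distributed path (hypothetical snippet, not part of this diff; it assumes `_combine_traces` from the module above is in scope):

import torch
from torch.profiler import profile, ProfilerActivity

def one_trace():
    # Profile a small CPU matmul and return the raw event list,
    # mirroring what each rank returns from _run_distributed_profile.
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        torch.randn(256, 256) @ torch.randn(256, 256)
    return prof.events()

# Aggregate two independent traces into a single averaged table.
combined = _combine_traces([one_trace(), one_trace()])
print(combined.table(sort_by="self_cpu_time_total", row_limit=10))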
@@ -568,7 +669,7 @@ def main(): |
 
         logger.log("check", "pass" if passed else "fail")
     elif mode == "profile":
-        run_profiling(logger, tests)
+        run_profiling(logger, pool, tests)
     else:
         # invalid mode
         return 2
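
Background, not part of the diff: the `nvtx_range` helper imported at the top is `torch.cuda.nvtx.range`, a context manager that brackets a region with an NVTX push/pop pair so it shows up as a named span in NVTX-aware profilers such as Nsight Systems. A minimal standalone sketch (the function and tensor are purely illustrative):

import torch
from torch.cuda.nvtx import range as nvtx_range

def scaled_copy(x: torch.Tensor) -> torch.Tensor:
    # Work inside this block appears as the named range "scaled_copy"
    # on the profiler timeline.
    with nvtx_range("scaled_copy"):
        y = 2.0 * x
        torch.cuda.synchronize()
    return y

if __name__ == "__main__":
    if torch.cuda.is_available():
        print(scaled_copy(torch.ones(4, device="cuda")))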