forked from modular/modular
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbencher_utils.py
More file actions
139 lines (114 loc) · 4.39 KB
/
bencher_utils.py
File metadata and controls
139 lines (114 loc) · 4.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
import csv
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class BenchMetric:
    """Describes one kind of throughput metric.

    Attributes:
        code: Op-code of the metric.
        name: Name of the metric.
        unit: Throughput rate unit of the metric (count/second).
    """

    code: int
    name: str
    unit: str
@dataclass
class ThroughputMeasure:
    """Pairs a BenchMetric with the raw count measured for it.

    Attributes:
        metric: Type of throughput metric being recorded.
        value: Measured count of that throughput metric.
    """

    metric: BenchMetric
    value: int

    def compute(self, elapsed_sec: float) -> float:
        """Return the throughput rate of this measurement per second.

        The measured count is scaled by 1e-9 and then divided by the elapsed
        time.

        Args:
            elapsed_sec: Elapsed time measured in seconds.

        Returns:
            The throughput rate as a float.

        Raises:
            ZeroDivisionError: If ``elapsed_sec`` is zero.
        """
        # TODO: do we need support other units of time (ms, ns)?
        scaled_count = self.value * 1e-9
        return scaled_count / elapsed_sec
@dataclass
class Bench:
    """Holds one benchmark measurement and writes it out as a CSV report.

    Args:
        name: Name of benchmark entry (string).
        met: Measured execution time for the entry in seconds (float).
        iters: Number of iterations in the measurement (int).
        metric_list: List of ThroughputMeasure's.
    """

    name: str
    met: float
    iters: int
    metric_list: list[ThroughputMeasure] = field(default_factory=list)

    # Predefined metric kinds. They carry no annotation, so the dataclass
    # machinery treats them as plain class attributes, not as fields.
    # NOTE: `bytes` shadows the builtin inside this class body only; the name
    # is part of the public interface and is kept as-is.
    elements = BenchMetric(0, "throughput", "GElems/s")
    bytes = BenchMetric(1, "DataMovement", "GB/s")
    flops = BenchMetric(2, "Arithmetic", "GFLOPS/s")
    theoretical_flops = BenchMetric(3, "TheoreticalArithmetic", "GFLOPS/s")

    # Column labels used in the report header.
    BENCH_LABEL = "name"
    MET_LABEL = "met (s)"
    ITERS_LABEL = "iters"

    def dump_report(self, output_path: Path | None = None) -> None:
        """Write a two-row CSV report (header row, value row).

        The report is always written to stdout. When ``output_path`` is
        given, the same rows are also written to that file, replacing any
        existing content.

        Args:
            output_path: Optional path of a file to write the report to.

        Raises:
            OSError: If ``output_path`` cannot be opened for writing.
            ZeroDivisionError: If ``met`` is zero and ``metric_list`` is not
                empty (raised by ``ThroughputMeasure.compute``).
        """
        header = [self.BENCH_LABEL, self.MET_LABEL, self.ITERS_LABEL] + [
            f"{m.metric.name} ({m.metric.unit})" for m in self.metric_list
        ]
        # The benchmark name is wrapped in literal double quotes; the csv
        # writer itself uses a single quote as its quote character.
        values = ['"' + self.name + '"', self.met, self.iters] + [
            f"{m.compute(self.met)}" for m in self.metric_list
        ]
        rows = (header, values)

        if output_path:
            # newline="" lets the csv module control line endings itself.
            with open(output_path, "w", newline="") as f:
                csv.writer(f, delimiter=",", quotechar="'").writerows(rows)

        # Do not wrap sys.stdout in a `with` block: leaving the block would
        # close the process-wide stream and break every later print().
        csv.writer(sys.stdout, delimiter=",", quotechar="'").writerows(rows)
        sys.stdout.flush()
def arg_parse(handle: str, default: Any = None, short_handle: str = "") -> Any:
    """Look up the value of a command-line option in ``sys.argv``.

    Recognized spellings are ``--handle value``, ``--handle=value``,
    ``-short value`` and ``-short=value``. The option name must match
    exactly, so ``--output`` does not match ``--output-dir``. The first
    matching argument wins.

    Args:
        handle: Long option name, with or without leading dashes.
        default: Value returned when the option is absent, or when it is the
            last argument and has no value after it.
        short_handle: Optional short option name, with or without a leading
            dash.

    Returns:
        The option's value as a string, or ``default`` if it was not found.
    """
    # TODO: add constraints on dtype of return value
    handle = handle.lstrip("-")
    short_handle = short_handle.lstrip("-")

    # Only non-empty names are searched for; an empty name would otherwise
    # match a bare "--" or "-".
    flags = []
    if handle:
        flags.append("--" + handle)
    if short_handle:
        flags.append("-" + short_handle)

    args = sys.argv
    for i, arg in enumerate(args):
        # Split on the first "=" only, so values may themselves contain "=".
        flag, sep, inline_value = arg.partition("=")
        if flag not in flags:
            continue
        if sep:
            return inline_value
        if i + 1 < len(args):
            return args[i + 1]
        # The flag is the final argument and carries no value.
        return default
    return default
def check_mpirun() -> int:
    """Detect an mpirun launch from the environment and export rank settings.

    When both OMPI_COMM_WORLD_RANK and OMPI_COMM_WORLD_SIZE are present in
    the environment, MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are set
    (overwriting any existing values) and the pe rank is read from
    OMPI_COMM_WORLD_LOCAL_RANK, falling back to 0 when that variable is
    absent. Otherwise the environment is left untouched.

    Raises:
        ValueError: If OMPI_COMM_WORLD_LOCAL_RANK is set to a value that
            cannot be converted to an integer.

    Returns:
        An integer representing pe rank (default=-1).
    """
    env = os.environ
    launched_by_mpirun = (
        "OMPI_COMM_WORLD_RANK" in env and "OMPI_COMM_WORLD_SIZE" in env
    )
    if not launched_by_mpirun:
        return -1

    # Port used as the torchrun default by this module.
    default_port = "29500"
    env.update(
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": default_port,
            "RANK": env["OMPI_COMM_WORLD_RANK"],
            "WORLD_SIZE": env["OMPI_COMM_WORLD_SIZE"],
        }
    )
    return int(env.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))