forked from modular/modular
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsample.mojo
More file actions
86 lines (69 loc) · 2.75 KB
/
Copy pathsample.mojo
File metadata and controls
86 lines (69 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
from sys import env_get_dtype, env_get_int
from benchmark import Bench, BenchConfig, Bencher, BenchId
from internal_utils import (
Mode,
arg_parse,
env_get_shape,
int_list_to_tuple,
update_bench_config_args,
)
from time import sleep
from os import getenv
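# Example usage: build the binary, then launch it under mpirun with 8
# processes, passing `-o output.csv` to the benchmark:
#   mojo build sample.mojo
#   mpirun -n 8 ./sample -o output.csv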
fn bench_func[
    dtype: DType, M: Int, N: Int, K: Int, stages: Int
](mut m: Bench, mode: Mode, pe_rank: Int) raises:
    """Registers and/or pretends to run a dummy "gemm" kernel, depending on `mode`.

    No real kernel is executed here: the benchmarked body only sleeps, and the
    verify/run paths only print a fixed message. The compile-time parameters
    are used solely to build the benchmark name.

    Parameters:
        dtype: Data type reported in the benchmark name.
        M: Value reported as `m` in the benchmark name.
        N: Value reported as `n` in the benchmark name.
        K: Value reported as `k` in the benchmark name.
        stages: Value reported as `stages` in the benchmark name.

    Args:
        m: Benchmark driver that the measurement is registered with.
        mode: Selects which of the benchmark / verify / run actions execute.
            Each action is checked independently, so more than one may run.
        pe_rank: Rank embedded in the benchmark's `input_id` label.

    Raises:
        Propagates any error raised while registering the benchmark.
    """

    @parameter
    @always_inline
    fn bench_iter(mut b: Bencher):
        @parameter
        @always_inline
        fn call_fn():
            # Stand-in workload: sleep instead of launching a kernel.
            sleep(0.01)

        b.iter[call_fn]()

    # The "/k=" field must report K (it previously repeated N, which
    # mislabelled every non-square problem shape in the report).
    var name = String(
        "gemm/dtype=", dtype, "/m=", M, "/n=", N, "/k=", K, "/stages=", stages
    )

    if Mode.BENCHMARK == mode:
        m.bench_function[bench_iter](
            BenchId(name, input_id=String("1st-metric (pe_rank=", pe_rank, ")"))
        )
        # TODO: enable the following line after adding support for multi-output to kplot and kprofile.
        # m.bench_function[bench_iter](BenchId(name, input_id=String("2nd-metric (pe_rank=",pe_rank,")")))

    if Mode.VERIFY == mode:
        print("verifying dummy results...PASS")

    if Mode.RUN == mode:
        print("pretending to run the kernel...PASS")
def main():
    """Sample benchmark driver.

    Reads compile-time settings (`dtype`, `shape`, `stages`), parses the
    runtime arguments `x` and `mode`, announces the selected mode, runs
    `bench_func` once, and dumps the benchmark report.
    """
    # Compile-time configuration, each with a fallback default.
    comptime stages = env_get_int["stages", 0]()
    comptime dtype = env_get_dtype["dtype", DType.float16]()
    comptime dims_list = env_get_shape["shape", "1024x1024x1024"]()
    comptime dims = int_list_to_tuple[dims_list]()

    # Runtime argument `x` (default 0); parsed here but not used further below.
    var runtime_x = arg_parse("x", 0)

    # Benchmark mode: one of [run, benchmark, verify] or a combo (run+benchmark).
    var selected_mode = Mode(arg_parse("mode", "benchmark"))
    var mode_banner = "mode=" + String(selected_mode)
    print(mode_banner)

    # The checks are deliberately independent: a combined mode may match
    # more than one of them.
    if Mode.RUN == selected_mode:
        print("-- mode: run kernel once")
    if Mode.BENCHMARK == selected_mode:
        print("-- mode: run kernel benchmark")
    if Mode.VERIFY == selected_mode:
        print("-- mode: verify kernel")

    var bench = Bench(
        BenchConfig(max_iters=1, max_batch_size=1)
    )
    var rank = bench.check_mpirun()
    update_bench_config_args(bench)

    bench_func[dtype, dims[0], dims[1], dims[2], stages](
        bench, selected_mode, rank
    )

    bench.dump_report()