Skip to content

Commit 9abdeb2

Browse files
committed
[ROCm] Add Triton autotuning configs for MI300 and MI350
Extend the Triton GEMM autotuner with dedicated config sets for AMD MI300 (gfx942, 33 configs) and MI350 (gfx950, 58 configs), going beyond the generic 6-config ROCm default. All other ROCm devices continue to use the generic ROCm default config set.
1 parent 421117c commit 9abdeb2

6 files changed

Lines changed: 135 additions & 1 deletion

File tree

xla/backends/gpu/autotuner/triton.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ namespace {
6767
std::vector<TritonGemmConfig> GetDefaultTritonConfigs(
6868
se::GpuComputeCapability compute_capability) {
6969
if (compute_capability.IsRocm()) {
70+
const auto* rocm_cc = compute_capability.rocm_compute_capability();
71+
if (rocm_cc->gfx9_mi300()) {
72+
return GetTritonConfigsForPlatform(TritonConfigsPlatform::kMI300);
73+
} else if (rocm_cc->gfx9_mi350()) {
74+
return GetTritonConfigsForPlatform(TritonConfigsPlatform::kMI350);
75+
}
7076
return GetTritonConfigsForPlatform(TritonConfigsPlatform::kDefaultRocm);
7177
}
7278

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2026 The OpenXLA Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
config { block_m: 32 block_n: 32 block_k: 256 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
16+
config { block_m: 64 block_n: 32 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
17+
config { block_m: 32 block_n: 64 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
18+
config { block_m: 128 block_n: 128 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
19+
config { block_m: 16 block_n: 16 block_k: 256 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
20+
config { block_m: 16 block_n: 128 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
21+
config { block_m: 256 block_n: 256 block_k: 32 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
22+
config { block_m: 128 block_n: 256 block_k: 64 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
23+
config { block_m: 128 block_n: 256 block_k: 32 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
24+
config { block_m: 256 block_n: 128 block_k: 64 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
25+
config { block_m: 128 block_n: 128 block_k: 64 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
26+
config { block_m: 32 block_n: 8 block_k: 16 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
27+
config { block_m: 64 block_n: 32 block_k: 16 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
28+
config { block_m: 128 block_n: 32 block_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
29+
config { block_m: 128 block_n: 64 block_k: 128 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
30+
config { block_m: 128 block_n: 128 block_k: 32 num_stages: 3 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
31+
config { block_m: 256 block_n: 128 block_k: 32 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
32+
config { block_m: 256 block_n: 256 block_k: 32 num_stages: 1 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
33+
config { block_m: 128 block_n: 32 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
34+
config { block_m: 64 block_n: 32 block_k: 32 num_stages: 3 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
35+
config { block_m: 128 block_n: 32 block_k: 32 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
36+
config { block_m: 32 block_n: 32 block_k: 32 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
37+
config { block_m: 64 block_n: 32 block_k: 128 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
38+
config { block_m: 256 block_n: 8 block_k: 32 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
39+
config { block_m: 128 block_n: 16 block_k: 128 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
40+
config { block_m: 32 block_n: 16 block_k: 128 num_stages: 5 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
41+
config { block_m: 32 block_n: 16 block_k: 128 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 4 }
42+
config { block_m: 64 block_n: 8 block_k: 128 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
43+
config { block_m: 64 block_n: 8 block_k: 128 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 4 }
44+
config { block_m: 32 block_n: 16 block_k: 256 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
45+
config { block_m: 32 block_n: 16 block_k: 256 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 2 }
46+
config { block_m: 256 block_n: 8 block_k: 16 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
47+
config { block_m: 128 block_n: 8 block_k: 16 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2026 The OpenXLA Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
config { block_m: 32 block_n: 32 block_k: 256 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
16+
config { block_m: 64 block_n: 32 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
17+
config { block_m: 32 block_n: 64 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
18+
config { block_m: 128 block_n: 128 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
19+
config { block_m: 16 block_n: 16 block_k: 256 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
20+
config { block_m: 16 block_n: 128 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
21+
config { block_m: 256 block_n: 256 block_k: 32 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
22+
config { block_m: 128 block_n: 256 block_k: 64 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
23+
config { block_m: 128 block_n: 256 block_k: 32 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
24+
config { block_m: 256 block_n: 128 block_k: 64 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
25+
config { block_m: 128 block_n: 128 block_k: 64 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
26+
config { block_m: 32 block_n: 8 block_k: 16 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
27+
config { block_m: 128 block_n: 32 block_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 4 }
28+
config { block_m: 32 block_n: 8 block_k: 32 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
29+
config { block_m: 128 block_n: 32 block_k: 16 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
30+
config { block_m: 256 block_n: 256 block_k: 16 num_stages: 4 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
31+
config { block_m: 256 block_n: 128 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
32+
config { block_m: 256 block_n: 128 block_k: 64 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
33+
config { block_m: 256 block_n: 128 block_k: 16 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
34+
config { block_m: 16 block_n: 16 block_k: 128 num_stages: 3 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
35+
config { block_m: 32 block_n: 16 block_k: 128 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 4 }
36+
config { block_m: 64 block_n: 8 block_k: 128 num_stages: 4 num_warps: 2 num_ctas: 1 waves_per_eu: 2 }
37+
config { block_m: 64 block_n: 16 block_k: 64 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
38+
config { block_m: 128 block_n: 32 block_k: 16 num_stages: 3 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
39+
config { block_m: 64 block_n: 32 block_k: 16 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
40+
config { block_m: 32 block_n: 16 block_k: 32 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
41+
config { block_m: 64 block_n: 8 block_k: 16 num_stages: 3 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
42+
config { block_m: 16 block_n: 8 block_k: 256 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 1 }
43+
config { block_m: 16 block_n: 16 block_k: 128 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
44+
config { block_m: 16 block_n: 16 block_k: 128 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
45+
config { block_m: 16 block_n: 64 block_k: 128 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 4 }
46+
config { block_m: 32 block_n: 16 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
47+
config { block_m: 64 block_n: 8 block_k: 16 num_stages: 3 num_warps: 2 num_ctas: 1 waves_per_eu: 4 }
48+
config { block_m: 64 block_n: 8 block_k: 64 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
49+
config { block_m: 64 block_n: 8 block_k: 256 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
50+
config { block_m: 64 block_n: 16 block_k: 256 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
51+
config { block_m: 128 block_n: 8 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
52+
config { block_m: 128 block_n: 32 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
53+
config { block_m: 128 block_n: 64 block_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
54+
config { block_m: 32 block_n: 32 block_k: 128 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
55+
config { block_m: 32 block_n: 32 block_k: 128 num_stages: 1 num_warps: 8 num_ctas: 1 waves_per_eu: 4 }
56+
config { block_m: 32 block_n: 64 block_k: 128 num_stages: 4 num_warps: 2 num_ctas: 1 waves_per_eu: 1 }
57+
config { block_m: 64 block_n: 32 block_k: 32 num_stages: 3 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
58+
config { block_m: 64 block_n: 64 block_k: 32 num_stages: 4 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
59+
config { block_m: 64 block_n: 64 block_k: 128 num_stages: 4 num_warps: 4 num_ctas: 1 waves_per_eu: 2 }
60+
config { block_m: 64 block_n: 64 block_k: 128 num_stages: 3 num_warps: 4 num_ctas: 1 waves_per_eu: 1 }
61+
config { block_m: 64 block_n: 64 block_k: 128 num_stages: 3 num_warps: 4 num_ctas: 1 waves_per_eu: 4 }
62+
config { block_m: 64 block_n: 64 block_k: 128 num_stages: 4 num_warps: 4 num_ctas: 1 waves_per_eu: 4 }
63+
config { block_m: 64 block_n: 128 block_k: 128 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 1 }
64+
config { block_m: 128 block_n: 64 block_k: 64 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 2 }
65+
config { block_m: 128 block_n: 64 block_k: 128 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
66+
config { block_m: 128 block_n: 128 block_k: 32 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 1 }
67+
config { block_m: 128 block_n: 128 block_k: 128 num_stages: 3 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
68+
config { block_m: 128 block_n: 128 block_k: 128 num_stages: 3 num_warps: 8 num_ctas: 1 waves_per_eu: 2 }
69+
config { block_m: 128 block_n: 128 block_k: 128 num_stages: 1 num_warps: 8 num_ctas: 1 waves_per_eu: 4 }
70+
config { block_m: 128 block_n: 256 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 2 }
71+
config { block_m: 256 block_n: 256 block_k: 64 num_stages: 3 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
72+
config { block_m: 256 block_n: 256 block_k: 64 num_stages: 3 num_warps: 8 num_ctas: 1 waves_per_eu: 1 }

xla/backends/gpu/autotuner/triton/triton_configs.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ const std::vector<TritonGemmConfig>& GetTritonConfigsForPlatform(
6060
ParseConfig(configs::get_cuda())},
6161
{TritonConfigsPlatform::kDefaultRocm,
6262
ParseConfig(configs::get_rocm())},
63-
{TritonConfigsPlatform::kHopper, ParseConfig(configs::get_h100())}});
63+
{TritonConfigsPlatform::kHopper, ParseConfig(configs::get_h100())},
64+
{TritonConfigsPlatform::kMI300, ParseConfig(configs::get_mi300())},
65+
{TritonConfigsPlatform::kMI350,
66+
ParseConfig(configs::get_mi350())}});
6467
return kConfigs->at(platform);
6568
}
6669

xla/backends/gpu/autotuner/triton/triton_configs.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ enum class TritonConfigsPlatform {
2929
kDefaultCuda,
3030
kDefaultRocm,
3131
kHopper,
32+
kMI300,
33+
kMI350,
3234
};
3335

3436
const std::vector<TritonGemmConfig>& GetTritonConfigsForPlatform(

xla/backends/gpu/autotuner/triton/triton_configs_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ TEST(TritonConfigsTest, PlatformsReturnNonEmptyConfig) {
3434
SizeIs(2));
3535
EXPECT_THAT(GetTritonConfigsForPlatform(TritonConfigsPlatform::kHopper),
3636
SizeIs(25));
37+
EXPECT_THAT(GetTritonConfigsForPlatform(TritonConfigsPlatform::kMI300),
38+
SizeIs(33));
39+
EXPECT_THAT(GetTritonConfigsForPlatform(TritonConfigsPlatform::kMI350),
40+
SizeIs(58));
3741
}
3842

3943
} // namespace

0 commit comments

Comments
 (0)