Skip to content

Commit 63cfd82

Browse files
committed
[ROCm] Add Triton autotuning configs for MI300 and MI350
Extend the Triton GEMM autotuner with dedicated config sets for AMD MI300 (gfx942, 33 configs) and MI350 (gfx950, 58 configs), expanding beyond the generic 6-config ROCm default.
1 parent 32b0167 commit 63cfd82

8 files changed

Lines changed: 136 additions & 3 deletions

File tree

xla/backends/gpu/autotuner/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,6 @@ cc_library(
405405
"//xla/service/gpu:matmul_utils",
406406
"//xla/service/gpu/model:triton_emitter_constraints",
407407
"//xla/stream_executor:device_description",
408-
"//xla/stream_executor:stream_executor_h",
409408
"//xla/stream_executor/cuda:cuda_compute_capability",
410409
"//xla/tsl/platform:env",
411410
"//xla/tsl/platform:errors",

xla/backends/gpu/autotuner/triton.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "absl/strings/str_cat.h"
2929
#include "absl/strings/string_view.h"
3030
#include "xla/tsl/platform/status_macros.h"
31+
#include "google/protobuf/any.pb.h"
3132
#include "google/protobuf/text_format.h"
3233
#include "xla/autotuning.pb.h"
3334
#include "xla/backends/autotuner/codegen_backend.h"
@@ -70,6 +71,13 @@ namespace {
7071
std::vector<TritonGemmConfig> GetDefaultTritonConfigs(
7172
se::GpuComputeCapability compute_capability) {
7273
if (compute_capability.IsRocm()) {
74+
const auto* rocm_cc = compute_capability.rocm_compute_capability();
75+
if (rocm_cc->gfx9_mi300()) {
76+
return GetTritonConfigsForPlatform(TritonConfigsPlatform::kMI300);
77+
}
78+
if (rocm_cc->gfx9_mi350()) {
79+
return GetTritonConfigsForPlatform(TritonConfigsPlatform::kMI350);
80+
}
7381
return GetTritonConfigsForPlatform(TritonConfigsPlatform::kDefaultRocm);
7482
}
7583

xla/backends/gpu/autotuner/triton.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ limitations under the License.
2929
#include "xla/hlo/ir/hlo_instruction.h"
3030
#include "xla/hlo/ir/hlo_module.h"
3131
#include "xla/service/compiler.h"
32-
#include "xla/stream_executor/stream_executor.h"
3332
#include "xla/xla.pb.h"
3433

3534
namespace xla {
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2026 The OpenXLA Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
config { block_m: 32 block_n: 32 block_k: 256 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
16+
config { block_m: 64 block_n: 32 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
17+
config { block_m: 32 block_n: 64 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
18+
config { block_m: 128 block_n: 128 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
19+
config { block_m: 16 block_n: 16 block_k: 256 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
20+
config { block_m: 16 block_n: 128 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
21+
config { block_m: 256 block_n: 256 block_k: 32 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
22+
config { block_m: 128 block_n: 256 block_k: 64 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
23+
config { block_m: 128 block_n: 256 block_k: 32 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
24+
config { block_m: 256 block_n: 128 block_k: 64 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
25+
config { block_m: 128 block_n: 128 block_k: 64 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
26+
config { block_m: 32 block_n: 8 block_k: 16 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
27+
config { block_m: 64 block_n: 32 block_k: 16 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
28+
config { block_m: 128 block_n: 32 block_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
29+
config { block_m: 128 block_n: 64 block_k: 128 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
30+
config { block_m: 128 block_n: 128 block_k: 32 num_stages: 3 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
31+
config { block_m: 256 block_n: 128 block_k: 32 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
32+
config { block_m: 256 block_n: 256 block_k: 32 num_stages: 1 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
33+
config { block_m: 128 block_n: 32 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
34+
config { block_m: 64 block_n: 32 block_k: 32 num_stages: 3 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
35+
config { block_m: 128 block_n: 32 block_k: 32 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
36+
config { block_m: 32 block_n: 32 block_k: 32 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
37+
config { block_m: 64 block_n: 32 block_k: 128 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
38+
config { block_m: 256 block_n: 8 block_k: 32 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
39+
config { block_m: 128 block_n: 16 block_k: 128 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
40+
config { block_m: 32 block_n: 16 block_k: 128 num_stages: 5 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
41+
config { block_m: 32 block_n: 16 block_k: 128 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 4 }
42+
config { block_m: 64 block_n: 8 block_k: 128 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
43+
config { block_m: 64 block_n: 8 block_k: 128 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 4 }
44+
config { block_m: 32 block_n: 16 block_k: 256 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
45+
config { block_m: 32 block_n: 16 block_k: 256 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 2 }
46+
config { block_m: 256 block_n: 8 block_k: 16 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
47+
config { block_m: 128 block_n: 8 block_k: 16 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2026 The OpenXLA Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
config { block_m: 32 block_n: 32 block_k: 256 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
16+
config { block_m: 64 block_n: 32 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
17+
config { block_m: 32 block_n: 64 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
18+
config { block_m: 128 block_n: 128 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
19+
config { block_m: 16 block_n: 16 block_k: 256 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
20+
config { block_m: 16 block_n: 128 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
21+
config { block_m: 256 block_n: 256 block_k: 32 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
22+
config { block_m: 128 block_n: 256 block_k: 64 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
23+
config { block_m: 128 block_n: 256 block_k: 32 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
24+
config { block_m: 256 block_n: 128 block_k: 64 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
25+
config { block_m: 128 block_n: 128 block_k: 64 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
26+
config { block_m: 32 block_n: 8 block_k: 16 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
27+
config { block_m: 128 block_n: 32 block_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 4 }
28+
config { block_m: 32 block_n: 8 block_k: 32 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
29+
config { block_m: 128 block_n: 32 block_k: 16 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
30+
config { block_m: 256 block_n: 256 block_k: 16 num_stages: 4 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
31+
config { block_m: 256 block_n: 128 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
32+
config { block_m: 256 block_n: 128 block_k: 64 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
33+
config { block_m: 256 block_n: 128 block_k: 16 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
34+
config { block_m: 16 block_n: 16 block_k: 128 num_stages: 3 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
35+
config { block_m: 32 block_n: 16 block_k: 128 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 4 }
36+
config { block_m: 64 block_n: 8 block_k: 128 num_stages: 4 num_warps: 2 num_ctas: 1 waves_per_eu: 2 }
37+
config { block_m: 64 block_n: 16 block_k: 64 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
38+
config { block_m: 128 block_n: 32 block_k: 16 num_stages: 3 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
39+
config { block_m: 64 block_n: 32 block_k: 16 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
40+
config { block_m: 32 block_n: 16 block_k: 32 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
41+
config { block_m: 64 block_n: 8 block_k: 16 num_stages: 3 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
42+
config { block_m: 16 block_n: 8 block_k: 256 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 1 }
43+
config { block_m: 16 block_n: 16 block_k: 128 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
44+
config { block_m: 16 block_n: 16 block_k: 128 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
45+
config { block_m: 16 block_n: 64 block_k: 128 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 4 }
46+
config { block_m: 32 block_n: 16 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
47+
config { block_m: 64 block_n: 8 block_k: 16 num_stages: 3 num_warps: 2 num_ctas: 1 waves_per_eu: 4 }
48+
config { block_m: 64 block_n: 8 block_k: 64 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
49+
config { block_m: 64 block_n: 8 block_k: 256 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
50+
config { block_m: 64 block_n: 16 block_k: 256 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
51+
config { block_m: 128 block_n: 8 block_k: 32 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
52+
config { block_m: 128 block_n: 32 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
53+
config { block_m: 128 block_n: 64 block_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
54+
config { block_m: 32 block_n: 32 block_k: 128 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 0 }
55+
config { block_m: 32 block_n: 32 block_k: 128 num_stages: 1 num_warps: 8 num_ctas: 1 waves_per_eu: 4 }
56+
config { block_m: 32 block_n: 64 block_k: 128 num_stages: 4 num_warps: 2 num_ctas: 1 waves_per_eu: 1 }
57+
config { block_m: 64 block_n: 32 block_k: 32 num_stages: 3 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
58+
config { block_m: 64 block_n: 64 block_k: 32 num_stages: 4 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
59+
config { block_m: 64 block_n: 64 block_k: 128 num_stages: 4 num_warps: 4 num_ctas: 1 waves_per_eu: 2 }
60+
config { block_m: 64 block_n: 64 block_k: 128 num_stages: 3 num_warps: 4 num_ctas: 1 waves_per_eu: 1 }
61+
config { block_m: 64 block_n: 64 block_k: 128 num_stages: 3 num_warps: 4 num_ctas: 1 waves_per_eu: 4 }
62+
config { block_m: 64 block_n: 64 block_k: 128 num_stages: 4 num_warps: 4 num_ctas: 1 waves_per_eu: 4 }
63+
config { block_m: 64 block_n: 128 block_k: 128 num_stages: 1 num_warps: 2 num_ctas: 1 waves_per_eu: 1 }
64+
config { block_m: 128 block_n: 64 block_k: 64 num_stages: 2 num_warps: 4 num_ctas: 1 waves_per_eu: 2 }
65+
config { block_m: 128 block_n: 64 block_k: 128 num_stages: 2 num_warps: 2 num_ctas: 1 waves_per_eu: 0 }
66+
config { block_m: 128 block_n: 128 block_k: 32 num_stages: 2 num_warps: 8 num_ctas: 1 waves_per_eu: 1 }
67+
config { block_m: 128 block_n: 128 block_k: 128 num_stages: 3 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
68+
config { block_m: 128 block_n: 128 block_k: 128 num_stages: 3 num_warps: 8 num_ctas: 1 waves_per_eu: 2 }
69+
config { block_m: 128 block_n: 128 block_k: 128 num_stages: 1 num_warps: 8 num_ctas: 1 waves_per_eu: 4 }
70+
config { block_m: 128 block_n: 256 block_k: 64 num_stages: 1 num_warps: 4 num_ctas: 1 waves_per_eu: 2 }
71+
config { block_m: 256 block_n: 256 block_k: 64 num_stages: 3 num_warps: 8 num_ctas: 1 waves_per_eu: 0 }
72+
config { block_m: 256 block_n: 256 block_k: 64 num_stages: 3 num_warps: 8 num_ctas: 1 waves_per_eu: 1 }

xla/backends/gpu/autotuner/triton/triton_configs.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ const std::vector<TritonGemmConfig>& GetTritonConfigsForPlatform(
6060
ParseConfig(configs::get_cuda())},
6161
{TritonConfigsPlatform::kDefaultRocm,
6262
ParseConfig(configs::get_rocm())},
63-
{TritonConfigsPlatform::kHopper, ParseConfig(configs::get_h100())}});
63+
{TritonConfigsPlatform::kHopper, ParseConfig(configs::get_h100())},
64+
{TritonConfigsPlatform::kMI300, ParseConfig(configs::get_mi300())},
65+
{TritonConfigsPlatform::kMI350, ParseConfig(configs::get_mi350())}});
6466
return kConfigs->at(platform);
6567
}
6668

xla/backends/gpu/autotuner/triton/triton_configs.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ enum class TritonConfigsPlatform {
2929
kDefaultCuda,
3030
kDefaultRocm,
3131
kHopper,
32+
kMI300,
33+
kMI350,
3234
};
3335

3436
const std::vector<TritonGemmConfig>& GetTritonConfigsForPlatform(

xla/backends/gpu/autotuner/triton/triton_configs_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ TEST(TritonConfigsTest, PlatformsReturnNonEmptyConfig) {
3434
SizeIs(2));
3535
EXPECT_THAT(GetTritonConfigsForPlatform(TritonConfigsPlatform::kHopper),
3636
SizeIs(25));
37+
EXPECT_THAT(GetTritonConfigsForPlatform(TritonConfigsPlatform::kMI300),
38+
SizeIs(33));
39+
EXPECT_THAT(GetTritonConfigsForPlatform(TritonConfigsPlatform::kMI350),
40+
SizeIs(58));
3741
}
3842

3943
} // namespace

0 commit comments

Comments
 (0)