Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
5d6b63c
Utilize amd_arch_db python bindings where applicable.
mirza-halilcevic Apr 19, 2026
b2c40d8
Address copilot comments.
mirza-halilcevic Apr 20, 2026
73fb737
Address Cursor code review comments.
mirza-halilcevic Apr 20, 2026
908e0b7
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic Apr 21, 2026
7112ca0
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic Apr 21, 2026
9f7c4de
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic Apr 23, 2026
b9a8d43
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic Apr 23, 2026
c8a62f0
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic Apr 24, 2026
3277fa8
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic Apr 24, 2026
e800f70
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic Apr 28, 2026
8cbd295
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic Apr 29, 2026
abac46d
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic Apr 30, 2026
6afe86f
Merge remote-tracking branch 'origin/develop' into amd-arch-db-py
mirza-halilcevic May 2, 2026
ef4c485
Merge remote-tracking branch 'origin/develop' into amd-arch-db-py
mirza-halilcevic May 3, 2026
811d0a2
Merge branch 'develop' into amd-arch-db-py
dorde-antic May 5, 2026
d604349
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic May 11, 2026
53deaef
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic May 12, 2026
1f0d563
Address code review.
mirza-halilcevic May 12, 2026
b962144
Fix typo
mirza-halilcevic May 14, 2026
e389a46
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic May 14, 2026
4745e80
Fix formatting.
mirza-halilcevic May 14, 2026
3da3f52
Merge branch 'develop' into amd-arch-db-py
mirza-halilcevic May 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions mlir/include/mlir/Dialect/Rock/IR/AmdArchDb.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct AmdArchInfo {
int64_t minNumCU;
bool hasFp8ConversionInstrs;
bool hasOcpFp8ConversionInstrs;
bool hasFp4;
bool hasScaledGemm;
int64_t maxNumXCC;
bool hasLdsTransposeLoad;
Expand All @@ -42,14 +43,15 @@ struct AmdArchInfo {
int64_t totalVGPRPerEU, int64_t sharedMemPerCU,
int64_t sharedMemPerWG, int64_t numEUPerCU,
int64_t minNumCU, bool hasFp8ConversionInstrs,
bool hasOcpFp8ConversionInstrs, bool hasScaledGemm,
int64_t maxNumXCC, bool hasLdsTransposeLoad)
bool hasOcpFp8ConversionInstrs, bool hasFp4,
bool hasScaledGemm, int64_t maxNumXCC,
bool hasLdsTransposeLoad)
: defaultFeatures(defaultFeatures), waveSize(waveSize),
maxWavesPerEU(maxWavesPerEU), totalSGPRPerEU(totalSGPRPerEU),
totalVGPRPerEU(totalVGPRPerEU), totalSharedMemPerCU(sharedMemPerCU),
maxSharedMemPerWG(sharedMemPerWG), numEUPerCU(numEUPerCU),
minNumCU(minNumCU), hasFp8ConversionInstrs(hasFp8ConversionInstrs),
hasOcpFp8ConversionInstrs(hasOcpFp8ConversionInstrs),
hasOcpFp8ConversionInstrs(hasOcpFp8ConversionInstrs), hasFp4(hasFp4),
hasScaledGemm(hasScaledGemm), maxNumXCC(maxNumXCC),
hasLdsTransposeLoad(hasLdsTransposeLoad) {}

Expand Down
35 changes: 22 additions & 13 deletions mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,30 +45,34 @@ static constexpr AmdArchInfo
/*totalVGPRPerEU*/ 256, /*totalSharedMemPerCU*/ 65536,
/*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/80,
/*hasFp8ConversionInstrs=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasScaledGemm=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasFp4=*/false,
/*hasScaledGemm=*/false,
/*maxNumXCC=*/1, /*hasLdsTransposeLoad=*/false),
cdna50Info(GemmFeatures::dot, /*waveSize=*/64, /*maxWavesPerEU*/ 8,
/*totalSGPRPerEU*/ 512, /*totalVGPRPerEU*/ 256,
/*totalSharedMemPerCU*/ 65536, /*maxSharedMemPerWG*/ 65536,
/*numEUPerCU=*/4, /*minNumCU=*/10,
/*hasFp8ConversionInstrs=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasScaledGemm=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasFp4=*/false,
/*hasScaledGemm=*/false,
/*maxNumXCC=*/1, /*hasLdsTransposeLoad=*/false),
cdnaInfo(GemmFeatures::mfma | GemmFeatures::dot | GemmFeatures::atomic_add |
GemmFeatures::atomic_add_f16,
/*waveSize=*/64, /*maxWavesPerEU*/ 10, /*totalSGPRPerEU*/ 800,
/*totalVGPRPerEU*/ 256, /*totalSharedMemPerCU*/ 65536,
/*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/120,
/*hasFp8ConversionInstrs=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasScaledGemm=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasFp4=*/false,
/*hasScaledGemm=*/false,
/*maxNumXCC=*/1, /*hasLdsTransposeLoad=*/false),
cdna2Info(GemmFeatures::mfma | GemmFeatures::dot |
GemmFeatures::atomic_add | GemmFeatures::atomic_add_f16,
/*waveSize=*/64, /*maxWavesPerEU*/ 8, /*totalSGPRPerEU*/ 800,
/*totalVGPRPerEU*/ 512, /*totalSharedMemPerCU*/ 65536,
/*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/104,
/*hasFp8ConversionInstrs=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasScaledGemm=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasFp4=*/false,
/*hasScaledGemm=*/false,
/*maxNumXCC=*/1, /*hasLdsTransposeLoad=*/false),
cdna3Info(GemmFeatures::mfma | GemmFeatures::dot |
GemmFeatures::atomic_add | GemmFeatures::atomic_add_f16 |
Expand All @@ -77,7 +81,8 @@ static constexpr AmdArchInfo
/*totalVGPRPerEU*/ 512, /*totalSharedMemPerCU*/ 65536,
/*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/20,
/*hasFp8ConversionInstrs=*/true,
/*hasOcpFp8ConversionInstrs=*/false, /*hasScaledGemm=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasFp4=*/false,
/*hasScaledGemm=*/false,
/*maxNumXCC=*/8, /*hasLdsTransposeLoad=*/false),
cdna40Info(GemmFeatures::mfma | GemmFeatures::dot |
GemmFeatures::atomic_add | GemmFeatures::atomic_add_f16 |
Expand All @@ -88,7 +93,8 @@ static constexpr AmdArchInfo
/*totalVGPRPerEU*/ 512, /*totalSharedMemPerCU*/ 163840,
/*maxSharedMemPerWG*/ 163840, /*numEUPerCU=*/4, /*minNumCU=*/256,
/*hasFp8ConversionInstrs=*/false,
/*hasOcpFp8ConversionInstrs=*/true, /*hasScaledGemm=*/true,
/*hasOcpFp8ConversionInstrs=*/true, /*hasFp4=*/true,
/*hasScaledGemm=*/true,
/*maxNumXCC=*/8, /*hasLdsTransposeLoad=*/true),
// amdgpu target builds all RDNA in WGP Mode
rdnaNoDotInfo(GemmFeatures::atomic_fmax_f32, /*waveSize=*/32,
Expand All @@ -97,22 +103,25 @@ static constexpr AmdArchInfo
/*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4,
/*minNumCU=*/30,
/*hasFp8ConversionInstrs=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasScaledGemm=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasFp4=*/false,
/*hasScaledGemm=*/false,
/*maxNumXCC=*/1, /*hasLdsTransposeLoad=*/false),
rdnaInfo(GemmFeatures::dot | GemmFeatures::atomic_fmax_f32,
/*waveSize=*/32, /*maxWavesPerEU*/ 16, /*totalSGPRPerEU*/ 512,
/*totalVGPRPerEU*/ 1024, /*totalSharedMemPerCU*/ 131072,
/*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/2,
/*hasFp8ConversionInstrs=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasScaledGemm=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasFp4=*/false,
/*hasScaledGemm=*/false,
/*maxNumXCC=*/1, /*hasLdsTransposeLoad=*/false),
rdna3Info(GemmFeatures::dot | GemmFeatures::atomic_add |
GemmFeatures::atomic_fmax_f32 | GemmFeatures::wmma,
/*waveSize=*/32, /*maxWavesPerEU*/ 16, /*totalSGPRPerEU*/ 800,
/*totalVGPRPerEU*/ 1536, /*totalSharedMemPerCU*/ 131072,
/*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/2,
/*hasFp8ConversionInstrs=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasScaledGemm=*/false,
/*hasOcpFp8ConversionInstrs=*/false, /*hasFp4=*/false,
/*hasScaledGemm=*/false,
/*maxNumXCC=*/1, /*hasLdsTransposeLoad=*/false),
rdna4Info(GemmFeatures::dot | GemmFeatures::atomic_add |
GemmFeatures::atomic_fmax_f32 | GemmFeatures::wmma |
Expand All @@ -121,7 +130,8 @@ static constexpr AmdArchInfo
/*totalVGPRPerEU*/ 1536, /*totalSharedMemPerCU*/ 131072,
/*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/12,
/*hasFp8ConversionInstrs=*/false,
/*hasOcpFp8ConversionInstrs=*/true, /*hasScaledGemm=*/false,
/*hasOcpFp8ConversionInstrs=*/true, /*hasFp4=*/false,
/*hasScaledGemm=*/false,
/*maxNumXCC=*/1, /*hasLdsTransposeLoad=*/false),
// TODO: update with right information
gfx1250Info(GemmFeatures::dot | GemmFeatures::atomic_add |
Expand All @@ -132,7 +142,8 @@ static constexpr AmdArchInfo
/*totalVGPRPerEU*/ 1536, /*totalSharedMemPerCU*/ 131072,
/*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/12,
/*hasFp8ConversionInstrs=*/false,
/*hasOcpFp8ConversionInstrs=*/true, /*hasScaledGemm=*/false,
/*hasOcpFp8ConversionInstrs=*/true, /*hasFp4=*/false,
/*hasScaledGemm=*/false,
Comment thread
umangyadav marked this conversation as resolved.
/*maxNumXCC=*/1, /*hasLdsTransposeLoad=*/false);

static std::tuple<StringRef, unsigned> parseArchString(StringRef arch) {
Expand Down Expand Up @@ -360,8 +371,6 @@ AmdArchInfo nativeArchInfo(unsigned deviceId = 0) {
#endif // !_WIN32 && ROCMLIR_ENABLE_NATIVE_ARCH

AmdArchInfo mlir::rock::lookupArchInfo(StringRef arch) {
// Keep this implementation in sync with
// mlir/test/lit.site.cfg.py.in:set_arch_features()
auto [chip, deviceId] = parseArchString(arch);
if (chip == "native") {
#if !defined(_WIN32) && defined(ROCMLIR_ENABLE_NATIVE_ARCH)
Expand Down
27 changes: 25 additions & 2 deletions mlir/lib/Dialect/Rock/utility/Bindings/AmdArchDbBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,25 @@
#include <pybind11/pybind11.h>

#include "mlir/Dialect/Rock/IR/AmdArchDb.h"
#include "llvm/ADT/StringRef.h"

namespace py = pybind11;

PYBIND11_MODULE(amd_arch_db, m) {
m.doc() = "Database of AMD GPU features";

py::enum_<mlir::rock::GemmFeatures>(m, "GemmFeatures")
py::enum_<mlir::rock::GemmFeatures>(m, "GemmFeatures", py::arithmetic())
.value("NONE", mlir::rock::GemmFeatures::none)
.value("MFMA", mlir::rock::GemmFeatures::mfma)
.value("WMMA", mlir::rock::GemmFeatures::wmma)
.value("DOT", mlir::rock::GemmFeatures::dot)
.value("ATOMIC_ADD", mlir::rock::GemmFeatures::atomic_add)
.value("ATOMIC_ADD_BF16", mlir::rock::GemmFeatures::atomic_add_bf16)
.value("ATOMIC_ADD_F16", mlir::rock::GemmFeatures::atomic_add_f16)
.value("ATOMIC_FMAX_F32", mlir::rock::GemmFeatures::atomic_fmax_f32);
.value("ATOMIC_FMAX_F32", mlir::rock::GemmFeatures::atomic_fmax_f32)
.value("DIRECT_TO_LDS_32B", mlir::rock::GemmFeatures::direct_to_lds_32b)
.value("DIRECT_TO_LDS_128B",
mlir::rock::GemmFeatures::direct_to_lds_128b);

py::class_<mlir::rock::AmdArchInfo>(m, "AmdArchInfo")
.def_readonly("default_features",
Expand All @@ -45,12 +49,31 @@ PYBIND11_MODULE(amd_arch_db, m) {
&mlir::rock::AmdArchInfo::hasFp8ConversionInstrs)
.def_readonly("has_ocp_fp8_conversion_instrs",
&mlir::rock::AmdArchInfo::hasOcpFp8ConversionInstrs)
.def_readonly("has_fp4", &mlir::rock::AmdArchInfo::hasFp4)
.def_readonly("has_scaled_gemm", &mlir::rock::AmdArchInfo::hasScaledGemm)
.def_readonly("max_num_xcc", &mlir::rock::AmdArchInfo::maxNumXCC)
.def_readonly("has_lds_transpose_load",
&mlir::rock::AmdArchInfo::hasLdsTransposeLoad);

m.def(
"has_feature",
[](mlir::rock::GemmFeatures features, mlir::rock::GemmFeatures flag) {
return bitEnumContainsAny(features, flag);
},
"Return True if any bit set in `flag` is also set in `features`. "
"Matches `bool(int(features) & int(flag))`.");

m.def("lookup_arch_info", [](const std::string &arch) {
// The "native:<deviceId>" code path in lookupArchInfo requires the build to
// have been configured with ROCMLIR_ENABLE_NATIVE_ARCH=ON. Without it the
// underlying call hits an llvm_unreachable, which would abort the Python
// interpreter; raise a Python-level error instead.
#ifndef ROCMLIR_ENABLE_NATIVE_ARCH
if (llvm::StringRef(arch).starts_with("native"))
throw py::value_error(
"\"native\" arch lookup is not available in this build "
"(requires ROCMLIR_ENABLE_NATIVE_ARCH=ON)");
#endif
return mlir::rock::lookupArchInfo(arch);
});
}
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Rock/utility/Bindings/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,12 @@ if(NOT WIN32)

pybind11_add_module(amd_arch_db AmdArchDbBindings.cpp)
target_link_libraries(amd_arch_db PUBLIC MLIRRockUtility)
if(ROCMLIR_ENABLE_NATIVE_ARCH)
target_compile_definitions(amd_arch_db PRIVATE ROCMLIR_ENABLE_NATIVE_ARCH=1)
endif()
set_target_properties(amd_arch_db PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${ROCMLIR_BIN_DIR}")
message(VERBOSE "amd_arch_db Python binding will be built in ${ROCMLIR_BIN_DIR}")
message(VERBOSE "To use it outside the build bin directory, set: "
"export PYTHONPATH=${ROCMLIR_BIN_DIR}:\$PYTHONPATH")
endif()
4 changes: 4 additions & 0 deletions mlir/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ list(APPEND ROCMLIR_TEST_DEPENDS
rocmlir-common-python-test-utils
)

if (TARGET amd_arch_db)
list(APPEND ROCMLIR_TEST_DEPENDS amd_arch_db)
endif()

if(MLIR_ENABLE_ROCM_RUNNER)
list(APPEND ROCMLIR_TEST_DEPENDS
mlir_runner_utils
Expand Down
71 changes: 27 additions & 44 deletions mlir/test/common_utils/common.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,31 @@
from hip import hip
from amd_arch_db import GemmFeatures, has_feature, lookup_arch_info


Comment thread
mirza-halilcevic marked this conversation as resolved.
# Helper function to decode arch to its features
# Keep this in sync with mlir/lib/Dialect/Rock/Generator/AmdArchDb.cpp:mlir::rock::lookupArchInfo
def get_arch_features(arch: str):
chip_name = arch.split(':')[0]
if len(chip_name) < 5:
return
def features_to_string(features):
val = int(features)
if val == 0:
return 'none'
# Iteration follows the .value(...) chain in
# mlir/lib/Dialect/Rock/utility/Bindings/AmdArchDbBindings.cpp, which is
# kept in sync with the bit positions in RockAttrDefs.td. Do not reorder
# without updating the bindings; lit tests match on this exact spelling.
names = []
for name, member in GemmFeatures.__members__.items():
bit = int(member)
if bit and (val & bit):
names.append(name.lower())
return '|'.join(names)


arch_features = None
support_mfma = False
support_wmma = False
support_accel_fp8 = False
major = chip_name[:-2]
minor = chip_name[-2:]
if major == 'gfx9':
if minor in ['08', '0a']:
arch_features = 'mfma|dot|atomic_add|atomic_add_f16'
elif minor == '42':
arch_features = 'mfma|dot|atomic_add|atomic_add_f16|direct_to_lds_32b'
support_accel_fp8 = True
elif minor == '50':
arch_features = 'mfma|dot|atomic_add|atomic_add_f16|atomic_add_bf16|direct_to_lds_32b|direct_to_lds_128b|lds_transpose_load'
support_accel_fp8 = True
elif minor == '06':
arch_features = 'dot'
else:
arch_features = 'none'
elif major == 'gfx10':
if minor in ['11', '13']:
arch_features = 'atomic_fmax_f32'
elif minor in ['10', '12'] or minor[0] == '3':
arch_features = 'dot|atomic_fmax_f32'
else:
arch_features = 'atomic_fmax_f32'
elif major == 'gfx11':
arch_features = 'dot|atomic_add|atomic_fmax_f32|wmma'
elif major == 'gfx12':
arch_features = 'dot|atomic_add|atomic_add_f16|atomic_add_bf16|atomic_fmax_f32|wmma'
support_accel_fp8 = True
if arch_features and 'mfma' in arch_features:
support_mfma = True
pass
elif arch_features and 'wmma' in arch_features:
support_wmma = True
pass
def get_arch_features(arch: str):
info = lookup_arch_info(arch)
arch_features = features_to_string(info.default_features)
if info.has_lds_transpose_load:
arch_features += '|lds_transpose_load'
support_mfma = has_feature(info.default_features, GemmFeatures.MFMA)
support_wmma = has_feature(info.default_features, GemmFeatures.WMMA)
support_accel_fp8 = info.has_fp8_conversion_instrs or info.has_ocp_fp8_conversion_instrs
return arch_features, support_mfma, support_wmma, support_accel_fp8


Expand Down Expand Up @@ -82,4 +63,6 @@ def get_default_agent():

def is_xdlops_present() -> bool:
"""This function checks whether a GPU with xdlops support is present"""
return any([agent.startswith("gfx9") for agent in get_agents()])
return any(
has_feature(lookup_arch_info(agent).default_features, GemmFeatures.MFMA)
for agent in get_agents())
1 change: 1 addition & 0 deletions mlir/test/e2e/lit.site.cfg.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ config.rocmlir_common_python_tests_utils = "@ROCMLIR_COMMON_PYTHON_TESTS_UTILS@"

# Add common python test utils
sys.path.append(config.rocmlir_common_python_tests_utils)
sys.path.append(config.mlir_rock_tools_dir)
from common import get_agents, get_arch_features, get_default_agent

# Support substitution of the tools_dir with user parameters. This is
Expand Down
9 changes: 3 additions & 6 deletions mlir/test/fusion/e2e/lit.site.cfg.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ config.rocmlir_common_python_tests_utils = "@ROCMLIR_COMMON_PYTHON_TESTS_UTILS@"

# Add common python test utils
sys.path.append(config.rocmlir_common_python_tests_utils)
from common import get_agents, get_default_agent
sys.path.append(config.mlir_rock_tools_dir)
from common import get_agents, get_arch_features, get_default_agent

# Support substitution of the tools_dir with user parameters. This is
# used when we can't determine the tool dir at configuration time.
Expand Down Expand Up @@ -78,11 +79,7 @@ if config.rocm_path:
"HIP_VISIBLE_DEVICES will be set to '0' to ensure binary compatibility."
% (', '.join(sorted(agents)), default_agent))
config.arch = default_agent
# Check features for the device we'll actually use
if any([arch in default_agent for arch in ["gfx908", "gfx90a", "gfx942", "gfx950"]]):
config.arch_support_mfma = True
elif "gfx11" in default_agent or "gfx12" in default_agent:
config.arch_support_wmma = True
_, config.arch_support_mfma, config.arch_support_wmma, _ = get_arch_features(default_agent)
if not config.arch:
config.no_AMD_GPU = True
except subprocess.CalledProcessError:
Expand Down
1 change: 1 addition & 0 deletions mlir/test/lit.site.cfg.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ config.rocmlir_common_python_tests_utils = "@ROCMLIR_COMMON_PYTHON_TESTS_UTILS@"

# Add common python test utils
sys.path.append(config.rocmlir_common_python_tests_utils)
sys.path.append(config.mlir_rock_tools_dir)
from common import get_agents, get_arch_features, get_default_agent

# Support substitution of the tools_dir with user parameters. This is
Expand Down
3 changes: 3 additions & 0 deletions mlir/unittests/Dialect/Rock/AmdArchDbTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ TEST_P(NativeArchTest, NativeArchInfoMatchesPresetInfo) {
nativeInfo.hasFp8ConversionInstrs);
EXPECT_EQ(presetInfo.hasOcpFp8ConversionInstrs,
nativeInfo.hasOcpFp8ConversionInstrs);
EXPECT_EQ(presetInfo.hasFp4, nativeInfo.hasFp4);
EXPECT_EQ(presetInfo.hasScaledGemm, nativeInfo.hasScaledGemm);
EXPECT_GE(presetInfo.maxNumXCC, nativeInfo.maxNumXCC);
EXPECT_EQ(presetInfo.hasLdsTransposeLoad, nativeInfo.hasLdsTransposeLoad);
}

INSTANTIATE_TEST_SUITE_P(NativeArchTests, NativeArchTest,
Expand Down
4 changes: 4 additions & 0 deletions mlir/utils/performance/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ list(TRANSFORM PERFORMANCE_SCRIPTS PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")

add_custom_target(ci-performance-scripts
COMMAND ${CMAKE_COMMAND} -E copy ${PERFORMANCE_SCRIPTS} ${ROCMLIR_BIN_DIR})

if (TARGET amd_arch_db)
add_dependencies(ci-performance-scripts amd_arch_db)
endif()
7 changes: 5 additions & 2 deletions mlir/utils/performance/analysis/quickTuningGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import pandas as pd
import pulp

from amd_arch_db import GemmFeatures, has_feature, lookup_arch_info

# Column definitions for grouping problems
GEMM_COLUMNS = ['TransA', 'TransB', 'G', 'M', 'K', 'N']
CONV_COLUMNS = [
Expand All @@ -37,9 +39,10 @@ def get_instruction_type(arch, dtype, op):
"""Determine instruction type based on architecture, data type, and operation."""
if op == "attention":
return "GemmGemm"
if arch.startswith("gfx9"):
features = lookup_arch_info(arch).default_features
if has_feature(features, GemmFeatures.MFMA):
return "XDL"
elif arch.startswith("gfx1") and dtype != "f32":
if has_feature(features, GemmFeatures.WMMA) and dtype != "f32":
return "Wmma"
return "NonAccel"

Expand Down
Loading
Loading