diff --git a/third_party/gpus/crosstool/BUILD.rocm.tpl b/third_party/gpus/crosstool/BUILD.rocm.tpl index c61f1fb8fc702..392b29f3c8b54 100644 --- a/third_party/gpus/crosstool/BUILD.rocm.tpl +++ b/third_party/gpus/crosstool/BUILD.rocm.tpl @@ -35,7 +35,7 @@ cc_toolchain_suite( cc_toolchain( name = "cc-compiler-local", - all_files = "@local_config_rocm//rocm:all_files", + all_files = ":crosstool_wrapper_driver_is_not_gcc", compiler_files = ":crosstool_wrapper_driver_is_not_gcc", ar_files = ":crosstool_wrapper_driver_is_not_gcc", as_files = ":crosstool_wrapper_driver_is_not_gcc", diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl index 38e11de3f261e..4529dd624d0d4 100644 --- a/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/gpus/rocm/BUILD.tpl @@ -1,6 +1,6 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//rules:common_settings.bzl", "string_flag") -load("@local_config_rocm//rocm:build_defs.bzl", "rocm_version_number", "select_threshold") +load("@local_config_rocm//rocm:build_defs.bzl", "rocm_lib_import") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like @@ -78,6 +78,7 @@ cc_library( hdrs = glob([ "%{rocm_root}/include/**", ]), + defines = ["__HIP_DISABLE_CPP_FUNCTIONS__=1"], strip_include_prefix = "%{rocm_root}/include", ) @@ -94,247 +95,197 @@ cc_library( deps = [ ":rocm_config", ":rocm_headers_includes", - ":rocm_rpath", ], ) -cc_library( - name = "rocm", - visibility = ["//visibility:public"], - deps = [ - ":hip", - ":hipblas", - ":hipblaslt", - ":hiprand", - ":hipsolver", - ":hipsparse", - ":hsa_rocr", - ":miopen", - ":rocblas", - ":rocm_config", - ":rocprofiler_register", - ":rocsolver", - ":rocsparse", - ":roctracer", - ":hipfft", - ], -) - -cc_library( - name = "hsa_rocr", - srcs = glob(["%{rocm_root}/lib/libhsa-runtime*.so*"]), - hdrs = glob(["%{rocm_root}/include/hsa/**"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", - ], - linkstatic = 1, - 
strip_include_prefix = "%{rocm_root}", - deps = [":rocm_config"], -) - -# workaround to bring data to the same fs layout as expected in the rocm libs -# rocblas assumes that miopen db files are located in ../share/miopen/db directory -# hibplatslt assumes that tensile files are located in ../hipblaslt/library directory +# Provides -Wl,-rpath flags for ROCm libraries. +# These must live in a cc_library (not a toolchain feature) because +# cc_library linkopts propagate transitively through CcInfo to the +# final linking target, whereas toolchain features do not. cc_library( name = "rocm_rpath", linkopts = select({ ":build_hermetic": [ "-Wl,-rpath,external/local_config_rocm/rocm/%{rocm_root}/lib", - "-Wl,-rpath,external/local_config_rocm/rocm/%{rocm_root}/lib/llvm/lib", - "-Lexternal/local_config_rocm/rocm/%{rocm_root}/lib", ], ":multiple_rocm_paths": [ "-Wl,-rpath=%{rocm_lib_paths}", - "-Lexternal/local_config_rocm/rocm/%{rocm_root}/lib", ], "//conditions:default": [ "-Wl,-rpath,/opt/rocm/lib", - "-Lexternal/local_config_rocm/rocm/%{rocm_root}/lib", ], }), visibility = ["//visibility:public"], ) -cc_library( +alias( name = "hip", + actual = ":hip_runtime", visibility = ["//visibility:public"], - deps = [ - ":rocm_hip", - ":rocm_rpath", - ], ) -cc_library( - name = "rocm_hip", - srcs = glob( +rocm_lib_import( + name = "hip_runtime", + data = glob( [ - "%{rocm_root}/lib/libamdhip*.so*", - "%{rocm_root}/lib/libhiprtc.so*", - "%{rocm_root}/lib/libhiprtc-builtins.so*", - ], - exclude = [ - # exclude files like libamdhip64.so.7.1.25445-7484b05b13 -> misplaced - "%{rocm_root}/**/*.so.*.*", + "%{rocm_root}/lib/libamdhip64.so*", ], ), - hdrs = glob(["%{rocm_root}/include/hip/**"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", + interface_library = "%{rocm_root}/lib/libamdhip64.so", + deps = [ + ":amd_comgr_libs", + ":hiprtc_libs", + ":hsa_rocr_libs", + ":rocprofiler_register_libs", + ":system_libs", ], - strip_include_prefix = "%{rocm_root}", - 
visibility = ["//visibility:public"], +) + +filegroup( + name = "hsa_rocr_libs_data", + srcs = glob(["%{rocm_root}/lib/libhsa-runtime64.so*"]), +) + +cc_library( + name = "hsa_rocr_libs", + data = [":hsa_rocr_libs_data"], deps = [ - ":amd_comgr", - ":hsa_rocr", - ":rocm_config", - ":rocm_smi", - ":rocprofiler_register", + ":rocprofiler_register_libs", ":system_libs", ], ) -# Used by jax_rocm_plugin to minimally link to hip runtime. cc_library( - name = "hip_runtime", - srcs = glob( + name = "hiprtc_libs", + data = glob( [ - "%{rocm_root}/lib/libamdhip*.so*", "%{rocm_root}/lib/libhiprtc.so*", "%{rocm_root}/lib/libhiprtc-builtins.so*", ], - exclude = [ - # exclude files like libamdhip64.so.7.1.25445-7484b05b13 -> misplaced - "%{rocm_root}/**/*.so.*.*", - ], ), - hdrs = glob(["%{rocm_root}/include/hip/**"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", + deps = [ + ":amd_comgr_libs", + ":hsa_rocr_libs", ], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], +) + +cc_library( + name = "amd_comgr_libs", + data = glob( + [ + "%{rocm_root}/lib/libamd_comgr_loader.so*", + "%{rocm_root}/lib/libamd_comgr.so*", + "%{rocm_root}/lib/llvm/lib/libLLVM.so*", + ], + ), deps = [ - ":amd_comgr", - ":rocm_config", - ":rocm_rpath", - ":rocprofiler_register", ":system_libs", ], ) +filegroup( + name = "rocprofiler_register_libs_data", + srcs = glob( + [ + "%{rocm_root}/lib/librocprofiler-register.so*", + ], + ), +) + cc_library( + name = "rocprofiler_register_libs", + data = [":rocprofiler_register_libs_data"], +) + +rocm_lib_import( name = "rocblas", - hdrs = glob(["%{rocm_root}/include/rocblas/**"]), data = glob([ - "%{rocm_root}/lib/librocblas*.so*", - "%{rocm_root}/lib/librocroller*.so*", + "%{rocm_root}/lib/librocblas.so*", "%{rocm_root}/lib/rocblas/**", ]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", - ], - linkopts = ["-lrocblas"], - strip_include_prefix = "%{rocm_root}", - visibility = 
["//visibility:public"], + interface_library = "%{rocm_root}/lib/librocblas.so", deps = [ - ":hipblaslt", - ":rocm_config", - ":rocm_rpath", - ":roctracer", + ":hip_runtime_libs", + ":hipblaslt_libs", + ":roctx_libs", ], ) -cc_library( - name = "rocfft", - data = glob(["%{rocm_root}/lib/librocfft*.so*"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", - ], - linkopts = ["-Wl,-rpath,external/local_config_rocm/rocm/%{rocm_root}/lib"], - linkstatic = 1, - visibility = ["//visibility:public"], +rocm_lib_import( + name = "hipfft", + data = glob(["%{rocm_root}/lib/libhipfft.so*"]), + interface_library = "%{rocm_root}/lib/libhipfft.so", deps = [ - ":rocm_config", - ":rocm_rpath", + ":hip_runtime_libs", + ":rocfft_libs", ], ) cc_library( - name = "hipfft", - data = glob(["%{rocm_root}/lib/libhipfft*.so*"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", - ], - linkopts = ["-Wl,-rpath,external/local_config_rocm/rocm/%{rocm_root}/lib"], - linkstatic = 1, - visibility = ["//visibility:public"], + name = "rocfft_libs", + data = glob(["%{rocm_root}/lib/librocfft.so*"]), deps = [ - ":rocm_config", - ":rocm_rpath", + ":hip_runtime_libs", + ":hiprtc_libs", ], ) -cc_library( +rocm_lib_import( name = "hiprand", - srcs = glob(["%{rocm_root}/lib/libhiprand*.so*"]), - hdrs = glob(["%{rocm_root}/include/hiprand/**"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", - "%{rocm_root}/include/rocrand", + data = glob(["%{rocm_root}/lib/libhiprand.so*"]), + interface_library = "%{rocm_root}/lib/libhiprand.so", + deps = [ + ":hip_runtime_libs", + ":rocrand_libs", ], - linkstatic = 1, - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], - deps = [":rocm_config"], ) cc_library( + name = "rocrand_libs", + data = glob(["%{rocm_root}/lib/librocrand.so*"]), + deps = [ + ":hip_runtime_libs", + ], +) + +rocm_lib_import( name = "miopen", - hdrs = glob(["%{rocm_root}/include/miopen/**"]), data = glob([ - 
"%{rocm_root}/lib/libMIOpen*.so*", + "%{rocm_root}/lib/libMIOpen.so*", "%{rocm_root}/share/miopen/**", ]), - linkopts = ["-lMIOpen"], - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", - ], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], + interface_library = "%{rocm_root}/lib/libMIOpen.so", deps = [ - ":rocm-core", - ":rocm_config", - ":rocm_rpath", + ":amd_comgr_libs", + ":hip_runtime_libs", + ":hipblaslt_libs", + ":hiprtc_libs", + ":rocblas_libs", + ":roctx_libs", + ":system_libs", ], ) -cc_library( +rocm_lib_import( name = "rccl", - srcs = glob(["%{rocm_root}/lib/librccl*.so*"]), - hdrs = glob(["%{rocm_root}/include/rccl/**"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", - ], - linkopts = ["-lnuma"], - linkstatic = 1, - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], + data = glob(["%{rocm_root}/lib/librccl.so*"]), + interface_library = "%{rocm_root}/lib/librccl.so", deps = [ - ":rocm_config", - ":roctracer", - ":system_libs", + ":hip_runtime_libs", + ":rocm_smi_libs", + ":rocprofiler_register_libs", + ":roctx_libs", ], ) +cc_library( + name = "rocm_smi_libs", + data = glob([ + "%{rocm_root}/lib/librocm_smi64.so*", + "%{rocm_root}/lib/libamd_smi.so*", + ]), +) + bzl_library( name = "build_defs_bzl", srcs = ["build_defs.bzl"], @@ -343,315 +294,144 @@ bzl_library( cc_library( name = "rocprim", - srcs = [ - "%{rocm_root}/include/hipcub/hipcub_version.hpp", - "%{rocm_root}/include/rocprim/rocprim_version.hpp", - ], - hdrs = glob([ - "%{rocm_root}/include/hipcub/**", - "%{rocm_root}/include/rocprim/**", - ]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include/hipcub", - "%{rocm_root}/include/rocprim", - ], - strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], deps = [ - ":rocm_config", - ":rocm_headers", + ":rocm_headers_includes", ], ) -cc_library( +rocm_lib_import( name = "hipsparse", - srcs = 
glob(["%{rocm_root}/lib/libhipsparse*.so*"]), - hdrs = glob(["%{rocm_root}/include/hipsparse/**"]), - data = glob(["%{rocm_root}/lib/libhipsparse*.so*"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include/", - ], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], - deps = [":rocm_config"], -) - -cc_library( - name = "roctracer", - srcs = glob([ - "%{rocm_root}/lib/libroctracer*.so*", - "%{rocm_root}/lib/libroctx64.so*", - ]), - hdrs = glob(["%{rocm_root}/include/roctracer/**"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include/", - ], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], - deps = [":rocm_config"], -) - -cc_library( - name = "rocprofiler-sdk", - srcs = glob(["%{rocm_root}/lib/librocprofiler-sdk*.so*"]), - hdrs = glob(["%{rocm_root}/include/rocprofiler-sdk/**"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include/", - ], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], - deps = [":rocm_config"], -) - -cc_library( - name = "rocsolver", - hdrs = glob(["%{rocm_root}/include/rocsolver/**"]), - data = glob(["%{rocm_root}/lib/librocsolver*.so*"]), - include_prefix = "rocm", - linkopts = ["-lrocsolver"], - includes = [ - "%{rocm_root}/include/", - ], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], + data = glob(["%{rocm_root}/lib/libhipsparse.so*"]), + interface_library = "%{rocm_root}/lib/libhipsparse.so", deps = [ - ":rocm_config", - ":rocm_rpath", + ":hip_runtime_libs", + ":rocsparse_libs", ], ) cc_library( - name = "rocsparse", - srcs = glob(["%{rocm_root}/lib/librocsparse*.so*"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include/", - ], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], - deps = [":rocm_config"], -) - -cc_library( - name = "hipsolver", - hdrs = glob(["%{rocm_root}/include/hipsolver/**"]), - data = 
glob(["%{rocm_root}/lib/libhipsolver*.so*"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include/", - ], - linkopts = ["-lhipsolver"], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], + name = "rocsparse_libs", + data = glob(["%{rocm_root}/lib/librocsparse.so*"]), deps = [ - ":rocm_config", - ":rocm_rpath", + ":hip_runtime_libs", + ":roctx_libs", ], ) cc_library( - name = "hipblas", - hdrs = glob(["%{rocm_root}/include/hipblas/**"]), - data = glob(["%{rocm_root}/lib/libhipblas.so*"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include/", - ], - linkopts = ["-lhipblas"], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], - deps = [ - ":hipblas-common", - ":rocm_config", - ":rocm_rpath", - ], -) - -cc_library( - name = "hipblas-common", - hdrs = glob(["%{rocm_root}/include/hipblas-common/**"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include/", - ], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], - deps = [":rocm_config"], -) - -cc_library( - name = "rocm-core", - srcs = glob([ - "%{rocm_root}/lib/librocm-core.so*", + name = "roctx_libs", + data = glob([ + "%{rocm_root}/lib/libroctx64.so*", ]), - visibility = ["//visibility:public"], - deps = [":rocm_config"], ) -cc_library( - name = "hipblaslt", - hdrs = glob(["%{rocm_root}/include/hipblaslt/**"]), +rocm_lib_import( + name = "roctracer", data = glob([ - "%{rocm_root}/lib/hipblaslt/**", - "%{rocm_root}/lib/libhipblaslt.so*", + "%{rocm_root}/lib/libroctracer64.so*", ]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include/hipblaslt", - ], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], + interface_library = "%{rocm_root}/lib/libroctracer64.so", deps = [ - ":hip_runtime", - ":rocm_config", - ":rocm_rpath", - ], -) - -cc_library( - name = "rocrand", - srcs = glob(["%{rocm_root}/lib/librocrand*.so*"]), - hdrs = 
glob(["%{rocm_root}/include/rocrand/**"]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include/", + ":hsa_rocr_libs", ], - strip_include_prefix = "%{rocm_root}", - visibility = ["//visibility:public"], - deps = [":rocm_config"], ) -cc_library( - name = "rocprofiler_register", - srcs = glob([ - "%{rocm_root}/lib/librocprofiler-register.so*", - ]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", +rocm_lib_import( + name = "rocprofiler_sdk", + data = glob(["%{rocm_root}/lib/librocprofiler-sdk*.so*"]), + interface_library = "%{rocm_root}/lib/librocprofiler-sdk.so", + deps = [ + ":amd_comgr_libs", + ":system_libs", ], - strip_include_prefix = "%{rocm_root}", - deps = [":rocm_config"], ) -cc_library( - name = "amd_comgr_dynamic", - srcs = ["%{rocm_root}/lib/libamd_comgr_stub.a"], - hdrs = glob(["%{rocm_root}/include/amd_comgr/**"]), +rocm_lib_import( + name = "rocsolver", data = glob([ - "%{rocm_root}/lib/libamd_comgr_loader.so*", - "%{rocm_root}/lib/libamd_comgr.so*", - "%{rocm_root}/lib/llvm/lib/libLLVM.so*", + "%{rocm_root}/lib/librocsolver.so*", + "%{rocm_root}/lib/host-math/lib/*.so*", ]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", - ], - linkopts = ["-lamd_comgr_loader"], - strip_include_prefix = "%{rocm_root}", + interface_library = "%{rocm_root}/lib/librocsolver.so", deps = [ - ":rocm_config", - ":rocm_rpath", - ":system_libs", + ":hip_runtime_libs", + ":rocblas_libs", ], ) -cc_library( - name = "amd_comgr_static", - hdrs = glob(["%{rocm_root}/include/amd_comgr/**"]), - data = glob([ - "%{rocm_root}/lib/libamd_comgr.so*", - ]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", - ], - linkopts = ["-lamd_comgr"], - strip_include_prefix = "%{rocm_root}", +rocm_lib_import( + name = "hipsolver", + data = glob(["%{rocm_root}/lib/libhipsolver.so*"]), + interface_library = "%{rocm_root}/lib/libhipsolver.so", deps = [ - ":rocm_config", - ":rocm_rpath", - ":system_libs", + ":hip_runtime_libs", + 
":rocblas_libs", + ":rocsolver_libs", + ":rocsparse_libs", ], ) - -alias( - name = "amd_comgr", - actual = select_threshold( - threshold_dict = { - 62000: ":amd_comgr_static", - 71000: ":amd_comgr_dynamic", - 71200: ":amd_comgr_static", - }, - value = rocm_version_number(), - ), - visibility = ["//visibility:public"], +rocm_lib_import( + name = "hipblas", + data = glob(["%{rocm_root}/lib/libhipblas.so*"]), + interface_library = "%{rocm_root}/lib/libhipblas.so", + deps = [ + ":rocblas_libs", + ":rocsolver_libs", + ], ) -cc_library( - name = "rocm_smi", - srcs = glob([ - "%{rocm_root}/lib/librocm_smi64.so*", - "%{rocm_root}/lib/libroam.so*", - ]), - hdrs = glob([ - "%{rocm_root}/include/oam/**", - "%{rocm_root}/include/rocm_smi/**", +rocm_lib_import( + name = "hipblaslt", + data = glob([ + "%{rocm_root}/lib/hipblaslt/**", + "%{rocm_root}/lib/libhipblaslt.so*", + "%{rocm_root}/lib/librocroller.so*", ]), - include_prefix = "rocm", - includes = [ - "%{rocm_root}/include", + interface_library = "%{rocm_root}/lib/libhipblaslt.so", + deps = [ + ":hip_runtime_libs", + ":roctx_libs", ], - strip_include_prefix = "%{rocm_root}", - deps = [":rocm_config"], ) -cc_library( - name = "system_libs", +filegroup( + name = "system_libs_data", srcs = glob([ "%{rocm_root}/lib/rocm_sysdeps/lib/*.so*", - ]), - data = glob([ "%{rocm_root}/lib/rocm_sysdeps/share/**", ]), ) -filegroup( - name = "rocm_root", - srcs = [ - "%{rocm_root}/bin/clang-offload-bundler", - ], - visibility = ["//visibility:public"], +cc_library( + name = "system_libs", + data = [":system_libs_data"], ) filegroup( name = "toolchain_data", - srcs = glob([ - "%{rocm_root}/bin/hipcc", - "%{rocm_root}/lib/llvm/**", - "%{rocm_root}/share/hip/**", - "%{rocm_root}/amdgcn/**", - "%{rocm_root}/lib/rocm_sysdeps/lib/*.so*", - "%{rocm_root}/lib/libamd_comgr_loader.so*", - "%{rocm_root}/lib/libamd_comgr.so*", - ]), - visibility = ["//visibility:public"], -) - -filegroup( - name = "all_files", - srcs = glob(["%{rocm_root}/**"]), + 
srcs = glob( + include = [ + "%{rocm_root}/bin/hipcc", + "%{rocm_root}/lib/llvm/**", + "%{rocm_root}/share/hip/version", + "%{rocm_root}/amdgcn/**", + ], + exclude = ["%{rocm_root}/lib/llvm/lib/*.a"], + ) + [":system_libs_data"], visibility = ["//visibility:public"], ) filegroup( name = "rocminfo", - srcs = ["%{rocm_root}/bin/rocminfo"], + srcs = [ + "%{rocm_root}/bin/rocminfo", + ] + [ + ":hsa_rocr_libs_data", + ":rocprofiler_register_libs_data", + ":system_libs_data", + ], visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl index fbb70bf194907..ddfbeddb4d166 100644 --- a/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/gpus/rocm/build_defs.bzl.tpl @@ -11,15 +11,6 @@ def if_rocm(if_true, if_false = []): "//conditions:default": if_false }) -def select_threshold(value, threshold_dict): - sorted_keys = sorted(threshold_dict.keys()) - result = threshold_dict[sorted_keys[0]] # Default to the first threshold's value - for key in sorted_keys: - if value >= key: - result = threshold_dict[key] - - return result - def rocm_default_copts(): """Default options for all ROCm compilations.""" return if_rocm(["-x", "rocm"] + %{rocm_extra_copts}) @@ -90,3 +81,27 @@ def rocm_library(copts = [], deps = [], **kwargs): def get_rbe_amdgpu_pool(is_single_gpu = False): return "%{single_gpu_rbe_pool}" if is_single_gpu else "%{multi_gpu_rbe_pool}" + +def rocm_lib_import(name, interface_library, data, deps=[]): + native.cc_import( + name = name + "_interface", + interface_library = interface_library, + system_provided = True, + visibility = ["//visibility:private"], + ) + native.cc_library( + name = name + "_libs", + data = data, + deps = deps, + visibility = ["//visibility:private"], + ) + native.cc_library( + name = name, + deps = [ + ":{}_interface".format(name), + ":{}_libs".format(name), + ":rocm_headers_includes", + ":rocm_rpath", + ], + visibility = ["//visibility:public"], + ) diff --git 
a/third_party/gpus/rocm/rocm_config.h.tpl b/third_party/gpus/rocm/rocm_config.h.tpl index 53e5fd35f1034..20506f64b2b9c 100644 --- a/third_party/gpus/rocm/rocm_config.h.tpl +++ b/third_party/gpus/rocm/rocm_config.h.tpl @@ -22,15 +22,5 @@ limitations under the License. #define TF_MIOPEN_VERSION %{miopen_version_number} #define TF_HIPRUNTIME_VERSION %{hipruntime_version_number} #define TF_HIPBLASLT %{hipblaslt_flag} -#define TF_HIPRUNTIME_SOVERSION "%{hip_soversion_number}" -#define TF_ROCBLAS_SOVERSION "%{rocblas_soversion_number}" -#define TF_HIPBLASLT_SOVERSION "%{hipblaslt_soversion_number}" -#define TF_MIOPEN_SOVERSION "%{miopen_soversion_number}" -#define TF_HIPFFT_SOVERSION "%{hipfft_soversion_number}" -#define TF_ROCSOLVER_SOVERSION "%{rocsolver_soversion_number}" -#define TF_HIPSPARSE_SOVERSION "%{hipsparse_soversion_number}" -#define TF_ROCTRACER_SOVERSION "%{roctracer_soversion_number}" -#define TF_HIPSOLVER_SOVERSION "%{hipsolver_soversion_number}" -#define TF_ROCRAND_SOVERSION "%{rocrand_soversion_number}" #endif // ROCM_ROCM_CONFIG_H_ diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 2679e2e0447a4..c5e1c3cd4e1b5 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -259,35 +259,6 @@ def _batch_files_exist(repository_ctx, libs_paths, bash_bin): all_paths.append(lib_path) return files_exist(repository_ctx, all_paths, bash_bin) -def _soversion(repository_ctx, path, bash_bin = None): - """Returns the soversion of a given library. 
- - Args: - repository_ctx: the repository_ctx - path: a path on the file system - bash_bin: path to the bash interpreter - - Returns: - Parsed soversion string form the SONAME dtag of the library - """ - if bash_bin == None: - bash_bin = get_bash_bin(repository_ctx) - - exec_result = execute(repository_ctx, [bash_bin, "-c", "readelf --dynamic \"%s\"" % path]) - - if exec_result.return_code: - auto_configure_fail("Failed to run readelf to find soversion: %s" % err_out(exec_result)) - - soversion = "" - for row in exec_result.stdout.strip().split("\n"): - match = row.find("SONAME") - if match >= 0: - match = row.find(".so.", match) - if match >= 0: - soversion = row[match + 4:-1] - break - return soversion - def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin): test_results = _batch_files_exist(repository_ctx, libs_paths, bash_bin) @@ -313,7 +284,6 @@ def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin): libs[name] = struct( file_name = selected_path.basename, path = realpath(repository_ctx, selected_path, bash_bin), - soversion = _soversion(repository_ctx, selected_path, bash_bin), ) return libs @@ -750,16 +720,6 @@ def _create_local_rocm_repository(repository_ctx): "%{miopen_version_number}": rocm_config.miopen_version_number, "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, "%{hipblaslt_flag}": "1", - "%{hip_soversion_number}": rocm_libs["amdhip64"].soversion, - "%{rocblas_soversion_number}": rocm_libs["rocblas"].soversion, - "%{hipblaslt_soversion_number}": rocm_libs["hipblaslt"].soversion if rocm_libs["hipblaslt"] != None else "", - "%{miopen_soversion_number}": rocm_libs["MIOpen"].soversion, - "%{hipfft_soversion_number}": rocm_libs["hipfft"].soversion, - "%{rocsolver_soversion_number}": rocm_libs["rocsolver"].soversion if rocm_libs["rocsolver"] != None else "", - "%{hipsolver_soversion_number}": rocm_libs["hipsolver"].soversion if rocm_libs["hipsolver"] != None else "", - "%{hipsparse_soversion_number}": 
rocm_libs["hipsparse"].soversion, - "%{roctracer_soversion_number}": rocm_libs["roctracer64"].soversion, - "%{rocrand_soversion_number}": rocm_libs["rocrand"].soversion, }, ) @@ -777,16 +737,6 @@ def _create_local_rocm_repository(repository_ctx): "%{miopen_version_number}": rocm_config.miopen_version_number, "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, "%{hipblaslt_flag}": "1", - "%{hip_soversion_number}": rocm_libs["amdhip64"].soversion, - "%{rocblas_soversion_number}": rocm_libs["rocblas"].soversion, - "%{hipblaslt_soversion_number}": rocm_libs["hipblaslt"].soversion if rocm_libs["hipblaslt"] != None else "", - "%{miopen_soversion_number}": rocm_libs["MIOpen"].soversion, - "%{hipfft_soversion_number}": rocm_libs["hipfft"].soversion, - "%{rocsolver_soversion_number}": rocm_libs["rocsolver"].soversion if rocm_libs["rocsolver"] != None else "", - "%{hipsolver_soversion_number}": rocm_libs["hipsolver"].soversion if rocm_libs["hipsolver"] != None else "", - "%{hipsparse_soversion_number}": rocm_libs["hipsparse"].soversion, - "%{roctracer_soversion_number}": rocm_libs["roctracer64"].soversion, - "%{rocrand_soversion_number}": rocm_libs["rocrand"].soversion, }, ) diff --git a/xla/backends/profiler/gpu/BUILD b/xla/backends/profiler/gpu/BUILD index 8b2074790bce9..b9ae3583b35a0 100644 --- a/xla/backends/profiler/gpu/BUILD +++ b/xla/backends/profiler/gpu/BUILD @@ -378,34 +378,38 @@ cc_library( "//conditions:default": ["XLA_GPU_ROCM_TRACER_BACKEND=3"], }), visibility = ["//visibility:public"], + deps = select({ + ":use_v1": ["@local_config_rocm//rocm:roctracer"], + ":use_rocprofiler_sdk": ["@local_config_rocm//rocm:rocprofiler_sdk"], + "//conditions:default": ["@local_config_rocm//rocm:rocprofiler_sdk"], + }), ) cc_library( - name = "rocm_tracer_utils", - srcs = ["rocm_tracer_utils.cc"], - hdrs = ["rocm_tracer_utils.h"], - tags = [ + name = "rocm_tracer_utils", + srcs = ["rocm_tracer_utils.cc"], + hdrs = ["rocm_tracer_utils.h"], + tags = [ "gpu", 
"manual", "rocm-only", ], - deps = [ - "//xla/tsl/profiler/backends/cpu:annotation_stack", - "//xla/tsl/profiler/utils:time_utils", - "//xla/tsl/profiler/utils:math_utils", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/container:node_hash_set", - "@tsl//tsl/platform:env_time", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:macros", - "@local_config_rocm//rocm:rocprofiler-sdk", - ], - visibility = ["//visibility:public"], + visibility = ["//visibility:public"], + deps = [ + "//xla/tsl/profiler/backends/cpu:annotation_stack", + "//xla/tsl/profiler/utils:math_utils", + "//xla/tsl/profiler/utils:time_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:env_time", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:macros", + ], ) cc_library( @@ -420,8 +424,8 @@ cc_library( "manual", ]), deps = [ - ":rocm_tracer_utils", ":rocm_profiler_backend_cfg", + ":rocm_tracer_utils", "//xla/stream_executor/rocm:roctracer_wrapper", "//xla/tsl/profiler/backends/cpu:annotation_stack", "//xla/tsl/profiler/utils:parse_annotation", @@ -446,28 +450,27 @@ cc_library( "@tsl//tsl/platform:types", "@tsl//tsl/profiler/lib:profiler_factory", "@tsl//tsl/profiler/lib:profiler_interface", - "@local_config_rocm//rocm:rocprofiler-sdk", ], ) cc_library( name = "rocm_tracer_headers", hdrs = [ - "rocm_tracer.h", "rocm_profiler_sdk.h", + "rocm_tracer.h", "rocm_tracer_v1.h", ], - tags = [ - "gpu", - "manual", - "rocm-only", - ], # PROPAGATE the layout macro to every 
dependent TU: defines = select({ ":use_v1": ["XLA_GPU_ROCM_TRACER_BACKEND=1"], ":use_rocprofiler_sdk": ["XLA_GPU_ROCM_TRACER_BACKEND=3"], "//conditions:default": ["XLA_GPU_ROCM_TRACER_BACKEND=3"], }), + tags = [ + "gpu", + "manual", + "rocm-only", + ], visibility = ["//visibility:public"], ) @@ -484,8 +487,9 @@ cc_library( "rocm-only", ], deps = [ - ":rocm_tracer_headers", ":rocm_collector", + ":rocm_tracer_headers", + "//xla:debug_options_flags", "//xla/stream_executor/rocm:roctracer_wrapper", "//xla/tsl/profiler/backends/cpu:annotation_stack", "//xla/tsl/profiler/utils:time_utils", @@ -493,7 +497,6 @@ cc_library( "//xla/tsl/profiler/utils:xplane_schema", "//xla/tsl/profiler/utils:xplane_utils", "//xla/tsl/util:env_var", - "//xla:debug_options_flags", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -522,8 +525,11 @@ cc_library( "manual", "rocm-only", ], - deps = [":rocm_tracer_headers", ":rocm_tracer_impl"], visibility = ["//visibility:public"], + deps = [ + ":rocm_tracer_headers", + ":rocm_tracer_impl", + ], ) # upstream it's called xla_cc_test as no GPU involved. 
@@ -534,7 +540,7 @@ xla_test( tags = [ "gpu", "rocm-only", - "skip_rocprofiler_sdk", # due to rocprofiler-sdk's rocprofiler_force_configure + "skip_rocprofiler_sdk", # due to rocprofiler-sdk's rocprofiler_force_configure ] + if_google([ # Optional: only run internally if ROCm config is enabled "manual", @@ -551,31 +557,31 @@ xla_test( ], ) -xla_test( - name = "rocm_collector_test", - size = "small", - srcs = ["rocm_collector_test.cc"], - tags = [ - "gpu", - "rocm-only", - ] + if_google([ - "manual", - ]), - deps = [ - ":rocm_collector", - ":rocm_tracer_utils", - "//xla/tsl/profiler/utils:xplane_builder", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:env_time", - "@tsl//tsl/platform:status_matchers", - "@tsl//tsl/platform:test", - "@tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:macros", - ], +xla_test( + name = "rocm_collector_test", + size = "small", + srcs = ["rocm_collector_test.cc"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + "manual", + ]), + deps = [ + ":rocm_collector", + ":rocm_tracer_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:env_time", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:macros", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:test", + "@tsl//tsl/profiler/protobuf:xplane_proto_cc", + ], ) cc_library( diff --git a/xla/backends/profiler/gpu/rocm_collector.h b/xla/backends/profiler/gpu/rocm_collector.h index c3bacc787b2a2..e670efcd188d5 100644 --- a/xla/backends/profiler/gpu/rocm_collector.h +++ b/xla/backends/profiler/gpu/rocm_collector.h @@ -34,6 +34,19 @@ limitations under the License. 
#include "xla/stream_executor/rocm/roctracer_wrapper.h" #include "xla/backends/profiler/gpu/rocm_tracer_utils.h" +// Backend: 3=v3 (rocprofiler-sdk), 1=v1 (roctracer). Must be set by the build. +#if !defined(XLA_GPU_ROCM_TRACER_BACKEND_V3) +#define XLA_GPU_ROCM_TRACER_BACKEND_V3 3 +#endif + +#if !defined(XLA_GPU_ROCM_TRACER_BACKEND_V1) +#define XLA_GPU_ROCM_TRACER_BACKEND_V1 1 +#endif + +#ifndef XLA_GPU_ROCM_TRACER_BACKEND +#error "XLA_GPU_ROCM_TRACER_BACKEND not defined" +#endif + namespace xla { namespace profiler { diff --git a/xla/backends/profiler/gpu/rocm_tracer.h b/xla/backends/profiler/gpu/rocm_tracer.h index 0ce1b17e074e1..0016cec73211f 100644 --- a/xla/backends/profiler/gpu/rocm_tracer.h +++ b/xla/backends/profiler/gpu/rocm_tracer.h @@ -17,11 +17,16 @@ limitations under the License. #define XLA_BACKENDS_PROFILER_GPU_ROCM_TRACER_FACADE_H_ // Backend: 3=v3 (rocprofiler-sdk), 1=v1 (roctracer). Default to v3. +#if !defined(XLA_GPU_ROCM_TRACER_BACKEND_V3) #define XLA_GPU_ROCM_TRACER_BACKEND_V3 3 +#endif + +#if !defined(XLA_GPU_ROCM_TRACER_BACKEND_V1) #define XLA_GPU_ROCM_TRACER_BACKEND_V1 1 +#endif #ifndef XLA_GPU_ROCM_TRACER_BACKEND -#define XLA_GPU_ROCM_TRACER_BACKEND XLA_GPU_ROCM_TRACER_BACKEND_V3 +#error "XLA_GPU_ROCM_TRACER_BACKEND not defined" #endif #if XLA_GPU_ROCM_TRACER_BACKEND == 3 diff --git a/xla/service/gpu/llvm_gpu_backend/BUILD b/xla/service/gpu/llvm_gpu_backend/BUILD index 2eefaa35f4bd9..9aad440216a43 100644 --- a/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/xla/service/gpu/llvm_gpu_backend/BUILD @@ -197,6 +197,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@llvm-project//llvm:AMDGPUAsmParser", # buildcleaner: keep "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BinaryFormat", "@llvm-project//llvm:BitReader", "@llvm-project//llvm:BitWriter", "@llvm-project//llvm:CodeGen", @@ -205,12 +206,12 @@ cc_library( "@llvm-project//llvm:Linker", "@llvm-project//llvm:MC", "@llvm-project//llvm:ObjCARC", # buildcleaner: keep + 
"@llvm-project//llvm:Object", "@llvm-project//llvm:Passes", "@llvm-project//llvm:Scalar", "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", "@llvm-project//llvm:TargetParser", - "@local_config_rocm//rocm:amd_comgr", "@tsl//tsl/platform:path", "@tsl//tsl/platform:random", "@tsl//tsl/profiler/lib:traceme", @@ -310,6 +311,34 @@ xla_cc_test( ], ) +xla_cc_test( + name = "amdgpu_register_spilling_test", + size = "small", + srcs = ["amdgpu_register_spilling_test.cc"], + data = [ + "tests_data/amdgpu_dynamic_stack.ll", + "tests_data/amdgpu_no_spills.ll", + "tests_data/amdgpu_sgpr_spills.ll", + "tests_data/amdgpu_vgpr_spills.ll", + ], + tags = [ + "gpu", + "rocm-only", + ], + deps = [ + ":amdgpu_backend", + ":load_ir_module", + "//xla:xla_proto_cc", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:ir_headers", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:test", + ], +) + xla_cc_test( name = "load_ir_module_test", size = "small", diff --git a/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc b/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc index e951c4954955a..52bab56520e28 100644 --- a/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc +++ b/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc @@ -46,6 +46,8 @@ limitations under the License. #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/BinaryFormat/ELF.h" +#include "llvm/BinaryFormat/MsgPackDocument.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/CodeGen/CommandFlags.h" @@ -61,9 +63,13 @@ limitations under the License. 
// Register-spill and stack figures gathered from the kernel metadata of an
// HSACO. All fields start out zero/false, i.e. "nothing spilled, no stack".
struct RegisterSpillInfo {
  uint64_t sgpr_spill_count = 0;      // from ".sgpr_spill_count"
  uint64_t vgpr_spill_count = 0;      // from ".vgpr_spill_count"
  uint64_t private_segment_size = 0;  // from ".private_segment_fixed_size"
  bool uses_dynamic_stack = false;    // from ".uses_dynamic_stack"

  // True when at least one scalar or vector register spill was recorded.
  bool HasSpilling() const {
    return !(sgpr_spill_count == 0 && vgpr_spill_count == 0);
  }

  // True when a private segment is reserved or a dynamic stack is in use.
  bool HasStackUsage() const {
    return uses_dynamic_stack || private_segment_size != 0;
  }
};

// Parse NT_AMDGPU_METADATA note contents and extract register spill counts.
// The metadata is in MessagePack format containing kernel information.
+RegisterSpillInfo ParseAMDGPUMetadataForSpills(llvm::StringRef metadata) { + RegisterSpillInfo spill_info; + + // Parse the MsgPack metadata + llvm::msgpack::Document doc; + if (!doc.readFromBlob(metadata, /*Multi=*/false)) { + VLOG(2) << "Could not parse MsgPack metadata from NT_AMDGPU_METADATA note"; + return spill_info; + } + + llvm::msgpack::DocNode root = doc.getRoot(); + if (!root.isMap()) { + VLOG(2) << "AMDGPU metadata root is not a map (unexpected format)"; + return spill_info; + } + + // Look for "amdhsa.kernels" array + llvm::msgpack::MapDocNode root_map = root.getMap(); + auto kernels_it = root_map.find("amdhsa.kernels"); + + if (kernels_it == root_map.end() || !kernels_it->second.isArray()) { + VLOG(2) << "NT_AMDGPU_METADATA found but missing 'amdhsa.kernels' array"; + return spill_info; + } + + llvm::msgpack::ArrayDocNode kernels_array = kernels_it->second.getArray(); + + // Iterate through each kernel + for (auto& kernel_node : kernels_array) { + uint64_t kernel_sgpr_spill = 0; + uint64_t kernel_vgpr_spill = 0; + uint64_t kernel_sgpr_count = 0; + uint64_t kernel_vgpr_count = 0; + uint64_t kernel_private_size = 0; + bool kernel_uses_dynamic = false; + + if (!kernel_node.isMap()) continue; + + llvm::msgpack::MapDocNode kernel_map = kernel_node.getMap(); + + // Look for ".sgpr_spill_count" + auto sgpr_it = kernel_map.find(".sgpr_spill_count"); + if (sgpr_it != kernel_map.end() && + sgpr_it->second.getKind() == llvm::msgpack::Type::UInt) { + kernel_sgpr_spill = sgpr_it->second.getUInt(); + spill_info.sgpr_spill_count = + std::max(spill_info.sgpr_spill_count, kernel_sgpr_spill); + } + + // Look for ".vgpr_spill_count" + auto vgpr_it = kernel_map.find(".vgpr_spill_count"); + if (vgpr_it != kernel_map.end() && + vgpr_it->second.getKind() == llvm::msgpack::Type::UInt) { + kernel_vgpr_spill = vgpr_it->second.getUInt(); + spill_info.vgpr_spill_count = + std::max(spill_info.vgpr_spill_count, kernel_vgpr_spill); + } + + // Look for ".private_segment_fixed_size" 
+ auto priv_it = kernel_map.find(".private_segment_fixed_size"); + if (priv_it != kernel_map.end() && + priv_it->second.getKind() == llvm::msgpack::Type::UInt) { + kernel_private_size = priv_it->second.getUInt(); + spill_info.private_segment_size = + std::max(spill_info.private_segment_size, kernel_private_size); + } + + // Look for ".uses_dynamic_stack" + auto dyn_it = kernel_map.find(".uses_dynamic_stack"); + if (dyn_it != kernel_map.end() && + dyn_it->second.getKind() == llvm::msgpack::Type::Boolean) { + kernel_uses_dynamic = dyn_it->second.getBool(); + spill_info.uses_dynamic_stack = + spill_info.uses_dynamic_stack || kernel_uses_dynamic; + } + + // Helper to get kernel name for logging (only when needed) + auto get_kernel_name = [&kernel_map]() -> std::string { + auto name_it = kernel_map.find(".name"); + if (name_it != kernel_map.end() && + name_it->second.getKind() == llvm::msgpack::Type::String) { + return name_it->second.getString().str(); + } + return "unknown"; + }; + + // Log per-kernel spill information with register usage + if (kernel_sgpr_spill > 0 || kernel_vgpr_spill > 0) { + // Look for ".sgpr_count" (total SGPRs used) + auto sgpr_count_it = kernel_map.find(".sgpr_count"); + if (sgpr_count_it != kernel_map.end() && + sgpr_count_it->second.getKind() == llvm::msgpack::Type::UInt) { + kernel_sgpr_count = sgpr_count_it->second.getUInt(); + } + + // Look for ".vgpr_count" (total VGPRs used) + auto vgpr_count_it = kernel_map.find(".vgpr_count"); + if (vgpr_count_it != kernel_map.end() && + vgpr_count_it->second.getKind() == llvm::msgpack::Type::UInt) { + kernel_vgpr_count = vgpr_count_it->second.getUInt(); + } + + VLOG(2) << "Kernel '" << get_kernel_name() << "' has register spilling: " + << "SGPR=" << kernel_sgpr_spill << ", VGPR=" << kernel_vgpr_spill + << ". 
Register count: SGPR=" << kernel_sgpr_count + << ", VGPR=" << kernel_vgpr_count; + } + + // Log per-kernel stack usage + if (kernel_private_size > 0 || kernel_uses_dynamic) { + VLOG(2) << "Kernel '" << get_kernel_name() << "' stack usage: " + << "private=" << kernel_private_size + << ", dynamic=" << (kernel_uses_dynamic ? "true" : "false"); + } + } + + return spill_info; +} + +// ELF note descriptor alignment per ELF specification +constexpr int kElfNoteDescAlignment = 4; + +// Returns spill counts by parsing AMDGPU metadata from note sections of HSACO +// ELF binary. +// +// HSACO file (ELF binary) +// -- .note section(s) +// -- ELF Note with type=NT_AMDGPU_METADATA +// -- MessagePack data +// -- Root map +// -- "amdhsa.kernels" array +// -- Each kernel object +// - ".sgpr_spill_count" +// - ".vgpr_spill_count" +// - ... (other kernel properties) +RegisterSpillInfo ExtractRegisterSpillingFromHsaco( + const std::vector& hsaco) { + RegisterSpillInfo spill_info; + + // Create memory buffer from HSACO data + std::unique_ptr mem_buffer = + llvm::MemoryBuffer::getMemBuffer( + llvm::StringRef(reinterpret_cast(hsaco.data()), + hsaco.size()), + "", /*RequiresNullTerminator=*/false); + + // Parse as ELF object file + llvm::Expected> obj_or_err = + llvm::object::ObjectFile::createObjectFile(mem_buffer->getMemBufferRef()); + + if (!obj_or_err) { + VLOG(2) << "Could not parse HSACO as ELF object file: " + << llvm::toString(obj_or_err.takeError()); + return spill_info; + } + + llvm::object::ObjectFile* obj = obj_or_err->get(); + + // Cast to ELF64LE object file (AMDGPU uses 64-bit little-endian ELF) + auto* elf_obj = llvm::dyn_cast(obj); + if (!elf_obj) { + VLOG(2) << "HSACO is not a 64-bit little-endian ELF file"; + return spill_info; + } + + // Get the underlying ELFFile to access the notes() API + const auto& elf_file = elf_obj->getELFFile(); + + for (const auto& section : elf_obj->sections()) { + llvm::Expected + shdr_or_err = 
elf_obj->getSection(section.getRawDataRefImpl()); + + if (!shdr_or_err) { + continue; // Skip sections we can't access + } + + const auto* shdr = *shdr_or_err; + + if (shdr->sh_type != llvm::ELF::SHT_NOTE) { + continue; + } + + llvm::Error err = llvm::Error::success(); + for (const auto& note : elf_file.notes(*shdr, err)) { + if (note.getType() == llvm::ELF::NT_AMDGPU_METADATA) { + llvm::StringRef metadata = + note.getDescAsStringRef(kElfNoteDescAlignment); + + if (metadata.empty()) { + VLOG(2) << "Found NT_AMDGPU_METADATA note but it contains no data"; + continue; + } + + // Parse the metadata and extract spill counts, return immediately + return ParseAMDGPUMetadataForSpills(metadata); + } + } + + if (err) { + VLOG(2) << "Error parsing notes: " << llvm::toString(std::move(err)); + } + } + + // If we reach here, no metadata was found + VLOG(2) << "No AMDGPU metadata found in HSACO"; + return spill_info; +} + bool HsacoCache::Find(const std::string& ir, uint64_t& hash, const std::string& gfx, std::vector& hsaco) { absl::MutexLock lock(g_hsacoCache.mutex); @@ -337,136 +569,43 @@ absl::StatusOr EmitModuleToHsaco( hsaco_file.close(); // Check for register spilling using HSACO metadata - // Use amd_comgr library for fast in-process metadata extraction VLOG(2) << "Checking for register spilling in: " << module->getModuleIdentifier(); - bool has_spilling = false; - int sgpr_spill_count = 0; - int vgpr_spill_count = 0; - int private_segment_size = 0; - - // Use already-loaded HSACO data for amd_comgr parsing - { - // Create amd_comgr data object from HSACO - amd_comgr_data_t comgr_data; - amd_comgr_status_t status = - amd_comgr_create_data(AMD_COMGR_DATA_KIND_EXECUTABLE, &comgr_data); - - if (status == AMD_COMGR_STATUS_SUCCESS) { - status = amd_comgr_set_data(comgr_data, hsaco.size(), - reinterpret_cast(hsaco.data())); - - if (status == AMD_COMGR_STATUS_SUCCESS) { - // Get metadata from the executable - amd_comgr_metadata_node_t metadata; - status = 
amd_comgr_get_data_metadata(comgr_data, &metadata); - - if (status == AMD_COMGR_STATUS_SUCCESS) { - // Helper lambda to lookup integer value from metadata map - auto lookup_int_value = [](amd_comgr_metadata_node_t root, - const char* key) -> int { - amd_comgr_metadata_node_t value_node; - amd_comgr_status_t s = - amd_comgr_metadata_lookup(root, key, &value_node); - if (s != AMD_COMGR_STATUS_SUCCESS) { - return 0; - } - - size_t size = 0; - s = amd_comgr_get_metadata_string(value_node, &size, nullptr); - if (s != AMD_COMGR_STATUS_SUCCESS || size == 0) { - amd_comgr_destroy_metadata(value_node); - return 0; - } - - std::string str_value(size, '\0'); - s = amd_comgr_get_metadata_string(value_node, &size, - str_value.data()); - amd_comgr_destroy_metadata(value_node); - - if (s != AMD_COMGR_STATUS_SUCCESS) { - return 0; - } - - // Parse the integer value - try { - return std::stoi(str_value); - } catch (...) { - return 0; - } - }; - - // Navigate to amdhsa.kernels array and check each kernel - amd_comgr_metadata_node_t kernels_node; - if (amd_comgr_metadata_lookup(metadata, "amdhsa.kernels", - &kernels_node) == - AMD_COMGR_STATUS_SUCCESS) { - size_t kernel_count = 0; - amd_comgr_get_metadata_list_size(kernels_node, &kernel_count); - - for (size_t i = 0; i < kernel_count; ++i) { - amd_comgr_metadata_node_t kernel_node; - if (amd_comgr_index_list_metadata(kernels_node, i, - &kernel_node) == - AMD_COMGR_STATUS_SUCCESS) { - // Get spill counts for this kernel - int kernel_sgpr_spill = - lookup_int_value(kernel_node, ".sgpr_spill_count"); - int kernel_vgpr_spill = - lookup_int_value(kernel_node, ".vgpr_spill_count"); - int kernel_private_size = lookup_int_value( - kernel_node, ".private_segment_fixed_size"); - - // Aggregate max values across all kernels - sgpr_spill_count = - std::max(sgpr_spill_count, kernel_sgpr_spill); - vgpr_spill_count = - std::max(vgpr_spill_count, kernel_vgpr_spill); - private_segment_size = - std::max(private_segment_size, kernel_private_size); - - 
amd_comgr_destroy_metadata(kernel_node); - } - } - amd_comgr_destroy_metadata(kernels_node); - } - - amd_comgr_destroy_metadata(metadata); - } else { - VLOG(2) << "Could not get HSACO metadata via amd_comgr"; - } - } - amd_comgr_release_data(comgr_data); - } else { - VLOG(2) << "Could not create amd_comgr data object"; - } + RegisterSpillInfo spill_info = ExtractRegisterSpillingFromHsaco(hsaco); - if (sgpr_spill_count > 0 || vgpr_spill_count > 0 || - private_segment_size > 0) { - has_spilling = true; - } + if (spill_info.HasSpilling()) { + // We can have SGPR spills without stack being used. They are saved to + // VGPRs. In that case, we don't want to discard such kernel, so just + // report such cases. + VLOG(1) << "Register spilling (SGPR: " << spill_info.sgpr_spill_count + << ", VGPR: " << spill_info.vgpr_spill_count << ") detected in " + << module->getModuleIdentifier(); + } else { + VLOG(2) << "No register spilling detected in " + << module->getModuleIdentifier(); } - if (has_spilling) { - VLOG(0) << "====== REGISTER SPILLING DETECTED ======"; - VLOG(0) << "Module: " << module->getModuleIdentifier(); - VLOG(0) << "SGPR spill count: " << sgpr_spill_count; - VLOG(0) << "VGPR spill count: " << vgpr_spill_count; - VLOG(0) << "Private segment size: " << private_segment_size << " bytes"; - VLOG(0) << "Performance may be degraded due to register pressure"; - VLOG(0) << "========================================"; + if (spill_info.HasStackUsage()) { + VLOG(1) << "Stack usage (private: " << spill_info.private_segment_size + << ", dynamic: " + << (spill_info.uses_dynamic_stack ? 
"true" : "false") + << ") detected in " << module->getModuleIdentifier(); // Filter out kernels with register spilling during autotuning // This matches NVIDIA's behavior in ptx_compiler_impl.cc - if (debug_options - .xla_gpu_filter_kernels_spilling_registers_on_autotuning() && + // TODO: remove ptx from xla_gpu_fail_ptx_compilation_on_register_spilling + // to make the flag more general + if (debug_options.xla_gpu_fail_ptx_compilation_on_register_spilling() && is_autotuning_compilation) { + VLOG(0) << "Discard module " << module->getModuleIdentifier() + << " due register spilling or stack usage"; return xla::Cancelled( - "Compilation result discarded due to register spilling"); + "Compilation result discarded due to register spilling or stack " + "usage"); } } else { - VLOG(2) << "No register spilling detected"; + VLOG(2) << "No stack usage detected in " << module->getModuleIdentifier(); } // Clean up temp files diff --git a/xla/service/gpu/llvm_gpu_backend/amdgpu_register_spilling_test.cc b/xla/service/gpu/llvm_gpu_backend/amdgpu_register_spilling_test.cc new file mode 100644 index 0000000000000..74b1c94feffa4 --- /dev/null +++ b/xla/service/gpu/llvm_gpu_backend/amdgpu_register_spilling_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
// Returns `filename` without its trailing ".ll" extension.
//
// Only a suffix is stripped: a ".ll" that appears earlier in the name (for
// example "kernel.llvm.ll") is left alone, and a name that does not end in
// ".ll" is returned unchanged.
static std::string RemoveLLExtension(const std::string& filename) {
  static constexpr char kExtension[] = ".ll";
  constexpr std::size_t kExtensionLen = sizeof(kExtension) - 1;

  const bool has_extension =
      filename.size() >= kExtensionLen &&
      filename.compare(filename.size() - kExtensionLen, kExtensionLen,
                       kExtension) == 0;
  if (!has_extension) {
    return filename;
  }
  return filename.substr(0, filename.size() - kExtensionLen);
}
LoadTestModule(&context, param.ir_filename); + ASSERT_NE(module, nullptr); + + // Generate module ID from filename and flag state + std::string module_id = + RemoveLLExtension(param.ir_filename) + + (param.fail_on_spilling ? "_fail_on_spilling" : "_allow_spilling"); + + auto result = CompileModule(module.get(), module_id, param.fail_on_spilling); + + EXPECT_EQ(result.status().code(), param.expected_code) + << "IR: " << param.ir_filename + << ", Flag: " << (param.fail_on_spilling ? "enabled" : "disabled") + << ", Status: " << result.status().message(); + + if (!param.expected_substring.empty()) { + EXPECT_THAT(result.status().message(), + ::testing::HasSubstr(param.expected_substring)) + << "IR: " << param.ir_filename; + } +} + +INSTANTIATE_TEST_SUITE_P( + RegisterSpillingTests, AMDGPURegisterSpillingTest, + ::testing::Values( + SpillingTestParam{"amdgpu_no_spills.ll", + /*fail_on_spilling=*/true, absl::StatusCode::kOk, ""}, + SpillingTestParam{"amdgpu_vgpr_spills.ll", + /*fail_on_spilling=*/false, absl::StatusCode::kOk, + ""}, + SpillingTestParam{"amdgpu_vgpr_spills.ll", + /*fail_on_spilling=*/true, + absl::StatusCode::kCancelled, "register spilling"}, + SpillingTestParam{"amdgpu_sgpr_spills.ll", + /*fail_on_spilling=*/false, absl::StatusCode::kOk, + ""}, + SpillingTestParam{"amdgpu_sgpr_spills.ll", + /*fail_on_spilling=*/true, absl::StatusCode::kOk, ""}, + SpillingTestParam{"amdgpu_dynamic_stack.ll", + /*fail_on_spilling=*/true, + absl::StatusCode::kCancelled, "stack usage"}), + [](const ::testing::TestParamInfo& info) { + return RemoveLLExtension(info.param.ir_filename) + + (info.param.fail_on_spilling ? 
"_fail_on_spilling" + : "_allow_spilling"); + }); + +} // namespace +} // namespace xla::gpu diff --git a/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_dynamic_stack.ll b/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_dynamic_stack.ll new file mode 100644 index 0000000000000..5a8b76446c3e5 --- /dev/null +++ b/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_dynamic_stack.ll @@ -0,0 +1,26 @@ +; AMDGPU kernel with dynamic stack usage (indirect function call) +; Based on real HIP code that uses function pointers +target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9" +target triple = "amdgcn-amd-amdhsa" + +@__hip_cuid_40fa47637d275275 = addrspace(1) global i8 0 + +@llvm.compiler.used = appending addrspace(1) global [1 x ptr] [ptr addrspacecast (ptr addrspace(1) @__hip_cuid_40fa47637d275275 to ptr)], section "llvm.metadata" + +; Kernel that uses indirect function call requiring dynamic stack +define protected amdgpu_kernel void @_Z4TestPDF16bS_S_(ptr addrspace(1) noundef %dst.coerce, ptr addrspace(1) noundef %ptr1.coerce, ptr addrspace(1) noundef %ptr2.coerce) local_unnamed_addr { +entry: + %0 = ptrtoint ptr addrspace(1) %dst.coerce to i64 + %1 = inttoptr i64 %0 to ptr + %2 = ptrtoint ptr addrspace(1) %ptr1.coerce to i64 + %3 = inttoptr i64 %2 to ptr + %4 = ptrtoint ptr addrspace(1) %ptr2.coerce to i64 + %5 = inttoptr i64 %4 to ptr + %6 = tail call ptr asm "", "=s"() #1 + tail call void %6(ptr noundef %1, ptr noundef %3, ptr noundef %5) #2 + ret void +} + +attributes #1 = { nounwind } +attributes #2 = { nounwind } + diff --git a/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_no_spills.ll b/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_no_spills.ll new file mode 100644 index 0000000000000..4ab9829a36f90 --- /dev/null +++ 
b/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_no_spills.ll @@ -0,0 +1,29 @@ +; Simple AMDGPU kernel for testing register spilling detection +; This module has no external dependencies and minimal module flags +target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9" +target triple = "amdgcn-amd-amdhsa" + +; Simple kernel that adds two arrays +define amdgpu_kernel void @simple_add(ptr addrspace(1) %a, ptr addrspace(1) %b, ptr addrspace(1) %c) { +entry: + %tid = call i32 @llvm.amdgcn.workitem.id.x() + %tidx = zext i32 %tid to i64 + + %a_ptr = getelementptr float, ptr addrspace(1) %a, i64 %tidx + %b_ptr = getelementptr float, ptr addrspace(1) %b, i64 %tidx + %c_ptr = getelementptr float, ptr addrspace(1) %c, i64 %tidx + + %a_val = load float, ptr addrspace(1) %a_ptr, align 4 + %b_val = load float, ptr addrspace(1) %b_ptr, align 4 + + %sum = fadd float %a_val, %b_val + + store float %sum, ptr addrspace(1) %c_ptr, align 4 + ret void +} + +; Intrinsic declaration +declare i32 @llvm.amdgcn.workitem.id.x() #0 + +attributes #0 = { nounwind readnone speculatable } + diff --git a/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_sgpr_spills.ll b/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_sgpr_spills.ll new file mode 100644 index 0000000000000..51dbc634d680c --- /dev/null +++ b/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_sgpr_spills.ll @@ -0,0 +1,166 @@ +; AMDGPU kernel with high SGPR pressure to force scalar register spilling +target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9" +target triple = "amdgcn-amd-amdhsa" + +; Kernel using many scalar operations 
with limited SGPRs +; We use readfirstlane to force values into SGPRs +define amdgpu_kernel void @sgpr_pressure(ptr addrspace(1) %in, ptr addrspace(1) %out) #0 { +entry: + %tid = call i32 @llvm.amdgcn.workitem.id.x() + %tidx = zext i32 %tid to i64 + + ; Load many scalar values from memory + ; Using readfirstlane forces values into SGPRs (uniform across wavefront) + %ptr0 = getelementptr i32, ptr addrspace(1) %in, i64 0 + %v0_vec = load i32, ptr addrspace(1) %ptr0, align 4 + %v0 = call i32 @llvm.amdgcn.readfirstlane(i32 %v0_vec) + + %ptr1 = getelementptr i32, ptr addrspace(1) %in, i64 1 + %v1_vec = load i32, ptr addrspace(1) %ptr1, align 4 + %v1 = call i32 @llvm.amdgcn.readfirstlane(i32 %v1_vec) + + %ptr2 = getelementptr i32, ptr addrspace(1) %in, i64 2 + %v2_vec = load i32, ptr addrspace(1) %ptr2, align 4 + %v2 = call i32 @llvm.amdgcn.readfirstlane(i32 %v2_vec) + + %ptr3 = getelementptr i32, ptr addrspace(1) %in, i64 3 + %v3_vec = load i32, ptr addrspace(1) %ptr3, align 4 + %v3 = call i32 @llvm.amdgcn.readfirstlane(i32 %v3_vec) + + %ptr4 = getelementptr i32, ptr addrspace(1) %in, i64 4 + %v4_vec = load i32, ptr addrspace(1) %ptr4, align 4 + %v4 = call i32 @llvm.amdgcn.readfirstlane(i32 %v4_vec) + + %ptr5 = getelementptr i32, ptr addrspace(1) %in, i64 5 + %v5_vec = load i32, ptr addrspace(1) %ptr5, align 4 + %v5 = call i32 @llvm.amdgcn.readfirstlane(i32 %v5_vec) + + %ptr6 = getelementptr i32, ptr addrspace(1) %in, i64 6 + %v6_vec = load i32, ptr addrspace(1) %ptr6, align 4 + %v6 = call i32 @llvm.amdgcn.readfirstlane(i32 %v6_vec) + + %ptr7 = getelementptr i32, ptr addrspace(1) %in, i64 7 + %v7_vec = load i32, ptr addrspace(1) %ptr7, align 4 + %v7 = call i32 @llvm.amdgcn.readfirstlane(i32 %v7_vec) + + %ptr8 = getelementptr i32, ptr addrspace(1) %in, i64 8 + %v8_vec = load i32, ptr addrspace(1) %ptr8, align 4 + %v8 = call i32 @llvm.amdgcn.readfirstlane(i32 %v8_vec) + + %ptr9 = getelementptr i32, ptr addrspace(1) %in, i64 9 + %v9_vec = load i32, ptr addrspace(1) %ptr9, 
align 4 + %v9 = call i32 @llvm.amdgcn.readfirstlane(i32 %v9_vec) + + %ptr10 = getelementptr i32, ptr addrspace(1) %in, i64 10 + %v10_vec = load i32, ptr addrspace(1) %ptr10, align 4 + %v10 = call i32 @llvm.amdgcn.readfirstlane(i32 %v10_vec) + + %ptr11 = getelementptr i32, ptr addrspace(1) %in, i64 11 + %v11_vec = load i32, ptr addrspace(1) %ptr11, align 4 + %v11 = call i32 @llvm.amdgcn.readfirstlane(i32 %v11_vec) + + %ptr12 = getelementptr i32, ptr addrspace(1) %in, i64 12 + %v12_vec = load i32, ptr addrspace(1) %ptr12, align 4 + %v12 = call i32 @llvm.amdgcn.readfirstlane(i32 %v12_vec) + + %ptr13 = getelementptr i32, ptr addrspace(1) %in, i64 13 + %v13_vec = load i32, ptr addrspace(1) %ptr13, align 4 + %v13 = call i32 @llvm.amdgcn.readfirstlane(i32 %v13_vec) + + %ptr14 = getelementptr i32, ptr addrspace(1) %in, i64 14 + %v14_vec = load i32, ptr addrspace(1) %ptr14, align 4 + %v14 = call i32 @llvm.amdgcn.readfirstlane(i32 %v14_vec) + + %ptr15 = getelementptr i32, ptr addrspace(1) %in, i64 15 + %v15_vec = load i32, ptr addrspace(1) %ptr15, align 4 + %v15 = call i32 @llvm.amdgcn.readfirstlane(i32 %v15_vec) + + ; Create many scalar computations - chain A + %a0 = add i32 %v0, %v1 + %a1 = mul i32 %a0, %v2 + %a2 = add i32 %a1, %v3 + %a3 = mul i32 %a2, %v4 + %a4 = add i32 %a3, %v5 + %a5 = mul i32 %a4, %v6 + %a6 = add i32 %a5, %v7 + %a7 = mul i32 %a6, %v8 + %a8 = add i32 %a7, %v9 + %a9 = mul i32 %a8, %v10 + %a10 = add i32 %a9, %v11 + %a11 = mul i32 %a10, %v12 + %a12 = add i32 %a11, %v13 + %a13 = mul i32 %a12, %v14 + %a14 = add i32 %a13, %v15 + + ; Chain B - reverse + %b0 = mul i32 %v15, %v14 + %b1 = add i32 %b0, %v13 + %b2 = mul i32 %b1, %v12 + %b3 = add i32 %b2, %v11 + %b4 = mul i32 %b3, %v10 + %b5 = add i32 %b4, %v9 + %b6 = mul i32 %b5, %v8 + %b7 = add i32 %b6, %v7 + %b8 = mul i32 %b7, %v6 + %b9 = add i32 %b8, %v5 + %b10 = mul i32 %b9, %v4 + %b11 = add i32 %b10, %v3 + %b12 = mul i32 %b11, %v2 + %b13 = add i32 %b12, %v1 + %b14 = mul i32 %b13, %v0 + + ; Chain C - 
subtraction + %c0 = sub i32 %v0, %v1 + %c1 = mul i32 %c0, %v2 + %c2 = sub i32 %c1, %v3 + %c3 = mul i32 %c2, %v4 + %c4 = sub i32 %c3, %v5 + %c5 = mul i32 %c4, %v6 + %c6 = sub i32 %c5, %v7 + %c7 = mul i32 %c6, %v8 + %c8 = sub i32 %c7, %v9 + %c9 = mul i32 %c8, %v10 + %c10 = sub i32 %c9, %v11 + %c11 = mul i32 %c10, %v12 + %c12 = sub i32 %c11, %v13 + %c13 = mul i32 %c12, %v14 + %c14 = sub i32 %c13, %v15 + + ; Chain D - cross dependencies + %d0 = add i32 %a0, %b0 + %d1 = mul i32 %d0, %c0 + %d2 = add i32 %a1, %b1 + %d3 = mul i32 %d2, %c1 + %d4 = add i32 %a2, %b2 + %d5 = mul i32 %d4, %c2 + %d6 = add i32 %a3, %b3 + %d7 = mul i32 %d6, %c3 + %d8 = add i32 %a4, %b4 + %d9 = mul i32 %d8, %c4 + %d10 = add i32 %a5, %b5 + %d11 = mul i32 %d10, %c5 + %d12 = add i32 %a6, %b6 + %d13 = mul i32 %d12, %c6 + + ; Combine all chains + %r0 = add i32 %a14, %b14 + %r1 = add i32 %r0, %c14 + %r2 = add i32 %r1, %d1 + %r3 = add i32 %r2, %d3 + %r4 = add i32 %r3, %d5 + %r5 = add i32 %r4, %d7 + %r6 = add i32 %r5, %d9 + %r7 = add i32 %r6, %d11 + %result = add i32 %r7, %d13 + + %out_ptr = getelementptr i32, ptr addrspace(1) %out, i64 %tidx + store i32 %result, ptr addrspace(1) %out_ptr, align 4 + ret void +} + +declare i32 @llvm.amdgcn.workitem.id.x() #1 +declare i32 @llvm.amdgcn.readfirstlane(i32) #1 + +; Limit SGPRs to 32, this should force SGPR spilling +attributes #0 = { "amdgpu-num-sgpr"="32" "amdgpu-flat-work-group-size"="1,256" } +attributes #1 = { nounwind readnone speculatable } diff --git a/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_vgpr_spills.ll b/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_vgpr_spills.ll new file mode 100644 index 0000000000000..5634790c8e6eb --- /dev/null +++ b/xla/service/gpu/llvm_gpu_backend/tests_data/amdgpu_vgpr_spills.ll @@ -0,0 +1,145 @@ +; AMDGPU kernel with high register pressure to force spilling +; This uses many vector operations to exhaust available VGPRs +target datalayout = 
"e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9" +target triple = "amdgcn-amd-amdhsa" + +; Kernel with many live values to force register spilling +define amdgpu_kernel void @high_register_pressure(ptr addrspace(1) %in, ptr addrspace(1) %out) #0 { +entry: + %tid = call i32 @llvm.amdgcn.workitem.id.x() + %tidx = zext i32 %tid to i64 + + ; Load many vectors from memory - using volatile to prevent optimization + %ptr0 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 %tidx + %v0 = load volatile <4 x float>, ptr addrspace(1) %ptr0, align 16 + + %ptr1 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 1 + %v1 = load volatile <4 x float>, ptr addrspace(1) %ptr1, align 16 + + %ptr2 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 2 + %v2 = load volatile <4 x float>, ptr addrspace(1) %ptr2, align 16 + + %ptr3 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 3 + %v3 = load volatile <4 x float>, ptr addrspace(1) %ptr3, align 16 + + %ptr4 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 4 + %v4 = load volatile <4 x float>, ptr addrspace(1) %ptr4, align 16 + + %ptr5 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 5 + %v5 = load volatile <4 x float>, ptr addrspace(1) %ptr5, align 16 + + %ptr6 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 6 + %v6 = load volatile <4 x float>, ptr addrspace(1) %ptr6, align 16 + + %ptr7 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 7 + %v7 = load volatile <4 x float>, ptr addrspace(1) %ptr7, align 16 + + %ptr8 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 8 + %v8 = load volatile <4 x float>, ptr addrspace(1) %ptr8, align 16 + + %ptr9 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 9 + %v9 = load volatile <4 x float>, ptr addrspace(1) %ptr9, align 16 + + %ptr10 = getelementptr <4 x float>, ptr 
addrspace(1) %in, i64 10 + %v10 = load volatile <4 x float>, ptr addrspace(1) %ptr10, align 16 + + %ptr11 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 11 + %v11 = load volatile <4 x float>, ptr addrspace(1) %ptr11, align 16 + + %ptr12 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 12 + %v12 = load volatile <4 x float>, ptr addrspace(1) %ptr12, align 16 + + %ptr13 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 13 + %v13 = load volatile <4 x float>, ptr addrspace(1) %ptr13, align 16 + + %ptr14 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 14 + %v14 = load volatile <4 x float>, ptr addrspace(1) %ptr14, align 16 + + %ptr15 = getelementptr <4 x float>, ptr addrspace(1) %in, i64 15 + %v15 = load volatile <4 x float>, ptr addrspace(1) %ptr15, align 16 + + ; Create many dependent calculations - chain A + %a0 = fadd <4 x float> %v0, %v1 + %a1 = fmul <4 x float> %a0, %v2 + %a2 = fadd <4 x float> %a1, %v3 + %a3 = fmul <4 x float> %a2, %v4 + %a4 = fadd <4 x float> %a3, %v5 + %a5 = fmul <4 x float> %a4, %v6 + %a6 = fadd <4 x float> %a5, %v7 + %a7 = fmul <4 x float> %a6, %v8 + %a8 = fadd <4 x float> %a7, %v9 + %a9 = fmul <4 x float> %a8, %v10 + %a10 = fadd <4 x float> %a9, %v11 + %a11 = fmul <4 x float> %a10, %v12 + %a12 = fadd <4 x float> %a11, %v13 + %a13 = fmul <4 x float> %a12, %v14 + %a14 = fadd <4 x float> %a13, %v15 + + ; Chain B - reverse direction + %b0 = fmul <4 x float> %v15, %v14 + %b1 = fadd <4 x float> %b0, %v13 + %b2 = fmul <4 x float> %b1, %v12 + %b3 = fadd <4 x float> %b2, %v11 + %b4 = fmul <4 x float> %b3, %v10 + %b5 = fadd <4 x float> %b4, %v9 + %b6 = fmul <4 x float> %b5, %v8 + %b7 = fadd <4 x float> %b6, %v7 + %b8 = fmul <4 x float> %b7, %v6 + %b9 = fadd <4 x float> %b8, %v5 + %b10 = fmul <4 x float> %b9, %v4 + %b11 = fadd <4 x float> %b10, %v3 + %b12 = fmul <4 x float> %b11, %v2 + %b13 = fadd <4 x float> %b12, %v1 + %b14 = fmul <4 x float> %b13, %v0 + + ; Chain C - subtraction chain + %c0 = fsub <4 x float> %v0, %v1 + %c1 
= fmul <4 x float> %c0, %v2 + %c2 = fsub <4 x float> %c1, %v3 + %c3 = fmul <4 x float> %c2, %v4 + %c4 = fsub <4 x float> %c3, %v5 + %c5 = fmul <4 x float> %c4, %v6 + %c6 = fsub <4 x float> %c5, %v7 + %c7 = fmul <4 x float> %c6, %v8 + %c8 = fsub <4 x float> %c7, %v9 + %c9 = fmul <4 x float> %c8, %v10 + %c10 = fsub <4 x float> %c9, %v11 + %c11 = fmul <4 x float> %c10, %v12 + %c12 = fsub <4 x float> %c11, %v13 + %c13 = fmul <4 x float> %c12, %v14 + %c14 = fsub <4 x float> %c13, %v15 + + ; Chain D - cross dependencies + %d0 = fadd <4 x float> %a0, %b0 + %d1 = fmul <4 x float> %d0, %c0 + %d2 = fadd <4 x float> %a1, %b1 + %d3 = fmul <4 x float> %d2, %c1 + %d4 = fadd <4 x float> %a2, %b2 + %d5 = fmul <4 x float> %d4, %c2 + %d6 = fadd <4 x float> %a3, %b3 + %d7 = fmul <4 x float> %d6, %c3 + %d8 = fadd <4 x float> %a4, %b4 + %d9 = fmul <4 x float> %d8, %c4 + %d10 = fadd <4 x float> %a5, %b5 + %d11 = fmul <4 x float> %d10, %c5 + + ; Final combination to keep all values live + %result0 = fadd <4 x float> %a14, %b14 + %result1 = fadd <4 x float> %result0, %c14 + %result2 = fadd <4 x float> %result1, %d1 + %result3 = fadd <4 x float> %result2, %d3 + %result4 = fadd <4 x float> %result3, %d5 + %result5 = fadd <4 x float> %result4, %d7 + %result6 = fadd <4 x float> %result5, %d9 + %result = fadd <4 x float> %result6, %d11 + + %out_ptr = getelementptr <4 x float>, ptr addrspace(1) %out, i64 %tidx + store <4 x float> %result, ptr addrspace(1) %out_ptr, align 16 + ret void +} + +declare i32 @llvm.amdgcn.workitem.id.x() #1 + +; Limit VGPRs to 64 to force spilling +attributes #0 = { "amdgpu-num-vgpr"="64" "amdgpu-flat-work-group-size"="1,256" } +attributes #1 = { nounwind readnone speculatable } diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index 8fd7b0c567d78..7b8406413a83a 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -48,7 +48,6 @@ cc_library( "rocm-only", ], deps = [ - ":rocm_driver_wrapper", ":rocm_status", 
"//xla/stream_executor:device_description", "//xla/stream_executor/gpu:context", @@ -65,22 +64,6 @@ cc_library( ], ) -cc_library( - name = "rocm_driver_wrapper", - hdrs = ["rocm_driver_wrapper.h"], - defines = {"__HIP_DISABLE_CPP_FUNCTIONS__": "1"}, - tags = [ - "gpu", - "rocm-only", - ], - deps = [ - "//xla/tsl/platform:env", - "@local_config_rocm//rocm:hip", # buildcleaner: keep - "@local_config_rocm//rocm:rocm_headers", - "@tsl//tsl/platform:dso_loader", - ], -) - cc_library( name = "rocm_event", srcs = ["rocm_event.cc"], @@ -90,7 +73,6 @@ cc_library( "rocm-only", ], deps = [ - ":rocm_driver_wrapper", ":rocm_status", "//xla/stream_executor:activate_context", "//xla/stream_executor:event", @@ -103,6 +85,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", ], ) @@ -137,7 +120,6 @@ cc_library( deps = [ ":rocm_command_buffer", ":rocm_context", - ":rocm_driver_wrapper", ":rocm_event", ":rocm_kernel", ":rocm_platform_id", @@ -190,6 +172,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:fingerprint", @@ -236,7 +219,6 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":rocm_driver_wrapper", ":rocm_status", "//xla/stream_executor:activate_context", "//xla/stream_executor:kernel", @@ -251,6 +233,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", ], ) @@ -289,7 +272,6 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":rocm_driver_wrapper", ":rocm_executor", ":rocm_platform_id", ":rocm_status", @@ -306,6 +288,7 @@ cc_library( 
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", ], alwayslink = True, # Registers itself with the PlatformManager. @@ -329,24 +312,6 @@ cc_library( ]), ) -cc_library( - name = "rocblas_wrapper", - hdrs = ["rocblas_wrapper.h"], - tags = [ - "gpu", - "rocm-only", - ], - deps = [ - ":rocm_executor", - "//xla/tsl/platform:env", - "//xla/tsl/util:determinism_for_kernels", - "@local_config_rocm//rocm:rocm_headers", - "@tsl//tsl/platform", - "@tsl//tsl/platform:dso_loader", - ], - alwayslink = True, -) - cc_library( name = "rocblas_plugin", srcs = ["rocm_blas.cc"], @@ -359,7 +324,6 @@ cc_library( deps = [ ":hipblas_lt_header", ":rocblas_if_static", - ":rocblas_wrapper", ":rocm_complex_converters", ":rocm_executor", ":rocm_helpers", @@ -410,24 +374,17 @@ cc_library( name = "rocm_solver_context", srcs = ["rocm_solver_context.cc"], hdrs = ["rocm_solver_context.h"], - local_defines = [ - "TENSORFLOW_USE_ROCM=1", - ], tags = [ "gpu", "manual", "rocm-only", ], deps = [ - ":hipsolver_wrapper", - ":rocblas_wrapper", ":rocm_platform_id", - ":rocsolver_wrapper", "//xla:comparison_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/stream_executor:blas", - "//xla/stream_executor:device_memory", "//xla/stream_executor:gpu_solver_context", "//xla/stream_executor:stream", "//xla/stream_executor/platform:platform_object_registry", @@ -436,6 +393,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", ], alwayslink = 1, @@ -539,25 +497,22 @@ cc_library( ) cc_library( - name = "hiprand_if_static", + name = "hipsolver_wrapper", + hdrs = ["hipsolver_wrapper.h"], tags = [ "gpu", "rocm-only", ], deps = if_static([ - "@local_config_rocm//rocm:hiprand", - ]), -) - -cc_library( - name = 
"hipsparse_if_static", - tags = [ - "gpu", - "rocm-only", + "@local_config_rocm//rocm:hipsolver", + ]) + [ + ":rocm_executor", + ":rocm_platform_id", + "//xla/tsl/platform:env", + "@local_config_rocm//rocm:rocm_headers", + "@tsl//tsl/platform:dso_loader", ], - deps = if_static([ - "@local_config_rocm//rocm:hipsparse", - ]), + alwayslink = True, ) cc_library( @@ -568,8 +523,9 @@ cc_library( "gpu", "rocm-only", ], - deps = [ - ":hipsparse_if_static", + deps = if_static([ + "@local_config_rocm//rocm:hipsparse", + ]) + [ ":rocm_executor", ":rocm_platform_id", "//xla/tsl/platform:env", @@ -581,55 +537,35 @@ cc_library( ) cc_library( - name = "rocsolver_if_static", - tags = [ - "gpu", - "rocm-only", - ], - deps = if_static([ - "@local_config_rocm//rocm:rocsolver", - ]), -) - -cc_library( - name = "rocsolver_wrapper", - srcs = ["rocsolver_wrapper.h"], - hdrs = ["rocsolver_wrapper.h"], + name = "rocblas_wrapper", + hdrs = ["rocblas_wrapper.h"], tags = [ "gpu", "rocm-only", ], deps = [ + ":rocblas_if_static", ":rocm_executor", - ":rocm_platform_id", - ":rocsolver_if_static", "//xla/tsl/platform:env", + "//xla/tsl/util:determinism_for_kernels", "@local_config_rocm//rocm:rocm_headers", + "@tsl//tsl/platform", "@tsl//tsl/platform:dso_loader", ], alwayslink = True, ) cc_library( - name = "hipsolver_if_static", + name = "rocsolver_wrapper", + srcs = ["rocsolver_wrapper.h"], + hdrs = ["rocsolver_wrapper.h"], tags = [ "gpu", "rocm-only", ], deps = if_static([ - "@local_config_rocm//rocm:hipsolver", - ]), -) - -cc_library( - name = "hipsolver_wrapper", - hdrs = ["hipsolver_wrapper.h"], - tags = [ - "gpu", - "rocm-only", - ], - deps = [ - ":hipsolver_if_static", + "@local_config_rocm//rocm:rocsolver", + ]) + [ ":rocm_executor", ":rocm_platform_id", "//xla/tsl/platform:env", @@ -656,7 +592,6 @@ cc_library( hdrs = [ "hip_blas_lt.h", "hip_blas_utils.h", - "hipblaslt_wrapper.h", ], defines = {"__HIP_DISABLE_CPP_FUNCTIONS__": "1"}, tags = [ @@ -696,9 +631,8 @@ cc_library( 
"@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform:dso_loader", "@tsl//tsl/platform:ml_dtypes", - ] + if_static([ - ":hipblaslt_if_static", - ]), + "@local_config_rocm//rocm:hipblaslt", + ], alwayslink = True, ) @@ -707,7 +641,6 @@ cc_library( hdrs = [ "hip_blas_lt.h", "hip_blas_utils.h", - "hipblaslt_wrapper.h", ], tags = [ "gpu", @@ -747,17 +680,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], -) - -cc_library( - name = "roctracer_if_static", - tags = [ - "gpu", - "rocm-only", - ], - deps = if_static([ - "@local_config_rocm//rocm:roctracer", - ]), + alwayslink = 1, ) cc_library( @@ -770,11 +693,7 @@ cc_library( ], deps = [ ":rocm_executor", - ":roctracer_if_static", # buildcleaner: keep - "//xla/tsl/platform:env", "@local_config_rocm//rocm:rocm_headers", - "@tsl//tsl/platform", - "@tsl//tsl/platform:dso_loader", ], alwayslink = True, ) @@ -889,7 +808,6 @@ cc_library( "rocm-only", ], deps = [ - ":rocm_driver_wrapper", ":rocm_event", ":rocm_kernel", ":rocm_status", @@ -913,6 +831,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", ], ) @@ -952,7 +871,6 @@ cc_library( "rocm-only", ], deps = [ - ":rocm_driver_wrapper", ":rocm_event", ":rocm_status", "//xla/stream_executor:activate_context", @@ -965,6 +883,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", ], ) @@ -1041,7 +960,6 @@ cc_library( "rocm-only", ], deps = [ - ":rocm_driver_wrapper", ":rocm_kernel", ":rocm_status", "//xla/stream_executor:bit_pattern", @@ -1064,6 +982,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_config_rocm//rocm:hip", 
"@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform:casts", ], diff --git a/xla/stream_executor/rocm/hip_blas_lt.cc b/xla/stream_executor/rocm/hip_blas_lt.cc index b746a81ede59c..cde7dfab4ea65 100644 --- a/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/xla/stream_executor/rocm/hip_blas_lt.cc @@ -47,7 +47,6 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/rocm/hip_blas_utils.h" -#include "xla/stream_executor/rocm/hipblaslt_wrapper.h" #include "xla/stream_executor/rocm/rocm_blas.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" @@ -83,32 +82,31 @@ namespace { template absl::Status SetAttr(hipblasLtMatrixLayout_t handle, hipblasLtMatrixLayoutAttribute_t attr, T value) { - return SET_ATTR(wrap::hipblasLtMatrixLayoutSetAttribute, handle, attr, value); + return SET_ATTR(hipblasLtMatrixLayoutSetAttribute, handle, attr, value); } template absl::StatusOr GetAttr(hipblasLtMatrixLayout_t handle, hipblasLtMatrixLayoutAttribute_t attr) { - return GET_ATTR(wrap::hipblasLtMatrixLayoutGetAttribute, handle, attr, T); + return GET_ATTR(hipblasLtMatrixLayoutGetAttribute, handle, attr, T); } template absl::Status SetAttr(hipblasLtMatmulDesc_t handle, hipblasLtMatmulDescAttributes_t attr, T value) { - return SET_ATTR(wrap::hipblasLtMatmulDescSetAttribute, handle, attr, value); + return SET_ATTR(hipblasLtMatmulDescSetAttribute, handle, attr, value); } template absl::StatusOr GetAttr(hipblasLtMatmulDesc_t handle, hipblasLtMatmulDescAttributes_t attr) { - return GET_ATTR(wrap::hipblasLtMatmulDescGetAttribute, handle, attr, T); + return GET_ATTR(hipblasLtMatmulDescGetAttribute, handle, attr, T); } template absl::Status SetAttr(hipblasLtMatmulPreference_t handle, hipblasLtMatmulPreferenceAttributes_t attr, T value) { - return SET_ATTR(wrap::hipblasLtMatmulPreferenceSetAttribute, handle, attr, - value); + return 
SET_ATTR(hipblasLtMatmulPreferenceSetAttribute, handle, attr, value); } static absl::StatusOr AsHipblasLtEpilogue( @@ -148,7 +146,7 @@ static absl::StatusOr AsHipblasLtEpilogue( absl::Status BlasLt::Init() { hipblasLtHandle_t blas_lt; - SE_HIPBLAS_RETURN_IF_ERROR(wrap::hipblasLtCreate(&blas_lt)); + SE_HIPBLAS_RETURN_IF_ERROR(hipblasLtCreate(&blas_lt)); absl::MutexLock lock(mu_); blas_lt_.reset(blas_lt); return absl::OkStatus(); @@ -160,9 +158,9 @@ absl::Status BlasLt::Init() { auto hipblas_data_type_ = AsHipblasDataType(type); hipblasLtMatrixLayout_t hip_layout; - SE_HIPBLAS_RETURN_IF_ERROR(wrap::hipblasLtMatrixLayoutCreate( - &hip_layout, hipblas_data_type_, m.num_rows, m.num_cols, - m.leading_dim_stride)); + SE_HIPBLAS_RETURN_IF_ERROR( + hipblasLtMatrixLayoutCreate(&hip_layout, hipblas_data_type_, m.num_rows, + m.num_cols, m.leading_dim_stride)); // Wrap hipblas handle immediately, so it is cleaned up if an error occurs. BlasLt::MatrixLayout layout(hip_layout, hipblas_data_type_); if (m.order != gpu::MatrixLayout::Order::kColumnMajor) @@ -194,8 +192,8 @@ absl::Status BlasLt::Init() { << int(pointer_mode) << "mx_mode: " << mx_mode; auto hip_scale_type = AsHipblasDataType(scale_type); auto hip_compute_type = AsHipblasComputeType(compute_type); - SE_HIPBLAS_RETURN_IF_ERROR(wrap::hipblasLtMatmulDescCreate( - &hip_desc, hip_compute_type, hip_scale_type)); + SE_HIPBLAS_RETURN_IF_ERROR( + hipblasLtMatmulDescCreate(&hip_desc, hip_compute_type, hip_scale_type)); int32_t bias_flag = static_cast(epilogue) & static_cast(Epilogue::kBias); @@ -228,11 +226,11 @@ auto BlasLt::MatmulPlan::GetAlgorithms(const Stream* stream, hipblasLtMatmulPreference_t hip_preference; SE_HIPBLAS_RETURN_IF_ERROR( - wrap::hipblasLtMatmulPreferenceCreate(&hip_preference)); + hipblasLtMatmulPreferenceCreate(&hip_preference)); // Wrap hipblas handle immediately, so it is cleaned up if an error occurs. 
Owned preference( - hip_preference, wrap::hipblasLtMatmulPreferenceDestroy); + hip_preference, hipblasLtMatmulPreferenceDestroy); TF_RETURN_IF_ERROR(SetAttr( hip_preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, @@ -279,7 +277,7 @@ auto BlasLt::MatmulPlan::GetAlgorithms(const Stream* stream, #endif // TF_ROCM_VERSION >= 70000 int found_algorithm_count = 0; - auto error = wrap::hipblasLtMatmulAlgoGetHeuristic( + auto error = hipblasLtMatmulAlgoGetHeuristic( blas_lt->blas_lt_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(), c_desc_.get(), d_desc_.get(), preference.get(), max_algorithm_count, results.data(), &found_algorithm_count); @@ -504,13 +502,13 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( std::unique_ptr activation = blas_lt->parent_->Activate(); if (palgo != nullptr) { - SE_HIPBLAS_RETURN_IF_ERROR(wrap::hipblasLtMatmul( - blas_lt->blas_lt_.get(), op_desc_.get(), alpha, a.opaque(), - a_desc_.get(), b.opaque(), b_desc_.get(), beta, args.c.opaque(), - c_desc_.get(), args.d.opaque(), d_desc_.get(), palgo, workspace_addr, - workspace_size, - absl::bit_cast( - stream->platform_specific_handle().stream))); + SE_HIPBLAS_RETURN_IF_ERROR( + hipblasLtMatmul(blas_lt->blas_lt_.get(), op_desc_.get(), alpha, + a.opaque(), a_desc_.get(), b.opaque(), b_desc_.get(), + beta, args.c.opaque(), c_desc_.get(), args.d.opaque(), + d_desc_.get(), palgo, workspace_addr, workspace_size, + absl::bit_cast( + stream->platform_specific_handle().stream))); } else { return absl::InternalError("hipblaslt: Invalid algorithm type"); } diff --git a/xla/stream_executor/rocm/hip_blas_lt.h b/xla/stream_executor/rocm/hip_blas_lt.h index 7760b5d587e6d..64066c7fbfaa6 100644 --- a/xla/stream_executor/rocm/hip_blas_lt.h +++ b/xla/stream_executor/rocm/hip_blas_lt.h @@ -47,8 +47,7 @@ class BlasLt : public gpu::BlasLt { private: MatrixLayout(hipblasLtMatrixLayout_t handle, hipDataType datatype) - : handle_(handle, wrap::hipblasLtMatrixLayoutDestroy), - datatype_(datatype) {} + : handle_(handle, 
hipblasLtMatrixLayoutDestroy), datatype_(datatype) {} Owned handle_; hipDataType datatype_; @@ -75,7 +74,7 @@ class BlasLt : public gpu::BlasLt { private: MatmulDesc(hipblasLtMatmulDesc_t handle, hipblasComputeType_t compute_type, hipDataType datatype, bool bias_epilogue, bool mx_mode) - : handle_(handle, wrap::hipblasLtMatmulDescDestroy), + : handle_(handle, hipblasLtMatmulDescDestroy), compute_type_(compute_type), datatype_(datatype), has_bias_epilogue_(bias_epilogue), @@ -136,7 +135,7 @@ class BlasLt : public gpu::BlasLt { }; // class MatmulPlan explicit BlasLt(StreamExecutor* parent) - : parent_(parent), blas_lt_(nullptr, wrap::hipblasLtDestroy) {} + : parent_(parent), blas_lt_(nullptr, hipblasLtDestroy) {} absl::Status Init() override; diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc index 487ee50fe29d4..1e70e8d9887fa 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/xla/stream_executor/rocm/hip_blas_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/str_cat.h" +#include "rocm/rocm_config.h" #include "xla/stream_executor/blas.h" #if TF_HIPBLASLT diff --git a/xla/stream_executor/rocm/hip_blas_utils.h b/xla/stream_executor/rocm/hip_blas_utils.h index 5d849ac1fa3e4..9be28a6dd121f 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.h +++ b/xla/stream_executor/rocm/hip_blas_utils.h @@ -19,8 +19,9 @@ limitations under the License. 
#include #include "absl/status/status.h" +#include "rocm/include/hipblas/hipblas.h" +#include "rocm/include/hipblaslt/hipblaslt.h" #include "xla/stream_executor/blas.h" -#include "xla/stream_executor/rocm/hipblaslt_wrapper.h" #include "xla/tsl/platform/errors.h" #if TF_HIPBLASLT diff --git a/xla/stream_executor/rocm/hipblaslt_wrapper.h b/xla/stream_executor/rocm/hipblaslt_wrapper.h deleted file mode 100644 index 0ff04d4d41f6d..0000000000000 --- a/xla/stream_executor/rocm/hipblaslt_wrapper.h +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file wraps rocsolver API calls with dso loader so that we don't need to -// have explicit linking to librocsolver. All TF hipsarse API usage should route -// through this wrapper. - -#ifndef XLA_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_ -#define XLA_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_ - -#include "rocm/rocm_config.h" - -#if TF_HIPBLASLT -#include "rocm/include/hipblas/hipblas.h" -#if TF_ROCM_VERSION >= 50500 -#include "rocm/include/hipblaslt/hipblaslt.h" -#else -#include "rocm/include/hipblaslt.h" -#endif -#include "xla/tsl/platform/env.h" -#include "tsl/platform/dso_loader.h" - -namespace stream_executor { -namespace wrap { - -#ifdef PLATFORM_GOOGLE - -#define HIPBLASLT_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... 
args) -> decltype(::api_name(args...)) { \ - return ::api_name(args...); \ - } - -#else - -#define TO_STR_(x) #x -#define TO_STR(x) TO_STR_(x) - -#define HIPBLASLT_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... args) -> decltype(::api_name(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = TO_STR(api_name); \ - void* f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ - tsl::internal::CachedDsoLoader::GetHipblasltDsoHandle().value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in hipblaslt lib; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ - } - -#endif - -// clang-format off -#define FOREACH_HIPBLASLT_API(__macro) \ - __macro(hipblasLtCreate) \ - __macro(hipblasLtDestroy) \ - __macro(hipblasLtMatmulPreferenceCreate) \ - __macro(hipblasLtMatmulPreferenceSetAttribute) \ - __macro(hipblasLtMatmulPreferenceDestroy) \ - __macro(hipblasLtMatmulDescSetAttribute) \ - __macro(hipblasLtMatmulDescGetAttribute) \ - __macro(hipblasLtMatmulAlgoGetHeuristic) \ - __macro(hipblasLtMatrixLayoutCreate) \ - __macro(hipblasLtMatrixLayoutDestroy) \ - __macro(hipblasLtMatrixLayoutSetAttribute) \ - __macro(hipblasLtMatrixLayoutGetAttribute) \ - __macro(hipblasLtMatmulDescCreate) \ - __macro(hipblasLtMatmulDescDestroy) \ - __macro(hipblasLtMatmul) \ - __macro(hipblasStatusToString) -// clang-format on - -FOREACH_HIPBLASLT_API(HIPBLASLT_API_WRAPPER) - -#undef TO_STR_ -#undef TO_STR -#undef FOREACH_HIPBLASLT_API -#undef HIPBLASLT_API_WRAPPER - -} // namespace wrap -} // namespace stream_executor - -#endif // TF_HIPBLASLT - -#endif // XLA_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_ diff --git a/xla/stream_executor/rocm/hipsolver_wrapper.h b/xla/stream_executor/rocm/hipsolver_wrapper.h index f5edb57bb96a4..16f6fc27c54e9 100644 --- a/xla/stream_executor/rocm/hipsolver_wrapper.h +++ 
b/xla/stream_executor/rocm/hipsolver_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The OpenXLA Authors. +/* Copyright 2026 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,59 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file wraps hipsolver API calls with dso loader so that we don't need to -// have explicit linking to libhipsolver. All TF hipsolver API usage should -// route through this wrapper. - #ifndef XLA_STREAM_EXECUTOR_ROCM_HIPSOLVER_WRAPPER_H_ #define XLA_STREAM_EXECUTOR_ROCM_HIPSOLVER_WRAPPER_H_ -#include "rocm/rocm_config.h" - -#if TF_ROCM_VERSION >= 40500 -#if TF_ROCM_VERSION >= 50600 #include "rocm/include/hipsolver/hipsolver.h" -#else - -#include "rocm/include/hipsolver.h" -#endif -#include "xla/tsl/platform/env.h" -#include "tsl/platform/dso_loader.h" +#include "rocm/rocm_config.h" namespace stream_executor { namespace wrap { -#ifdef PLATFORM_GOOGLE - -#define HIPSOLVER_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... args) -> decltype(::api_name(args...)) { \ - return ::api_name(args...); \ - } - -#else - -#define TO_STR_(x) #x -#define TO_STR(x) TO_STR_(x) - -#define HIPSOLVER_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... 
args) -> decltype(::api_name(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = TO_STR(api_name); \ - void* f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ - tsl::internal::CachedDsoLoader::GetHipsolverDsoHandle().value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in hipsolver lib; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ - } - -#endif +#define HIPSOLVER_API_WRAPPER(api_name) using ::api_name; // clang-format off #define FOREACH_HIPSOLVER_API(__macro) \ @@ -136,8 +93,6 @@ namespace wrap { FOREACH_HIPSOLVER_API(HIPSOLVER_API_WRAPPER) -#undef TO_STR_ -#undef TO_STR #undef FOREACH_HIPSOLVER_API #undef HIPSOLVER_API_WRAPPER diff --git a/xla/stream_executor/rocm/hipsparse_wrapper.h b/xla/stream_executor/rocm/hipsparse_wrapper.h index 2a85ed02041ac..eaa06947220a6 100644 --- a/xla/stream_executor/rocm/hipsparse_wrapper.h +++ b/xla/stream_executor/rocm/hipsparse_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The OpenXLA Authors. +/* Copyright 2026 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,67 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file wraps hipsparse API calls with dso loader so that we don't need to -// have explicit linking to libhipsparse. All TF hipsarse API usage should route -// through this wrapper. 
- #ifndef XLA_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_ #define XLA_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_ -#include "rocm/rocm_config.h" - -#if (TF_ROCM_VERSION >= 50200) #include "rocm/include/hipsparse/hipsparse.h" -#else -#include "rocm/include/hipsparse.h" -#endif -#include "xla/tsl/platform/env.h" -#include "tsl/platform/dso_loader.h" -#include "tsl/platform/platform.h" +#include "rocm/rocm_config.h" namespace stream_executor { namespace wrap { -#ifdef PLATFORM_GOOGLE - -#define HIPSPARSE_API_WRAPPER(__name) \ - struct WrapperShim__##__name { \ - template \ - hipsparseStatus_t operator()(Args... args) { \ - hipsparseStatus_t retval = ::__name(args...); \ - return retval; \ - } \ - } __name; - -#else - -#define HIPSPARSE_API_WRAPPER(__name) \ - static struct DynLoadShim__##__name { \ - constexpr static const char* kName = #__name; \ - using FuncPtrT = std::add_pointer::type; \ - static void* GetDsoHandle() { \ - auto s = tsl::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \ - return s.value(); \ - } \ - static FuncPtrT LoadOrDie() { \ - void* f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in miopen DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - } \ - static FuncPtrT DynLoad() { \ - static FuncPtrT f = LoadOrDie(); \ - return f; \ - } \ - template \ - hipsparseStatus_t operator()(Args... 
args) { \ - return DynLoad()(args...); \ - } \ - } __name; - -#endif +#define HIPSPARSE_API_WRAPPER(__name) using ::__name; // clang-format off #define FOREACH_HIPSPARSE_API(__macro) \ @@ -115,10 +64,7 @@ namespace wrap { __macro(hipsparseZcsrgemm) \ __macro(hipsparseZcsrmm) \ __macro(hipsparseZcsrmm2) \ - __macro(hipsparseZcsrmv) - -#if TF_ROCM_VERSION >= 40200 -#define FOREACH_HIPSPARSE_ROCM42_API(__macro) \ + __macro(hipsparseZcsrmv) \ __macro(hipsparseCcsru2csr_bufferSizeExt) \ __macro(hipsparseCcsru2csr) \ __macro(hipsparseCreateCsr) \ @@ -133,13 +79,6 @@ namespace wrap { __macro(hipsparseSpMM) \ __macro(hipsparseZcsru2csr_bufferSizeExt) \ __macro(hipsparseZcsru2csr) - - -FOREACH_HIPSPARSE_ROCM42_API(HIPSPARSE_API_WRAPPER) - -#undef FOREACH_HIPSPARSE_ROCM42_API -#endif - // clang-format on FOREACH_HIPSPARSE_API(HIPSPARSE_API_WRAPPER) diff --git a/xla/stream_executor/rocm/rocblas_wrapper.h b/xla/stream_executor/rocm/rocblas_wrapper.h index 9e31d8814186c..9e5fc21482900 100644 --- a/xla/stream_executor/rocm/rocblas_wrapper.h +++ b/xla/stream_executor/rocm/rocblas_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The OpenXLA Authors. +/* Copyright 2026 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file wraps rocblas API calls with dso loader so that we don't need to -// have explicit linking to librocblas. All TF hipsarse API usage should route -// through this wrapper. - #ifndef XLA_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_ #define XLA_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_ @@ -25,53 +21,11 @@ limitations under the License. 
#include "rocm/include/rocblas/rocblas.h" #include "rocm/rocm_config.h" -#include "xla/tsl/platform/env.h" -#include "tsl/platform/dso_loader.h" -#include "tsl/platform/platform.h" namespace stream_executor { namespace wrap { -#ifdef PLATFORM_GOOGLE -#define ROCBLAS_API_WRAPPER(__name) \ - struct WrapperShim__##__name { \ - constexpr static const char* kName = #__name; \ - template \ - rocblas_status operator()(Args... args) { \ - return (::__name)(args...); \ - } \ - } __name; - -#else -using tsl::internal::CachedDsoLoader::GetRocblasDsoHandle; - -#define ROCBLAS_API_WRAPPER(__name) \ - static struct DynLoadShim__##__name { \ - constexpr static const char* kName = #__name; \ - using FuncPtrT = std::add_pointer::type; \ - static void* GetDsoHandle() { \ - auto s = GetRocblasDsoHandle(); \ - return s.value(); \ - } \ - static FuncPtrT LoadOrDie() { \ - void* f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in rocblas DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - } \ - static FuncPtrT DynLoad() { \ - static FuncPtrT f = LoadOrDie(); \ - return f; \ - } \ - template \ - auto operator()(Args... args) { \ - return DynLoad()(args...); \ - } \ - } __name; - -#endif +#define ROCBLAS_API_WRAPPER(__name) using ::__name; // clang-format off #define FOREACH_ROCBLAS_API(__macro) \ @@ -280,6 +234,9 @@ using tsl::internal::CachedDsoLoader::GetRocblasDsoHandle; FOREACH_ROCBLAS_API(ROCBLAS_API_WRAPPER) +#undef FOREACH_ROCBLAS_API +#undef ROCBLAS_API_WRAPPER + } // namespace wrap } // namespace stream_executor diff --git a/xla/stream_executor/rocm/rocm_blas.cc b/xla/stream_executor/rocm/rocm_blas.cc index 732454d86d218..eb0e420df6c3a 100644 --- a/xla/stream_executor/rocm/rocm_blas.cc +++ b/xla/stream_executor/rocm/rocm_blas.cc @@ -48,7 +48,6 @@ limitations under the License. 
#include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/plugin_registry.h" -#include "xla/stream_executor/rocm/rocblas_wrapper.h" #include "xla/stream_executor/rocm/rocm_complex_converters.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scratch_allocator.h" @@ -62,6 +61,59 @@ limitations under the License. using tsl::OpDeterminismRequired; namespace stream_executor { + +namespace wrap { + +namespace { + +#define ROCBLAS_API_WRAPPER(__name) \ + struct WrapperShim__##__name { \ + constexpr static const char *kName = #__name; \ + template \ + rocblas_status operator()(Args... args) { \ + return (::__name)(args...); \ + } \ + } __name; + +// clang-format off +#define FOREACH_ROCBLAS_API(__macro) \ + __macro(rocblas_sscal) \ + __macro(rocblas_dscal) \ + __macro(rocblas_cscal) \ + __macro(rocblas_csscal) \ + __macro(rocblas_zscal) \ + __macro(rocblas_zdscal) \ + __macro(rocblas_strsm) \ + __macro(rocblas_dtrsm) \ + __macro(rocblas_ctrsm) \ + __macro(rocblas_ztrsm) \ + __macro(rocblas_sgemv) \ + __macro(rocblas_dgemv) \ + __macro(rocblas_cgemv) \ + __macro(rocblas_zgemv) \ + __macro(rocblas_sgemm) \ + __macro(rocblas_dgemm) \ + __macro(rocblas_hgemm) \ + __macro(rocblas_cgemm) \ + __macro(rocblas_zgemm) \ + __macro(rocblas_hgemm_strided_batched) \ + __macro(rocblas_sgemm_strided_batched) \ + __macro(rocblas_dgemm_strided_batched) \ + __macro(rocblas_cgemm_strided_batched) \ + __macro(rocblas_zgemm_strided_batched) \ + __macro(rocblas_gemm_ex) \ + __macro(rocblas_gemm_strided_batched_ex) \ + __macro(rocblas_strsm_batched) \ + __macro(rocblas_dtrsm_batched) \ + __macro(rocblas_ctrsm_batched) \ + __macro(rocblas_ztrsm_batched) + +// clang-format on + +FOREACH_ROCBLAS_API(ROCBLAS_API_WRAPPER) +} // namespace +} // namespace wrap + namespace gpu { using rocm::ROCMComplex; @@ -125,7 +177,7 @@ static std::string ToString(rocblas_status status) { bool 
ROCMBlas::Init() { std::unique_ptr activation = parent_->Activate(); - rocblas_status ret = wrap::rocblas_create_handle(&blas_); + rocblas_status ret = rocblas_create_handle(&blas_); if (ret != rocblas_status_success) { LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret); return false; @@ -164,7 +216,7 @@ ROCMBlas::ROCMBlas(StreamExecutor *parent) ROCMBlas::~ROCMBlas() { if (blas_ != nullptr) { std::unique_ptr activation = parent_->Activate(); - wrap::rocblas_destroy_handle(blas_); + rocblas_destroy_handle(blas_); } } @@ -174,7 +226,7 @@ bool ROCMBlas::SetStream(Stream *stream) { (stream != nullptr) ? static_cast(stream->platform_specific_handle().stream) : nullptr; - if (auto ret = wrap::rocblas_set_stream(blas_, handle); + if (auto ret = rocblas_set_stream(blas_, handle); ret != rocblas_status_success) { LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret); return false; @@ -186,7 +238,7 @@ absl::StatusOr ROCMBlas::IsMainStreamSet() const { absl::MutexLock lock{mu_}; CHECK(blas_ != nullptr); hipStream_t handle{}; - if (auto ret = wrap::rocblas_get_stream(blas_, &handle); + if (auto ret = rocblas_get_stream(blas_, &handle); ret != rocblas_status_success) { return absl::InternalError("failed to get the current stream value"); } @@ -377,7 +429,7 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, // set the atomics mode, leaving default to library bool allow_atomics = !OpDeterminismRequired(); if (!allow_atomics) { - ret = wrap::rocblas_set_atomics_mode(blas_, rocblas_atomics_not_allowed); + ret = rocblas_set_atomics_mode(blas_, rocblas_atomics_not_allowed); if (err_on_failure && ret != rocblas_status_success) { LOG(ERROR) << "failed to set atomics mode before " << FuncT::kName << ": " << ToString(ret); @@ -391,7 +443,7 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, auto *workspace = GetWorkspace(); auto *wptr = workspace != nullptr ? 
workspace->opaque() : nullptr; size_t wsize = workspace != nullptr ? workspace->size() : 0; - ret = wrap::rocblas_set_workspace(blas_, wptr, wsize); + ret = rocblas_set_workspace(blas_, wptr, wsize); if (err_on_failure && ret != rocblas_status_success) { LOG(ERROR) << "failed to set workspace before " << FuncT::kName << ": " << ToString(ret); @@ -745,18 +797,17 @@ bool ROCMBlas::GetBlasGemmAlgorithms( if (c->batch_size == 1) { return DoBlasInternalFailureOK( - NameWrap{blas_lambda}, stream, true, - wrap::rocblas_gemm_ex_get_solutions, ROCMBlasTranspose(a.transpose), - ROCMBlasTranspose(b.transpose), c->m, c->n, c->k, alpha, - a.data.opaque(), roc_type_a, a.leading_dim_stride, b.data.opaque(), - roc_type_a, b.leading_dim_stride, beta, c->data.opaque(), roc_type_c, - c->leading_dim_stride, c->data.opaque(), roc_type_c, - c->leading_dim_stride, roc_comp_type, rocblas_gemm_algo_solution_index, - 0); + NameWrap{blas_lambda}, stream, true, rocblas_gemm_ex_get_solutions, + ROCMBlasTranspose(a.transpose), ROCMBlasTranspose(b.transpose), c->m, + c->n, c->k, alpha, a.data.opaque(), roc_type_a, a.leading_dim_stride, + b.data.opaque(), roc_type_a, b.leading_dim_stride, beta, + c->data.opaque(), roc_type_c, c->leading_dim_stride, c->data.opaque(), + roc_type_c, c->leading_dim_stride, roc_comp_type, + rocblas_gemm_algo_solution_index, 0); } return DoBlasInternalFailureOK( NameWrap{blas_lambda}, stream, true, - wrap::rocblas_gemm_strided_batched_ex_get_solutions, + rocblas_gemm_strided_batched_ex_get_solutions, ROCMBlasTranspose(a.transpose), ROCMBlasTranspose(b.transpose), c->m, c->n, c->k, alpha, a.data.opaque(), roc_type_a, a.leading_dim_stride, a.batch_stride, b.data.opaque(), roc_type_a, b.leading_dim_stride, @@ -1262,13 +1313,13 @@ IMPL_DoBlasGemmBatched(float, wrap::rocblas_sgemm_strided_batched) absl::Status ROCMBlas::GetVersion(std::string *version) { absl::MutexLock lock{mu_}; size_t len = 0; - if (auto res = wrap::rocblas_get_version_string_size(&len); + if (auto res = 
rocblas_get_version_string_size(&len); res != rocblas_status_success) { return absl::InternalError( absl::StrCat("GetVersion failed with: ", ToString(res))); } std::vector buf(len + 1); - if (auto res = wrap::rocblas_get_version_string(buf.data(), len); + if (auto res = rocblas_get_version_string(buf.data(), len); res != rocblas_status_success) { return absl::InternalError( absl::StrCat("GetVersion failed with: ", ToString(res))); diff --git a/xla/stream_executor/rocm/rocm_blas.h b/xla/stream_executor/rocm/rocm_blas.h index 48a3576293b59..a1d9e4873ff78 100644 --- a/xla/stream_executor/rocm/rocm_blas.h +++ b/xla/stream_executor/rocm/rocm_blas.h @@ -26,11 +26,7 @@ limitations under the License. #include "rocm/rocm_config.h" #define ROCBLAS_BETA_FEATURES_API -#if TF_ROCM_VERSION >= 50600 #include "rocm/include/rocblas/rocblas.h" -#else -#include "rocm/include/rocblas.h" -#endif #include "xla/stream_executor/blas.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/plugin_registry.h" diff --git a/xla/stream_executor/rocm/rocm_command_buffer.cc b/xla/stream_executor/rocm/rocm_command_buffer.cc index 92d6067151002..bd67f11b02529 100644 --- a/xla/stream_executor/rocm/rocm_command_buffer.cc +++ b/xla/stream_executor/rocm/rocm_command_buffer.cc @@ -41,7 +41,6 @@ limitations under the License. 
#include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_kernel.h" #include "xla/stream_executor/rocm/rocm_status.h" #include "xla/stream_executor/stream_executor.h" @@ -55,7 +54,7 @@ namespace { absl::StatusOr CreateGraph() { VLOG(2) << "Create new HIP graph"; hipGraph_t graph; - TF_RETURN_IF_ERROR(ToStatus(wrap::hipGraphCreate(&graph, /*flags=*/0), + TF_RETURN_IF_ERROR(ToStatus(hipGraphCreate(&graph, /*flags=*/0), "Failed to create HIP graph")); VLOG(2) << "Created HIP graph " << graph; return graph; @@ -156,8 +155,8 @@ absl::StatusOr RocmCommandBuffer::CreateMemsetNode( hipGraphNode_t node_handle = nullptr; TF_RETURN_IF_ERROR( - ToStatus(wrap::hipGraphAddMemsetNode(&node_handle, graph_, deps.data(), - deps.size(), ¶ms), + ToStatus(hipGraphAddMemsetNode(&node_handle, graph_, deps.data(), + deps.size(), ¶ms), "Failed to add memset node to a HIP graph")); return FromHipGraphHandle(node_handle); } @@ -179,7 +178,7 @@ absl::Status RocmCommandBuffer::UpdateMemsetNode(GraphNodeHandle node_handle, params.value = bit_pattern.GetPatternBroadcastedToUint32(); params.width = num_elements; - return ToStatus(wrap::hipGraphExecMemsetNodeSetParams( + return ToStatus(hipGraphExecMemsetNodeSetParams( exec_, ToHipGraphHandle(node_handle), ¶ms), "Failed to set memset node params"); } @@ -195,10 +194,9 @@ absl::StatusOr RocmCommandBuffer::CreateMemcpyD2DNode( hipGraphNode_t node_handle = nullptr; TF_RETURN_IF_ERROR(ToStatus( - wrap::hipGraphAddMemcpyNode1D(&node_handle, graph_, deps.data(), - deps.size(), AsDevicePtr(destination), - AsDevicePtr(source), size, - hipMemcpyDeviceToDevice), + hipGraphAddMemcpyNode1D(&node_handle, graph_, deps.data(), deps.size(), + AsDevicePtr(destination), AsDevicePtr(source), + size, hipMemcpyDeviceToDevice), "Failed to add memcpy d2d node to a HIP graph")); return 
FromHipGraphHandle(node_handle); } @@ -212,7 +210,7 @@ absl::Status RocmCommandBuffer::UpdateMemcpyD2DNode( << "; size: " << size; return ToStatus( - wrap::hipGraphExecMemcpyNodeSetParams1D( + hipGraphExecMemcpyNodeSetParams1D( exec_, ToHipGraphHandle(node_handle), AsDevicePtr(destination), AsDevicePtr(source), size, hipMemcpyDeviceToDevice), "Failed to set memcpy d2d node params"); @@ -236,8 +234,8 @@ absl::StatusOr RocmCommandBuffer::CreateChildNode( hipGraphNode_t node_handle = nullptr; TF_RETURN_IF_ERROR(ToStatus( - wrap::hipGraphAddChildGraphNode(&node_handle, graph_, deps.data(), - deps.size(), child_graph), + hipGraphAddChildGraphNode(&node_handle, graph_, deps.data(), deps.size(), + child_graph), "Failed to create a child graph node and add it to a HIP graph")); return FromHipGraphHandle(node_handle); } @@ -253,7 +251,7 @@ absl::Status RocmCommandBuffer::UpdateChildNode(ChildCommandType type, VLOG(2) << "Set child node params " << node_handle << " in graph executable " << exec_ << "to params contained in " << child_graph; - return ToStatus(wrap::hipGraphExecChildGraphNodeSetParams( + return ToStatus(hipGraphExecChildGraphNodeSetParams( exec_, ToHipGraphHandle(node_handle), child_graph), "Failed to set HIP graph child node params"); } @@ -287,19 +285,19 @@ absl::StatusOr RocmCommandBuffer::CreateKernelNode( params.extra = nullptr; if (shared_mem_bytes != 0) { - TF_RETURN_IF_ERROR(ToStatus( - wrap::hipFuncSetAttribute(function, - hipFuncAttributeMaxDynamicSharedMemorySize, - shared_mem_bytes), - "Failed to set shared memory size")); + TF_RETURN_IF_ERROR( + ToStatus(hipFuncSetAttribute(function, + hipFuncAttributeMaxDynamicSharedMemorySize, + shared_mem_bytes), + "Failed to set shared memory size")); } std::vector deps = ToHipGraphHandles(dependencies); hipGraphNode_t node_handle = nullptr; TF_RETURN_IF_ERROR( - ToStatus(wrap::hipGraphAddKernelNode(&node_handle, graph_, deps.data(), - deps.size(), ¶ms), + ToStatus(hipGraphAddKernelNode(&node_handle, graph_, 
deps.data(), + deps.size(), ¶ms), "Failed to add kernel node to a HIP graph")); return FromHipGraphHandle(node_handle); @@ -333,14 +331,14 @@ absl::Status RocmCommandBuffer::UpdateKernelNode( params.extra = nullptr; if (shared_mem_bytes != 0) { - TF_RETURN_IF_ERROR(ToStatus( - wrap::hipFuncSetAttribute(function, - hipFuncAttributeMaxDynamicSharedMemorySize, - shared_mem_bytes), - "Failed to set shared memory size")); + TF_RETURN_IF_ERROR( + ToStatus(hipFuncSetAttribute(function, + hipFuncAttributeMaxDynamicSharedMemorySize, + shared_mem_bytes), + "Failed to set shared memory size")); } - return ToStatus(wrap::hipGraphExecKernelNodeSetParams( + return ToStatus(hipGraphExecKernelNodeSetParams( exec_, ToHipGraphHandle(node_handle), ¶ms), "Failed to set HIP graph kernel node params"); } @@ -366,20 +364,19 @@ absl::Status RocmCommandBuffer::Trace( // Switch stream into the capture mode. uint64_t start_nanos = tsl::Env::Default()->NowNanos(); - TF_RETURN_IF_ERROR( - ToStatus(wrap::hipStreamBeginCapture(stream_handle, - hipStreamCaptureModeThreadLocal), - "Failed to begin stream capture")); + TF_RETURN_IF_ERROR(ToStatus( + hipStreamBeginCapture(stream_handle, hipStreamCaptureModeThreadLocal), + "Failed to begin stream capture")); auto traced = function(); // Always stop capturing the stream before checking `traced` result. 
VLOG(5) << "End stream " << stream << " capture"; hipGraph_t captured_graph; TF_RETURN_IF_ERROR( - ToStatus(wrap::hipStreamEndCapture(stream_handle, &captured_graph), + ToStatus(hipStreamEndCapture(stream_handle, &captured_graph), "Failed to end stream capture")); TF_RETURN_IF_ERROR( - ToStatus(wrap::hipGraphDestroy(std::exchange(graph_, captured_graph)), + ToStatus(hipGraphDestroy(std::exchange(graph_, captured_graph)), "Failed to destroy HIP graph")); uint64_t end_nanos = tsl::Env::Default()->NowNanos(); @@ -396,15 +393,15 @@ absl::Status RocmCommandBuffer::Trace( absl::Status RocmCommandBuffer::LaunchGraph(Stream* stream) { VLOG(3) << "Launch command buffer executable graph " << exec_ << " on a stream: " << stream; - return ToStatus(wrap::hipGraphLaunch( - exec_, static_cast( - stream->platform_specific_handle().stream)), - "Failed to launch HIP graph"); + return ToStatus( + hipGraphLaunch(exec_, static_cast( + stream->platform_specific_handle().stream)), + "Failed to launch HIP graph"); } absl::StatusOr RocmCommandBuffer::GetNodeCount() const { size_t numNodes; TF_RETURN_IF_ERROR( - ToStatus(wrap::hipGraphGetNodes(graph_, /*nodes=*/nullptr, &numNodes), + ToStatus(hipGraphGetNodes(graph_, /*nodes=*/nullptr, &numNodes), "Failed to get HIP graph node count")); return numNodes; @@ -425,15 +422,14 @@ absl::Status RocmCommandBuffer::WriteGraphToDotFile(absl::string_view path) { int flags = hipGraphDebugDotFlagsVerbose; return ToStatus( - wrap::hipGraphDebugDotPrint(graph_, std::string{path}.c_str(), flags), + hipGraphDebugDotPrint(graph_, std::string{path}.c_str(), flags), "Failed to print gpu graph debug file"); } absl::Status RocmCommandBuffer::InstantiateGraph() { VLOG(2) << "Instantiate HIP executable graph from graph " << graph_; - return ToStatus( - wrap::hipGraphInstantiate(&exec_, graph_, nullptr, nullptr, 0), - "Failed to instantiate HIP graph"); + return ToStatus(hipGraphInstantiate(&exec_, graph_, nullptr, nullptr, 0), + "Failed to instantiate HIP graph"); 
} RocmCommandBuffer::~RocmCommandBuffer() { diff --git a/xla/stream_executor/rocm/rocm_context.cc b/xla/stream_executor/rocm/rocm_context.cc index ab61f6902aa2d..c7632a9b9a9f3 100644 --- a/xla/stream_executor/rocm/rocm_context.cc +++ b/xla/stream_executor/rocm/rocm_context.cc @@ -28,7 +28,6 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/context_map.h" #include "xla/stream_executor/gpu/scoped_activate_context.h" -#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_status.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/status.h" @@ -82,7 +81,7 @@ ContextMap* RocmContext::GetContextMap() { bool RocmContext::GetDeviceTotalMemory(hipDevice_t device, uint64_t* result) { size_t value = -1; - hipError_t res = wrap::hipDeviceTotalMem(&value, device); + hipError_t res = hipDeviceTotalMem(&value, device); if (res != hipSuccess) { LOG(ERROR) << "failed to query total available memory: " << ToString(res); return false; @@ -95,7 +94,7 @@ bool RocmContext::GetDeviceMemoryUsage(int64_t* free_out, int64_t* total_out) { ScopedActivateContext activation(this); size_t free = 0; size_t total = 0; - hipError_t res = wrap::hipMemGetInfo(&free, &total); + hipError_t res = hipMemGetInfo(&free, &total); if (res != hipSuccess) { LOG(ERROR) << "failed to query device memory info: " << ToString(res); return false; @@ -121,10 +120,10 @@ RocmContext::~RocmContext() { // about calling a virtual method in the destructor. 
RocmContext::SetActive(); hipDevice_t device; - CHECK_EQ(hipSuccess, wrap::hipCtxGetDevice(&device)); - CHECK_EQ(hipSuccess, wrap::hipCtxSetCurrent(former_context)); + CHECK_EQ(hipSuccess, hipCtxGetDevice(&device)); + CHECK_EQ(hipSuccess, hipCtxSetCurrent(former_context)); - auto res = wrap::hipDevicePrimaryCtxRelease(device); + auto res = hipDevicePrimaryCtxRelease(device); if (res != hipSuccess) { LOG(ERROR) << "failed to release HIP context; leaking: " << ToString(res); @@ -134,16 +133,15 @@ RocmContext::~RocmContext() { } void RocmContext::SetActive() { - TF_CHECK_OK( - ToStatus(wrap::hipCtxSetCurrent(context_), "Failed setting context")); + TF_CHECK_OK(ToStatus(hipCtxSetCurrent(context_), "Failed setting context")); } bool RocmContext::IsActive() const { return CurrentContext() == context_; } absl::Status RocmContext::Synchronize() { ScopedActivateContext activation(this); - TF_RETURN_IF_ERROR(ToStatus(wrap::hipDeviceSynchronize(), - "could not synchronize on ROCM device")); + TF_RETURN_IF_ERROR( + ToStatus(hipDeviceSynchronize(), "could not synchronize on ROCM device")); return absl::OkStatus(); } @@ -159,9 +157,9 @@ absl::StatusOr RocmContext::Create(int device_ordinal, unsigned int former_primary_context_flags; int former_primary_context_is_active; - CHECK_EQ(hipSuccess, wrap::hipDevicePrimaryCtxGetState( - device, &former_primary_context_flags, - &former_primary_context_is_active)); + CHECK_EQ(hipSuccess, + hipDevicePrimaryCtxGetState(device, &former_primary_context_flags, + &former_primary_context_is_active)); if (former_primary_context_flags != flags) { if (former_primary_context_is_active) { LOG(ERROR) @@ -169,15 +167,15 @@ absl::StatusOr RocmContext::Create(int device_ordinal, << former_primary_context_flags << ") than the desired flag set (" << flags << ")."; } else { - CHECK_EQ(hipSuccess, wrap::hipDevicePrimaryCtxSetFlags(device, flags)); + CHECK_EQ(hipSuccess, hipDevicePrimaryCtxSetFlags(device, flags)); } } former_context = CurrentContextOrDie(); 
- res = wrap::hipDevicePrimaryCtxRetain(&new_context, device); + res = hipDevicePrimaryCtxRetain(&new_context, device); if (former_context != nullptr) { hipDevice_t former_device; - if (wrap::hipCtxGetDevice(&former_device) == hipSuccess) { + if (hipCtxGetDevice(&former_device) == hipSuccess) { if (former_device == device) { if (former_context == new_context) { VLOG(2) << "The primary context " << former_context << " for device " @@ -196,7 +194,7 @@ absl::StatusOr RocmContext::Create(int device_ordinal, << former_context; } } - CHECK_EQ(hipSuccess, wrap::hipCtxSetCurrent(former_context)); + CHECK_EQ(hipSuccess, hipCtxSetCurrent(former_context)); if (res == hipSuccess) { context = GetContextMap()->Add(new_context, device_ordinal); diff --git a/xla/stream_executor/rocm/rocm_dnn.cc b/xla/stream_executor/rocm/rocm_dnn.cc index d6747d7811d89..57f2e0ac46706 100644 --- a/xla/stream_executor/rocm/rocm_dnn.cc +++ b/xla/stream_executor/rocm/rocm_dnn.cc @@ -255,311 +255,6 @@ class MIOpenHandle { miopenHandle_t handle_; // Not owned. }; -namespace wrap { - -#ifdef PLATFORM_GOOGLE -#define STREAM_EXECUTOR_MIOPEN_WRAP(__name) \ - struct WrapperShim__##__name { \ - template \ - miopenStatus_t operator()(Args... 
args) { \ - miopenStatus_t retval = ::__name(args...); \ - return retval; \ - } \ - } __name; - -#else - -#define STREAM_EXECUTOR_MIOPEN_WRAP(__name) \ - struct DynLoadShim__##__name { \ - static const char* kName; \ - using FuncPtrT = std::add_pointer::type; \ - static void* GetDsoHandle() { \ - auto s = tsl::internal::CachedDsoLoader::GetMiopenDsoHandle(); \ - return s.value(); \ - } \ - static FuncPtrT LoadOrDie() { \ - void* f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in miopen DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - } \ - static FuncPtrT DynLoad() { \ - static FuncPtrT f = LoadOrDie(); \ - return f; \ - } \ - template \ - miopenStatus_t operator()(Args... args) { \ - return DynLoad()(args...); \ - } \ - } __name; \ - const char* DynLoadShim__##__name::kName = #__name; - -#endif - -#if (TF_ROCM_VERSION >= 50000) -// clang-format off -#define MIOPEN_DNN_ROUTINE_EACH(__macro) \ - __macro(miopenBatchNormalizationBackward) \ - __macro(miopenBatchNormalizationForwardInference) \ - __macro(miopenBatchNormalizationForwardTraining) \ - __macro(miopenGetConvolutionForwardOutputDim) \ - __macro(miopenGetConvolutionNdForwardOutputDim) \ - __macro(miopenFindConvolutionForwardAlgorithm) \ - __macro(miopenCreateTensorDescriptor) \ - __macro(miopenDestroyTensorDescriptor) \ - __macro(miopenSetNdPoolingDescriptor) \ - __macro(miopenSetPoolingIndexType) \ - __macro(miopenSetLRNDescriptor) \ - __macro(miopenLRNGetWorkSpaceSize) \ - __macro(miopenCreateConvolutionDescriptor) \ - __macro(miopenCreatePoolingDescriptor) \ - __macro(miopenDestroyPoolingDescriptor) \ - __macro(miopenCreateLRNDescriptor) \ - __macro(miopenDestroyLRNDescriptor) \ - __macro(miopenDestroyConvolutionDescriptor) \ - __macro(miopenCreateWithStream) \ - __macro(miopenDestroy) \ - __macro(miopenSetStream) \ - __macro(miopenSetAllocator) \ - __macro(miopenActivationForward) \ - 
__macro(miopenConvolutionForward) \ - __macro(miopenConvolutionBackwardBias) \ - __macro(miopenConvolutionForwardGetWorkSpaceSize) \ - __macro(miopenInitConvolutionDescriptor) \ - __macro(miopenInitConvolutionNdDescriptor) \ - __macro(miopenGetConvolutionDescriptor) \ - __macro(miopenGetConvolutionNdDescriptor) \ - __macro(miopenSetConvolutionGroupCount) \ - __macro(miopenSet4dTensorDescriptor) \ - __macro(miopenGetTensorDescriptor) \ - __macro(miopenSetTensorDescriptor) \ - __macro(miopenGetTensorDescriptorSize) \ - __macro(miopenPoolingForward) \ - __macro(miopenPoolingGetWorkSpaceSizeV2) \ - __macro(miopenPoolingBackward) \ - __macro(miopenLRNForward) \ - __macro(miopenLRNBackward) \ - __macro(miopenOpTensor) \ - __macro(miopenConvolutionBackwardData) \ - __macro(miopenConvolutionBackwardWeights) \ - __macro(miopenConvolutionBackwardWeightsGetWorkSpaceSize) \ - __macro(miopenFindConvolutionBackwardDataAlgorithm) \ - __macro(miopenFindConvolutionBackwardWeightsAlgorithm) \ - __macro(miopenConvolutionBackwardDataGetWorkSpaceSize) \ - __macro(miopenCreateRNNDescriptor) \ - __macro(miopenSetRNNDescriptor) \ - __macro(miopenSetRNNDescriptor_V2) \ - __macro(miopenDestroyRNNDescriptor) \ - __macro(miopenGetRNNParamsSize) \ - __macro(miopenGetRNNLayerParam) \ - __macro(miopenGetRNNLayerBias) \ - __macro(miopenGetRNNWorkspaceSize) \ - __macro(miopenGetRNNTrainingReserveSize) \ - __macro(miopenRNNForwardInference) \ - __macro(miopenRNNForwardTraining) \ - __macro(miopenRNNBackwardData) \ - __macro(miopenRNNBackwardWeights) \ - __macro(miopenGetRNNLayerParamOffset) \ - __macro(miopenGetRNNLayerParamSize) \ - __macro(miopenGetRNNLayerBiasOffset) \ - __macro(miopenGetRNNLayerBiasSize) \ - __macro(miopenGetRNNParamsDescriptor) \ - __macro(miopenCreateDropoutDescriptor) \ - __macro(miopenSetDropoutDescriptor) \ - __macro(miopenGetDropoutDescriptor) \ - __macro(miopenDestroyDropoutDescriptor) \ - __macro(miopenRestoreDropoutDescriptor) \ - 
__macro(miopenDropoutGetReserveSpaceSize) \ - __macro(miopenDropoutGetStatesSize) \ - __macro(miopenDropoutForward) \ - __macro(miopenDropoutBackward) \ - __macro(miopenCreateActivationDescriptor) \ - __macro(miopenSetActivationDescriptor) \ - __macro(miopenGetActivationDescriptor) \ - __macro(miopenDestroyActivationDescriptor) \ - __macro(miopenCreateFusionPlan) \ - __macro(miopenCreateOpConvForward) \ - __macro(miopenCreateOpBiasForward) \ - __macro(miopenCreateOpActivationForward) \ - __macro(miopenCreateOpActivationBackward) \ - __macro(miopenCreateOpBatchNormInference) \ - __macro(miopenCreateOpBatchNormForward) \ - __macro(miopenCreateOpBatchNormBackward) \ - __macro(miopenCompileFusionPlan) \ - __macro(miopenFusionPlanGetOp) \ - __macro(miopenCreateOperatorArgs) \ - __macro(miopenSetOpArgsConvForward) \ - __macro(miopenSetOpArgsBiasForward) \ - __macro(miopenSetOpArgsActivForward) \ - __macro(miopenSetOpArgsActivBackward) \ - __macro(miopenSetOpArgsBatchNormInference) \ - __macro(miopenSetOpArgsBatchNormForward) \ - __macro(miopenSetOpArgsBatchNormBackward) \ - __macro(miopenExecuteFusionPlan) \ - __macro(miopenDestroyOperatorArgs) \ - __macro(miopenDestroyFusionPlan) \ - __macro(miopenConvolutionForwardGetSolutionCount) \ - __macro(miopenConvolutionForwardGetSolution) \ - __macro(miopenConvolutionForwardGetSolutionWorkspaceSize) \ - __macro(miopenConvolutionForwardCompileSolution) \ - __macro(miopenConvolutionForwardImmediate) \ - __macro(miopenConvolutionForwardBias) \ - __macro(miopenConvolutionBiasActivationForward) \ - __macro(miopenConvolutionBackwardDataGetSolutionCount) \ - __macro(miopenConvolutionBackwardDataGetSolution) \ - __macro(miopenConvolutionBackwardDataGetSolutionWorkspaceSize) \ - __macro(miopenConvolutionBackwardDataCompileSolution) \ - __macro(miopenConvolutionBackwardDataImmediate) \ - __macro(miopenConvolutionBackwardWeightsGetSolutionCount) \ - __macro(miopenConvolutionBackwardWeightsGetSolution) \ - 
__macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize) \ - __macro(miopenConvolutionBackwardWeightsCompileSolution) \ - __macro(miopenConvolutionBackwardWeightsImmediate) \ - __macro(miopenCreateCTCLossDescriptor) \ - __macro(miopenSetCTCLossDescriptor) \ - __macro(miopenGetCTCLossWorkspaceSize) \ - __macro(miopenCTCLoss) \ - __macro(miopenDestroyCTCLossDescriptor) \ - __macro(miopenSetConvolutionAttribute) // clang-format on -#else -// clang-format off -#define MIOPEN_DNN_ROUTINE_EACH(__macro) \ - __macro(miopenBatchNormalizationBackward) \ - __macro(miopenBatchNormalizationForwardInference) \ - __macro(miopenBatchNormalizationForwardTraining) \ - __macro(miopenGetConvolutionForwardOutputDim) \ - __macro(miopenGetConvolutionNdForwardOutputDim) \ - __macro(miopenFindConvolutionForwardAlgorithm) \ - __macro(miopenCreateTensorDescriptor) \ - __macro(miopenDestroyTensorDescriptor) \ - __macro(miopenSetNdPoolingDescriptor) \ - __macro(miopenSetPoolingIndexType) \ - __macro(miopenSetLRNDescriptor) \ - __macro(miopenLRNGetWorkSpaceSize) \ - __macro(miopenCreateConvolutionDescriptor) \ - __macro(miopenCreatePoolingDescriptor) \ - __macro(miopenDestroyPoolingDescriptor) \ - __macro(miopenCreateLRNDescriptor) \ - __macro(miopenDestroyLRNDescriptor) \ - __macro(miopenDestroyConvolutionDescriptor) \ - __macro(miopenCreateWithStream) \ - __macro(miopenDestroy) \ - __macro(miopenSetStream) \ - __macro(miopenSetAllocator) \ - __macro(miopenActivationForward) \ - __macro(miopenConvolutionForward) \ - __macro(miopenConvolutionBackwardBias) \ - __macro(miopenConvolutionForwardGetWorkSpaceSize) \ - __macro(miopenInitConvolutionDescriptor) \ - __macro(miopenInitConvolutionNdDescriptor) \ - __macro(miopenGetConvolutionDescriptor) \ - __macro(miopenGetConvolutionNdDescriptor) \ - __macro(miopenSetConvolutionGroupCount) \ - __macro(miopenSet4dTensorDescriptor) \ - __macro(miopenGetTensorDescriptor) \ - __macro(miopenSetTensorDescriptor) \ - 
__macro(miopenGetTensorDescriptorSize) \ - __macro(miopenPoolingForward) \ - __macro(miopenPoolingGetWorkSpaceSizeV2) \ - __macro(miopenPoolingBackward) \ - __macro(miopenLRNForward) \ - __macro(miopenLRNBackward) \ - __macro(miopenOpTensor) \ - __macro(miopenConvolutionBackwardData) \ - __macro(miopenConvolutionBackwardWeights) \ - __macro(miopenConvolutionBackwardWeightsGetWorkSpaceSize) \ - __macro(miopenFindConvolutionBackwardDataAlgorithm) \ - __macro(miopenFindConvolutionBackwardWeightsAlgorithm) \ - __macro(miopenConvolutionBackwardDataGetWorkSpaceSize) \ - __macro(miopenCreateRNNDescriptor) \ - __macro(miopenSetRNNDescriptor) \ - __macro(miopenSetRNNDescriptor_V2) \ - __macro(miopenDestroyRNNDescriptor) \ - __macro(miopenGetRNNParamsSize) \ - __macro(miopenGetRNNLayerParam) \ - __macro(miopenGetRNNLayerBias) \ - __macro(miopenGetRNNWorkspaceSize) \ - __macro(miopenGetRNNTrainingReserveSize) \ - __macro(miopenRNNForwardInference) \ - __macro(miopenRNNForwardTraining) \ - __macro(miopenRNNBackwardData) \ - __macro(miopenRNNBackwardWeights) \ - __macro(miopenGetRNNLayerParamOffset) \ - __macro(miopenGetRNNLayerParamSize) \ - __macro(miopenGetRNNLayerBiasOffset) \ - __macro(miopenGetRNNLayerBiasSize) \ - __macro(miopenGetRNNParamsDescriptor) \ - __macro(miopenCreateDropoutDescriptor) \ - __macro(miopenSetDropoutDescriptor) \ - __macro(miopenGetDropoutDescriptor) \ - __macro(miopenDestroyDropoutDescriptor) \ - __macro(miopenRestoreDropoutDescriptor) \ - __macro(miopenDropoutGetReserveSpaceSize) \ - __macro(miopenDropoutGetStatesSize) \ - __macro(miopenDropoutForward) \ - __macro(miopenDropoutBackward) \ - __macro(miopenCreateActivationDescriptor) \ - __macro(miopenSetActivationDescriptor) \ - __macro(miopenGetActivationDescriptor) \ - __macro(miopenDestroyActivationDescriptor) \ - __macro(miopenCreateFusionPlan) \ - __macro(miopenCreateOpConvForward) \ - __macro(miopenCreateOpBiasForward) \ - __macro(miopenCreateOpActivationForward) \ - 
__macro(miopenCreateOpActivationBackward) \ - __macro(miopenCreateOpBatchNormInference) \ - __macro(miopenCreateOpBatchNormForward) \ - __macro(miopenCreateOpBatchNormBackward) \ - __macro(miopenCompileFusionPlan) \ - __macro(miopenFusionPlanGetOp) \ - __macro(miopenCreateOperatorArgs) \ - __macro(miopenSetOpArgsConvForward) \ - __macro(miopenSetOpArgsBiasForward) \ - __macro(miopenSetOpArgsActivForward) \ - __macro(miopenSetOpArgsActivBackward) \ - __macro(miopenSetOpArgsBatchNormInference) \ - __macro(miopenSetOpArgsBatchNormForward) \ - __macro(miopenSetOpArgsBatchNormBackward) \ - __macro(miopenExecuteFusionPlan) \ - __macro(miopenDestroyOperatorArgs) \ - __macro(miopenDestroyFusionPlan) \ - __macro(miopenConvolutionBiasActivationForward) \ - __macro(miopenConvolutionForwardGetSolutionCount) \ - __macro(miopenConvolutionForwardGetSolution) \ - __macro(miopenConvolutionForwardGetSolutionWorkspaceSize) \ - __macro(miopenConvolutionForwardCompileSolution) \ - __macro(miopenConvolutionForwardImmediate) \ - __macro(miopenConvolutionForwardBias) \ - __macro(miopenConvolutionBackwardDataGetSolutionCount) \ - __macro(miopenConvolutionBackwardDataGetSolution) \ - __macro(miopenConvolutionBackwardDataGetSolutionWorkspaceSize) \ - __macro(miopenConvolutionBackwardDataCompileSolution) \ - __macro(miopenConvolutionBackwardDataImmediate) \ - __macro(miopenConvolutionBackwardWeightsGetSolutionCount) \ - __macro(miopenConvolutionBackwardWeightsGetSolution) \ - __macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize) \ - __macro(miopenConvolutionBackwardWeightsCompileSolution) \ - __macro(miopenConvolutionBackwardWeightsImmediate) \ - __macro(miopenCreateCTCLossDescriptor) \ - __macro(miopenSetCTCLossDescriptor) \ - __macro(miopenGetCTCLossWorkspaceSize) \ - __macro(miopenCTCLoss) \ - __macro(miopenDestroyCTCLossDescriptor) -// clang-format on -#endif - -#if (MIOPEN_BETA_API && TF_ROCM_VERSION >= 60300) -STREAM_EXECUTOR_MIOPEN_WRAP(miopenSetTensorDescriptorV2) -#endif 
- -MIOPEN_DNN_ROUTINE_EACH(STREAM_EXECUTOR_MIOPEN_WRAP) - -#undef MIOPEN_DNN_ROUTINE_EACH - -} // namespace wrap - namespace { // These routines should ideally be provided as an MIOpen API. @@ -575,7 +270,7 @@ uint64_t GetHashValue(miopenTensorDescriptor_t tensor_desc) { miopenDataType_t datatype = miopenFloat; int dims[kMaxMIOpenTensorSize] = {0}; int strides[kMaxMIOpenTensorSize] = {0}; - wrap::miopenGetTensorDescriptor(tensor_desc, &datatype, dims, strides); + miopenGetTensorDescriptor(tensor_desc, &datatype, dims, strides); uint64_t hash_value = tsl::hash()(datatype); for (int dim : dims) @@ -589,15 +284,15 @@ uint64_t GetHashValue(miopenTensorDescriptor_t tensor_desc) { uint64_t GetHashValue(miopenConvolutionDescriptor_t conv_desc) { miopenConvolutionMode_t c_mode = miopenConvolution; int nd = 0; - wrap::miopenGetConvolutionNdDescriptor(conv_desc, 0, &nd, nullptr, nullptr, - nullptr, &c_mode); + miopenGetConvolutionNdDescriptor(conv_desc, 0, &nd, nullptr, nullptr, nullptr, + &c_mode); std::vector stride(nd); std::vector pad(nd); std::vector dilation(nd); - wrap::miopenGetConvolutionNdDescriptor( - conv_desc, nd, &nd, pad.data(), stride.data(), dilation.data(), &c_mode); + miopenGetConvolutionNdDescriptor(conv_desc, nd, &nd, pad.data(), + stride.data(), dilation.data(), &c_mode); uint64_t hash_value = tsl::hash()(c_mode); auto hash64Combine = [&hash_value](int element) { @@ -638,8 +333,8 @@ class CachedFusionPlans { found_cached_plan = true; } else { VLOG(2) << "Creating a new plan for " << hash; - auto status = wrap::miopenCreateFusionPlan(fusion_plan, fusion_direction, - input_descriptor); + auto status = miopenCreateFusionPlan(fusion_plan, fusion_direction, + input_descriptor); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenCreateFusionPlan failed: " << ToString(status); @@ -656,7 +351,7 @@ class CachedFusionPlans { absl::MutexLock lock{cached_plans_mutex}; for (auto it : cached_plans) { - auto status = 
wrap::miopenDestroyFusionPlan(it.second); + auto status = miopenDestroyFusionPlan(it.second); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenDestroyFusionPlan failed: " << ToString(status); @@ -746,7 +441,7 @@ class MIOpenAccess { ~MIOpenAccess() { absl::MutexLock lock(mutex_); - wrap::miopenDestroy(handle_); + miopenDestroy(handle_); } // Creates a MIOpenHandle instance for stream. @@ -770,7 +465,7 @@ class MIOpenAccess { stream ? static_cast( stream->platform_specific_handle().stream) : nullptr; - auto status = wrap::miopenSetStream(handle_, hip_stream); + auto status = miopenSetStream(handle_, hip_stream); CHECK_EQ(status, miopenStatusSuccess) << "Failed to set MIOpen stream."; return MIOpenHandle(executor, std::move(lock), handle_); } @@ -808,7 +503,7 @@ MIOpenSupport::MIOpenSupport(StreamExecutor* parent) : parent_(parent) { absl::Status MIOpenSupport::Init() { std::unique_ptr context = parent_->Activate(); miopenHandle_t miopen_handle = nullptr; - auto status = wrap::miopenCreateWithStream( + auto status = miopenCreateWithStream( reinterpret_cast(&miopen_handle), (hipStream_t) nullptr); if (status == miopenStatusSuccess) { miopen_ = std::make_unique(miopen_handle); @@ -834,22 +529,22 @@ miopenStatus_t miDestroyObject(T obj) { template <> miopenStatus_t miDestroyObject(miopenTensorDescriptor_t obj) { - return wrap::miopenDestroyTensorDescriptor(obj); + return miopenDestroyTensorDescriptor(obj); } template <> miopenStatus_t miDestroyObject(miopenConvolutionDescriptor_t obj) { - return wrap::miopenDestroyConvolutionDescriptor(obj); + return miopenDestroyConvolutionDescriptor(obj); } template <> miopenStatus_t miDestroyObject(miopenPoolingDescriptor_t obj) { - return wrap::miopenDestroyPoolingDescriptor(obj); + return miopenDestroyPoolingDescriptor(obj); } template <> miopenStatus_t miDestroyObject(miopenLRNDescriptor_t obj) { - return wrap::miopenDestroyLRNDescriptor(obj); + return miopenDestroyLRNDescriptor(obj); } template @@ -864,8 +559,8 
@@ struct ScopedDescriptor { ~ScopedDescriptor() { if (handle_ == nullptr) return; - auto status = miDestroyObject( - handle_); // wrap::miopenDestroyTensorDescriptor(handle_); + auto status = + miDestroyObject(handle_); // miopenDestroyTensorDescriptor(handle_); if (status != miopenStatusSuccess) { LOG(ERROR) << "could not destroy miopen tensor descriptor: " << ToString(status); @@ -890,7 +585,7 @@ using ScopedNormalizeDescriptor = ScopedDescriptor; absl::StatusOr scope( const BatchDescriptor& batch_descriptor, miopenDataType_t data_type) { ScopedTensorDescriptor obj; - auto status = wrap::miopenCreateTensorDescriptor(&obj.handle_); + auto status = miopenCreateTensorDescriptor(&obj.handle_); if (status != miopenStatusSuccess) { return absl::InternalError("could not create miopen tensor descriptor: " + ToString(status)); @@ -908,9 +603,9 @@ absl::StatusOr scope( batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX); #if (MIOPEN_BETA_API && TF_ROCM_VERSION >= 60300) - status = wrap::miopenSetTensorDescriptorV2( - obj.handle_, data_type, nd, (const size_t*)dims64.data(), - (const size_t*)strides64.data()); + status = miopenSetTensorDescriptorV2(obj.handle_, data_type, nd, + (const size_t*)dims64.data(), + (const size_t*)strides64.data()); #else // MIOpen requires arrays of ints. 
std::vector strides(nd); @@ -919,8 +614,8 @@ absl::StatusOr scope( &CheckedNarrowing); std::transform(dims64.cbegin(), dims64.cend(), dims.begin(), &CheckedNarrowing); - status = wrap::miopenSetTensorDescriptor(obj.handle_, data_type, nd, - dims.data(), strides.data()); + status = miopenSetTensorDescriptor(obj.handle_, data_type, nd, + dims.data(), strides.data()); #endif if (status != miopenStatusSuccess) { return absl::InternalError( @@ -939,7 +634,7 @@ absl::StatusOr scope( absl::StatusOr scope( const FilterDescriptor& filter_descriptor, miopenDataType_t data_type) { ScopedFilterDescriptor obj; - auto status = wrap::miopenCreateTensorDescriptor(&obj.handle_); + auto status = miopenCreateTensorDescriptor(&obj.handle_); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not create miopen filter descriptor: " << ToString(status); @@ -990,9 +685,9 @@ absl::StatusOr scope( filter_descriptor.full_dims(dnn::FilterLayout::kOutputInputYX); #if (MIOPEN_BETA_API && TF_ROCM_VERSION >= 60300) - status = wrap::miopenSetTensorDescriptorV2( - obj.handle_, data_type, nd, (const size_t*)dims64.data(), - (const size_t*)strides64.data()); + status = miopenSetTensorDescriptorV2(obj.handle_, data_type, nd, + (const size_t*)dims64.data(), + (const size_t*)strides64.data()); #else // MIOpen requires arrays of ints. 
std::vector strides; @@ -1001,8 +696,8 @@ absl::StatusOr scope( &CheckedNarrowing); absl::c_transform(dims64, std::back_inserter(dims), &CheckedNarrowing); - status = wrap::miopenSetTensorDescriptor(obj.handle_, data_type, nd, - dims.data(), strides.data()); + status = miopenSetTensorDescriptor(obj.handle_, data_type, nd, + dims.data(), strides.data()); #endif if (status != miopenStatusSuccess) { LOG(FATAL) << "could not convert FilterDescriptor " @@ -1021,7 +716,7 @@ absl::StatusOr scope( absl::StatusOr scope( const ConvolutionDescriptor& convolution_descriptor) { ScopedConvolutionDescriptor obj; - auto status = wrap::miopenCreateConvolutionDescriptor(&obj.handle_); + auto status = miopenCreateConvolutionDescriptor(&obj.handle_); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not create miopen convolution descriptor: " << ToString(status); @@ -1042,7 +737,7 @@ absl::StatusOr scope( std::transform(dilations64.cbegin(), dilations64.cend(), upscale.begin(), &CheckedNarrowing); - status = wrap::miopenInitConvolutionNdDescriptor( + status = miopenInitConvolutionNdDescriptor( obj.handle_, convolution_descriptor.ndims(), padding.data(), strides.data(), upscale.data(), miopenConvolution); if (status != miopenStatusSuccess) { @@ -1052,8 +747,8 @@ absl::StatusOr scope( VLOG(2) << "Requesting grouped convolution: " << convolution_descriptor.group_count(); - status = wrap::miopenSetConvolutionGroupCount( - obj.handle_, convolution_descriptor.group_count()); + status = miopenSetConvolutionGroupCount(obj.handle_, + convolution_descriptor.group_count()); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not set miopen convolution group count: " << ToString(status); @@ -1061,7 +756,7 @@ absl::StatusOr scope( #if (TF_ROCM_VERSION >= 50300) if (RequireMIOpenDeterminism()) { - status = wrap::miopenSetConvolutionAttribute( + status = miopenSetConvolutionAttribute( obj.handle_, MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC, 1); if (status != miopenStatusSuccess) { 
LOG(FATAL) << "could not set miopen convolution attribute: " @@ -1075,7 +770,7 @@ absl::StatusOr scope( absl::StatusOr scope( const PoolingDescriptor& pooling_descriptor) { ScopedPoolingDescriptor obj; - auto status = wrap::miopenCreatePoolingDescriptor(&obj.handle_); + auto status = miopenCreatePoolingDescriptor(&obj.handle_); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not create miopen pooling descriptor: " << ToString(status); @@ -1096,7 +791,7 @@ absl::StatusOr scope( std::transform(shape64.cbegin(), shape64.cend(), shape.begin(), &CheckedNarrowing); - status = wrap::miopenSetNdPoolingDescriptor( + status = miopenSetNdPoolingDescriptor( obj.handle_, (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum ? miopenPoolingMax @@ -1107,7 +802,7 @@ absl::StatusOr scope( // API assumes all input indexes to be the same type. Since a tensor // descriptor can only use int32 type, the index type here need to be // aligned with the tensor index type of the (input) tensor descritptor - status = wrap::miopenSetPoolingIndexType(obj.handle_, miopenIndexUint32); + status = miopenSetPoolingIndexType(obj.handle_, miopenIndexUint32); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not set miopen pooling descriptor: " @@ -1119,7 +814,7 @@ absl::StatusOr scope( absl::StatusOr scope( const NormalizeDescriptor& normalize_descriptor) { ScopedNormalizeDescriptor obj; - auto status = wrap::miopenCreateLRNDescriptor(&obj.handle_); + auto status = miopenCreateLRNDescriptor(&obj.handle_); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not create miopen LRN descriptor: " << ToString(status); @@ -1145,8 +840,8 @@ absl::StatusOr scope( double lrn_beta = normalize_descriptor.beta(); double lrn_k = normalize_descriptor.bias(); - status = wrap::miopenSetLRNDescriptor(obj.handle_, miopenLRNCrossChannel, - lrn_N, lrn_alpha, lrn_beta, lrn_k); + status = miopenSetLRNDescriptor(obj.handle_, miopenLRNCrossChannel, lrn_N, + lrn_alpha, lrn_beta, lrn_k); if (status != 
miopenStatusSuccess) { LOG(FATAL) << "could not set miopen LRN descriptor: " << ToString(status); } @@ -1161,7 +856,7 @@ struct ScopedActivationDescriptor dnn::ActivationMode activation_mode, double alpha = 0.0) { ScopedActivationDescriptor obj; obj.alpha_ = alpha; - auto status = wrap::miopenCreateActivationDescriptor(&obj.handle_); + auto status = miopenCreateActivationDescriptor(&obj.handle_); if (status != miopenStatusSuccess) { return absl::InternalError( "call to miopenCreateActivationDescriptor failed: " + @@ -1206,9 +901,9 @@ struct ScopedActivationDescriptor return absl::InternalError("Activation not implemented"); } - status = wrap::miopenSetActivationDescriptor( - obj.handle_, obj.miopen_activation_mode_, obj.alpha_, obj.beta_, - obj.gamma_); + status = miopenSetActivationDescriptor(obj.handle_, + obj.miopen_activation_mode_, + obj.alpha_, obj.beta_, obj.gamma_); if (status != miopenStatusSuccess) { return absl::InternalError( "call to miopenSetActivationDescriptor failed: " + @@ -1261,7 +956,7 @@ class ScopedFusionPlanBase { fusion_plan_(nullptr), fusion_args_(nullptr), fusion_plan_compiled_(false) { - auto status = wrap::miopenCreateOperatorArgs(&fusion_args_); + auto status = miopenCreateOperatorArgs(&fusion_args_); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenCreateOperatorArgs failed: " << ToString(status); @@ -1270,7 +965,7 @@ class ScopedFusionPlanBase { virtual ~ScopedFusionPlanBase() { if (fusion_args_ == nullptr) return; - auto status = wrap::miopenDestroyOperatorArgs(fusion_args_); + auto status = miopenDestroyOperatorArgs(fusion_args_); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenDestroyoperatorArgs failed: " << ToString(status); @@ -1281,7 +976,7 @@ class ScopedFusionPlanBase { const void* input_data, miopenTensorDescriptor_t output_descriptor, void* output_data) { - auto status = wrap::miopenExecuteFusionPlan( + auto status = miopenExecuteFusionPlan( miopen_handle_, fusion_plan_, input_descriptor, 
input_data, output_descriptor, output_data, fusion_args_); if (status != miopenStatusSuccess) { @@ -1297,14 +992,14 @@ class ScopedFusionPlanBase { miopenStatus_t SetConvolutionArgs(const int op_idx, const float* alpha, const float* beta, const void* data) { miopenFusionOpDescriptor_t conv_op; - auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &conv_op); + auto status = miopenFusionPlanGetOp(fusion_plan_, op_idx, &conv_op); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenFusionPlanGetOp failed: " << ToString(status); } - status = wrap::miopenSetOpArgsConvForward(fusion_args_, conv_op, alpha, - beta, data); + status = + miopenSetOpArgsConvForward(fusion_args_, conv_op, alpha, beta, data); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenSetOpArgsConvForward failed: " << ToString(status); @@ -1315,14 +1010,14 @@ class ScopedFusionPlanBase { miopenStatus_t SetBiasArgs(const int op_idx, const float* alpha, const float* beta, const void* data) { miopenFusionOpDescriptor_t bias_op; - auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &bias_op); + auto status = miopenFusionPlanGetOp(fusion_plan_, op_idx, &bias_op); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenFusionPlanGetOp failed: " << ToString(status); } - status = wrap::miopenSetOpArgsBiasForward(fusion_args_, bias_op, alpha, - beta, data); + status = + miopenSetOpArgsBiasForward(fusion_args_, bias_op, alpha, beta, data); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenSetOpArgsBiasForward failed: " << ToString(status); @@ -1336,16 +1031,15 @@ class ScopedFusionPlanBase { const void* variance, double epsilon) { miopenFusionOpDescriptor_t batchnorm_op; - auto status = - wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op); + auto status = miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenFusionPlanGetOp failed: " << 
ToString(status); } - status = wrap::miopenSetOpArgsBatchNormInference(fusion_args_, batchnorm_op, - alpha, beta, scale, offset, - mean, variance, epsilon); + status = miopenSetOpArgsBatchNormInference(fusion_args_, batchnorm_op, + alpha, beta, scale, offset, mean, + variance, epsilon); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenSetOpArgsBatchNormInference failed: " << ToString(status); @@ -1359,14 +1053,13 @@ class ScopedFusionPlanBase { void* running_variance, void* saved_mean, void* saved_inv_variance, double exponential_average_factor, double epsilon) { miopenFusionOpDescriptor_t batchnorm_op; - auto status = - wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op); + auto status = miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenFusionPlanGetOp failed: " << ToString(status); } - status = wrap::miopenSetOpArgsBatchNormForward( + status = miopenSetOpArgsBatchNormForward( fusion_args_, batchnorm_op, alpha, beta, scale, offset, saved_mean, saved_inv_variance, running_mean, running_variance, exponential_average_factor, epsilon); @@ -1384,14 +1077,13 @@ class ScopedFusionPlanBase { const void* saved_mean, const void* saved_inv_variance) { miopenFusionOpDescriptor_t batchnorm_op; - auto status = - wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op); + auto status = miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenFusionPlanGetOp failed: " << ToString(status); } - status = wrap::miopenSetOpArgsBatchNormBackward( + status = miopenSetOpArgsBatchNormBackward( fusion_args_, batchnorm_op, alpha, beta, x, scale, offset, scale_grad, offset_grad, saved_mean, saved_inv_variance); if (status != miopenStatusSuccess) { @@ -1406,15 +1098,14 @@ class ScopedFusionPlanBase { double activ_beta, double activ_gamma) { miopenFusionOpDescriptor_t actv_op; - auto status = 
wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op); + auto status = miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenFusionPlanGetOp failed: " << ToString(status); } - status = - wrap::miopenSetOpArgsActivForward(fusion_args_, actv_op, alpha, beta, - activ_alpha, activ_beta, activ_gamma); + status = miopenSetOpArgsActivForward(fusion_args_, actv_op, alpha, beta, + activ_alpha, activ_beta, activ_gamma); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenSetOpArgsActivForward failed: " << ToString(status); @@ -1428,15 +1119,15 @@ class ScopedFusionPlanBase { double activ_beta, double activ_gamma) { miopenFusionOpDescriptor_t actv_op; - auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op); + auto status = miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenFusionPlanGetOp failed: " << ToString(status); } - status = wrap::miopenSetOpArgsActivBackward(fusion_args_, actv_op, alpha, - beta, y, nullptr, activ_alpha, - activ_beta, activ_gamma); + status = miopenSetOpArgsActivBackward(fusion_args_, actv_op, alpha, beta, y, + nullptr, activ_alpha, activ_beta, + activ_gamma); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenSetOpArgsActivBackward failed: " << ToString(status); @@ -1497,20 +1188,20 @@ class ScopedFusionPlanConvolutionBiasActivation : public ScopedFusionPlanBase { hash, &obj.fusion_plan_, miopenVerticalFusion, input_descriptor); if (is_compiled) VLOG(2) << "Cache hit"; if (!is_compiled) { - auto status = wrap::miopenCreateOpConvForward( + auto status = miopenCreateOpConvForward( obj.fusion_plan_, &obj.conv_op, conv_descriptor, filter_descriptor); if (status != miopenStatusSuccess) return absl::InternalError("miopenCreateOpConvForward failed: " + ToString(status)); - status = wrap::miopenCreateOpBiasForward(obj.fusion_plan_, &obj.bias_op, - bias_descriptor); + 
status = miopenCreateOpBiasForward(obj.fusion_plan_, &obj.bias_op, + bias_descriptor); if (status != miopenStatusSuccess) return absl::InternalError("miopenCreateOpBiasForward failed: " + ToString(status)); if (act_descriptor.miopen_activation_mode_ != miopenActivationPASTHRU) { - status = wrap::miopenCreateOpActivationForward( + status = miopenCreateOpActivationForward( obj.fusion_plan_, &obj.actv_op, act_descriptor.miopen_activation_mode_); if (status != miopenStatusSuccess) @@ -1518,7 +1209,7 @@ class ScopedFusionPlanConvolutionBiasActivation : public ScopedFusionPlanBase { "miopenCreateOpActivationForward failed: " + ToString(status)); } - status = wrap::miopenCompileFusionPlan(miopen_handle, obj.fusion_plan_); + status = miopenCompileFusionPlan(miopen_handle, obj.fusion_plan_); if (status != miopenStatusSuccess) { VLOG(2) << "call to miopenCompileFusionPlan (CBA) failed: " << ToString(status); @@ -1615,7 +1306,7 @@ class ScopedFusionPlanBatchNormActivationInference if (!is_compiled) { miopenFusionOpDescriptor_t batchnorm_op; - auto status = wrap::miopenCreateOpBatchNormInference( + auto status = miopenCreateOpBatchNormInference( fusion_plan_, &batchnorm_op, miopenBNSpatial, scale_offset_mean_variance_descriptor); @@ -1625,7 +1316,7 @@ class ScopedFusionPlanBatchNormActivationInference } miopenFusionOpDescriptor_t actv_op; - status = wrap::miopenCreateOpActivationForward( + status = miopenCreateOpActivationForward( fusion_plan_, &actv_op, activation_descriptor.miopen_activation_mode_); if (status != miopenStatusSuccess) { @@ -1633,7 +1324,7 @@ class ScopedFusionPlanBatchNormActivationInference << ToString(status); } - status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_); + status = miopenCompileFusionPlan(miopen_handle_, fusion_plan_); if (status != miopenStatusSuccess) { VLOG(2) << "call to miopenCompileFusionPlan (BnA inference) failed: " << ToString(status); @@ -1716,7 +1407,7 @@ class ScopedFusionPlanBatchNormActivationForward : public 
ScopedFusionPlanBase { if (!is_compiled) { miopenFusionOpDescriptor_t batchnorm_op; - auto status = wrap::miopenCreateOpBatchNormForward( + auto status = miopenCreateOpBatchNormForward( fusion_plan_, &batchnorm_op, miopenBNSpatial, true /* runningMeanVariance */); @@ -1726,7 +1417,7 @@ class ScopedFusionPlanBatchNormActivationForward : public ScopedFusionPlanBase { } miopenFusionOpDescriptor_t actv_op; - status = wrap::miopenCreateOpActivationForward( + status = miopenCreateOpActivationForward( fusion_plan_, &actv_op, activation_descriptor.miopen_activation_mode_); if (status != miopenStatusSuccess) { @@ -1734,7 +1425,7 @@ class ScopedFusionPlanBatchNormActivationForward : public ScopedFusionPlanBase { << ToString(status); } - status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_); + status = miopenCompileFusionPlan(miopen_handle_, fusion_plan_); if (status != miopenStatusSuccess) { VLOG(2) << "call to miopenCompileFusionPlan (BnA forward) failed: " << ToString(status); @@ -1818,8 +1509,8 @@ class ScopedFusionPlanBatchNormActivationBackward if (!is_compiled) { miopenFusionOpDescriptor_t batchnorm_op; - auto status = wrap::miopenCreateOpBatchNormBackward( - fusion_plan_, &batchnorm_op, miopenBNSpatial); + auto status = miopenCreateOpBatchNormBackward(fusion_plan_, &batchnorm_op, + miopenBNSpatial); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenCreateOpBatchNormBackward failed: " @@ -1827,7 +1518,7 @@ class ScopedFusionPlanBatchNormActivationBackward } miopenFusionOpDescriptor_t actv_op; - status = wrap::miopenCreateOpActivationBackward( + status = miopenCreateOpActivationBackward( fusion_plan_, &actv_op, activation_descriptor.miopen_activation_mode_); if (status != miopenStatusSuccess) { @@ -1835,7 +1526,7 @@ class ScopedFusionPlanBatchNormActivationBackward << ToString(status); } - status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_); + status = miopenCompileFusionPlan(miopen_handle_, fusion_plan_); if (status != 
miopenStatusSuccess) { VLOG(2) << "call to miopenCompileFusionPlan (BnA backward) failed: " << ToString(status); @@ -2017,7 +1708,7 @@ class MIOpenRnnParamsDescriptor : public MIOpenDescriptorCommon { MIOpenRnnParamsDescriptor(miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc); ~MIOpenRnnParamsDescriptor() { - auto status = wrap::miopenDestroyTensorDescriptor(handle_); + auto status = miopenDestroyTensorDescriptor(handle_); RETURN_IF_MIOPEN_ERROR(status, "Failed to destroy RNN tensor descriptor"); } miopenTensorDescriptor_t handle() const { @@ -2051,7 +1742,7 @@ class MIOpenDropoutDescriptor { MIOpenDropoutDescriptor(miopenHandle_t miopen_handle, float dropout, uint64_t seed, ScratchAllocator* state_allocator) : dropout_desc_(nullptr) { - auto status = wrap::miopenCreateDropoutDescriptor(&dropout_desc_); + auto status = miopenCreateDropoutDescriptor(&dropout_desc_); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenCreateDropoutDescriptor failed: " << ToString(status); @@ -2061,8 +1752,8 @@ class MIOpenDropoutDescriptor { DeviceMemory state_memory; if (state_allocator) { size_t state_sizes_in_bytes = 0; - status = wrap::miopenDropoutGetStatesSize(miopen_handle, - &state_sizes_in_bytes); + status = + miopenDropoutGetStatesSize(miopen_handle, &state_sizes_in_bytes); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenDropoutGetStatesSize failed: " << ToString(status); @@ -2078,7 +1769,7 @@ class MIOpenDropoutDescriptor { bool state_evo = false; // input placeholder, currently not enabled bool use_mask = true; - status = wrap::miopenSetDropoutDescriptor( + status = miopenSetDropoutDescriptor( dropout_desc_ /*dropoutDesc*/, miopen_handle /*handle*/, dropout /*dropout*/, state_memory.opaque() /*states*/, state_memory.size() /*stateSizeInBytes*/, seed /*seed*/, @@ -2092,7 +1783,7 @@ class MIOpenDropoutDescriptor { } ~MIOpenDropoutDescriptor() { - auto status = wrap::miopenDestroyDropoutDescriptor(dropout_desc_); + auto status 
= miopenDestroyDropoutDescriptor(dropout_desc_); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenDestroyDropoutDescriptor failed: " << ToString(status); @@ -2131,9 +1822,9 @@ class MIOpenRnnDescriptor : public MIOpenDescriptorCommon { miopen_dropout_desc_ = std::make_unique( miopen_handle, dropout, seed, state_allocator); // Create the RNN handle - auto status = wrap::miopenCreateRNNDescriptor(&rnn_desc_); + auto status = miopenCreateRNNDescriptor(&rnn_desc_); RETURN_IF_MIOPEN_ERROR(status, "Unable to create RNN descriptor"); - status = wrap::miopenSetRNNDescriptor_V2( + status = miopenSetRNNDescriptor_V2( rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/, num_layers /*numLayers*/, miopen_dropout_desc_->handle() /*dropoutDesc*/, @@ -2151,7 +1842,7 @@ class MIOpenRnnDescriptor : public MIOpenDescriptorCommon { } ~MIOpenRnnDescriptor() override { if (rnn_desc_) { - auto status = wrap::miopenDestroyRNNDescriptor(rnn_desc_); + auto status = miopenDestroyRNNDescriptor(rnn_desc_); RETURN_IF_MIOPEN_ERROR(status, "Unable to destroy RNN descriptor"); } } @@ -2236,10 +1927,10 @@ class MIOpenRnnSequenceTensorDescriptor SetFailure(absl::UnknownError(error_msg)); return; } - auto status = wrap::miopenCreateTensorDescriptor(&handle); + auto status = miopenCreateTensorDescriptor(&handle); RETURN_IF_MIOPEN_ERROR(status, "Failed to create tensor descriptor"); std::array dims = {{batch_size, data_size}}; - status = wrap::miopenSetTensorDescriptor( + status = miopenSetTensorDescriptor( handle /*tensorDesc*/, data_type /*dataType*/, 2 /*nbDims*/, dims.data() /*dimA*/, nullptr /*strideA*/); RETURN_IF_MIOPEN_ERROR(status, "Failed to update tensor descriptor"); @@ -2249,7 +1940,7 @@ class MIOpenRnnSequenceTensorDescriptor ~MIOpenRnnSequenceTensorDescriptor() override { // Only the first one needs to be destroyed. All others are the same. 
- auto status = wrap::miopenDestroyTensorDescriptor(handles_[0]); + auto status = miopenDestroyTensorDescriptor(handles_[0]); RETURN_IF_MIOPEN_ERROR(status, "Failed to destroy sequence tensor descriptor"); } @@ -2286,10 +1977,10 @@ class MIOpenRnnStateTensorDescriptor batch_size_(batch_size), data_size_(data_size), data_type_(data_type) { - auto status = wrap::miopenCreateTensorDescriptor(&handle_); + auto status = miopenCreateTensorDescriptor(&handle_); RETURN_IF_MIOPEN_ERROR(status, "Failed to create tensor descriptor"); std::array dims = {{num_layers, batch_size, data_size}}; - status = wrap::miopenSetTensorDescriptor( + status = miopenSetTensorDescriptor( handle_ /*tensorDesc*/, data_type /*dataType*/, 3 /*nbDims*/, dims.data() /*dimA*/, nullptr /*strideA*/); RETURN_IF_MIOPEN_ERROR(status, "Failed to update tensor descriptor"); @@ -2297,7 +1988,7 @@ class MIOpenRnnStateTensorDescriptor ~MIOpenRnnStateTensorDescriptor() override { if (!handle_) { - auto status = wrap::miopenDestroyTensorDescriptor(handle_); + auto status = miopenDestroyTensorDescriptor(handle_); RETURN_IF_MIOPEN_ERROR(status, "Unable to destroy RNN state tensor"); } } @@ -2398,7 +2089,7 @@ bool CheckRNNParameterSize( miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc, const MIOpenRnnSequenceTensorDescriptor& input_desc) { size_t params_size_in_bytes = 0; - auto status = wrap::miopenGetRNNParamsSize( + auto status = miopenGetRNNParamsSize( miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, input_desc.handles()[0] /*xDesc*/, ¶ms_size_in_bytes /*sizeInBytes*/, rnn_desc.data_type() /*dataType*/); @@ -2417,7 +2108,7 @@ bool CreateRnnWorkspace(Stream* stream, miopenHandle_t miopen_handle, DeviceMemory* workspace) { // Query the workspace size. 
size_t workspace_size_in_bytes = 0; - auto status = wrap::miopenGetRNNWorkspaceSize( + auto status = miopenGetRNNWorkspaceSize( miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, input_desc.seq_length() /*seqLength*/, input_desc.handles() /*xDesc*/, &workspace_size_in_bytes /*sizeInBytes*/); @@ -2497,7 +2188,7 @@ absl::Status MIOpenSupport::DoRnnForwardImpl( DeviceMemory reserve_space; if (is_training) { size_t reserve_space_size_in_bytes = 0; - auto status = wrap::miopenGetRNNTrainingReserveSize( + auto status = miopenGetRNNTrainingReserveSize( miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/, &reserve_space_size_in_bytes /*sizeInBytes*/); @@ -2530,7 +2221,7 @@ absl::Status MIOpenSupport::DoRnnForwardImpl( // make the forward call if (!is_training) { - auto status = wrap::miopenRNNForwardInference( + auto status = miopenRNNForwardInference( miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/, @@ -2548,7 +2239,7 @@ absl::Status MIOpenSupport::DoRnnForwardImpl( return absl::InternalError("miopenRNNForwardInference failed"); } } else { - auto status = wrap::miopenRNNForwardTraining( + auto status = miopenRNNForwardTraining( miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/, @@ -2661,7 +2352,7 @@ absl::Status MIOpenSupport::DoRnnBackwardImpl( } // make the backward data call - auto status = wrap::miopenRNNBackwardData( + auto status = miopenRNNBackwardData( miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/, output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/, @@ -2690,7 +2381,7 @@ absl::Status MIOpenSupport::DoRnnBackwardImpl( 
TF_RETURN_IF_ERROR( stream->MemZero(params_backprop_data, params_backprop_data->size())); // make the backward weight call - status = wrap::miopenRNNBackwardWeights( + status = miopenRNNBackwardWeights( miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/, @@ -2722,16 +2413,16 @@ MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor( miopenTensorDescriptor_t input_desc = nullptr; { // Query the params size. - auto status = wrap::miopenCreateTensorDescriptor(&input_desc); + auto status = miopenCreateTensorDescriptor(&input_desc); RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to create tensor descriptor"); std::array dims = {{1, rnn_desc.input_size()}}; - status = wrap::miopenSetTensorDescriptor( + status = miopenSetTensorDescriptor( input_desc /*tensorDesc*/, rnn_desc.data_type() /*dataType*/, 2 /*nbDims*/, dims.data() /*dimA*/, nullptr /*strideA*/); RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to set tensor descriptor"); size_t params_size = 0; - status = wrap::miopenGetRNNParamsSize( + status = miopenGetRNNParamsSize( miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, input_desc /*xDesc*/, ¶ms_size /*sizeInBytes*/, rnn_desc.data_type() /*dataType*/); @@ -2741,18 +2432,18 @@ MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor( { // Create the params descriptor. - auto status = wrap::miopenCreateTensorDescriptor(&handle_); + auto status = miopenCreateTensorDescriptor(&handle_); RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to create RNN params descriptor"); - status = wrap::miopenGetRNNParamsDescriptor(miopen_handle, - rnn_desc.handle(), input_desc, - handle_, rnn_desc.data_type()); + status = + miopenGetRNNParamsDescriptor(miopen_handle, rnn_desc.handle(), + input_desc, handle_, rnn_desc.data_type()); RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to update RNN filter descriptor"); } { // Release the dummy input tensor descriptor. 
- auto status = wrap::miopenDestroyTensorDescriptor(input_desc); + auto status = miopenDestroyTensorDescriptor(input_desc); RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to destroy tensor descriptor"); } } @@ -2760,15 +2451,15 @@ MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor( class MIOpenCTCLossDescriptor { public: explicit MIOpenCTCLossDescriptor(miopenDataType_t data_type) { - auto status = wrap::miopenCreateCTCLossDescriptor(&handle_); + auto status = miopenCreateCTCLossDescriptor(&handle_); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenCreateCTCLossDescriptor failed: " << ToString(status); } bool apply_softmax_layer = true; - status = wrap::miopenSetCTCLossDescriptor(handle_, data_type, 0, - apply_softmax_layer); + status = + miopenSetCTCLossDescriptor(handle_, data_type, 0, apply_softmax_layer); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenSetCTCLossDescriptor failed: " << ToString(status); @@ -2776,7 +2467,7 @@ class MIOpenCTCLossDescriptor { } ~MIOpenCTCLossDescriptor() { - auto status = wrap::miopenDestroyCTCLossDescriptor(handle_); + auto status = miopenDestroyCTCLossDescriptor(handle_); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: " << ToString(status); @@ -2814,7 +2505,7 @@ absl::Status MIOpenSupport::DoPrepareForCtcLoss( const MIOpenRnnStateTensorDescriptor& miopen_grads_desc = static_cast(grads_desc); - auto status = wrap::miopenGetCTCLossWorkspaceSize( + auto status = miopenGetCTCLossWorkspaceSize( miopen.handle(), miopen_probs_desc.handle(), miopen_grads_desc.handle(), labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(), MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, miopen_ctc_loss_desc.handle(), @@ -2871,7 +2562,7 @@ absl::Status MIOpenSupport::DoCtcLossImpl( int total_size = kNumLabels * kNumTimestamps * kBatchSize; (void)total_size; - auto status = wrap::miopenCTCLoss( + auto status = miopenCTCLoss( miopen.handle(), probs_desc.handle(), 
probs_data.opaque(), labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(), costs_data.opaque(), grads_desc.handle(), grads_data.opaque(), @@ -3323,8 +3014,8 @@ class RocmConvRunner : public dnn::ConvRunner { (kind == dnn::ConvolutionKind::BACKWARD_FILTER)); // #if TF_ROCM_VERSION >= 50000 if (is_backprop && (ToMIOpenDataType(input_type) == miopenHalf)) { - wrap::miopenSetConvolutionAttribute( - conv_desc_.handle(), MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1); + miopenSetConvolutionAttribute(conv_desc_.handle(), + MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1); } // #endif } @@ -3363,14 +3054,14 @@ class RocmConvRunner : public dnn::ConvRunner { switch (kind_) { case dnn::ConvolutionKind::FORWARD: { if (use_immediate_mode_) { - status = wrap::miopenConvolutionForwardImmediate( + status = miopenConvolutionForwardImmediate( miopen.handle(), filter_desc_.handle(), filter_data.opaque(), input_desc_.handle(), input_data.opaque(), conv_desc_.handle(), output_desc_.handle(), output_data.opaque(), scratch_memory.opaque(), scratch_memory.size(), static_cast(algo_id_)); } else { - status = wrap::miopenConvolutionForward( + status = miopenConvolutionForward( miopen.handle(), &alpha, input_desc_.handle(), input_data.opaque(), filter_desc_.handle(), filter_data.opaque(), conv_desc_.handle(), @@ -3383,14 +3074,14 @@ class RocmConvRunner : public dnn::ConvRunner { } case dnn::ConvolutionKind::BACKWARD_DATA: { if (use_immediate_mode_) { - status = wrap::miopenConvolutionBackwardDataImmediate( + status = miopenConvolutionBackwardDataImmediate( miopen.handle(), output_desc_.handle(), output_data.opaque(), filter_desc_.handle(), filter_data.opaque(), conv_desc_.handle(), input_desc_.handle(), input_data.opaque(), scratch_memory.opaque(), scratch_memory.size(), static_cast(algo_id_)); } else { - status = wrap::miopenConvolutionBackwardData( + status = miopenConvolutionBackwardData( miopen.handle(), &alpha, output_desc_.handle(), output_data.opaque(), 
filter_desc_.handle(), filter_data.opaque(), conv_desc_.handle(), @@ -3402,14 +3093,14 @@ class RocmConvRunner : public dnn::ConvRunner { } case dnn::ConvolutionKind::BACKWARD_FILTER: { if (use_immediate_mode_) { - status = wrap::miopenConvolutionBackwardWeightsImmediate( + status = miopenConvolutionBackwardWeightsImmediate( miopen.handle(), output_desc_.handle(), output_data.opaque(), input_desc_.handle(), input_data.opaque(), conv_desc_.handle(), filter_desc_.handle(), filter_data.opaque(), scratch_memory.opaque(), scratch_memory.size(), static_cast(algo_id_)); } else { - status = wrap::miopenConvolutionBackwardWeights( + status = miopenConvolutionBackwardWeights( miopen.handle(), &alpha, output_desc_.handle(), output_data.opaque(), input_desc_.handle(), input_data.opaque(), conv_desc_.handle(), @@ -3590,8 +3281,8 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( #if TF_ROCM_VERSION >= 50000 if (is_backprop && (ToMIOpenDataType(input_type) == miopenHalf)) { - wrap::miopenSetConvolutionAttribute( - conv.handle(), MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1); + miopenSetConvolutionAttribute(conv.handle(), + MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1); } #endif // First determine the number of algorithms available @@ -3599,7 +3290,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( switch (kind) { case dnn::ConvolutionKind::FORWARD: { - auto status = wrap::miopenConvolutionForwardGetSolutionCount( + auto status = miopenConvolutionForwardGetSolutionCount( miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(), output_nd.handle(), &maxSolutionCount); if (status != miopenStatusSuccess) { @@ -3610,7 +3301,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( break; } case dnn::ConvolutionKind::BACKWARD_DATA: { - auto status = wrap::miopenConvolutionBackwardDataGetSolutionCount( + auto status = miopenConvolutionBackwardDataGetSolutionCount( miopen.handle(), output_nd.handle(), filter.handle(), 
conv.handle(), input_nd.handle(), &maxSolutionCount); if (status != miopenStatusSuccess) { @@ -3622,7 +3313,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( break; } case dnn::ConvolutionKind::BACKWARD_FILTER: { - auto status = wrap::miopenConvolutionBackwardWeightsGetSolutionCount( + auto status = miopenConvolutionBackwardWeightsGetSolutionCount( miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(), filter.handle(), &maxSolutionCount); if (status != miopenStatusSuccess) { @@ -3654,7 +3345,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( switch (kind) { case dnn::ConvolutionKind::FORWARD: { - auto status = wrap::miopenConvolutionForwardGetSolution( + auto status = miopenConvolutionForwardGetSolution( miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(), output_nd.handle(), maxSolutionCount, &solutionCount, solutions.get()); @@ -3676,7 +3367,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( << ", " << solution.workspace_size << ", " << solution.solution_id << ", " << ToString(solution.algorithm); - status = wrap::miopenConvolutionForwardCompileSolution( + status = miopenConvolutionForwardCompileSolution( miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(), output_nd.handle(), solution.solution_id); @@ -3693,7 +3384,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( } case dnn::ConvolutionKind::BACKWARD_DATA: { - auto status = wrap::miopenConvolutionBackwardDataGetSolution( + auto status = miopenConvolutionBackwardDataGetSolution( miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(), input_nd.handle(), maxSolutionCount, &solutionCount, solutions.get()); if (status != miopenStatusSuccess) { @@ -3713,7 +3404,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( << ", " << solution.workspace_size << ", " << solution.solution_id << ", " << ToString(solution.algorithm); - status = 
wrap::miopenConvolutionBackwardDataCompileSolution( + status = miopenConvolutionBackwardDataCompileSolution( miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(), input_nd.handle(), solution.solution_id); @@ -3730,7 +3421,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( break; } case dnn::ConvolutionKind::BACKWARD_FILTER: { - auto status = wrap::miopenConvolutionBackwardWeightsGetSolution( + auto status = miopenConvolutionBackwardWeightsGetSolution( miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(), filter.handle(), maxSolutionCount, &solutionCount, solutions.get()); if (status != miopenStatusSuccess) { @@ -3750,7 +3441,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( << ", " << solution.workspace_size << ", " << solution.solution_id << ", " << ToString(solution.algorithm); - status = wrap::miopenConvolutionBackwardWeightsCompileSolution( + status = miopenConvolutionBackwardWeightsCompileSolution( miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(), filter.handle(), solution.solution_id); @@ -3802,8 +3493,8 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( #if TF_ROCM_VERSION >= 50000 if (is_backprop && (ToMIOpenDataType(input_type) == miopenHalf)) { - wrap::miopenSetConvolutionAttribute( - conv.handle(), MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1); + miopenSetConvolutionAttribute(conv.handle(), + MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1); } #endif @@ -3811,7 +3502,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( size_t scratch_memory_size = 0; switch (kind) { case dnn::ConvolutionKind::FORWARD: { - auto status = wrap::miopenConvolutionForwardGetWorkSpaceSize( + auto status = miopenConvolutionForwardGetWorkSpaceSize( miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(), output_nd.handle(), &scratch_memory_size); if (status != miopenStatusSuccess) { @@ -3822,7 +3513,7 @@ absl::Status 
MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( break; } case dnn::ConvolutionKind::BACKWARD_DATA: { - auto status = wrap::miopenConvolutionBackwardDataGetWorkSpaceSize( + auto status = miopenConvolutionBackwardDataGetWorkSpaceSize( miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(), input_nd.handle(), &scratch_memory_size); if (status != miopenStatusSuccess) { @@ -3833,7 +3524,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( break; } case dnn::ConvolutionKind::BACKWARD_FILTER: { - auto status = wrap::miopenConvolutionBackwardWeightsGetWorkSpaceSize( + auto status = miopenConvolutionBackwardWeightsGetWorkSpaceSize( miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(), filter.handle(), &scratch_memory_size); if (status != miopenStatusSuccess) { @@ -3887,7 +3578,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( switch (kind) { case dnn::ConvolutionKind::FORWARD: { - auto status = wrap::miopenFindConvolutionForwardAlgorithm( + auto status = miopenFindConvolutionForwardAlgorithm( miopen.handle(), input_nd.handle(), input_data.opaque(), filter.handle(), filter_data.opaque(), conv.handle(), output_nd.handle(), output_data.opaque(), requestedAlgorithmCount, @@ -3901,7 +3592,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( break; } case dnn::ConvolutionKind::BACKWARD_DATA: { - auto status = wrap::miopenFindConvolutionBackwardDataAlgorithm( + auto status = miopenFindConvolutionBackwardDataAlgorithm( miopen.handle(), output_nd.handle(), output_data.opaque(), filter.handle(), filter_data.opaque(), conv.handle(), input_nd.handle(), input_data.opaque(), requestedAlgorithmCount, @@ -3915,7 +3606,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( break; } case dnn::ConvolutionKind::BACKWARD_FILTER: { - auto status = wrap::miopenFindConvolutionBackwardWeightsAlgorithm( + auto status = miopenFindConvolutionBackwardWeightsAlgorithm( miopen.handle(), 
output_nd.handle(), output_data.opaque(), input_nd.handle(), input_data.opaque(), conv.handle(), filter.handle(), filter_data.opaque(), requestedAlgorithmCount, @@ -4052,7 +3743,7 @@ absl::Status MIOpenSupport::DoBatchNormalizationForwardImpl( auto status = miopenStatusInvalidValue; if (is_training) { - status = wrap::miopenBatchNormalizationForwardTraining( + status = miopenBatchNormalizationForwardTraining( miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), const_cast(scale.opaque()), const_cast(offset.opaque()), @@ -4060,7 +3751,7 @@ absl::Status MIOpenSupport::DoBatchNormalizationForwardImpl( epsilon, saved_mean->opaque(), saved_inv_var->opaque()); } else { const void* maybe_inv_var = estimated_variance.opaque(); - status = wrap::miopenBatchNormalizationForwardInference( + status = miopenBatchNormalizationForwardInference( miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), const_cast(scale.opaque()), const_cast(offset.opaque()), @@ -4153,7 +3844,7 @@ absl::Status MIOpenSupport::DoBatchNormalizationBackwardImpl( float one = 1.0; float zero = 0.0; - auto status = wrap::miopenBatchNormalizationBackward( + auto status = miopenBatchNormalizationBackward( miopen.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y_backprop.opaque(), x_descriptor.handle(), x_backprop->opaque(), @@ -4423,7 +4114,7 @@ absl::Status MIOpenSupport::DoPoolForward( size_t workspace_size = 0; if (m_pooling_cache_enabled && element_type == dnn::DataType::kFloat) { do_backward = true; - auto status = wrap::miopenPoolingGetWorkSpaceSizeV2( + auto status = miopenPoolingGetWorkSpaceSizeV2( pooling_desc.handle(), dest_desc.handle(), &workspace_size); if (status != miopenStatusSuccess) { return absl::InternalError(absl::StrCat( @@ -4449,7 +4140,7 @@ absl::Status 
MIOpenSupport::DoPoolForward( } } - auto status = wrap::miopenPoolingForward( + auto status = miopenPoolingForward( miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(), output_data.opaque(), do_backward, workspace, workspace_size); @@ -4578,7 +4269,7 @@ absl::Status MIOpenSupport::DoPoolBackward( PoolingWorkspaceDescriptor* pdesc = nullptr; size_t workspace_size_in_bytes = 0; - auto status = wrap::miopenPoolingGetWorkSpaceSizeV2( + auto status = miopenPoolingGetWorkSpaceSizeV2( pooling_desc.handle(), dest_desc.handle(), &workspace_size_in_bytes); if (status != miopenStatusSuccess) { return absl::InternalError(absl::StrCat( @@ -4632,7 +4323,7 @@ absl::Status MIOpenSupport::DoPoolBackward( "backward pooling"; } - status = wrap::miopenPoolingForward( + status = miopenPoolingForward( miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true, workspace.opaque(), workspace_size_in_bytes); @@ -4646,7 +4337,7 @@ absl::Status MIOpenSupport::DoPoolBackward( } } - status = wrap::miopenPoolingBackward( + status = miopenPoolingBackward( miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), @@ -4695,10 +4386,10 @@ bool MIOpenSupport::DoNormalizeWithDimensions( // Beta is the scaling factor for output. 
float beta = 0.0f; - auto status = wrap::miopenLRNForward( - miopen.handle(), normalize.handle(), &alpha, dims.handle(), - input_data.opaque(), &beta, dims.handle(), output_data->opaque(), false, - nullptr); + auto status = + miopenLRNForward(miopen.handle(), normalize.handle(), &alpha, + dims.handle(), input_data.opaque(), &beta, dims.handle(), + output_data->opaque(), false, nullptr); if (status != miopenStatusSuccess) { LOG(ERROR) << "failed to run miopenLRNForward"; return false; @@ -4734,7 +4425,7 @@ bool MIOpenSupport::DoNormalizeBackwardWithDimensions( DeviceMemory workspace; size_t workspace_size_in_bytes = 0; auto status = - wrap::miopenLRNGetWorkSpaceSize(dims.handle(), &workspace_size_in_bytes); + miopenLRNGetWorkSpaceSize(dims.handle(), &workspace_size_in_bytes); if (status != miopenStatusSuccess) { LOG(ERROR) << "failed to obtain workspace size for miopenLRNBackward"; @@ -4781,22 +4472,22 @@ bool MIOpenSupport::DoNormalizeBackwardWithDimensions( "backward LRN"; } - status = wrap::miopenLRNForward(miopen.handle(), normalize.handle(), &alpha, - dims.handle(), raw_data.opaque(), &beta, - dims.handle(), dest2.opaque(), true, - workspace.opaque()); + status = + miopenLRNForward(miopen.handle(), normalize.handle(), &alpha, + dims.handle(), raw_data.opaque(), &beta, dims.handle(), + dest2.opaque(), true, workspace.opaque()); if (status != miopenStatusSuccess) { LOG(ERROR) << "failed to run miopenLRNForward"; return false; } - status = wrap::miopenLRNBackward( - miopen.handle(), normalize.handle(), &alpha, dims.handle(), - normalized_data.opaque(), dims.handle(), - normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(), - &beta, dims.handle(), raw_variable_gradient->opaque(), - workspace.opaque()); + status = + miopenLRNBackward(miopen.handle(), normalize.handle(), &alpha, + dims.handle(), normalized_data.opaque(), dims.handle(), + normalized_variable_gradient.opaque(), dims.handle(), + raw_data.opaque(), &beta, dims.handle(), + 
raw_variable_gradient->opaque(), workspace.opaque()); if (status != miopenStatusSuccess) { LOG(ERROR) << "failed to run miopenLRNBackward"; @@ -4816,7 +4507,7 @@ bool MIOpenSupport::DeriveOutputBatchDescriptor( int dn = batch_descriptor.ndims() + 2; std::vector dims(dn); // in BDYX - auto status = wrap::miopenGetConvolutionNdForwardOutputDim( + auto status = miopenGetConvolutionNdForwardOutputDim( conv.handle(), input_nd.handle(), filter.handle(), &dn, dims.data()); if (status != miopenStatusSuccess) { LOG(ERROR) << "could not get output tensor for convolution: " @@ -4873,10 +4564,10 @@ class RocmFusedConvRunner : public dnn::FusedConvRunner { } miopenStatus_t status; - status = wrap::miopenExecuteFusionPlan( - miopen.handle(), fusion_plan_.fusion_plan_, input_nd_.handle(), - input_data.opaque(), output_nd_.handle(), output_data.opaque(), - fusion_plan_.fusion_args_); + status = miopenExecuteFusionPlan(miopen.handle(), fusion_plan_.fusion_plan_, + input_nd_.handle(), input_data.opaque(), + output_nd_.handle(), output_data.opaque(), + fusion_plan_.fusion_args_); if (status != miopenStatusSuccess) { LOG(ERROR) << "Failed to enqueue fused convolution on stream: " diff --git a/xla/stream_executor/rocm/rocm_driver_wrapper.h b/xla/stream_executor/rocm/rocm_driver_wrapper.h deleted file mode 100644 index a6716a61002a0..0000000000000 --- a/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ /dev/null @@ -1,198 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file wraps rocm driver calls with dso loader so that we don't need to -// have explicit linking to librocm. All TF rocm driver usage should route -// through this wrapper. - -#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ -#define XLA_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ - -#include "rocm/include/hip/hip_runtime.h" -#include "rocm/rocm_config.h" -#include "xla/tsl/platform/env.h" -#include "tsl/platform/dso_loader.h" - -namespace stream_executor { -namespace wrap { -#ifdef PLATFORM_GOOGLE -// Use static linked library -#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \ - template \ - auto hipSymbolName(Args... args) -> decltype(::hipSymbolName(args...)) { \ - return ::hipSymbolName(args...); \ - } - -// This macro wraps a global identifier, given by hipSymbolName, in a callable -// structure that loads the DLL symbol out of the DSO handle in a thread-safe -// manner on first use. This dynamic loading technique is used to avoid DSO -// dependencies on vendor libraries which may or may not be available in the -// deployed binary environment. -#else -#define TO_STR_(x) #x -#define TO_STR(x) TO_STR_(x) - -#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \ - template \ - auto hipSymbolName(Args... 
args) -> decltype(::hipSymbolName(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char *kName = TO_STR(hipSymbolName); \ - void *f; \ - auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \ - tsl::internal::CachedDsoLoader::GetHipDsoHandle().value(), kName, \ - &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in HIP DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ - } -#endif - -// clang-format off -// IMPORTANT: if you add a new HIP API to this list, please notify -// the rocm-profiler developers to track the API traces. -#define HIP_ROUTINE_EACH(__macro) \ - __macro(hipCtxGetDevice) \ - __macro(hipCtxSetCurrent) \ - __macro(hipCtxEnablePeerAccess) \ - __macro(hipDeviceCanAccessPeer) \ - __macro(hipDeviceEnablePeerAccess) \ - __macro(hipDeviceGet) \ - __macro(hipDeviceGetAttribute) \ - __macro(hipDeviceGetName) \ - __macro(hipDeviceGetPCIBusId) \ - __macro(hipDeviceGetSharedMemConfig) \ - __macro(hipDeviceGetStreamPriorityRange) \ - __macro(hipDeviceGraphMemTrim) \ - __macro(hipDevicePrimaryCtxGetState) \ - __macro(hipDevicePrimaryCtxSetFlags) \ - __macro(hipDevicePrimaryCtxRetain) \ - __macro(hipDevicePrimaryCtxRelease) \ - __macro(hipDeviceSetSharedMemConfig) \ - __macro(hipDeviceSynchronize) \ - __macro(hipDeviceTotalMem) \ - __macro(hipDriverGetVersion) \ - __macro(hipEventCreateWithFlags) \ - __macro(hipEventDestroy) \ - __macro(hipEventElapsedTime) \ - __macro(hipEventQuery) \ - __macro(hipEventRecord) \ - __macro(hipEventSynchronize) \ - __macro(hipFree) \ - __macro(hipFuncSetCacheConfig) \ - __macro(hipFuncGetAttribute) \ - __macro(hipFuncSetAttribute) \ - __macro(hipGetDevice) \ - __macro(hipGetDeviceCount) \ - __macro(hipGetDeviceProperties) \ - __macro(hipGetErrorString) \ - __macro(hipGetLastError) \ - __macro(hipGraphAddKernelNode) \ - __macro(hipGraphAddChildGraphNode) \ - __macro(hipGraphAddEmptyNode) \ - 
__macro(hipGraphAddMemAllocNode) \ - __macro(hipGraphAddMemcpyNode1D) \ - __macro(hipGraphAddMemsetNode) \ - __macro(hipGraphAddMemFreeNode) \ - __macro(hipGraphCreate) \ - __macro(hipGraphDebugDotPrint) \ - __macro(hipGraphDestroy) \ - __macro(hipGraphGetNodes) \ - __macro(hipGraphExecChildGraphNodeSetParams) \ - __macro(hipGraphExecDestroy) \ - __macro(hipGraphExecKernelNodeSetParams) \ - __macro(hipGraphExecMemcpyNodeSetParams1D) \ - __macro(hipGraphExecMemsetNodeSetParams) \ - __macro(hipGraphExecUpdate) \ - __macro(hipGraphInstantiate) \ - __macro(hipGraphMemAllocNodeGetParams) \ - __macro(hipGraphLaunch) \ - __macro(hipGraphNodeGetType) \ - __macro(hipGraphNodeSetEnabled) \ - __macro(hipHostFree) \ - __macro(hipHostMalloc) \ - __macro(hipHostRegister) \ - __macro(hipHostUnregister) \ - __macro(hipInit) \ - __macro(hipKernelNameRefByPtr) \ - __macro(hipLaunchHostFunc) \ - __macro(hipLaunchKernel) \ - __macro(hipMalloc) \ - __macro(hipMallocManaged) \ - __macro(hipExtMallocWithFlags) \ - __macro(hipMemGetAddressRange) \ - __macro(hipMemGetInfo) \ - __macro(hipMemcpyDtoD) \ - __macro(hipMemcpyDtoDAsync) \ - __macro(hipMemcpyDtoH) \ - __macro(hipMemcpyDtoHAsync) \ - __macro(hipMemcpyHtoD) \ - __macro(hipMemcpyHtoDAsync) \ - __macro(hipMemset) \ - __macro(hipMemsetD8) \ - __macro(hipMemsetD16) \ - __macro(hipMemsetD32) \ - __macro(hipMemsetAsync) \ - __macro(hipMemsetD8Async) \ - __macro(hipMemsetD16Async) \ - __macro(hipMemsetD32Async) \ - __macro(hipModuleGetFunction) \ - __macro(hipModuleGetGlobal) \ - __macro(hipModuleLaunchKernel) \ - __macro(hipModuleLoadData) \ - __macro(hipModuleUnload) \ - __macro(hipModuleOccupancyMaxActiveBlocksPerMultiprocessor) \ - __macro(hipModuleOccupancyMaxPotentialBlockSize) \ - __macro(hipPointerGetAttribute) \ - __macro(hipPointerGetAttributes) \ - __macro(hipRuntimeGetVersion) \ - __macro(hipSetDevice) \ - __macro(hipStreamAddCallback) \ - __macro(hipStreamBeginCapture) \ - __macro(hipStreamCreateWithFlags) \ - 
__macro(hipStreamCreateWithPriority) \ - __macro(hipStreamDestroy) \ - __macro(hipStreamEndCapture) \ - __macro(hipStreamIsCapturing) \ - __macro(hipStreamQuery) \ - __macro(hipStreamSynchronize) \ - __macro(hipStreamWaitEvent) // clang-format on - -HIP_ROUTINE_EACH(STREAM_EXECUTOR_HIP_WRAP) - -#if TF_ROCM_VERSION >= 60200 - -// clang-format off -#define HIP_ROUTINE_EACH_62(__macro) \ - __macro(hipGetFuncBySymbol) \ - __macro(hipStreamBeginCaptureToGraph) -// clang-format on - -HIP_ROUTINE_EACH_62(STREAM_EXECUTOR_HIP_WRAP) - -#undef HIP_ROUTINE_EACH_62 -#endif // TF_ROCM_VERSION >= 60200 - -#undef HIP_ROUTINE_EACH -#undef STREAM_EXECUTOR_HIP_WRAP -#undef TO_STR -#undef TO_STR_ - -} // namespace wrap -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ diff --git a/xla/stream_executor/rocm/rocm_event.cc b/xla/stream_executor/rocm/rocm_event.cc index aa7e62e5190df..45cd48e997074 100644 --- a/xla/stream_executor/rocm/rocm_event.cc +++ b/xla/stream_executor/rocm/rocm_event.cc @@ -27,7 +27,6 @@ limitations under the License. 
#include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_status.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" @@ -39,7 +38,7 @@ absl::Status WaitStreamOnEvent(StreamExecutor *executor, hipStream_t stream, hipEvent_t event) { std::unique_ptr activation = executor->Activate(); TF_RETURN_IF_ERROR( - ToStatus(wrap::hipStreamWaitEvent(stream, event, 0 /* = flags */), + ToStatus(hipStreamWaitEvent(stream, event, 0 /* = flags */), "could not wait stream on event")); return absl::OkStatus(); } @@ -61,7 +60,7 @@ absl::StatusOr InitEvent(StreamExecutor *executor, std::unique_ptr activation = executor->Activate(); hipEvent_t event; - hipError_t res = wrap::hipEventCreateWithFlags(&event, hipflags); + hipError_t res = hipEventCreateWithFlags(&event, hipflags); if (res == hipSuccess) { return event; @@ -80,7 +79,7 @@ void DestroyEvent(StreamExecutor *executor, hipEvent_t event) { } std::unique_ptr activation = executor->Activate(); - hipError_t res = wrap::hipEventDestroy(event); + hipError_t res = hipEventDestroy(event); if (res != hipSuccess) { LOG(ERROR) << absl::StrFormat( @@ -93,7 +92,7 @@ void DestroyEvent(StreamExecutor *executor, hipEvent_t event) { Event::Status RocmEvent::PollForStatus() { std::unique_ptr activated = executor_->Activate(); - hipError_t res = wrap::hipEventQuery(handle_); + hipError_t res = hipEventQuery(handle_); if (res == hipSuccess) { return Event::Status::kComplete; diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 269a2afe6194a..f88853998b2ef 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -71,7 +71,6 @@ limitations under the License. 
#include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/rocm/rocm_command_buffer.h" #include "xla/stream_executor/rocm/rocm_context.h" -#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/rocm/rocm_kernel.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" @@ -147,7 +146,7 @@ absl::StatusOr LoadHsaco(Context* context, GetDriverExecutor()->Schedule( [context, hsaco_contents, &module, &returned_status, ¬ification]() { ScopedActivateContext activation(context); - hipError_t res = wrap::hipModuleLoadData(&module, hsaco_contents); + hipError_t res = hipModuleLoadData(&module, hsaco_contents); if (res != hipSuccess) { returned_status = absl::InternalError( @@ -174,7 +173,7 @@ absl::StatusOr GetModuleFunction(Context* context, CHECK(module != nullptr && kernel_name != nullptr); hipFunction_t function; TF_RETURN_IF_ERROR( - ToStatus(wrap::hipModuleGetFunction(&function, module, kernel_name), + ToStatus(hipModuleGetFunction(&function, module, kernel_name), "Failed to get kernel")); return function; } @@ -189,14 +188,14 @@ absl::Status GetModuleSymbol(Context* context, hipModule_t module, ScopedActivateContext activated(context); CHECK(module != nullptr && symbol_name != nullptr && (dptr != nullptr || bytes != nullptr)); - return ToStatus(wrap::hipModuleGetGlobal(dptr, bytes, module, symbol_name), + return ToStatus(hipModuleGetGlobal(dptr, bytes, module, symbol_name), absl::StrCat("Failed to get symbol '", symbol_name, "'")); } // Unloads module from the current context via cuModuleUnload. 
void UnloadRocmModule(Context* context, hipModule_t module) { ScopedActivateContext activated(context); - hipError_t res = wrap::hipModuleUnload(module); + hipError_t res = hipModuleUnload(module); if (res != hipSuccess) { LOG(ERROR) << "failed to unload module " << module << "; leaking: " << ToString(res); @@ -208,7 +207,7 @@ absl::StatusOr GetDeviceName(hipDevice_t device) { static const size_t kCharLimit = 64; absl::InlinedVector chars(kCharLimit); TF_RETURN_IF_ERROR( - ToStatus(wrap::hipDeviceGetName(chars.begin(), kCharLimit - 1, device), + ToStatus(hipDeviceGetName(chars.begin(), kCharLimit - 1, device), "Failed to get device name")); chars[kCharLimit - 1] = '\0'; return chars.begin(); @@ -216,7 +215,7 @@ absl::StatusOr GetDeviceName(hipDevice_t device) { absl::StatusOr GetGpuISAVersion(hipDevice_t device) { hipDeviceProp_t props; - hipError_t result = wrap::hipGetDeviceProperties(&props, device); + hipError_t result = hipGetDeviceProperties(&props, device); if (result == hipSuccess) { std::string gcnName = props.gcnArchName; std::vector tokens = absl::StrSplit(gcnName, ':'); @@ -235,7 +234,7 @@ absl::StatusOr GetGpuISAVersion(hipDevice_t device) { // for eg: amdgcn-amd-amdhsa--gfx908:sramecc+:xnack- absl::StatusOr GetGpuGCNArchName(hipDevice_t device) { hipDeviceProp_t props; - hipError_t result = wrap::hipGetDeviceProperties(&props, device); + hipError_t result = hipGetDeviceProperties(&props, device); if (result == hipSuccess) { return props.gcnArchName; } @@ -249,7 +248,7 @@ template static absl::StatusOr GetSimpleAttribute(hipDevice_t device, hipDeviceAttribute_t attribute) { int value = -1; - hipError_t result = wrap::hipDeviceGetAttribute(&value, attribute, device); + hipError_t result = hipDeviceGetAttribute(&value, attribute, device); if (result != hipSuccess) { return absl::NotFoundError( absl::StrCat("could not retrieve ROCM device attribute (", attribute, @@ -292,22 +291,19 @@ absl::StatusOr GetThreadsPerWarp(hipDevice_t device) { absl::Status 
GetGridLimits(int* x, int* y, int* z, hipDevice_t device) { int value; - TF_RETURN_IF_ERROR( - ToStatus(wrap::hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxGridDimX, device), - "failed to query max grid dim x")); + TF_RETURN_IF_ERROR(ToStatus( + hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimX, device), + "failed to query max grid dim x")); *x = value; - TF_RETURN_IF_ERROR( - ToStatus(wrap::hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxGridDimY, device), - "failed to query max grid dim y")); + TF_RETURN_IF_ERROR(ToStatus( + hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimY, device), + "failed to query max grid dim y")); *y = value; - TF_RETURN_IF_ERROR( - ToStatus(wrap::hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxGridDimZ, device), - "failed to query max grid dim z")); + TF_RETURN_IF_ERROR(ToStatus( + hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimZ, device), + "failed to query max grid dim z")); *z = value; return absl::OkStatus(); } @@ -320,7 +316,7 @@ absl::StatusOr GetMaxRegistersPerMultiprocessor(hipDevice_t device) { // Returns the device associated with the given device_ordinal. 
absl::StatusOr GetDevice(int device_ordinal) { hipDevice_t device; - hipError_t res = wrap::hipDeviceGet(&device, device_ordinal); + hipError_t res = hipDeviceGet(&device, device_ordinal); if (res == hipSuccess) { return device; } @@ -333,7 +329,7 @@ absl::StatusOr GetDevice(int device_ordinal) { absl::StatusOr DeviceFromContext(Context* context) { ScopedActivateContext activated(context); hipDevice_t device = -1; - hipError_t result = wrap::hipCtxGetDevice(&device); + hipError_t result = hipCtxGetDevice(&device); if (result == hipSuccess) return device; return absl::InternalError( @@ -342,7 +338,7 @@ absl::StatusOr DeviceFromContext(Context* context) { bool CanEnablePeerAccess(hipDevice_t from, hipDevice_t to) { int can_access_peer = -1; - hipError_t result = wrap::hipDeviceCanAccessPeer(&can_access_peer, from, to); + hipError_t result = hipDeviceCanAccessPeer(&can_access_peer, from, to); if (result != hipSuccess) { LOG(ERROR) << "failed to detect peer access capability: " << ToString(result); @@ -378,18 +374,18 @@ absl::Status EnablePeerAccess(Context* from, Context* to) { ScopedActivateContext activated(from); hipError_t result = - wrap::hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */); + hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */); if (result == hipErrorPeerAccessAlreadyEnabled) { // hipGetLastError is used to reset per thread error state, // as hipGetLastError would get the recent error code since rocm7 even the // last call is successful. 
- (void)wrap::hipGetLastError(); + (void)hipGetLastError(); } else if (result != hipSuccess) { return absl::InternalError( absl::StrFormat("failed to enable peer access from %d to %d: %s", from->device_ordinal(), to->device_ordinal(), - wrap::hipGetErrorString(result))); + hipGetErrorString(result))); } return absl::OkStatus(); @@ -400,8 +396,7 @@ std::string GetPCIBusID(hipDevice_t device) { static const int kBufferSize = 64; absl::InlinedVector chars(kBufferSize); chars[kBufferSize - 1] = '\0'; - hipError_t res = - wrap::hipDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device); + hipError_t res = hipDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device); if (res != hipSuccess) { LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res); return pci_bus_id; @@ -413,15 +408,14 @@ std::string GetPCIBusID(hipDevice_t device) { absl::StatusOr IsEccEnabled(hipDevice_t device) { int value = 0; TF_RETURN_IF_ERROR(ToStatus( - wrap::hipDeviceGetAttribute(&value, hipDeviceAttributeEccEnabled, device), + hipDeviceGetAttribute(&value, hipDeviceAttributeEccEnabled, device), "hipDeviceGetAttribute(hipDeviceAttributeEccEnabled) failed")); return value != 0; } bool GetDeviceProperties(hipDeviceProp_t* device_properties, int device_ordinal) { - hipError_t res = - wrap::hipGetDeviceProperties(device_properties, device_ordinal); + hipError_t res = hipGetDeviceProperties(device_properties, device_ordinal); if (res != hipSuccess) { LOG(ERROR) << "failed to query device properties: " << ToString(res); return false; @@ -445,10 +439,9 @@ void* DeviceAllocate(Context* context, uint64_t bytes, // execution. This type of memory is only used in P2P communication to solve // the cache coherence issue for some archs (e.g., MI200); most of the time, // you don't have to use it. 
- res = wrap::hipExtMallocWithFlags(&device_mem, bytes, - hipDeviceMallocFinegrained); + res = hipExtMallocWithFlags(&device_mem, bytes, hipDeviceMallocFinegrained); } else { - res = wrap::hipMalloc(&device_mem, bytes); + res = hipMalloc(&device_mem, bytes); } if (res != hipSuccess) { // LOG(INFO) because this isn't always important to users (e.g. BFCAllocator @@ -469,7 +462,7 @@ void* DeviceAllocate(Context* context, uint64_t bytes, void DeviceDeallocate(Context* context, void* location) { ScopedActivateContext activation(context); hipDeviceptr_t pointer = absl::bit_cast(location); - hipError_t res = wrap::hipFree(pointer); + hipError_t res = hipFree(pointer); if (res != hipSuccess) { LOG(ERROR) << "failed to free device memory at " << location << "; result: " << ToString(res); @@ -485,7 +478,7 @@ absl::StatusOr HostAllocate(Context* context, uint64_t bytes) { void* host_mem = nullptr; // "Portable" memory is visible to all ROCM contexts. Safe for our use model. TF_RETURN_IF_ERROR( - ToStatus(wrap::hipHostMalloc(&host_mem, bytes, hipHostMallocPortable), + ToStatus(hipHostMalloc(&host_mem, bytes, hipHostMallocPortable), "failed to allocate host memory")); return host_mem; } @@ -497,7 +490,7 @@ absl::StatusOr> AllocateHostMemory( << size << " bytes of host memory"; return std::make_unique( ptr, size, [rocm_context](void* location, uint64_t size) { - hipError_t res = wrap::hipHostFree(location); + hipError_t res = hipHostFree(location); if (res != hipSuccess) { LOG(ERROR) << "error deallocating host memory at " << location << ": " << ToString(res); @@ -530,7 +523,7 @@ absl::StatusOr RocmExecutor::GetMemoryRange( const DeviceMemoryBase& location) { hipDeviceptr_t device_pointer; size_t size; - hipError_t result = wrap::hipMemGetAddressRange( + hipError_t result = hipMemGetAddressRange( &device_pointer, &size, const_cast(location.opaque())); if (result == hipSuccess) { return DeviceMemoryBase(device_pointer, size); @@ -691,9 +684,9 @@ absl::StatusOr> 
RocmExecutor::LoadKernel( #if TF_ROCM_VERSION >= 60200 hipFunction_t func; - TF_RETURN_IF_ERROR(ToStatus( - wrap::hipGetFuncBySymbol(&func, spec.in_process_symbol()->symbol), - "Failed call to hipGetFuncBySymbol")); + TF_RETURN_IF_ERROR( + ToStatus(hipGetFuncBySymbol(&func, spec.in_process_symbol()->symbol), + "Failed call to hipGetFuncBySymbol")); rocm_kernel->set_gpu_function(func); #else rocm_kernel->set_gpu_function( @@ -800,9 +793,9 @@ RocmExecutor::CreateMemoryAllocator(MemoryType type) { std::unique_ptr activation = Activate(); hipDeviceptr_t result = nullptr; // "managed" memory is visible to both CPU and GPU. - TF_RETURN_IF_ERROR(ToStatus( - wrap::hipMallocManaged(&result, size, hipMemAttachGlobal), - "Failed to allocate managed memory")); + TF_RETURN_IF_ERROR( + ToStatus(hipMallocManaged(&result, size, hipMemAttachGlobal), + "Failed to allocate managed memory")); void* ptr = reinterpret_cast(result); VLOG(2) << "allocated " << ptr << " for context " << rocm_context_ << " of " << size << " bytes in unified memory"; @@ -811,7 +804,7 @@ RocmExecutor::CreateMemoryAllocator(MemoryType type) { std::unique_ptr activation = Activate(); hipDeviceptr_t pointer = absl::bit_cast(location); - hipError_t res = wrap::hipFree(pointer); + hipError_t res = hipFree(pointer); if (res != hipSuccess) { LOG(ERROR) << "failed to free unified memory at " << location << "; result: " << ToString(res); @@ -826,7 +819,7 @@ RocmExecutor::CreateMemoryAllocator(MemoryType type) { [](uint64_t size) -> absl::StatusOr> { void* ptr = nullptr; - auto hipResult = wrap::hipMalloc(&ptr, size); + auto hipResult = hipMalloc(&ptr, size); if (hipResult != hipSuccess) { return absl::InternalError(absl::StrFormat( "failed to allocate %s (%llu bytes) from device collective " @@ -839,7 +832,7 @@ RocmExecutor::CreateMemoryAllocator(MemoryType type) { << " bytes of collective memory"; return std::make_unique( ptr, size, [](void* location, uint64_t size) { - auto status = wrap::hipFree(location); + auto 
status = hipFree(location); if (status != hipSuccess) { LOG(ERROR) << "failed to free collective memory at " << location << "; result: " << status; @@ -897,11 +890,10 @@ absl::Status RocmExecutor::SynchronousMemZero(DeviceMemoryBase* location, hipDeviceptr_t rocm_location = AsROCmDevicePtr(location); if (reinterpret_cast(location->opaque()) % sizeof(uint32_t) == 0 && size % sizeof(uint32_t) == 0) { - return ToStatus( - wrap::hipMemsetD32(rocm_location, 0x0, size / sizeof(uint32_t)), - "Failed to memset memory"); + return ToStatus(hipMemsetD32(rocm_location, 0x0, size / sizeof(uint32_t)), + "Failed to memset memory"); } - return ToStatus(wrap::hipMemsetD8(rocm_location, 0x0, size), + return ToStatus(hipMemsetD8(rocm_location, 0x0, size), "Failed to memset memory"); } @@ -910,8 +902,8 @@ absl::Status RocmExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, uint64_t size) { std::unique_ptr activation = Activate(); TF_RETURN_IF_ERROR(ToStatus( - wrap::hipMemcpyHtoD(AsROCmDevicePtr(gpu_dst), const_cast(host_src), - size), + hipMemcpyHtoD(AsROCmDevicePtr(gpu_dst), const_cast(host_src), + size), absl::StrFormat( "failed to synchronous memcpy from host to device: Gpu dst: %p;" " host src: %p; size: %llu=0x%llx", @@ -925,7 +917,7 @@ absl::Status RocmExecutor::SynchronousMemcpy(void* host_dst, uint64_t size) { std::unique_ptr activation = Activate(); TF_RETURN_IF_ERROR(ToStatus( - wrap::hipMemcpyDtoH(host_dst, AsROCmDevicePtr(gpu_src), size), + hipMemcpyDtoH(host_dst, AsROCmDevicePtr(gpu_src), size), absl::StrFormat("failed to synchronous memcpy from device to host: " "host dst: %p; Gpu src: %p; size: %llu=0x%llx", host_dst, AsROCmDevicePtr(gpu_src), size, size))); @@ -1180,12 +1172,12 @@ RocmExecutor::CreateDeviceDescription(int device_ordinal) { desc.set_compile_time_toolkit_version( SemanticVersion{HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH}); int32_t runtime_version; - TF_RETURN_IF_ERROR(ToStatus(wrap::hipRuntimeGetVersion(&runtime_version), + 
TF_RETURN_IF_ERROR(ToStatus(hipRuntimeGetVersion(&runtime_version), "Failed call to hipRuntimeGetVersion")); desc.set_runtime_version( ParseRocmVersion(runtime_version).value_or(SemanticVersion{0, 0, 0})); int32_t driver_version; - TF_RETURN_IF_ERROR(ToStatus(wrap::hipDriverGetVersion(&driver_version), + TF_RETURN_IF_ERROR(ToStatus(hipDriverGetVersion(&driver_version), "Could not get driver version")); desc.set_driver_version( ParseRocmVersion(driver_version).value_or(SemanticVersion{0, 0, 0})); @@ -1210,7 +1202,7 @@ absl::StatusOr RocmExecutor::GetPointerMemorySpace( hipDeviceptr_t pointer = reinterpret_cast(const_cast(ptr)); unsigned int value; - hipError_t result = wrap::hipPointerGetAttribute( + hipError_t result = hipPointerGetAttribute( &value, HIP_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer); if (result == hipSuccess) { switch (value) { diff --git a/xla/stream_executor/rocm/rocm_fft.cc b/xla/stream_executor/rocm/rocm_fft.cc index cb69633296aa2..4d0143b9c75b8 100644 --- a/xla/stream_executor/rocm/rocm_fft.cc +++ b/xla/stream_executor/rocm/rocm_fft.cc @@ -35,11 +35,6 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/platform/logging.h" -#ifndef PLATFORM_GOOGLE -#include "xla/tsl/platform/env.h" -#include "tsl/platform/dso_loader.h" -#endif - namespace stream_executor { namespace gpu { @@ -47,12 +42,8 @@ using rocm::ROCMComplex; namespace wrap { -#ifdef PLATFORM_GOOGLE -// This macro wraps a global identifier, given by __name, in a callable -// structure that loads the DLL symbol out of the DSO handle in a thread-safe -// manner on first use. This dynamic loading technique is used to avoid DSO -// dependencies on vendor libraries which may or may not be available in the -// deployed binary environment. 
+namespace { + #define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \ struct WrapperShim__##__name { \ template \ @@ -62,38 +53,6 @@ namespace wrap { } \ } __name; -#else - -#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \ - struct DynLoadShim__##__name { \ - static const char *kName; \ - using FuncPtrT = std::add_pointer::type; \ - static void *GetDsoHandle() { \ - auto s = tsl::internal::CachedDsoLoader::GetHipfftDsoHandle(); \ - return s.value(); \ - } \ - static FuncPtrT LoadOrDie() { \ - void *f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in rocfft DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - } \ - static FuncPtrT DynLoad() { \ - static FuncPtrT f = LoadOrDie(); \ - return f; \ - } \ - template \ - hipfftResult operator()(StreamExecutor *parent, Args... args) { \ - std::unique_ptr activation = parent->Activate(); \ - return DynLoad()(args...); \ - } \ - } __name; \ - const char *DynLoadShim__##__name::kName = #__name; - -#endif - // clang-format off #define ROCFFT_ROUTINE_EACH(__macro) \ __macro(hipfftDestroy) \ @@ -105,13 +64,9 @@ namespace wrap { __macro(hipfftCreate) \ __macro(hipfftSetAutoAllocation) \ __macro(hipfftSetWorkArea) \ - __macro(hipfftGetSize1d) \ __macro(hipfftMakePlan1d) \ - __macro(hipfftGetSize2d) \ __macro(hipfftMakePlan2d) \ - __macro(hipfftGetSize3d) \ __macro(hipfftMakePlan3d) \ - __macro(hipfftGetSizeMany) \ __macro(hipfftMakePlanMany) \ __macro(hipfftExecD2Z) \ __macro(hipfftExecZ2D) \ @@ -124,6 +79,7 @@ namespace wrap { ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP) +} // namespace } // namespace wrap namespace { diff --git a/xla/stream_executor/rocm/rocm_kernel.cc b/xla/stream_executor/rocm/rocm_kernel.cc index cc10638d09c78..1c517f59cdee2 100644 --- a/xla/stream_executor/rocm/rocm_kernel.cc +++ b/xla/stream_executor/rocm/rocm_kernel.cc @@ -25,10 +25,10 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_status.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/platform/errors.h" @@ -42,7 +42,7 @@ namespace { absl::Status FuncGetAttribute(hipFunction_attribute attribute, hipFunction_t func, int* attribute_value) { return ToStatus( - wrap::hipFuncGetAttribute(attribute_value, attribute, func), + hipFuncGetAttribute(attribute_value, attribute, func), absl::StrCat("Failed to query kernel attribute: ", attribute)); } @@ -58,7 +58,7 @@ absl::StatusOr RocmKernel::GetMaxOccupiedBlocksPerCore( int max_blocks = 0; TF_RETURN_IF_ERROR( - ToStatus(wrap::hipModuleOccupancyMaxActiveBlocksPerMultiprocessor( + ToStatus(hipModuleOccupancyMaxActiveBlocksPerMultiprocessor( &max_blocks, rocm_function_, threads_per_block, dynamic_shared_memory_bytes), "Failed to calculate maximal active blocks per SM")); diff --git a/xla/stream_executor/rocm/rocm_platform.cc b/xla/stream_executor/rocm/rocm_platform.cc index cf2def95771ed..fe06fa845e573 100644 --- a/xla/stream_executor/rocm/rocm_platform.cc +++ b/xla/stream_executor/rocm/rocm_platform.cc @@ -27,7 +27,6 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/platform_manager.h" -#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_executor.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/rocm/rocm_status.h" @@ -41,7 +40,7 @@ namespace { // Actually performs the work of ROCM initialization. Wrapped up in one-time // execution guard. 
static absl::Status InternalInitialize() { - hipError_t res = wrap::hipInit(0 /* = flags */); + hipError_t res = hipInit(0 /* = flags */); if (res == hipSuccess) { return absl::OkStatus(); @@ -75,7 +74,7 @@ int ROCmPlatform::VisibleDeviceCount() const { } int device_count = 0; - hipError_t res = wrap::hipGetDeviceCount(&device_count); + hipError_t res = hipGetDeviceCount(&device_count); if (res != hipSuccess) { LOG(ERROR) << "could not retrieve ROCM device count: " << ToString(res); return 0; diff --git a/xla/stream_executor/rocm/rocm_solver_context.cc b/xla/stream_executor/rocm/rocm_solver_context.cc index 9df1d382205b9..03ee36c23b0eb 100644 --- a/xla/stream_executor/rocm/rocm_solver_context.cc +++ b/xla/stream_executor/rocm/rocm_solver_context.cc @@ -50,8 +50,8 @@ struct GpuComplexT { using gpuDataType_t = hipDataType; #if TF_ROCM_VERSION >= 40500 -#define GPU_SOLVER_CONTEXT_PREFIX wrap::hipsolver -#define GPU_SOLVER_PREFIX wrap::hipsolver +#define GPU_SOLVER_CONTEXT_PREFIX hipsolver +#define GPU_SOLVER_PREFIX hipsolver template <> struct GpuComplexT> { diff --git a/xla/stream_executor/rocm/rocm_solver_context.h b/xla/stream_executor/rocm/rocm_solver_context.h index c1e8acf9a041c..f9ff78bc3292a 100644 --- a/xla/stream_executor/rocm/rocm_solver_context.h +++ b/xla/stream_executor/rocm/rocm_solver_context.h @@ -29,7 +29,7 @@ limitations under the License. #include "rocm/rocm_config.h" // Macros to ease the transition from rocsolver to hipsolver. 
#if TENSORFLOW_USE_HIPSOLVER -#include "xla/stream_executor/rocm/hipsolver_wrapper.h" +#include "rocm/include/hipsolver/hipsolver.h" using gpusolverHandle_t = hipsolverHandle_t; #else // TENSORFLOW_USE_ROCSOLVER #include "xla/stream_executor/rocm/rocblas_wrapper.h" diff --git a/xla/stream_executor/rocm/rocm_stream.cc b/xla/stream_executor/rocm/rocm_stream.cc index 0ddfdccfe0eaf..8ab0a6672c757 100644 --- a/xla/stream_executor/rocm/rocm_stream.cc +++ b/xla/stream_executor/rocm/rocm_stream.cc @@ -40,7 +40,6 @@ limitations under the License. #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/rocm/rocm_kernel.h" #include "xla/stream_executor/rocm/rocm_status.h" @@ -57,7 +56,7 @@ int GetGpuStreamPriority(StreamExecutor* executor, return 0; } int lowest, highest; - hipError_t res = wrap::hipDeviceGetStreamPriorityRange(&lowest, &highest); + hipError_t res = hipDeviceGetStreamPriorityRange(&lowest, &highest); if (res != hipSuccess) { LOG(ERROR) << "Could not query stream priority range. Returning default priority."; @@ -73,11 +72,11 @@ absl::StatusOr CreateStream(StreamExecutor* executor, hipStream_t stream; if (priority == 0) { TF_RETURN_IF_ERROR(ToStatus( - wrap::hipStreamCreateWithFlags(&stream, hipStreamDefault), + hipStreamCreateWithFlags(&stream, hipStreamDefault), "Failed to create stream")); // switch to hipStreamNonBlocking? } else { TF_RETURN_IF_ERROR(ToStatus( - wrap::hipStreamCreateWithPriority(&stream, hipStreamDefault, priority), + hipStreamCreateWithPriority(&stream, hipStreamDefault, priority), "Failed to create stream")); // switch to hipStreamNonBlocking? 
} @@ -89,7 +88,7 @@ absl::StatusOr CreateStream(StreamExecutor* executor, absl::Status RecordEvent(StreamExecutor* executor, hipEvent_t event, hipStream_t stream) { std::unique_ptr activation = executor->Activate(); - hipError_t res = wrap::hipEventRecord(event, stream); + hipError_t res = hipEventRecord(event, stream); switch (res) { case hipSuccess: return absl::OkStatus(); @@ -109,7 +108,7 @@ absl::Status WaitStreamOnEvent(StreamExecutor* executor, hipStream_t stream, hipEvent_t event) { std::unique_ptr activation = executor->Activate(); TF_RETURN_IF_ERROR( - ToStatus(wrap::hipStreamWaitEvent(stream, event, 0 /* = flags */), + ToStatus(hipStreamWaitEvent(stream, event, 0 /* = flags */), "could not wait stream on event")); return absl::OkStatus(); } @@ -119,7 +118,7 @@ absl::Status AsynchronousMemcpyD2H(StreamExecutor* executor, void* host_dst, hipStream_t stream) { std::unique_ptr activation = executor->Activate(); TF_RETURN_IF_ERROR(ToStatus( - wrap::hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream), + hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream), absl::StrFormat( "failed to enqueue async memcpy from device to host: host dst: %p; " "Gpu src: %p; size: %llu=0x%llx", @@ -137,8 +136,7 @@ absl::Status AsynchronousMemcpyH2D(StreamExecutor* executor, uint64_t size, hipStream_t stream) { std::unique_ptr activation = executor->Activate(); TF_RETURN_IF_ERROR(ToStatus( - wrap::hipMemcpyHtoDAsync(gpu_dst, const_cast(host_src), size, - stream), + hipMemcpyHtoDAsync(gpu_dst, const_cast(host_src), size, stream), absl::StrFormat( "failed to enqueue async memcpy from host to device: Gpu dst: %p; " "host src: %p; size: %llu=0x%llx", @@ -157,7 +155,7 @@ absl::Status AsynchronousMemcpyD2D(StreamExecutor* executor, hipStream_t stream) { std::unique_ptr activation = executor->Activate(); TF_RETURN_IF_ERROR(ToStatus( - wrap::hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream), + hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream), absl::StrFormat("failed to enqueue async 
memcpy from device to device: " "Gpu dst: %p ; Gpu src: %p ; size: %llu=0x%llx", absl::bit_cast(gpu_dst), @@ -172,7 +170,7 @@ absl::Status AsynchronousMemcpyD2D(StreamExecutor* executor, absl::Status SynchronizeStream(StreamExecutor* executor, hipStream_t stream) { std::unique_ptr activation = executor->Activate(); - TF_RETURN_IF_ERROR(ToStatus(wrap::hipStreamSynchronize(stream), + TF_RETURN_IF_ERROR(ToStatus(hipStreamSynchronize(stream), "Could not synchronize on ROCM stream")); VLOG(2) << "successfully synchronized stream " << stream << " on device " << executor->device_ordinal(); @@ -231,13 +229,13 @@ void DestroyStream(StreamExecutor* executor, hipStream_t stream) { if (stream == nullptr) { return; } - hipError_t res = wrap::hipStreamQuery(stream); + hipError_t res = hipStreamQuery(stream); if (res != hipSuccess) { LOG(ERROR) << "stream not idle on destroy: " << ToString(res); } std::unique_ptr activation = executor->Activate(); - res = wrap::hipStreamDestroy(stream); + res = hipStreamDestroy(stream); if (res != hipSuccess) { LOG(ERROR) << "failed to destroy ROCM stream for device " << executor->device_ordinal() << ": " << ToString(res); @@ -263,9 +261,9 @@ absl::Status RocmStream::Memset32(DeviceMemoryBase* location, uint32_t pattern, if (size % sizeof(uint32_t) != 0) { return absl::InvalidArgumentError("size must be a multiple of 4 bytes."); } - return ToStatus(wrap::hipMemsetD32Async(location->opaque(), pattern, size / 4, - stream_handle_), - "Failed to memset memory"); + return ToStatus( + hipMemsetD32Async(location->opaque(), pattern, size / 4, stream_handle_), + "Failed to memset memory"); } absl::Status RocmStream::MemZero(DeviceMemoryBase* location, uint64_t size) { @@ -275,7 +273,7 @@ absl::Status RocmStream::MemZero(DeviceMemoryBase* location, uint64_t size) { } else { std::unique_ptr activation = executor_->Activate(); return ToStatus( - wrap::hipMemsetAsync(location->opaque(), 0x0, size, stream_handle_), + hipMemsetAsync(location->opaque(), 0x0, 
size, stream_handle_), "Failed to enqueue async memset operation"); } } @@ -320,8 +318,8 @@ absl::Status RocmStream::DoHostCallbackWithStatus( } }); return ToStatus( - wrap::hipLaunchHostFunc(stream_handle_, (hipHostFn_t)InternalHostCallback, - callback_ptr), + hipLaunchHostFunc(stream_handle_, (hipHostFn_t)InternalHostCallback, + callback_ptr), "unable to add host callback"); } @@ -343,18 +341,18 @@ absl::Status LaunchRocmKernel( #if TF_ROCM_VERSION < 60200 // for in-process kernel this function returns mangled kernel function name, // and null otherwise - auto name = wrap::hipKernelNameRefByPtr((const void*)function, stream); + auto name = hipKernelNameRefByPtr((const void*)function, stream); if (name != nullptr) { - res = wrap::hipLaunchKernel((const void*)function, - dim3(grid_dim_x, grid_dim_y, grid_dim_z), - dim3(block_dim_x, block_dim_y, block_dim_z), - kernel_params, shared_mem_bytes, stream); + res = hipLaunchKernel((const void*)function, + dim3(grid_dim_x, grid_dim_y, grid_dim_z), + dim3(block_dim_x, block_dim_y, block_dim_z), + kernel_params, shared_mem_bytes, stream); } else // NOLINT(readability/braces) #endif // TF_ROCM_VERSION < 60200 { - res = wrap::hipModuleLaunchKernel( - function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y, - block_dim_z, shared_mem_bytes, stream, kernel_params, extra); + res = hipModuleLaunchKernel(function, grid_dim_x, grid_dim_y, grid_dim_z, + block_dim_x, block_dim_y, block_dim_z, + shared_mem_bytes, stream, kernel_params, extra); } TF_RETURN_IF_ERROR( ToStatus(res, absl::StrCat("Failed to launch ROCm kernel: ", kernel_name, diff --git a/xla/stream_executor/rocm/rocm_timer.cc b/xla/stream_executor/rocm/rocm_timer.cc index 31b1c11514b15..05a9656f0c8e9 100644 --- a/xla/stream_executor/rocm/rocm_timer.cc +++ b/xla/stream_executor/rocm/rocm_timer.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "absl/time/time.h" #include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/activate_context.h" -#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/rocm/rocm_status.h" #include "xla/stream_executor/stream.h" @@ -39,14 +38,14 @@ absl::StatusOr GetEventElapsedTime(StreamExecutor* executor, std::unique_ptr activation = executor->Activate(); // The stop event must have completed in order for hipEventElapsedTime to // work. - hipError_t res = wrap::hipEventSynchronize(stop); + hipError_t res = hipEventSynchronize(stop); if (res != hipSuccess) { LOG(ERROR) << "failed to synchronize the stop event: " << ToString(res); return false; } float elapsed_milliseconds; TF_RETURN_IF_ERROR( - ToStatus(wrap::hipEventElapsedTime(&elapsed_milliseconds, start, stop), + ToStatus(hipEventElapsedTime(&elapsed_milliseconds, start, stop), "failed to get elapsed time between events")); return elapsed_milliseconds; diff --git a/xla/stream_executor/rocm/rocsolver_wrapper.h b/xla/stream_executor/rocm/rocsolver_wrapper.h index ef6bbfaefdbe5..6f3a7c3ee82c1 100644 --- a/xla/stream_executor/rocm/rocsolver_wrapper.h +++ b/xla/stream_executor/rocm/rocsolver_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The OpenXLA Authors. +/* Copyright 2026 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,57 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file wraps rocsolver API calls with dso loader so that we don't need to -// have explicit linking to librocsolver. All TF hipsarse API usage should route -// through this wrapper. 
- #ifndef XLA_STREAM_EXECUTOR_ROCM_ROCSOLVER_WRAPPER_H_ #define XLA_STREAM_EXECUTOR_ROCM_ROCSOLVER_WRAPPER_H_ -#include "rocm/rocm_config.h" -#if (TF_ROCM_VERSION >= 50200) #include "rocm/include/rocsolver/rocsolver.h" -#else -#include "rocm/include/rocsolver.h" -#endif - -#include "xla/tsl/platform/env.h" -#include "tsl/platform/dso_loader.h" +#include "rocm/rocm_config.h" namespace stream_executor { namespace wrap { -#ifdef PLATFORM_GOOGLE - -#define ROCSOLVER_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... args)->decltype(::api_name(args...)) { \ - return ::api_name(args...); \ - } - -#else - -#define TO_STR_(x) #x -#define TO_STR(x) TO_STR_(x) - -#define ROCSOLVER_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... args) -> decltype(::api_name(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = TO_STR(api_name); \ - void* f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ - tsl::internal::CachedDsoLoader::GetRocsolverDsoHandle().value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in rocsolver lib; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ - } - -#endif +#define ROCSOLVER_API_WRAPPER(api_name) using ::api_name; // clang-format off #define FOREACH_ROCSOLVER_API(__macro) \ @@ -107,8 +66,6 @@ namespace wrap { FOREACH_ROCSOLVER_API(ROCSOLVER_API_WRAPPER) -#undef TO_STR_ -#undef TO_STR #undef FOREACH_ROCSOLVER_API #undef ROCSOLVER_API_WRAPPER diff --git a/xla/stream_executor/rocm/roctracer_wrapper.h b/xla/stream_executor/rocm/roctracer_wrapper.h index 4155290a4cc4f..c18ca1bf6bcea 100644 --- a/xla/stream_executor/rocm/roctracer_wrapper.h +++ b/xla/stream_executor/rocm/roctracer_wrapper.h @@ -39,42 +39,10 @@ limitations under the License. 
#include "rocm/include/roctracer/roctracer.h" #include "rocm/include/roctracer/roctracer_hip.h" -#include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" -#include "tsl/platform/platform.h" - namespace stream_executor { namespace wrap { -#ifdef PLATFORM_GOOGLE - -#define ROCTRACER_API_WRAPPER(API_NAME) \ - template \ - auto API_NAME(Args... args) -> decltype((::API_NAME)(args...)) { \ - return (::API_NAME)(args...); \ - } - -#else - -#define ROCTRACER_API_WRAPPER(API_NAME) \ - template \ - auto API_NAME(Args... args) -> decltype(::API_NAME(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = #API_NAME; \ - void* f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ - tsl::internal::CachedDsoLoader::GetRoctracerDsoHandle().value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in roctracer DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ - } - -#endif // PLATFORM_GOOGLE - +#define ROCTRACER_API_WRAPPER(API_NAME) using ::API_NAME; #if TF_ROCM_VERSION >= 50300 #define FOREACH_ROCTRACER_API(DO_FUNC) \ diff --git a/xla/tsl/platform/default/dlopen_checker.cc b/xla/tsl/platform/default/dlopen_checker.cc index 763df14caf62d..5cf0674dc3ce3 100644 --- a/xla/tsl/platform/default/dlopen_checker.cc +++ b/xla/tsl/platform/default/dlopen_checker.cc @@ -43,31 +43,11 @@ absl::Status TryDlopenCUDALibraries() { } } -absl::Status TryDlopenROCmLibraries() { - auto rocblas_status = GetRocblasDsoHandle(); - auto miopen_status = GetMiopenDsoHandle(); - auto rocfft_status = GetHipfftDsoHandle(); - auto rocrand_status = GetRocrandDsoHandle(); -#if TF_HIPBLASLT - auto hiplaslt_status = CachedLoader::GetHipblasLtDsoHandle(); -#endif - if (!rocblas_status.status().ok() || !miopen_status.status().ok() || - !rocfft_status.status().ok() || !rocrand_status.status().ok() -#if TF_HIPBLASLT - || 
!hipblaslt_status.status().ok() -#endif - ) { - return absl::InternalError("Cannot dlopen all ROCm libraries."); - } else { - return absl::OkStatus(); - } -} - absl::Status MaybeTryDlopenGPULibraries() { #if GOOGLE_CUDA return TryDlopenCUDALibraries(); #elif TENSORFLOW_USE_ROCM - return TryDlopenROCmLibraries(); + return absl::OkStatus(); #else LOG(INFO) << "Not built with GPU enabled. Skip GPU library dlopen check."; return absl::OkStatus(); diff --git a/xla/tsl/platform/default/dso_loader.cc b/xla/tsl/platform/default/dso_loader.cc index 1a05bcd1e6733..91b7fc3c8e4ad 100644 --- a/xla/tsl/platform/default/dso_loader.cc +++ b/xla/tsl/platform/default/dso_loader.cc @@ -50,77 +50,7 @@ absl::string_view GetCusparseVersion() { return TF_CUSPARSE_VERSION; } absl::string_view GetNcclVersion() { return TF_NCCL_VERSION; } absl::string_view GetTensorRTVersion() { return TF_TENSORRT_VERSION; } absl::string_view GetNvshmemVersion() { return XLA_NVSHMEM_VERSION; } -absl::string_view GetHipVersion() { -#if TENSORFLOW_USE_ROCM - return TF_HIPRUNTIME_SOVERSION; -#else // TENSORFLOW_USE_ROCM - return ""; -#endif // TENSORFLOW_USE_ROCM -} -absl::string_view GetRocblasVersion() { -#if TENSORFLOW_USE_ROCM - return TF_ROCBLAS_SOVERSION; -#else // TENSORFLOW_USE_ROCM - return ""; -#endif // TENSORFLOW_USE_ROCM -} -std::string GetHipblasltVersion() { -#if TENSORFLOW_USE_ROCM - return TF_HIPBLASLT_SOVERSION; -#else // TENSORFLOW_USE_ROCM - return ""; -#endif // TENSORFLOW_USE_ROCM -} -std::string GetMiopenVersion() { -#if TENSORFLOW_USE_ROCM - return TF_MIOPEN_SOVERSION; -#else // TENSORFLOW_USE_ROCM - return ""; -#endif // TENSORFLOW_USE_ROCM -} -std::string GetHipfftVersion() { -#if TENSORFLOW_USE_ROCM - return TF_HIPFFT_SOVERSION; -#else // TENSORFLOW_USE_ROCM - return ""; -#endif // TENSORFLOW_USE_ROCM -} -std::string GetRocsolverVersion() { -#if TENSORFLOW_USE_ROCM - return TF_ROCSOLVER_SOVERSION; -#else // TENSORFLOW_USE_ROCM - return ""; -#endif // TENSORFLOW_USE_ROCM -} 
-std::string GetHipsparseVersion() { -#if TENSORFLOW_USE_ROCM - return TF_HIPSPARSE_SOVERSION; -#else // TENSORFLOW_USE_ROCM - return ""; -#endif // TENSORFLOW_USE_ROCM -} -std::string GetRoctracerVersion() { -#if TENSORFLOW_USE_ROCM - return TF_ROCTRACER_SOVERSION; -#else // TENSORFLOW_USE_ROCM - return ""; -#endif // TENSORFLOW_USE_ROCM -} -std::string GetHipsolverVersion() { -#if TENSORFLOW_USE_ROCM - return TF_HIPSOLVER_SOVERSION; -#else // TENSORFLOW_USE_ROCM - return ""; -#endif // TENSORFLOW_USE_ROCM -} -std::string GetRocrandVersion() { -#if TENSORFLOW_USE_ROCM - return TF_ROCRAND_SOVERSION; -#else // TENSORFLOW_USE_ROCM - return ""; -#endif // TENSORFLOW_USE_ROCM -} absl::StatusOr GetDsoHandle(const std::string& name, absl::string_view version) { @@ -221,48 +151,6 @@ absl::StatusOr GetNvInferPluginDsoHandle() { #endif } -absl::StatusOr GetRocblasDsoHandle() { - return GetDsoHandle("rocblas", GetRocblasVersion()); -} - -absl::StatusOr GetMiopenDsoHandle() { - return GetDsoHandle("MIOpen", GetMiopenVersion()); -} - -absl::StatusOr GetHipfftDsoHandle() { - return GetDsoHandle("hipfft", GetHipfftVersion()); -} - -absl::StatusOr GetRocrandDsoHandle() { - return GetDsoHandle("rocrand", GetRocrandVersion()); -} - -absl::StatusOr GetRocsolverDsoHandle() { - return GetDsoHandle("rocsolver", GetRocsolverVersion()); -} - -#if TF_ROCM_VERSION >= 40500 -absl::StatusOr GetHipsolverDsoHandle() { - return GetDsoHandle("hipsolver", GetHipsolverVersion()); -} -#endif - -absl::StatusOr GetRoctracerDsoHandle() { - return GetDsoHandle("roctracer64", GetRoctracerVersion()); -} - -absl::StatusOr GetHipsparseDsoHandle() { - return GetDsoHandle("hipsparse", GetHipsparseVersion()); -} - -absl::StatusOr GetHipblasltDsoHandle() { - return GetDsoHandle("hipblaslt", GetHipblasltVersion()); -} - -absl::StatusOr GetHipDsoHandle() { - return GetDsoHandle("amdhip64", GetHipVersion()); -} - } // namespace DsoLoader namespace CachedDsoLoader { @@ -311,58 +199,6 @@ absl::StatusOr 
GetCudnnDsoHandle() { return *result; } -absl::StatusOr GetRocblasDsoHandle() { - static auto result = new auto(DsoLoader::GetRocblasDsoHandle()); - return *result; -} - -absl::StatusOr GetMiopenDsoHandle() { - static auto result = new auto(DsoLoader::GetMiopenDsoHandle()); - return *result; -} - -absl::StatusOr GetHipfftDsoHandle() { - static auto result = new auto(DsoLoader::GetHipfftDsoHandle()); - return *result; -} - -absl::StatusOr GetRocrandDsoHandle() { - static auto result = new auto(DsoLoader::GetRocrandDsoHandle()); - return *result; -} - -absl::StatusOr GetRoctracerDsoHandle() { - static auto result = new auto(DsoLoader::GetRoctracerDsoHandle()); - return *result; -} - -absl::StatusOr GetRocsolverDsoHandle() { - static auto result = new auto(DsoLoader::GetRocsolverDsoHandle()); - return *result; -} - -#if TF_ROCM_VERSION >= 40500 -absl::StatusOr GetHipsolverDsoHandle() { - static auto result = new auto(DsoLoader::GetHipsolverDsoHandle()); - return *result; -} -#endif - -absl::StatusOr GetHipsparseDsoHandle() { - static auto result = new auto(DsoLoader::GetHipsparseDsoHandle()); - return *result; -} - -absl::StatusOr GetHipblasltDsoHandle() { - static auto result = new auto(DsoLoader::GetHipblasltDsoHandle()); - return *result; -} - -absl::StatusOr GetHipDsoHandle() { - static auto result = new auto(DsoLoader::GetHipDsoHandle()); - return *result; -} - } // namespace CachedDsoLoader } // namespace internal } // namespace tsl diff --git a/xla/tsl/platform/default/dso_loader.h b/xla/tsl/platform/default/dso_loader.h index 4234da6b7272a..bf30efdeb1f1f 100644 --- a/xla/tsl/platform/default/dso_loader.h +++ b/xla/tsl/platform/default/dso_loader.h @@ -42,16 +42,6 @@ absl::StatusOr GetNvshmemDsoHandle(); absl::StatusOr GetNvInferDsoHandle(); absl::StatusOr GetNvInferPluginDsoHandle(); -absl::StatusOr GetRocblasDsoHandle(); -absl::StatusOr GetMiopenDsoHandle(); -absl::StatusOr GetHipfftDsoHandle(); -absl::StatusOr GetRocrandDsoHandle(); -absl::StatusOr 
GetRoctracerDsoHandle(); -absl::StatusOr GetRocsolverDsoHandle(); -absl::StatusOr GetHipsolverDsoHandle(); -absl::StatusOr GetHipsparseDsoHandle(); -absl::StatusOr GetHipDsoHandle(); - // The following method tries to dlopen all necessary GPU libraries for the GPU // platform TF is built with (CUDA or ROCm) only when these libraries should be // dynamically loaded. Error status is returned when any of the libraries cannot @@ -77,17 +67,6 @@ absl::StatusOr GetCusolverDsoHandle(); absl::StatusOr GetCusparseDsoHandle(); absl::StatusOr GetCuptiDsoHandle(); absl::StatusOr GetCudnnDsoHandle(); - -absl::StatusOr GetRocblasDsoHandle(); -absl::StatusOr GetMiopenDsoHandle(); -absl::StatusOr GetHipfftDsoHandle(); -absl::StatusOr GetRocrandDsoHandle(); -absl::StatusOr GetRocsolverDsoHandle(); -absl::StatusOr GetHipsolverDsoHandle(); -absl::StatusOr GetRoctracerDsoHandle(); -absl::StatusOr GetHipsparseDsoHandle(); -absl::StatusOr GetHipblasltDsoHandle(); -absl::StatusOr GetHipDsoHandle(); } // namespace CachedDsoLoader } // namespace internal