Skip to content

Commit 2e5d65e

Browse files
amd-songpiaommakevic-amd
authored andcommitted
PR tensorflow#34806: [ROCm] fix the calling convention for AMD GPU
Imported from GitHub PR openxla/xla#34806 Bugfix: PR tensorflow#34230 ("argument removal without building prototype") removed the call to **BuildKernelPrototypeFromUniqueName** which internally called **AnnotateFunctionAsGpuKernel** to set the correct calling convention based on the target GPU. Without this, Triton's **PTX_Kernel** calling convention was copied directly, which doesn't work on AMD GPUs and lead to "LLVM ERROR: unsupported calling convention". Fix: Added a call to **AnnotateFunctionAsGpuKernel** in **RemoveUnusedTritonAbiArguments** to properly set: PTX_Kernel (71) for NVIDIA AMDGPU_KERNEL (91) for AMD SPIR_KERNEL (76) for SPIR @xla-rotation could you review my PR, please? Copybara import of the project: -- ebd6e1fa03033bc9f6913351323fce26e1a8e4d2 by Songlin Piao <Songlin.Piao@amd.com>: replace the manual calling convention fix with AnnotateFunctionAsGpuKernel -- 4f16d9579b11c2984c8ebe58041b0d2b9ea5ba3f by Songlin Piao <Songlin.Piao@amd.com>: added a filecheck test Merging this change closes tensorflow#34806 PiperOrigin-RevId: 842146580
1 parent 692e221 commit 2e5d65e

3 files changed

Lines changed: 35 additions & 1 deletion

File tree

third_party/xla/xla/backends/gpu/codegen/fusion_emitter.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,15 @@ absl::StatusOr<llvm::Function*> RemoveUnusedTritonAbiArguments(
266266
.getCallee();
267267
llvm::Function* new_function = static_cast<llvm::Function*>(inserted);
268268

269-
new_function->setCallingConv(impl_fn->getCallingConv());
270269
new_function->copyMetadata(impl_fn, 0);
271270
new_function->setAttributes(impl_fn->getAttributes());
272271

272+
// Set the correct calling convention for the target GPU.
273+
// Triton generates PTX_Kernel CC even for AMD, so we need to use
274+
// AnnotateFunctionAsGpuKernel to set the correct CC based on target triple.
275+
llvm::IRBuilder<> builder(llvm_module->getContext());
276+
AnnotateFunctionAsGpuKernel(llvm_module, new_function, &builder);
277+
273278
new_function->splice(new_function->begin(), impl_fn);
274279

275280
for (const auto& [impl_fn_arg, kernel_arg] :

third_party/xla/xla/service/gpu/tests/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,7 @@ lit_test_suite_for_gpus(
661661
"slice_to_dynamic.hlo",
662662
"sorting.hlo",
663663
"sub_byte_collectives.hlo",
664+
"triton_calling_convention.hlo",
664665
"triton_naming.hlo",
665666
"zero_clamp_abs_index.hlo",
666667
],
@@ -673,10 +674,12 @@ lit_test_suite_for_gpus(
673674
disabled_on_gpus = {
674675
"v100": [
675676
"kernel_reuse.hlo",
677+
"triton_calling_convention.hlo",
676678
"triton_naming.hlo",
677679
],
678680
"p100": [
679681
"kernel_reuse.hlo",
682+
"triton_calling_convention.hlo",
680683
"triton_naming.hlo",
681684
],
682685
"mi200": [
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK-%{PTX} %s
2+
3+
// Verify that Triton kernels have the correct calling convention:
4+
// - PTX_KERNEL (71) for NVIDIA targets
5+
// - AMDGPU_KERNEL (91) for AMD targets
6+
// CHECK-PTX: define ptx_kernel void @triton_
7+
// CHECK-GCN: define amdgpu_kernel void @triton_
8+
9+
HloModule TritonCallingConvention, is_scheduled=true
10+
11+
triton_softmax {
12+
param_0 = f32[4,4]{1,0} parameter(0)
13+
ROOT exp = f32[4,4]{1,0} exponential(param_0)
14+
}
15+
16+
ENTRY main {
17+
param_0 = f32[4,4]{1,0} parameter(0)
18+
ROOT triton_softmax = f32[4,4]{1,0} fusion(param_0), kind=kCustom,
19+
calls=triton_softmax,
20+
backend_config={"fusion_backend_config":{
21+
"kind":"__triton",
22+
"block_level_fusion_config":{"output_tiles":[{"sizes":["4","4"]}],
23+
"num_warps":"1",
24+
"num_ctas":"1",
25+
"num_stages":"1"}}}
26+
}

0 commit comments

Comments
 (0)