Skip to content

Commit 9e95cff

Browse files
authored
[AArch64] Add vector expansion support for ISD::FPOW when using ArmPL (llvm#183526)
This patch is split off from PR llvm#183319 and teaches the backend how to lower the FPOW DAG node to the vector math library function when using ArmPL. This is similar to what we already do for llvm.sincos/FSINCOS today.
1 parent 28cbc68 commit 9e95cff

File tree

5 files changed

+125
-5
lines changed

5 files changed

+125
-5
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,6 +1318,15 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
13181318
// scalarizing.
13191319
break;
13201320
}
1321+
case ISD::FPOW: {
1322+
RTLIB::Libcall LC = RTLIB::getPOW(Node->getValueType(0));
1323+
if (tryExpandVecMathCall(Node, LC, Results))
1324+
return;
1325+
1326+
// TODO: Try to see if there's a narrower call available to use before
1327+
// scalarizing.
1328+
break;
1329+
}
13211330
case ISD::FMODF: {
13221331
EVT VT = Node->getValueType(0);
13231332
RTLIB::Libcall LC = RTLIB::getMODF(VT);

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,24 @@ RTLIB::Libcall RTLIB::getPOWI(EVT RetVT) {
539539
}
540540

541541
RTLIB::Libcall RTLIB::getPOW(EVT RetVT) {
542+
// TODO: Tablegen should generate this function
543+
if (RetVT.isVector()) {
544+
if (!RetVT.isSimple())
545+
return RTLIB::UNKNOWN_LIBCALL;
546+
switch (RetVT.getSimpleVT().SimpleTy) {
547+
case MVT::v4f32:
548+
return RTLIB::POW_V4F32;
549+
case MVT::v2f64:
550+
return RTLIB::POW_V2F64;
551+
case MVT::nxv4f32:
552+
return RTLIB::POW_NXV4F32;
553+
case MVT::nxv2f64:
554+
return RTLIB::POW_NXV2F64;
555+
default:
556+
return RTLIB::UNKNOWN_LIBCALL;
557+
}
558+
}
559+
542560
return getFPLibCall(RetVT, POW_F32, POW_F64, POW_F80, POW_F128, POW_PPCF128);
543561
}
544562

llvm/lib/IR/RuntimeLibcalls.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,15 @@ RuntimeLibcallsInfo::RuntimeLibcallsInfo(const Triple &TT,
6565
RTLIB::impl_armpl_svsincos_f64_x, RTLIB::impl_armpl_svsincos_f32_x,
6666
RTLIB::impl_armpl_vsincospiq_f32, RTLIB::impl_armpl_vsincospiq_f64,
6767
RTLIB::impl_armpl_svsincospi_f32_x,
68-
RTLIB::impl_armpl_svsincospi_f64_x})
68+
RTLIB::impl_armpl_svsincospi_f64_x, RTLIB::impl_armpl_svpow_f32_x,
69+
RTLIB::impl_armpl_svpow_f64_x, RTLIB::impl_armpl_vpowq_f32,
70+
RTLIB::impl_armpl_vpowq_f64})
6971
setAvailable(Impl);
7072

7173
for (RTLIB::LibcallImpl Impl :
7274
{RTLIB::impl_armpl_vfmodq_f32, RTLIB::impl_armpl_vfmodq_f64,
73-
RTLIB::impl_armpl_vsincosq_f64, RTLIB::impl_armpl_vsincosq_f32})
75+
RTLIB::impl_armpl_vsincosq_f64, RTLIB::impl_armpl_vsincosq_f32,
76+
RTLIB::impl_armpl_vpowq_f32, RTLIB::impl_armpl_vpowq_f64})
7477
setLibcallImplCallingConv(Impl, CallingConv::AArch64_VectorCall);
7578
break;
7679
default:
@@ -288,16 +291,24 @@ RuntimeLibcallsInfo::getFunctionTy(LLVMContext &Ctx, const Triple &TT,
288291
case RTLIB::impl_armpl_vfmodq_f32:
289292
case RTLIB::impl_armpl_vfmodq_f64:
290293
case RTLIB::impl_armpl_svfmod_f32_x:
291-
case RTLIB::impl_armpl_svfmod_f64_x: {
294+
case RTLIB::impl_armpl_svfmod_f64_x:
295+
case RTLIB::impl_armpl_vpowq_f32:
296+
case RTLIB::impl_armpl_vpowq_f64:
297+
case RTLIB::impl_armpl_svpow_f32_x:
298+
case RTLIB::impl_armpl_svpow_f64_x: {
292299
bool IsF32 = LibcallImpl == RTLIB::impl__ZGVnN4vv_fmodf ||
293300
LibcallImpl == RTLIB::impl__ZGVsMxvv_fmodf ||
294301
LibcallImpl == RTLIB::impl_armpl_svfmod_f32_x ||
295-
LibcallImpl == RTLIB::impl_armpl_vfmodq_f32;
302+
LibcallImpl == RTLIB::impl_armpl_vfmodq_f32 ||
303+
LibcallImpl == RTLIB::impl_armpl_vpowq_f32 ||
304+
LibcallImpl == RTLIB::impl_armpl_svpow_f32_x;
296305

297306
bool IsScalable = LibcallImpl == RTLIB::impl__ZGVsMxvv_fmod ||
298307
LibcallImpl == RTLIB::impl__ZGVsMxvv_fmodf ||
299308
LibcallImpl == RTLIB::impl_armpl_svfmod_f32_x ||
300-
LibcallImpl == RTLIB::impl_armpl_svfmod_f64_x;
309+
LibcallImpl == RTLIB::impl_armpl_svfmod_f64_x ||
310+
LibcallImpl == RTLIB::impl_armpl_svpow_f32_x ||
311+
LibcallImpl == RTLIB::impl_armpl_svpow_f64_x;
301312

302313
AttrBuilder FuncAttrBuilder(Ctx);
303314

@@ -448,6 +459,8 @@ bool RuntimeLibcallsInfo::hasVectorMaskArgument(RTLIB::LibcallImpl Impl) {
448459
case RTLIB::impl_armpl_svsincospi_f64_x:
449460
case RTLIB::impl__ZGVsMxvv_fmod:
450461
case RTLIB::impl__ZGVsMxvv_fmodf:
462+
case RTLIB::impl_armpl_svpow_f32_x:
463+
case RTLIB::impl_armpl_svpow_f64_x:
451464
return true;
452465
default:
453466
return false;
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
2+
; RUN: llc -start-before=codegenprepare -mtriple=aarch64-gnu-linux -mattr=+neon,+sve \
3+
; RUN: -vector-library=ArmPL < %s | FileCheck %s -check-prefix=ARMPL
4+
5+
define <4 x float> @test_pow_v4f32(<4 x float> %x, <4 x float> %y) nounwind {
6+
; ARMPL-LABEL: test_pow_v4f32:
7+
; ARMPL: // %bb.0:
8+
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
9+
; ARMPL-NEXT: bl armpl_vpowq_f32
10+
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
11+
; ARMPL-NEXT: ret
12+
%result = call <4 x float> @llvm.pow.v4f32(<4 x float> %x, <4 x float> %y)
13+
ret <4 x float> %result
14+
}
15+
16+
define <2 x double> @test_pow_v2f64(<2 x double> %x, <2 x double> %y) nounwind {
17+
; ARMPL-LABEL: test_pow_v2f64:
18+
; ARMPL: // %bb.0:
19+
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
20+
; ARMPL-NEXT: bl armpl_vpowq_f64
21+
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
22+
; ARMPL-NEXT: ret
23+
%result = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> %y)
24+
ret <2 x double> %result
25+
}
26+
27+
define <vscale x 4 x float> @test_pow_nxv4f32(<vscale x 4 x float> %x, <vscale x 4 x float> %y) nounwind {
28+
; ARMPL-LABEL: test_pow_nxv4f32:
29+
; ARMPL: // %bb.0:
30+
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
31+
; ARMPL-NEXT: ptrue p0.s
32+
; ARMPL-NEXT: bl armpl_svpow_f32_x
33+
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
34+
; ARMPL-NEXT: ret
35+
%result = call <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float> %x, <vscale x 4 x float> %y)
36+
ret <vscale x 4 x float> %result
37+
}
38+
39+
define <vscale x 2 x double> @test_pow_nxv2f64(<vscale x 2 x double> %x, <vscale x 2 x double> %y) nounwind {
40+
; ARMPL-LABEL: test_pow_nxv2f64:
41+
; ARMPL: // %bb.0:
42+
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
43+
; ARMPL-NEXT: ptrue p0.d
44+
; ARMPL-NEXT: bl armpl_svpow_f64_x
45+
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
46+
; ARMPL-NEXT: ret
47+
%result = call <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> %x, <vscale x 2 x double> %y)
48+
ret <vscale x 2 x double> %result
49+
}
50+
51+
define <4 x float> @test_pow_v4f32_025(<4 x float> %x) nounwind {
52+
; ARMPL-LABEL: test_pow_v4f32_025:
53+
; ARMPL: // %bb.0:
54+
; ARMPL-NEXT: fsqrt v0.4s, v0.4s
55+
; ARMPL-NEXT: fsqrt v0.4s, v0.4s
56+
; ARMPL-NEXT: ret
57+
%result = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %x, <4 x float> splat (float 2.5e-01))
58+
ret <4 x float> %result
59+
}
60+
61+
define <vscale x 2 x double> @test_pow_nxv2f64_075(<vscale x 2 x double> %x) nounwind {
62+
; ARMPL-LABEL: test_pow_nxv2f64_075:
63+
; ARMPL: // %bb.0:
64+
; ARMPL-NEXT: ptrue p0.d
65+
; ARMPL-NEXT: fsqrt z0.d, p0/m, z0.d
66+
; ARMPL-NEXT: movprfx z1, z0
67+
; ARMPL-NEXT: fsqrt z1.d, p0/m, z0.d
68+
; ARMPL-NEXT: fmul z0.d, z0.d, z1.d
69+
; ARMPL-NEXT: ret
70+
%result = call fast <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> %x, <vscale x 2 x double> splat (double 7.5e-01))
71+
ret <vscale x 2 x double> %result
72+
}

llvm/test/Transforms/Util/DeclareRuntimeLibcalls/armpl.ll

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99

1010
; CHECK: declare <vscale x 2 x double> @armpl_svmodf_f64_x(<vscale x 2 x double>, ptr noalias nonnull writeonly align 16, <vscale x 2 x i1>) [[ATTRS_PTR_ARG]]
1111

12+
; CHECK: declare <vscale x 4 x float> @armpl_svpow_f32_x(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x i1>) [[ATTRS]]
13+
14+
; CHECK: declare <vscale x 2 x double> @armpl_svpow_f64_x(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x i1>) [[ATTRS]]
15+
1216
; CHECK: declare void @armpl_svsincos_f32_x(<vscale x 4 x float>, ptr noalias nonnull writeonly align 16, ptr noalias nonnull writeonly align 16, <vscale x 4 x i1>) [[ATTRS_PTR_ARG]]
1317

1418
; CHECK: declare void @armpl_svsincos_f64_x(<vscale x 2 x double>, ptr noalias nonnull writeonly align 16, ptr noalias nonnull writeonly align 16, <vscale x 2 x i1>) [[ATTRS_PTR_ARG]]
@@ -25,6 +29,10 @@
2529

2630
; CHECK: declare <2 x double> @armpl_vmodfq_f64(<2 x double>, ptr noalias nonnull writeonly align 16) [[ATTRS_PTR_ARG]]
2731

32+
; CHECK: declare aarch64_vector_pcs <4 x float> @armpl_vpowq_f32(<4 x float>, <4 x float>) [[ATTRS]]
33+
34+
; CHECK: declare aarch64_vector_pcs <2 x double> @armpl_vpowq_f64(<2 x double>, <2 x double>) [[ATTRS]]
35+
2836
; CHECK: declare void @armpl_vsincospiq_f32(<4 x float>, ptr noalias nonnull writeonly align 16, ptr noalias nonnull writeonly align 16) [[ATTRS_PTR_ARG]]
2937

3038
; CHECK: declare void @armpl_vsincospiq_f64(<2 x double>, ptr noalias nonnull writeonly align 16, ptr noalias nonnull writeonly align 16) [[ATTRS_PTR_ARG]]

0 commit comments

Comments
 (0)