Skip to content

Commit 058398c

Browse files
authored
[NVPTX] Constant fold blockDim when reqntid is specified (#191575)
Currently, NVPTX cannot fold the `ntid.x/y/z` intrinsic calls into const values when `reqntid` is specified, which prevents the code from further optimization. Therefore, in this change, we extend the `NVVMIntrRange` pass to: - Tighten `ntid.x/y/z` intrinsic calls to one value range, which can be const folded in later InstCombine pass - Tighten `tid.x/y/z` range attributes to use per-dimension reqntid bounds - When .reqntid exceeds hardware limits, garbage-in/garbage-out
1 parent 695e1ba commit 058398c

3 files changed

Lines changed: 123 additions & 23 deletions

File tree

llvm/lib/Target/NVPTX/NVVMIntrRange.cpp

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,47 +66,57 @@ static bool runNVVMIntrRange(Function &F) {
6666
if (!isKernelFunction(F))
6767
return false;
6868

69-
const auto OverallReqNTID = getOverallReqNTID(F);
69+
auto ReqNTID = getReqNTID(F);
7070
const auto OverallMaxNTID = getOverallMaxNTID(F);
7171
const auto OverallClusterRank = getOverallClusterRank(F);
7272

7373
// If this function lacks any range information, do nothing.
74-
if (!(OverallReqNTID || OverallMaxNTID || OverallClusterRank))
74+
if (!(!ReqNTID.empty() || OverallMaxNTID || OverallClusterRank))
7575
return false;
7676

77-
const unsigned FunctionNTID = OverallReqNTID.value_or(
78-
OverallMaxNTID.value_or(std::numeric_limits<unsigned>::max()));
77+
const unsigned MaxNTID =
78+
OverallMaxNTID.value_or(std::numeric_limits<unsigned>::max());
7979

8080
const unsigned FunctionClusterRank =
8181
OverallClusterRank.value_or(std::numeric_limits<unsigned>::max());
8282

83-
const Vector3 MaxBlockSize{std::min(1024u, FunctionNTID),
84-
std::min(1024u, FunctionNTID),
85-
std::min(64u, FunctionNTID)};
83+
// When reqntid is specified, block dimensions are exact compile-time
84+
// constants. Otherwise, use maxntid (capped at hardware limits) as upper
85+
// bounds.
86+
Vector3 MinBlockDim, MaxBlockDim;
87+
if (!ReqNTID.empty()) {
88+
ReqNTID.resize(3, 1);
89+
MinBlockDim = MaxBlockDim = {ReqNTID[0], ReqNTID[1], ReqNTID[2]};
90+
} else {
91+
MinBlockDim = {1, 1, 1};
92+
MaxBlockDim = {std::min(1024u, MaxNTID), std::min(1024u, MaxNTID),
93+
std::min(64u, MaxNTID)};
94+
}
8695

8796
// We conservatively use the maximum grid size as an upper bound for the
8897
// cluster rank.
8998
const Vector3 MaxClusterRank{std::min(0x7fffffffu, FunctionClusterRank),
9099
std::min(0xffffu, FunctionClusterRank),
91100
std::min(0xffffu, FunctionClusterRank)};
92101

93-
const auto ProccessIntrinsic = [&](IntrinsicInst *II) -> bool {
102+
const auto ProcessIntrinsic = [&](IntrinsicInst *II) -> bool {
94103
switch (II->getIntrinsicID()) {
95104
// Index within block
96105
case Intrinsic::nvvm_read_ptx_sreg_tid_x:
97-
return addRangeAttr(0, MaxBlockSize.X, II);
106+
return addRangeAttr(0, MaxBlockDim.X, II);
98107
case Intrinsic::nvvm_read_ptx_sreg_tid_y:
99-
return addRangeAttr(0, MaxBlockSize.Y, II);
108+
return addRangeAttr(0, MaxBlockDim.Y, II);
100109
case Intrinsic::nvvm_read_ptx_sreg_tid_z:
101-
return addRangeAttr(0, MaxBlockSize.Z, II);
110+
return addRangeAttr(0, MaxBlockDim.Z, II);
102111

103-
// Block size
112+
// Block size: use single-value range when reqntid is specified;
113+
// InstCombine will fold these to constants later.
104114
case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
105-
return addRangeAttr(1, MaxBlockSize.X + 1, II);
115+
return addRangeAttr(MinBlockDim.X, MaxBlockDim.X + 1, II);
106116
case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
107-
return addRangeAttr(1, MaxBlockSize.Y + 1, II);
117+
return addRangeAttr(MinBlockDim.Y, MaxBlockDim.Y + 1, II);
108118
case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
109-
return addRangeAttr(1, MaxBlockSize.Z + 1, II);
119+
return addRangeAttr(MinBlockDim.Z, MaxBlockDim.Z + 1, II);
110120

111121
// Cluster size
112122
case Intrinsic::nvvm_read_ptx_sreg_cluster_ctaid_x:
@@ -140,7 +150,7 @@ static bool runNVVMIntrRange(Function &F) {
140150
bool Changed = false;
141151
for (Instruction &I : instructions(F))
142152
if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I))
143-
Changed |= ProccessIntrinsic(II);
153+
Changed |= ProcessIntrinsic(II);
144154

145155
return Changed;
146156
}

llvm/test/CodeGen/NVPTX/intr-range.ll

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,17 @@ define ptx_kernel i32 @test_reqntid() "nvvm.reqntid"="20" {
3636
; CHECK-LABEL: define ptx_kernel i32 @test_reqntid(
3737
; CHECK-SAME: ) #[[ATTR1:[0-9]+]] {
3838
; CHECK-NEXT: [[TMP1:%.*]] = call range(i32 0, 20) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
39-
; CHECK-NEXT: [[TMP5:%.*]] = call range(i32 0, 20) i32 @llvm.nvvm.read.ptx.sreg.tid.y()
40-
; CHECK-NEXT: [[TMP2:%.*]] = call range(i32 0, 20) i32 @llvm.nvvm.read.ptx.sreg.tid.z()
41-
; CHECK-NEXT: [[TMP4:%.*]] = call range(i32 1, 21) i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
42-
; CHECK-NEXT: [[TMP3:%.*]] = call range(i32 1, 21) i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
43-
; CHECK-NEXT: [[TMP6:%.*]] = call range(i32 1, 21) i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
39+
; CHECK-NEXT: [[TMP5:%.*]] = call range(i32 0, 1) i32 @llvm.nvvm.read.ptx.sreg.tid.y()
40+
; CHECK-NEXT: [[TMP2:%.*]] = call range(i32 0, 1) i32 @llvm.nvvm.read.ptx.sreg.tid.z()
41+
; CHECK-NEXT: [[TMP4:%.*]] = call range(i32 20, 21) i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
42+
; CHECK-NEXT: [[TMP12:%.*]] = call range(i32 1, 2) i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
43+
; CHECK-NEXT: [[TMP6:%.*]] = call range(i32 1, 2) i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
4444
; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP1]], [[TMP5]]
4545
; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP7]], [[TMP2]]
4646
; CHECK-NEXT: [[TMP9:%.*]] = add i32 [[TMP8]], [[TMP4]]
47-
; CHECK-NEXT: [[TMP10:%.*]] = add i32 [[TMP9]], [[TMP3]]
47+
; CHECK-NEXT: [[TMP10:%.*]] = add i32 [[TMP9]], [[TMP12]]
4848
; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP10]], [[TMP6]]
49-
; CHECK-NEXT: ret i32 [[TMP3]]
49+
; CHECK-NEXT: ret i32 [[TMP12]]
5050
;
5151
%1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
5252
%2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
2+
; RUN: opt < %s -S -mtriple=nvptx-nvidia-cuda -mcpu=sm_20 -passes=nvvm-intr-range | FileCheck %s
3+
4+
; When .reqntid specifies 3D dimensions, ntid.x/y/z should be replaced with
5+
; constants and tid.x/y/z should get per-dimension ranges.
6+
; Product 128*4*2 = 1024 is within the hardware limit.
7+
define ptx_kernel i32 @test_reqntid_3d() "nvvm.reqntid"="128,4,2" {
8+
; CHECK-LABEL: define ptx_kernel i32 @test_reqntid_3d(
9+
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
10+
; CHECK-NEXT: [[TID_X:%.*]] = call range(i32 0, 128) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
11+
; CHECK-NEXT: [[TID_Y:%.*]] = call range(i32 0, 4) i32 @llvm.nvvm.read.ptx.sreg.tid.y()
12+
; CHECK-NEXT: [[TID_Z:%.*]] = call range(i32 0, 2) i32 @llvm.nvvm.read.ptx.sreg.tid.z()
13+
; CHECK-NEXT: [[NTID_X:%.*]] = call range(i32 128, 129) i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
14+
; CHECK-NEXT: [[NTID_Y:%.*]] = call range(i32 4, 5) i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
15+
; CHECK-NEXT: [[NTID_Z:%.*]] = call range(i32 2, 3) i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
16+
; CHECK-NEXT: [[A:%.*]] = add i32 [[TID_X]], [[TID_Y]]
17+
; CHECK-NEXT: [[B:%.*]] = add i32 [[A]], [[TID_Z]]
18+
; CHECK-NEXT: [[C:%.*]] = add i32 [[B]], [[NTID_X]]
19+
; CHECK-NEXT: [[D:%.*]] = add i32 [[C]], [[NTID_Y]]
20+
; CHECK-NEXT: [[E:%.*]] = add i32 [[D]], [[NTID_Z]]
21+
; CHECK-NEXT: ret i32 [[E]]
22+
;
23+
%tid.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
24+
%tid.y = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
25+
%tid.z = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
26+
%ntid.x = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
27+
%ntid.y = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
28+
%ntid.z = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
29+
%a = add i32 %tid.x, %tid.y
30+
%b = add i32 %a, %tid.z
31+
%c = add i32 %b, %ntid.x
32+
%d = add i32 %c, %ntid.y
33+
%e = add i32 %d, %ntid.z
34+
ret i32 %e
35+
}
36+
37+
; When .reqntid specifies only 1D, y and z default to 1.
38+
define ptx_kernel i32 @test_reqntid_1d() "nvvm.reqntid"="128" {
39+
; CHECK-LABEL: define ptx_kernel i32 @test_reqntid_1d(
40+
; CHECK-SAME: ) #[[ATTR1:[0-9]+]] {
41+
; CHECK-NEXT: [[TID_X:%.*]] = call range(i32 0, 128) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
42+
; CHECK-NEXT: [[TID_Y:%.*]] = call range(i32 0, 1) i32 @llvm.nvvm.read.ptx.sreg.tid.y()
43+
; CHECK-NEXT: [[TID_Z:%.*]] = call range(i32 0, 1) i32 @llvm.nvvm.read.ptx.sreg.tid.z()
44+
; CHECK-NEXT: [[NTID_X:%.*]] = call range(i32 128, 129) i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
45+
; CHECK-NEXT: [[NTID_Y:%.*]] = call range(i32 1, 2) i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
46+
; CHECK-NEXT: [[NTID_Z:%.*]] = call range(i32 1, 2) i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
47+
; CHECK-NEXT: [[A:%.*]] = add i32 [[TID_X]], [[TID_Y]]
48+
; CHECK-NEXT: [[B:%.*]] = add i32 [[A]], [[TID_Z]]
49+
; CHECK-NEXT: [[C:%.*]] = add i32 [[B]], [[NTID_X]]
50+
; CHECK-NEXT: [[D:%.*]] = add i32 [[C]], [[NTID_Y]]
51+
; CHECK-NEXT: [[E:%.*]] = add i32 [[D]], [[NTID_Z]]
52+
; CHECK-NEXT: ret i32 [[E]]
53+
;
54+
%tid.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
55+
%tid.y = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
56+
%tid.z = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
57+
%ntid.x = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
58+
%ntid.y = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
59+
%ntid.z = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
60+
%a = add i32 %tid.x, %tid.y
61+
%b = add i32 %a, %tid.z
62+
%c = add i32 %b, %ntid.x
63+
%d = add i32 %c, %ntid.y
64+
%e = add i32 %d, %ntid.z
65+
ret i32 %e
66+
}
67+
68+
; When .reqntid exceeds hardware limits, garbage-in/garbage-out: the range
69+
; intersection with intrinsic builtin ranges may produce empty or unexpected
70+
; ranges.
71+
define ptx_kernel i32 @test_reqntid_invalid() "nvvm.reqntid"="2048" {
72+
; CHECK-LABEL: define ptx_kernel i32 @test_reqntid_invalid(
73+
; CHECK-SAME: ) #[[ATTR2:[0-9]+]] {
74+
; CHECK-NEXT: [[TID_X:%.*]] = call range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
75+
; CHECK-NEXT: [[NTID_X:%.*]] = call range(i32 0, 0) i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
76+
; CHECK-NEXT: [[A:%.*]] = add i32 [[TID_X]], [[NTID_X]]
77+
; CHECK-NEXT: ret i32 [[A]]
78+
;
79+
%tid.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
80+
%ntid.x = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
81+
%a = add i32 %tid.x, %ntid.x
82+
ret i32 %a
83+
}
84+
85+
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
86+
declare i32 @llvm.nvvm.read.ptx.sreg.tid.y()
87+
declare i32 @llvm.nvvm.read.ptx.sreg.tid.z()
88+
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
89+
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
90+
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.z()

0 commit comments

Comments
 (0)