Skip to content

Commit 0ab50a4

Browse files
committed
builtin: add subgroup builtins
1 parent c797246 commit 0ab50a4

File tree

7 files changed

+323
-4
lines changed

7 files changed

+323
-4
lines changed

crates/spirv-std/src/subgroup.rs

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,160 @@ pub enum GroupOperation {
6363
PartitionedExclusiveScanNV = 8,
6464
}
6565

66+
/// The number of subgroups within the local workgroup. The value of this variable is at least 1, and is uniform across
67+
/// the invocation group.
68+
///
69+
/// * GLSL: [`gl_NumSubgroups`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
70+
/// * WGSL: [`num_subgroups`](https://www.w3.org/TR/WGSL/#num-subgroups-builtin-value)
71+
/// * SPIR-V: [`NumSubgroups`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_builtin)
72+
#[doc(alias = "gl_NumSubgroups")]
73+
#[doc(alias = "WorkgroupId")]
74+
#[inline]
75+
#[gpu_only]
76+
pub fn num_subgroups() -> u32 {
77+
crate::load_builtin!(NumSubgroups)
78+
}
79+
80+
/// The index of the subgroup within the local workgroup. The value of this variable is in the range `0` to
81+
/// [`num_subgroups`]`-1`.
82+
///
83+
/// * GLSL: [`gl_SubgroupID`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
84+
/// * WGSL: [`subgroup_id`](https://www.w3.org/TR/WGSL/#subgroup-id-builtin-value)
85+
/// * SPIR-V: [`SubgroupId`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_builtin)
86+
#[doc(alias = "gl_SubgroupID")]
87+
#[doc(alias = "SubgroupId")]
88+
#[inline]
89+
#[gpu_only]
90+
pub fn subgroup_id() -> u32 {
91+
crate::load_builtin!(SubgroupId)
92+
}
93+
94+
// custom: don't mention glsl extensions
95+
/// The number of invocations within a subgroup, and its value is always a power of 2. The maximum subgroup size
96+
/// supported is 128.
97+
///
98+
/// * GLSL: [`gl_SubgroupSize`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
99+
/// * WGSL: [`subgroup_size`](https://www.w3.org/TR/WGSL/#subgroup-size-builtin-value)
100+
/// * SPIR-V: [`SubgroupSize`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_builtin)
101+
#[doc(alias = "gl_SubgroupSize")]
102+
#[doc(alias = "SubgroupId")]
103+
#[inline]
104+
#[gpu_only]
105+
pub fn subgroup_size() -> u32 {
106+
crate::load_builtin!(SubgroupSize)
107+
}
108+
109+
/// The index of an invocation within a subgroup. The value of this variable is in the range `0` to
110+
/// [`subgroup_size`]`-1`.
111+
///
112+
/// There is no direct relationship between [`subgroup_invocation_id`] and [`local_invocation_id`] or
113+
/// [`local_invocation_index`]. If the pipeline or shader object was created with full subgroups applications can
114+
/// compute their own local invocation index to serve the same purpose:
115+
///
116+
/// index = SubgroupLocalInvocationId + SubgroupId × SubgroupSize
117+
///
118+
/// If full subgroups are not enabled, some subgroups may be dispatched with inactive invocations that do not correspond
119+
/// to a local workgroup invocation, making the value of index unreliable.
120+
///
121+
/// * GLSL: [`gl_SubgroupInvocationID`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
122+
/// * WGSL: [`subgroup_invocation_id`](https://www.w3.org/TR/WGSL/#subgroup-invocation-id-builtin-value)
123+
/// * SPIR-V: [`SubgroupLocalInvocationId`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_builtin)
124+
///
125+
/// [`local_invocation_id`]: crate::compute::local_invocation_id
126+
/// [`local_invocation_index`]: crate::compute::local_invocation_index
127+
#[doc(alias = "gl_SubgroupInvocationID")]
128+
#[doc(alias = "SubgroupLocalInvocationId")]
129+
#[inline]
130+
#[gpu_only]
131+
pub fn subgroup_invocation_id() -> u32 {
132+
crate::load_builtin!(SubgroupLocalInvocationId)
133+
}
134+
135+
/// Provides a bitmask of all invocations, with one bit per invocation, where `bit index == gl_SubgroupInvocationID`.
136+
///
137+
/// Bit 0 of the first vector component represents the first invocation, higher-order bits within a component and higher
138+
/// component numbers both represent, in order, higher invocations, and the last invocation is the highest-order bit
139+
/// needed, in the last component needed, to contiguously represent all bits of the invocations in a subgroup.
140+
///
141+
/// * GLSL: [`gl_SubgroupEqMask`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
142+
/// * WGSL: None
143+
/// * SPIR-V: [`SubgroupEqMask`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_builtin)
144+
#[doc(alias = "gl_SubgroupEqMask")]
145+
#[doc(alias = "SubgroupEqMask")]
146+
#[inline]
147+
#[gpu_only]
148+
pub fn subgroup_eq_mask() -> SubgroupMask {
149+
crate::load_builtin!(SubgroupEqMask)
150+
}
151+
152+
/// Provides a bitmask of all invocations, with one bit per invocation, where `bit index >= gl_SubgroupInvocationID`.
153+
///
154+
/// Bit 0 of the first vector component represents the first invocation, higher-order bits within a component and higher
155+
/// component numbers both represent, in order, higher invocations, and the last invocation is the highest-order bit
156+
/// needed, in the last component needed, to contiguously represent all bits of the invocations in a subgroup.
157+
///
158+
/// * GLSL: [`gl_SubgroupGeMask`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
159+
/// * WGSL: None
160+
/// * SPIR-V: [`SubgroupEqMask`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_builtin)
161+
#[doc(alias = "gl_SubgroupGeMask")]
162+
#[doc(alias = "SubgroupGeMask")]
163+
#[inline]
164+
#[gpu_only]
165+
pub fn subgroup_ge_mask() -> SubgroupMask {
166+
crate::load_builtin!(SubgroupGeMask)
167+
}
168+
169+
/// Provides a bitmask of all invocations, with one bit per invocation, where `bit index > gl_SubgroupInvocationID`.
170+
///
171+
/// Bit 0 of the first vector component represents the first invocation, higher-order bits within a component and higher
172+
/// component numbers both represent, in order, higher invocations, and the last invocation is the highest-order bit
173+
/// needed, in the last component needed, to contiguously represent all bits of the invocations in a subgroup.
174+
///
175+
/// * GLSL: [`gl_SubgroupGtMask`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
176+
/// * WGSL: None
177+
/// * SPIR-V: [`SubgroupGtMask`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_builtin)
178+
#[doc(alias = "gl_SubgroupGtMask")]
179+
#[doc(alias = "SubgroupGtMask")]
180+
#[inline]
181+
#[gpu_only]
182+
pub fn subgroup_gt_mask() -> SubgroupMask {
183+
crate::load_builtin!(SubgroupGtMask)
184+
}
185+
186+
/// Provides a bitmask of all invocations, with one bit per invocation, where `bit index <= gl_SubgroupInvocationID`.
187+
///
188+
/// Bit 0 of the first vector component represents the first invocation, higher-order bits within a component and higher
189+
/// component numbers both represent, in order, higher invocations, and the last invocation is the highest-order bit
190+
/// needed, in the last component needed, to contiguously represent all bits of the invocations in a subgroup.
191+
///
192+
/// * GLSL: [`gl_SubgroupLeMask`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
193+
/// * WGSL: None
194+
/// * SPIR-V: [`SubgroupLeMask`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_builtin)
195+
#[doc(alias = "gl_SubgroupLeMask")]
196+
#[doc(alias = "SubgroupLeMask")]
197+
#[inline]
198+
#[gpu_only]
199+
pub fn subgroup_le_mask() -> SubgroupMask {
200+
crate::load_builtin!(SubgroupLeMask)
201+
}
202+
203+
/// Provides a bitmask of all invocations, with one bit per invocation, where `bit index < gl_SubgroupInvocationID`.
204+
///
205+
/// Bit 0 of the first vector component represents the first invocation, higher-order bits within a component and higher
206+
/// component numbers both represent, in order, higher invocations, and the last invocation is the highest-order bit
207+
/// needed, in the last component needed, to contiguously represent all bits of the invocations in a subgroup.
208+
///
209+
/// * GLSL: [`gl_SubgroupLtMask`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
210+
/// * WGSL: None
211+
/// * SPIR-V: [`SubgroupLtMask`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_builtin)
212+
#[doc(alias = "gl_SubgroupLtMask")]
213+
#[doc(alias = "SubgroupLtMask")]
214+
#[inline]
215+
#[gpu_only]
216+
pub fn subgroup_lt_mask() -> SubgroupMask {
217+
crate::load_builtin!(SubgroupLtMask)
218+
}
219+
66220
/// The function `subgroupBarrier()` enforces that all active invocations within a
67221
/// subgroup must execute this function before any are allowed to continue their
68222
/// execution, and the results of any memory stores performed using coherent
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// build-pass
2+
// compile-flags: -C llvm-args=--disassemble-globals
3+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformBallot
4+
// normalize-stderr-test "OpSource .*\n" -> ""
5+
// normalize-stderr-test "OpLine .*\n" -> ""
6+
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
7+
// normalize-stderr-test "; .*\n" -> ""
8+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
9+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
10+
// ignore-vulkan1.0
11+
// ignore-spv1.0
12+
// ignore-spv1.1
13+
// ignore-spv1.2
14+
15+
use spirv_std::glam::*;
16+
use spirv_std::spirv;
17+
use spirv_std::subgroup::*;
18+
19+
#[spirv(compute(threads(1)))]
20+
pub fn compute(
21+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] buffer: &mut [u32],
22+
#[spirv(num_subgroups)] num_subgroups: u32,
23+
#[spirv(subgroup_id)] subgroup_id: u32,
24+
#[spirv(subgroup_size)] subgroup_size: u32,
25+
// spirv name differs!
26+
#[spirv(subgroup_local_invocation_id)] subgroup_invocation_id: u32,
27+
#[spirv(subgroup_eq_mask)] subgroup_eq_mask: SubgroupMask,
28+
#[spirv(subgroup_ge_mask)] subgroup_ge_mask: SubgroupMask,
29+
#[spirv(subgroup_gt_mask)] subgroup_gt_mask: SubgroupMask,
30+
#[spirv(subgroup_le_mask)] subgroup_le_mask: SubgroupMask,
31+
#[spirv(subgroup_lt_mask)] subgroup_lt_mask: SubgroupMask,
32+
) {
33+
buffer[0] = num_subgroups + subgroup_id + subgroup_size + subgroup_invocation_id;
34+
buffer[1] = subgroup_eq_mask.x
35+
+ subgroup_ge_mask.x
36+
+ subgroup_gt_mask.x
37+
+ subgroup_le_mask.x
38+
+ subgroup_lt_mask.x;
39+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
OpCapability Shader
2+
OpCapability GroupNonUniform
3+
OpCapability GroupNonUniformBallot
4+
OpMemoryModel Logical Simple
5+
OpEntryPoint GLCompute %1 "compute" %2 %3 %4 %5 %6 %7 %8 %9 %10 %11
6+
OpExecutionMode %1 LocalSize 1 1 1
7+
OpName %2 "buffer"
8+
OpName %3 "num_subgroups"
9+
OpName %4 "subgroup_id"
10+
OpName %5 "subgroup_size"
11+
OpName %6 "subgroup_invocation_id"
12+
OpName %7 "subgroup_eq_mask"
13+
OpName %8 "subgroup_ge_mask"
14+
OpName %9 "subgroup_gt_mask"
15+
OpName %10 "subgroup_le_mask"
16+
OpName %11 "subgroup_lt_mask"
17+
OpDecorate %13 ArrayStride 4
18+
OpDecorate %14 Block
19+
OpMemberDecorate %14 0 Offset 0
20+
OpDecorate %2 Binding 0
21+
OpDecorate %2 DescriptorSet 0
22+
OpDecorate %3 BuiltIn NumSubgroups
23+
OpDecorate %4 BuiltIn SubgroupId
24+
OpDecorate %5 BuiltIn SubgroupSize
25+
OpDecorate %6 BuiltIn SubgroupLocalInvocationId
26+
OpDecorate %7 BuiltIn SubgroupEqMask
27+
OpDecorate %8 BuiltIn SubgroupGeMask
28+
OpDecorate %9 BuiltIn SubgroupGtMask
29+
OpDecorate %10 BuiltIn SubgroupLeMask
30+
OpDecorate %11 BuiltIn SubgroupLtMask
31+
%15 = OpTypeInt 32 0
32+
%13 = OpTypeRuntimeArray %15
33+
%14 = OpTypeStruct %13
34+
%16 = OpTypePointer StorageBuffer %14
35+
%17 = OpTypePointer Input %15
36+
%18 = OpTypeVector %15 4
37+
%19 = OpTypePointer Input %18
38+
%20 = OpTypeVoid
39+
%21 = OpTypeFunction %20
40+
%22 = OpTypePointer StorageBuffer %13
41+
%2 = OpVariable %16 StorageBuffer
42+
%23 = OpConstant %15 0
43+
%3 = OpVariable %17 Input
44+
%4 = OpVariable %17 Input
45+
%5 = OpVariable %17 Input
46+
%6 = OpVariable %17 Input
47+
%7 = OpVariable %19 Input
48+
%8 = OpVariable %19 Input
49+
%9 = OpVariable %19 Input
50+
%10 = OpVariable %19 Input
51+
%11 = OpVariable %19 Input
52+
%24 = OpTypeBool
53+
%25 = OpTypePointer StorageBuffer %15
54+
%26 = OpConstant %15 1
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// build-pass
2+
// compile-flags: -C llvm-args=--disassemble-globals
3+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformBallot
4+
// normalize-stderr-test "OpSource .*\n" -> ""
5+
// normalize-stderr-test "OpLine .*\n" -> ""
6+
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
7+
// normalize-stderr-test "; .*\n" -> ""
8+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
9+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
10+
// ignore-vulkan1.0
11+
// ignore-spv1.0
12+
// ignore-spv1.1
13+
// ignore-spv1.2
14+
15+
use spirv_std::glam::*;
16+
use spirv_std::spirv;
17+
use spirv_std::subgroup::*;
18+
19+
#[spirv(compute(threads(1)))]
20+
pub fn compute(#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] buffer: &mut [u32]) {
21+
buffer[0] = num_subgroups() + subgroup_id() + subgroup_size() + subgroup_invocation_id();
22+
buffer[1] = subgroup_eq_mask().x
23+
+ subgroup_ge_mask().x
24+
+ subgroup_gt_mask().x
25+
+ subgroup_le_mask().x
26+
+ subgroup_lt_mask().x;
27+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
OpCapability Shader
2+
OpCapability GroupNonUniform
3+
OpCapability GroupNonUniformBallot
4+
OpMemoryModel Logical Simple
5+
OpEntryPoint GLCompute %1 "compute" %2 %3 %4 %5 %6 %7 %8 %9 %10 %11
6+
OpExecutionMode %1 LocalSize 1 1 1
7+
OpName %2 "buffer"
8+
OpDecorate %14 ArrayStride 4
9+
OpDecorate %15 Block
10+
OpMemberDecorate %15 0 Offset 0
11+
OpDecorate %2 Binding 0
12+
OpDecorate %2 DescriptorSet 0
13+
OpDecorate %3 BuiltIn NumSubgroups
14+
OpDecorate %4 BuiltIn SubgroupId
15+
OpDecorate %5 BuiltIn SubgroupSize
16+
OpDecorate %6 BuiltIn SubgroupLocalInvocationId
17+
OpDecorate %7 BuiltIn SubgroupEqMask
18+
OpDecorate %8 BuiltIn SubgroupGeMask
19+
OpDecorate %9 BuiltIn SubgroupGtMask
20+
OpDecorate %10 BuiltIn SubgroupLeMask
21+
OpDecorate %11 BuiltIn SubgroupLtMask
22+
%16 = OpTypeInt 32 0
23+
%14 = OpTypeRuntimeArray %16
24+
%15 = OpTypeStruct %14
25+
%17 = OpTypePointer StorageBuffer %15
26+
%18 = OpTypeVoid
27+
%19 = OpTypeFunction %18
28+
%20 = OpTypePointer StorageBuffer %14
29+
%2 = OpVariable %17 StorageBuffer
30+
%21 = OpConstant %16 0
31+
%22 = OpTypePointer Input %16
32+
%3 = OpVariable %22 Input
33+
%4 = OpVariable %22 Input
34+
%5 = OpVariable %22 Input
35+
%6 = OpVariable %22 Input
36+
%23 = OpTypeBool
37+
%24 = OpTypePointer StorageBuffer %16
38+
%25 = OpTypeVector %16 4
39+
%26 = OpTypePointer Input %25
40+
%7 = OpVariable %26 Input
41+
%8 = OpVariable %26 Input
42+
%9 = OpVariable %26 Input
43+
%10 = OpVariable %26 Input
44+
%11 = OpVariable %26 Input
45+
%27 = OpConstant %16 1

tests/compiletests/ui/subgroup/subgroup_cluster_size_0_fail.stderr

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
error[E0080]: evaluation panicked: `ClusterSize` must be at least 1
2-
--> $SPIRV_STD_SRC/subgroup.rs:953:1
2+
--> $SPIRV_STD_SRC/subgroup.rs:1107:1
33
|
44
LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r"
55
LL | | An integer add group operation of all `value` operands contributed by active invocations in the group.
@@ -13,7 +13,7 @@ LL | | ");
1313
= note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info)
1414

1515
note: erroneous constant encountered
16-
--> $SPIRV_STD_SRC/subgroup.rs:953:1
16+
--> $SPIRV_STD_SRC/subgroup.rs:1107:1
1717
|
1818
LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r"
1919
LL | | An integer add group operation of all `value` operands contributed by active invocations in the group.

tests/compiletests/ui/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
error[E0080]: evaluation panicked: `ClusterSize` must be a power of 2
2-
--> $SPIRV_STD_SRC/subgroup.rs:953:1
2+
--> $SPIRV_STD_SRC/subgroup.rs:1107:1
33
|
44
LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r"
55
LL | | An integer add group operation of all `value` operands contributed by active invocations in the group.
@@ -13,7 +13,7 @@ LL | | ");
1313
= note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info)
1414

1515
note: erroneous constant encountered
16-
--> $SPIRV_STD_SRC/subgroup.rs:953:1
16+
--> $SPIRV_STD_SRC/subgroup.rs:1107:1
1717
|
1818
LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r"
1919
LL | | An integer add group operation of all `value` operands contributed by active invocations in the group.

0 commit comments

Comments
 (0)