Skip to content

Commit 2dfc1eb

Browse files
fluffysquirrelsFirestar99
authored andcommitted
builtin: Add more compute shader builtins
1 parent 9b22c65 commit 2dfc1eb

File tree

3 files changed

+191
-48
lines changed

3 files changed

+191
-48
lines changed

crates/spirv-std/src/builtin.rs

Lines changed: 125 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,138 @@
1-
//! Symbols to query SPIR-V read-only global built-ins
1+
//! Query SPIR-V read-only global built-in values
2+
//!
3+
//! Reference links:
4+
//! * [WGSL specification describing these builtins](https://www.w3.org/TR/WGSL/#builtin-inputs-outputs)
5+
//! * [SPIR-V specification for builtins](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_builtin)
6+
//! * [GLSL 4.x reference](https://registry.khronos.org/OpenGL-Refpages/gl4/)
27
3-
/// compute shader built-ins
8+
#[cfg(target_arch = "spirv")]
9+
macro_rules! load_builtin {
10+
($ty:ty, $name:ident) => {
11+
unsafe {
12+
let mut result = <$ty>::default();
13+
asm! {
14+
"%builtin = OpVariable typeof{result_ref} Input",
15+
concat!("OpDecorate %builtin BuiltIn ", stringify!($name)),
16+
"%result = OpLoad typeof*{result_ref} %builtin",
17+
"OpStore {result_ref} %result",
18+
result_ref = in(reg) &mut result,
19+
}
20+
result
21+
}
22+
};
23+
}
24+
25+
/// Compute shader built-ins
426
pub mod compute {
527
#[cfg(target_arch = "spirv")]
628
use core::arch::asm;
729
use glam::UVec3;
830

9-
/// GLSL: `gl_LocalInvocationID()`
31+
// Local builtins (for this invocation's position in the workgroup).
32+
33+
/// The current invocation’s local invocation ID,
34+
/// i.e. its position in the workgroup grid.
35+
///
36+
/// GLSL: `gl_LocalInvocationID`
1037
/// WGSL: `local_invocation_id`
1138
#[doc(alias = "gl_LocalInvocationID")]
1239
#[inline]
1340
#[gpu_only]
1441
pub fn local_invocation_id() -> UVec3 {
15-
unsafe {
16-
let result = UVec3::default();
17-
asm! {
18-
"%builtin = OpVariable typeof{result} Input",
19-
"OpDecorate %builtin BuiltIn LocalInvocationId",
20-
"%result = OpLoad typeof*{result} %builtin",
21-
"OpStore {result} %result",
22-
result = in(reg) &result,
23-
}
24-
result
25-
}
42+
load_builtin!(UVec3, LocalInvocationId)
43+
}
44+
45+
/// The current invocation’s local invocation index,
46+
/// a linearized index of the invocation’s position within the workgroup grid.
47+
///
48+
/// GLSL: `gl_LocalInvocationIndex`
49+
/// WGSL: `local_invocation_index`
50+
#[doc(alias = "gl_LocalInvocationIndex")]
51+
#[inline]
52+
#[gpu_only]
53+
pub fn local_invocation_index() -> u32 {
54+
load_builtin!(u32, LocalInvocationIndex)
55+
}
56+
57+
// Global builtins, for this invocation's position in the compute grid.
58+
59+
/// The current invocation’s global invocation ID,
60+
/// i.e. its position in the compute shader grid.
61+
///
62+
/// GLSL: `gl_GlobalInvocationID`
63+
/// WGSL: `global_invocation_id`
64+
#[doc(alias = "gl_GlobalInvocationID")]
65+
#[inline]
66+
#[gpu_only]
67+
pub fn global_invocation_id() -> UVec3 {
68+
load_builtin!(UVec3, GlobalInvocationId)
69+
}
70+
71+
// Subgroup builtins
72+
73+
/// The number of subgroups in the current invocation’s workgroup.
74+
///
75+
/// WGSL: `num_subgroups`
76+
/// No equivalent in GLSL.
77+
#[inline]
78+
#[gpu_only]
79+
pub fn num_subgroups() -> u32 {
80+
load_builtin!(u32, NumSubgroups)
81+
}
82+
83+
/// The subgroup ID of current invocation’s subgroup within the workgroup.
84+
///
85+
/// WGSL: `subgroup_id`
86+
/// No equivalent in GLSL.
87+
#[inline]
88+
#[gpu_only]
89+
pub fn subgroup_id() -> u32 {
90+
load_builtin!(u32, SubgroupId)
91+
}
92+
93+
/// This invocation's ID within its subgroup.
94+
///
95+
/// WGSL: `subgroup_invocation_id`
96+
/// No equivalent in GLSL.
97+
#[doc(alias = "subgroup_invocation_id")]
98+
#[inline]
99+
#[gpu_only]
100+
pub fn subgroup_local_invocation_id() -> u32 {
101+
load_builtin!(u32, SubgroupLocalInvocationId)
102+
}
103+
104+
/// The subgroup size of current invocation’s subgroup.
105+
///
106+
/// WGSL: `subgroup_size`
107+
/// No equivalent in GLSL.
108+
#[inline]
109+
#[gpu_only]
110+
pub fn subgroup_size() -> u32 {
111+
load_builtin!(u32, SubgroupSize)
112+
}
113+
114+
// Workgroup builtins
115+
116+
/// The number of workgroups that have been dispatched in the compute shader grid.
117+
///
118+
/// GLSL: `gl_NumWorkGroups`
119+
/// WGSL: `num_workgroups`
120+
#[doc(alias = "gl_WorkGroupID")]
121+
#[inline]
122+
#[gpu_only]
123+
pub fn num_workgroups() -> UVec3 {
124+
load_builtin!(UVec3, NumWorkgroups)
125+
}
126+
127+
/// The current invocation’s workgroup ID,
128+
/// i.e. the position of the workgroup in the overall compute shader grid.
129+
///
130+
/// GLSL: `gl_WorkGroupID`
131+
/// WGSL: `workgroup_id`
132+
#[doc(alias = "gl_WorkGroupID")]
133+
#[inline]
134+
#[gpu_only]
135+
pub fn workgroup_id() -> UVec3 {
136+
load_builtin!(UVec3, WorkgroupId)
26137
}
27138
}
Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// build-pass
22
// compile-flags: -C llvm-args=--disassemble
3+
// compile-flags: -C target-feature=+GroupNonUniform
34
// normalize-stderr-test "OpLine .*\n" -> ""
45
// normalize-stderr-test "OpSource .*\n" -> ""
56
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
@@ -13,20 +14,26 @@
1314
// ignore-vulkan1.0
1415
// ignore-vulkan1.1
1516

16-
use spirv_std::glam::*;
17-
use spirv_std::spirv;
17+
use spirv_std::{builtin::compute, glam::*, spirv};
1818

1919
#[spirv(compute(threads(1)))]
20-
pub fn compute(
21-
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut u32,
22-
// #[spirv(global_invocation_id)] global_invocation_id: UVec3,
23-
// #[spirv(local_invocation_id)] local_invocation_id: UVec3,
24-
// #[spirv(subgroup_local_invocation_id)] subgroup_local_invocation_id: u32,
25-
// #[spirv(num_subgroups)] num_subgroups: u32,
26-
// #[spirv(num_workgroups)] num_workgroups: UVec3,
27-
// #[spirv(subgroup_id)] subgroup_id: u32,
28-
// #[spirv(workgroup_id)] workgroup_id: UVec3,
29-
) {
30-
let local_invocation_id = spirv_std::builtin::compute::local_invocation_id();
31-
*out = local_invocation_id.x;
20+
pub fn compute(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut u32) {
21+
// Local ID's
22+
let _local_invocation_id: UVec3 = compute::local_invocation_id();
23+
let local_invocation_index: u32 = compute::local_invocation_index();
24+
25+
// Global ID's
26+
let _global_invocation_id: UVec3 = compute::global_invocation_id();
27+
28+
// Subgroup ID's
29+
let _num_subgroups: u32 = compute::num_subgroups();
30+
let _subgroup_id: u32 = compute::subgroup_id();
31+
let _subgroup_local_invocation_index: u32 = compute::subgroup_local_invocation_id();
32+
let _subgroup_size: u32 = compute::subgroup_size();
33+
34+
// Workgroup ID's
35+
let _num_workgroups: UVec3 = compute::num_workgroups();
36+
let _workgroup_id: UVec3 = compute::workgroup_id();
37+
38+
*out = local_invocation_index;
3239
}
Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,54 @@
11
OpCapability Shader
2+
OpCapability GroupNonUniform
23
OpMemoryModel Logical Simple
3-
OpEntryPoint GLCompute %1 "compute" %2 %3
4+
OpEntryPoint GLCompute %1 "compute" %2 %3 %4 %5 %6 %7 %8 %9 %10 %11
45
OpExecutionMode %1 LocalSize 1 1 1
5-
OpDecorate %6 Block
6-
OpMemberDecorate %6 0 Offset 0
6+
OpDecorate %14 Block
7+
OpMemberDecorate %14 0 Offset 0
78
OpDecorate %2 Binding 0
89
OpDecorate %2 DescriptorSet 0
910
OpDecorate %3 BuiltIn LocalInvocationId
10-
%7 = OpTypeInt 32 0
11-
%6 = OpTypeStruct %7
12-
%8 = OpTypePointer StorageBuffer %6
13-
%9 = OpTypeVoid
14-
%10 = OpTypeFunction %9
15-
%11 = OpTypePointer StorageBuffer %7
16-
%2 = OpVariable %8 StorageBuffer
17-
%12 = OpConstant %7 0
18-
%13 = OpTypeVector %7 3
19-
%14 = OpTypePointer Input %13
20-
%3 = OpVariable %14 Input
21-
%1 = OpFunction %9 None %10
22-
%15 = OpLabel
23-
%16 = OpInBoundsAccessChain %11 %2 %12
24-
%17 = OpLoad %13 %3
25-
%18 = OpCompositeExtract %7 %17 0
26-
OpStore %16 %18
11+
OpDecorate %4 BuiltIn LocalInvocationIndex
12+
OpDecorate %5 BuiltIn GlobalInvocationId
13+
OpDecorate %6 BuiltIn NumSubgroups
14+
OpDecorate %7 BuiltIn SubgroupId
15+
OpDecorate %8 BuiltIn SubgroupLocalInvocationId
16+
OpDecorate %9 BuiltIn SubgroupSize
17+
OpDecorate %10 BuiltIn NumWorkgroups
18+
OpDecorate %11 BuiltIn WorkgroupId
19+
%15 = OpTypeInt 32 0
20+
%14 = OpTypeStruct %15
21+
%16 = OpTypePointer StorageBuffer %14
22+
%17 = OpTypeVoid
23+
%18 = OpTypeFunction %17
24+
%19 = OpTypePointer StorageBuffer %15
25+
%2 = OpVariable %16 StorageBuffer
26+
%20 = OpConstant %15 0
27+
%21 = OpTypeVector %15 3
28+
%22 = OpTypePointer Input %21
29+
%3 = OpVariable %22 Input
30+
%23 = OpTypePointer Input %15
31+
%4 = OpVariable %23 Input
32+
%5 = OpVariable %22 Input
33+
%6 = OpVariable %23 Input
34+
%7 = OpVariable %23 Input
35+
%8 = OpVariable %23 Input
36+
%9 = OpVariable %23 Input
37+
%10 = OpVariable %22 Input
38+
%11 = OpVariable %22 Input
39+
%1 = OpFunction %17 None %18
40+
%24 = OpLabel
41+
%25 = OpInBoundsAccessChain %19 %2 %20
42+
%26 = OpLoad %21 %3
43+
%27 = OpLoad %15 %4
44+
%28 = OpLoad %21 %5
45+
%29 = OpLoad %15 %6
46+
%30 = OpLoad %15 %7
47+
%31 = OpLoad %15 %8
48+
%32 = OpLoad %15 %9
49+
%33 = OpLoad %21 %10
50+
%34 = OpLoad %21 %11
51+
OpStore %25 %27
2752
OpNoLine
2853
OpReturn
2954
OpFunctionEnd

0 commit comments

Comments
 (0)