Skip to content

Commit 5c0c2d4

Browse files
39aliFirestar99
authored andcommitted
add slice cast support : *[T; N] -> *[T] , *[T; N] -> *T
1 parent 3f7cb64 commit 5c0c2d4

3 files changed

Lines changed: 146 additions & 1 deletion

File tree

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2423,7 +2423,46 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
24232423
"pointercast called on non-pointer dest type: {other:?}"
24242424
)),
24252425
};
2426-
let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self);
2426+
2427+
let dst_pointee_ty = self.lookup_type(dest_pointee);
2428+
let dest_pointee_size = dst_pointee_ty.sizeof(self);
2429+
let src_pointee_ty = self.lookup_type(ptr_pointee);
2430+
2431+
// array -> element: *[T; N] -> *T
2432+
if let SpirvType::Array {
2433+
element: elem_ty, ..
2434+
} = src_pointee_ty
2435+
&& elem_ty == dest_pointee
2436+
{
2437+
let zero = self.constant_u32(self.span(), 0).def(self);
2438+
return self
2439+
.emit()
2440+
.in_bounds_access_chain(dest_ty, None, ptr.def(self), [zero])
2441+
.unwrap()
2442+
.with_type(dest_ty);
2443+
}
2444+
2445+
// array -> RuntimeArray: *[T; N] -> *[T]
2446+
if let SpirvType::Array {
2447+
element: elem_ty, ..
2448+
} = src_pointee_ty
2449+
&& let SpirvType::RuntimeArray {
2450+
element: rt_elem_ty,
2451+
} = dst_pointee_ty
2452+
&& elem_ty == rt_elem_ty
2453+
{
2454+
let zero = self.constant_u32(self.span(), 0).def(self);
2455+
let elem_ptr_ty = self.type_ptr_to(elem_ty);
2456+
let elem_ptr = self
2457+
.emit()
2458+
.in_bounds_access_chain(elem_ptr_ty, None, ptr.def(self), [zero])
2459+
.unwrap();
2460+
return self
2461+
.emit()
2462+
.bitcast(dest_ty, None, elem_ptr)
2463+
.unwrap()
2464+
.with_type(dest_ty);
2465+
}
24272466

24282467
if let Some((indices, _)) = self.recover_access_chain_from_offset(
24292468
ptr_pointee,
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// build-pass
2+
// compile-flags: -C llvm-args=--disassemble
3+
// normalize-stderr-test "OpSource .*\n" -> ""
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
6+
// normalize-stderr-test "^(; .*\n)*" -> ""
7+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
8+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
9+
// ignore-spv1.0
10+
// ignore-spv1.1
11+
// ignore-spv1.2
12+
// ignore-spv1.3
13+
// ignore-spv1.4
14+
// ignore-spv1.5
15+
// ignore-spv1.6
16+
// ignore-vulkan1.0
17+
// ignore-vulkan1.1
18+
19+
use spirv_std::spirv;
20+
21+
const CONST_ARRAY: [u32; 3] = [1, 2, 3];
22+
23+
#[spirv(compute(threads(64)))]
24+
pub fn main(
25+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32; 3],
26+
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut u32,
27+
) {
28+
let mut out = 0;
29+
30+
// &[u32] fat pointer from a runtime storage buffer array
31+
let slice: &[u32] = input;
32+
out += slice[1];
33+
34+
// *[u32; 3] -> *u32 via AccessChain (array->element path)
35+
let array_ptr: *const [u32; 3] = input;
36+
let element_ptr: *const u32 = array_ptr as *const u32;
37+
out += unsafe { *element_ptr };
38+
39+
// &[u32] fat pointer from a runtime storage buffer array
40+
let slice: &[u32] = &CONST_ARRAY;
41+
out += slice[1];
42+
43+
// *[u32; 3] -> *u32 via AccessChain (array->element path)
44+
let array_ptr: *const [u32; 3] = &CONST_ARRAY;
45+
let element_ptr: *const u32 = array_ptr as *const u32;
46+
out += unsafe { *element_ptr };
47+
48+
*output = out;
49+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
OpCapability Shader
2+
OpMemoryModel Logical Simple
3+
OpEntryPoint GLCompute %1 "main" %2 %3 %4
4+
OpExecutionMode %1 LocalSize 64 1 1
5+
OpName %2 "input"
6+
OpName %3 "output"
7+
OpDecorate %6 ArrayStride 4
8+
OpDecorate %7 Block
9+
OpMemberDecorate %7 0 Offset 0
10+
OpDecorate %8 Block
11+
OpMemberDecorate %8 0 Offset 0
12+
OpDecorate %2 NonWritable
13+
OpDecorate %2 Binding 0
14+
OpDecorate %2 DescriptorSet 0
15+
OpDecorate %3 Binding 1
16+
OpDecorate %3 DescriptorSet 0
17+
%9 = OpTypeInt 32 0
18+
%10 = OpConstant %9 3
19+
%6 = OpTypeArray %9 %10
20+
%7 = OpTypeStruct %6
21+
%11 = OpTypePointer StorageBuffer %7
22+
%8 = OpTypeStruct %9
23+
%12 = OpTypePointer StorageBuffer %8
24+
%13 = OpTypeVoid
25+
%14 = OpTypeFunction %13
26+
%15 = OpTypePointer StorageBuffer %6
27+
%2 = OpVariable %11 StorageBuffer
28+
%16 = OpConstant %9 0
29+
%17 = OpTypePointer StorageBuffer %9
30+
%3 = OpVariable %12 StorageBuffer
31+
%18 = OpConstant %9 1
32+
%19 = OpTypePointer Private %9
33+
%20 = OpTypeArray %9 %10
34+
%21 = OpTypePointer Private %20
35+
%22 = OpConstant %9 2
36+
%23 = OpConstantComposite %20 %18 %22 %10
37+
%4 = OpVariable %21 Private %23
38+
%1 = OpFunction %13 None %14
39+
%24 = OpLabel
40+
%25 = OpInBoundsAccessChain %15 %2 %16
41+
%26 = OpInBoundsAccessChain %17 %3 %16
42+
%27 = OpInBoundsAccessChain %17 %25 %18
43+
%28 = OpLoad %9 %27
44+
%29 = OpIAdd %9 %16 %28
45+
%30 = OpInBoundsAccessChain %17 %25 %16
46+
%31 = OpLoad %9 %30
47+
%32 = OpIAdd %9 %29 %31
48+
%33 = OpInBoundsAccessChain %19 %4 %18
49+
%34 = OpLoad %9 %33
50+
%35 = OpIAdd %9 %32 %34
51+
%36 = OpInBoundsAccessChain %19 %4 %16
52+
%37 = OpLoad %9 %36
53+
%38 = OpIAdd %9 %35 %37
54+
OpStore %26 %38
55+
OpNoLine
56+
OpReturn
57+
OpFunctionEnd

0 commit comments

Comments
 (0)