Skip to content

Commit f0ead83

Browse files
committed
add slice cast support : *[T; N] -> *[T] , *[T; N] -> *T
1 parent ab9ca5e commit f0ead83

2 files changed

Lines changed: 61 additions & 1 deletion

File tree

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2423,7 +2423,46 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
24232423
"pointercast called on non-pointer dest type: {other:?}"
24242424
)),
24252425
};
2426-
let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self);
2426+
2427+
let dst_pointee_ty = self.lookup_type(dest_pointee);
2428+
let dest_pointee_size = dst_pointee_ty.sizeof(self);
2429+
let src_pointee_ty = self.lookup_type(ptr_pointee);
2430+
2431+
// array -> element: *[T; N] -> *T
2432+
if let SpirvType::Array {
2433+
element: elem_ty, ..
2434+
} = src_pointee_ty
2435+
&& elem_ty == dest_pointee
2436+
{
2437+
let zero = self.constant_u32(self.span(), 0).def(self);
2438+
return self
2439+
.emit()
2440+
.in_bounds_access_chain(dest_ty, None, ptr.def(self), [zero])
2441+
.unwrap()
2442+
.with_type(dest_ty);
2443+
}
2444+
2445+
// array -> RuntimeArray: *[T; N] -> *[T]
2446+
if let SpirvType::Array {
2447+
element: elem_ty, ..
2448+
} = src_pointee_ty
2449+
&& let SpirvType::RuntimeArray {
2450+
element: rt_elem_ty,
2451+
} = dst_pointee_ty
2452+
&& elem_ty == rt_elem_ty
2453+
{
2454+
let zero = self.constant_u32(self.span(), 0).def(self);
2455+
let elem_ptr_ty = self.type_ptr_to(elem_ty);
2456+
let elem_ptr = self
2457+
.emit()
2458+
.in_bounds_access_chain(elem_ptr_ty, None, ptr.def(self), [zero])
2459+
.unwrap();
2460+
return self
2461+
.emit()
2462+
.bitcast(dest_ty, None, elem_ptr)
2463+
.unwrap()
2464+
.with_type(dest_ty);
2465+
}
24272466

24282467
if let Some((indices, _)) = self.recover_access_chain_from_offset(
24292468
ptr_pointee,
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
2+
// build-pass
3+
4+
use spirv_std::spirv;
5+
6+
#[spirv(compute(threads(64)))]
7+
pub fn main(
8+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32; 3],
9+
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut u32,
10+
) {
11+
// &[u32] fat pointer from a runtime storage buffer array
12+
let slice: &[u32] = input;
13+
let val = slice[1];
14+
15+
// *[u32; 3] -> *u32 via AccessChain (array->element path)
16+
let array_ptr: *const [u32; 3] = input;
17+
let element_ptr: *const u32 = array_ptr as *const u32;
18+
let val2 = unsafe { *element_ptr };
19+
20+
*output = val + val2;
21+
}

0 commit comments

Comments
 (0)