Skip to content

Commit e46d090

Browse files
committed
fix array slice cast
1 parent 3f7cb64 commit e46d090

5 files changed

Lines changed: 263 additions & 1 deletion

File tree

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2423,7 +2423,28 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
24232423
"pointercast called on non-pointer dest type: {other:?}"
24242424
)),
24252425
};
2426-
let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self);
2426+
2427+
let dst_pointee_ty = self.lookup_type(dest_pointee);
2428+
let dest_pointee_size = dst_pointee_ty.sizeof(self);
2429+
let src_pointee_ty = self.lookup_type(ptr_pointee);
2430+
2431+
// *[T; N] -> *RuntimeArray<T>
2432+
if let SpirvType::Array { element: elem_ty, .. } = src_pointee_ty
2433+
&& let SpirvType::RuntimeArray { element: rt_elem_ty } = dst_pointee_ty
2434+
&& elem_ty == rt_elem_ty
2435+
{
2436+
let zero = self.constant_u32(self.span(), 0).def(self);
2437+
let elem_ptr_ty = self.type_ptr_to(elem_ty);
2438+
let elem_ptr = self
2439+
.emit()
2440+
.in_bounds_access_chain(elem_ptr_ty, None, ptr.def(self), [zero])
2441+
.unwrap();
2442+
return self
2443+
.emit()
2444+
.bitcast(dest_ty, None, elem_ptr)
2445+
.unwrap()
2446+
.with_type(dest_ty);
2447+
}
24272448

24282449
if let Some((indices, _)) = self.recover_access_chain_from_offset(
24292450
ptr_pointee,

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,7 @@ pub fn link(
660660
peephole_opts::composite_construct(&types, func);
661661
peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
662662
peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func);
663+
peephole_opts::fold_array_bitcast_access_chain(&types, func);
663664
}
664665
}
665666

crates/rustc_codegen_spirv/src/linker/peephole_opts.rs

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,3 +680,94 @@ pub fn fold_load_from_constant_variable(module: &mut Module) {
680680
}
681681
}
682682
}
683+
684+
/// Eliminate the `OpBitcast` that arises from `*[T; N] → *RuntimeArray<T>` pointer casts.
685+
///
686+
/// When a local array is coerced to a slice (`&[T;N] as &[T]`), codegen emits:
687+
/// ```text
688+
/// %elem0 = OpInBoundsAccessChain %arr 0 ; pointer to arr[0]
689+
/// %rta = OpBitcast %elem0 ; *RuntimeArray<T> (invalid in Logical SPIR-V)
690+
/// ```
691+
/// After inlining the slice-taking function, indexing becomes:
692+
/// ```text
693+
/// %ei = OpInBoundsAccessChain %rta i ; data[i]
694+
///
695+
pub fn fold_array_bitcast_access_chain(
696+
types: &FxHashMap<Word, Instruction>,
697+
function: &mut Function,
698+
) {
699+
let func_defs: FxHashMap<Word, Instruction> = function
700+
.all_inst_iter()
701+
.filter_map(|inst| Some((inst.result_id?, inst.clone())))
702+
.collect();
703+
704+
// look up an ID in either function-local defs or module-level types/globals.
705+
let lookup = |id: Word| -> Option<&Instruction> {
706+
func_defs.get(&id).or_else(|| types.get(&id))
707+
};
708+
709+
for block in &mut function.blocks {
710+
for inst in &mut block.instructions {
711+
if !matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain) {
712+
continue;
713+
}
714+
if inst.operands.is_empty() {
715+
continue;
716+
}
717+
let base_id = inst.operands[0].unwrap_id_ref();
718+
719+
// base must be an OpBitcast
720+
let Some(bitcast) = lookup(base_id) else { continue };
721+
if bitcast.class.opcode != Op::Bitcast {
722+
continue;
723+
}
724+
725+
// bitcast result type must be *SC RuntimeArray<T>
726+
let Some(bitcast_dst_ptr) = lookup(bitcast.result_type.unwrap()) else { continue };
727+
if bitcast_dst_ptr.class.opcode != Op::TypePointer {
728+
continue;
729+
}
730+
let rta_type_id = bitcast_dst_ptr.operands[1].unwrap_id_ref();
731+
let Some(rta_ty) = lookup(rta_type_id) else { continue };
732+
if rta_ty.class.opcode != Op::TypeRuntimeArray {
733+
continue;
734+
}
735+
let rta_elem_ty = rta_ty.operands[0].unwrap_id_ref();
736+
737+
// bitcast source must be OpInBoundsAccessChain(arr, 0)
738+
let bitcast_src_id = bitcast.operands[0].unwrap_id_ref();
739+
let Some(inner_ac) = lookup(bitcast_src_id) else { continue };
740+
if inner_ac.class.opcode != Op::InBoundsAccessChain {
741+
continue;
742+
}
743+
// Exactly one index operand
744+
if inner_ac.operands.len() != 2 {
745+
continue;
746+
}
747+
// That index must be the constant 0
748+
let idx0_id = inner_ac.operands[1].unwrap_id_ref();
749+
let Some(idx0) = lookup(idx0_id) else { continue };
750+
if idx0.class.opcode != Op::Constant {
751+
continue;
752+
}
753+
if !matches!(idx0.operands[0], Operand::LiteralBit32(0)) {
754+
continue;
755+
}
756+
757+
// inner AccessChain result type must be *SC T where T == rta_elem_ty
758+
let Some(inner_dst_ptr) = lookup(inner_ac.result_type.unwrap()) else { continue };
759+
if inner_dst_ptr.class.opcode != Op::TypePointer {
760+
continue;
761+
}
762+
let elem_ty = inner_dst_ptr.operands[1].unwrap_id_ref();
763+
if elem_ty != rta_elem_ty {
764+
continue;
765+
}
766+
767+
// AccessChain(Bitcast(InBoundsAccessChain(arr, 0)), i)
768+
// Replace base with arr — the dead bitcast and intermediate AC are cleaned by DCE.
769+
let arr_id = inner_ac.operands[0].unwrap_id_ref();
770+
inst.operands[0] = Operand::IdRef(arr_id);
771+
}
772+
}
773+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// build-pass
2+
// compile-flags: -C llvm-args=--disassemble
3+
// normalize-stderr-test "OpSource .*\n" -> ""
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
6+
// normalize-stderr-test "^(; .*\n)*" -> ""
7+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
8+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
9+
// ignore-spv1.0
10+
// ignore-spv1.1
11+
// ignore-spv1.2
12+
// ignore-spv1.3
13+
// ignore-spv1.4
14+
// ignore-spv1.5
15+
// ignore-spv1.6
16+
// ignore-vulkan1.0
17+
// ignore-vulkan1.1
18+
use spirv_std::spirv;
19+
20+
// Copies elements 0, 1 and 2 of `data` into the same positions of `slab`, using one
// indexed read and one indexed write per element.
//
// Both parameters are slices. `compute_shader` calls this with a reference to a
// three-element local array as `data`, which is the array-to-slice pointer cast this
// test exercises. Kept as three explicit statements on purpose: the expected
// disassembly for this test is tied to the exact shape of this body.
fn do_work(data: &[u32], slab: &mut [u32]) {
21+
slab[0] = data[0];
22+
slab[1] = data[1];
23+
slab[2] = data[2];
24+
}
25+
26+
// Compute entry point, declared with 64 threads.
//
// `slab` is a storage-buffer slice bound at descriptor set 0, binding 0, and
// `global_id` is the global invocation id. The three components of `global_id` are
// packed into a local fixed-size array, and a reference to that array is handed to
// `do_work` as a slice together with `slab`.
#[spirv(compute(threads(64)))]
27+
pub fn compute_shader(
28+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] slab: &mut [u32],
29+
#[spirv(global_invocation_id)] global_id: glam::UVec3,
30+
) {
31+
// Local `[u32; 3]`; passing `&data` where `&[u32]` is expected coerces it to a slice.
let data = [global_id.x, global_id.y, global_id.z];
32+
do_work(&data, slab);
33+
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
OpCapability Shader
2+
OpMemoryModel Logical Simple
3+
OpEntryPoint GLCompute %1 "compute_shader" %2 %3
4+
OpExecutionMode %1 LocalSize 64 1 1
5+
OpName %2 "slab"
6+
OpName %3 "global_id"
7+
OpDecorate %5 ArrayStride 4
8+
OpDecorate %6 Block
9+
OpMemberDecorate %6 0 Offset 0
10+
OpDecorate %2 Binding 0
11+
OpDecorate %2 DescriptorSet 0
12+
OpDecorate %3 BuiltIn GlobalInvocationId
13+
%7 = OpTypeInt 32 0
14+
%5 = OpTypeRuntimeArray %7
15+
%6 = OpTypeStruct %5
16+
%8 = OpTypePointer StorageBuffer %6
17+
%9 = OpTypeVector %7 3
18+
%10 = OpTypePointer Input %9
19+
%11 = OpTypeVoid
20+
%12 = OpTypeFunction %11
21+
%13 = OpConstant %7 3
22+
%14 = OpTypeArray %7 %13
23+
%15 = OpTypePointer Function %14
24+
%16 = OpTypePointer StorageBuffer %5
25+
%2 = OpVariable %8 StorageBuffer
26+
%17 = OpConstant %7 0
27+
%3 = OpVariable %10 Input
28+
%18 = OpTypePointer Function %7
29+
%19 = OpConstant %7 1
30+
%20 = OpConstant %7 2
31+
%21 = OpTypeBool
32+
%22 = OpTypePointer StorageBuffer %7
33+
%1 = OpFunction %11 None %12
34+
%23 = OpLabel
35+
%24 = OpVariable %15 Function
36+
%25 = OpInBoundsAccessChain %16 %2 %17
37+
%26 = OpArrayLength %7 %2 0
38+
%27 = OpLoad %9 %3
39+
%28 = OpCompositeExtract %7 %27 0
40+
%29 = OpCompositeExtract %7 %27 1
41+
%30 = OpCompositeExtract %7 %27 2
42+
%31 = OpInBoundsAccessChain %18 %24 %17
43+
OpStore %31 %28
44+
%32 = OpInBoundsAccessChain %18 %24 %19
45+
OpStore %32 %29
46+
%33 = OpInBoundsAccessChain %18 %24 %20
47+
OpStore %33 %30
48+
%34 = OpULessThan %21 %17 %13
49+
OpNoLine
50+
OpSelectionMerge %35 None
51+
OpBranchConditional %34 %36 %37
52+
%36 = OpLabel
53+
OpBranch %35
54+
%37 = OpLabel
55+
OpReturn
56+
%35 = OpLabel
57+
%38 = OpInBoundsAccessChain %18 %24 %17
58+
%39 = OpLoad %7 %38
59+
%40 = OpULessThan %21 %17 %26
60+
OpNoLine
61+
OpSelectionMerge %41 None
62+
OpBranchConditional %40 %42 %43
63+
%42 = OpLabel
64+
OpBranch %41
65+
%43 = OpLabel
66+
OpReturn
67+
%41 = OpLabel
68+
%44 = OpInBoundsAccessChain %22 %25 %17
69+
OpStore %44 %39
70+
%45 = OpULessThan %21 %19 %13
71+
OpNoLine
72+
OpSelectionMerge %46 None
73+
OpBranchConditional %45 %47 %48
74+
%47 = OpLabel
75+
OpBranch %46
76+
%48 = OpLabel
77+
OpReturn
78+
%46 = OpLabel
79+
%49 = OpInBoundsAccessChain %18 %24 %19
80+
%50 = OpLoad %7 %49
81+
%51 = OpULessThan %21 %19 %26
82+
OpNoLine
83+
OpSelectionMerge %52 None
84+
OpBranchConditional %51 %53 %54
85+
%53 = OpLabel
86+
OpBranch %52
87+
%54 = OpLabel
88+
OpReturn
89+
%52 = OpLabel
90+
%55 = OpInBoundsAccessChain %22 %25 %19
91+
OpStore %55 %50
92+
%56 = OpULessThan %21 %20 %13
93+
OpNoLine
94+
OpSelectionMerge %57 None
95+
OpBranchConditional %56 %58 %59
96+
%58 = OpLabel
97+
OpBranch %57
98+
%59 = OpLabel
99+
OpReturn
100+
%57 = OpLabel
101+
%60 = OpInBoundsAccessChain %18 %24 %20
102+
%61 = OpLoad %7 %60
103+
%62 = OpULessThan %21 %20 %26
104+
OpNoLine
105+
OpSelectionMerge %63 None
106+
OpBranchConditional %62 %64 %65
107+
%64 = OpLabel
108+
OpBranch %63
109+
%65 = OpLabel
110+
OpReturn
111+
%63 = OpLabel
112+
%66 = OpInBoundsAccessChain %22 %25 %20
113+
OpStore %66 %61
114+
OpNoLine
115+
OpReturn
116+
OpFunctionEnd

0 commit comments

Comments
 (0)