Skip to content

Commit 53695eb

Browse files
minansysCopilot
andcommitted
update
Co-authored-by: Copilot <copilot@github.com>
1 parent dd7f36c commit 53695eb

1 file changed

Lines changed: 28 additions & 25 deletions

File tree

enzyme/test/Enzyme/ReverseMode/loosetypes_umax.ll

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,20 @@ target triple = "x86_64-unknown-linux-gnu"
2121

2222
declare i32 @llvm.umax.i32(i32, i32)
2323

24-
define void @compute_alpha(ptr noalias %op, i32 %face_idx, ptr noalias %nbr) {
24+
define void @compute_alpha(i8* noalias %op, i32 %face_idx, i32* noalias %nbr) {
2525
entry:
26-
%out_gep = getelementptr i8, ptr %op, i64 0
27-
%out_ptr = load ptr, ptr %out_gep, align 8
28-
%flags_gep = getelementptr i8, ptr %op, i64 8
29-
%flags_ptr = load ptr, ptr %flags_gep, align 8
30-
%clip_gep = getelementptr i8, ptr %op, i64 16
31-
%clip_val = load float, ptr %clip_gep, align 4
26+
%out_gep = getelementptr i8, i8* %op, i64 0
27+
%out_ptr_gep = bitcast i8* %out_gep to float**
28+
%out_ptr = load float*, float** %out_ptr_gep, align 8
29+
%flags_gep = getelementptr i8, i8* %op, i64 8
30+
%flags_ptr_gep = bitcast i8* %flags_gep to i32**
31+
%flags_ptr = load i32*, i32** %flags_ptr_gep, align 8
32+
%clip_gep = getelementptr i8, i8* %op, i64 16
33+
%clip_ptr = bitcast i8* %clip_gep to float*
34+
%clip_val = load float, float* %clip_ptr, align 4
3235
%idx = zext i32 %face_idx to i64
33-
%face_gep = getelementptr float, ptr %out_ptr, i64 %idx
34-
%face_val = load float, ptr %face_gep, align 4
36+
%face_gep = getelementptr float, float* %out_ptr, i64 %idx
37+
%face_val = load float, float* %face_gep, align 4
3538
%dot = fmul float %face_val, %clip_val
3639
%cmp = fcmp ole float %dot, 0.0
3740
br i1 %cmp, label %left_handed, label %normal
@@ -45,40 +48,40 @@ normal:
4548
merge:
4649
%result = phi float [ %clip_val, %left_handed ], [ %dot, %normal ]
4750
%flag = phi i32 [ 1, %left_handed ], [ 0, %normal ]
48-
store float %result, ptr %face_gep, align 4
49-
%n0 = load i32, ptr %nbr, align 4
51+
store float %result, float* %face_gep, align 4
52+
%n0 = load i32, i32* %nbr, align 4
5053
%n0_ext = zext i32 %n0 to i64
51-
%flags_elem = getelementptr i32, ptr %flags_ptr, i64 %n0_ext
52-
%old_flag = load i32, ptr %flags_elem, align 4
54+
%flags_elem = getelementptr i32, i32* %flags_ptr, i64 %n0_ext
55+
%old_flag = load i32, i32* %flags_elem, align 4
5356
%new_flag = tail call i32 @llvm.umax.i32(i32 %old_flag, i32 %flag)
54-
store i32 %new_flag, ptr %flags_elem, align 4
57+
store i32 %new_flag, i32* %flags_elem, align 4
5558
ret void
5659
}
5760

58-
define void @caller(ptr %op, ptr %d_op, i32 %face_idx, ptr %nbr) {
61+
define void @caller(i8* %op, i8* %d_op, i32 %face_idx, i32* %nbr) {
5962
entry:
6063
call void (...) @__enzyme_autodiff(
61-
ptr @compute_alpha,
62-
metadata !"enzyme_dup", ptr %op, ptr %d_op,
64+
i8* bitcast (void (i8*, i32, i32*)* @compute_alpha to i8*),
65+
metadata !"enzyme_dup", i8* %op, i8* %d_op,
6366
metadata !"enzyme_const", i32 %face_idx,
64-
metadata !"enzyme_const", ptr %nbr
67+
metadata !"enzyme_const", i32* %nbr
6568
)
6669
ret void
6770
}
6871

6972
declare void @__enzyme_autodiff(...)
7073

71-
; CHECK: define internal void @diffecompute_alpha(ptr noalias
72-
; CHECK-SAME: %op, ptr
73-
; CHECK-SAME: %"op'"
74-
; CHECK-SAME: i32 %face_idx, ptr noalias
74+
; CHECK: define internal void @diffecompute_alpha(
75+
; CHECK-SAME: %op,
76+
; CHECK-SAME: %"op'",
77+
; CHECK-SAME: i32 %face_idx,
7578
; CHECK-SAME: %nbr)
7679

7780
; Verify the primal block no longer uses the old zero-shadow fallback.
7881
; CHECK: %new_flag = {{(tail )?}}call i32 @llvm.umax.i32(i32 %old_flag, i32 %flag)
79-
; CHECK-NOT: store i32 0, ptr %"flags_elem'ipg"
80-
; CHECK: store i32 {{.+}}, ptr %"flags_elem'ipg"
81-
; CHECK: store i32 %new_flag, ptr %flags_elem
82+
; CHECK-NOT: store i32 0, {{(ptr|i32\*)}} %"flags_elem'ipg"
83+
; CHECK: store i32 {{.+}}, {{(ptr|i32\*)}} %"flags_elem'ipg"
84+
; CHECK: store i32 %new_flag, {{(ptr|i32\*)}} %flags_elem
8285

8386
; Verify the reverse block correctly propagates float derivatives
8487
; CHECK: invertentry:

0 commit comments

Comments
 (0)