@@ -21,17 +21,20 @@ target triple = "x86_64-unknown-linux-gnu"
2121
2222declare i32 @llvm.umax.i32 (i32 , i32 )
2323
24- define void @compute_alpha (ptr noalias %op , i32 %face_idx , ptr noalias %nbr ) {
24+ define void @compute_alpha (i8* noalias %op , i32 %face_idx , i32* noalias %nbr ) {
2525entry:
26- %out_gep = getelementptr i8 , ptr %op , i64 0
27- %out_ptr = load ptr , ptr %out_gep , align 8
28- %flags_gep = getelementptr i8 , ptr %op , i64 8
29- %flags_ptr = load ptr , ptr %flags_gep , align 8
30- %clip_gep = getelementptr i8 , ptr %op , i64 16
31- %clip_val = load float , ptr %clip_gep , align 4
26+ %out_gep = getelementptr i8 , i8* %op , i64 0
27+ %out_ptr_gep = bitcast i8* %out_gep to float **
28+ %out_ptr = load float *, float ** %out_ptr_gep , align 8
29+ %flags_gep = getelementptr i8 , i8* %op , i64 8
30+ %flags_ptr_gep = bitcast i8* %flags_gep to i32**
31+ %flags_ptr = load i32* , i32** %flags_ptr_gep , align 8
32+ %clip_gep = getelementptr i8 , i8* %op , i64 16
33+ %clip_ptr = bitcast i8* %clip_gep to float *
34+ %clip_val = load float , float * %clip_ptr , align 4
3235 %idx = zext i32 %face_idx to i64
33- %face_gep = getelementptr float , ptr %out_ptr , i64 %idx
34- %face_val = load float , ptr %face_gep , align 4
36+ %face_gep = getelementptr float , float * %out_ptr , i64 %idx
37+ %face_val = load float , float * %face_gep , align 4
3538 %dot = fmul float %face_val , %clip_val
3639 %cmp = fcmp ole float %dot , 0 .0
3740 br i1 %cmp , label %left_handed , label %normal
@@ -45,40 +48,40 @@ normal:
4548merge:
4649 %result = phi float [ %clip_val , %left_handed ], [ %dot , %normal ]
4750 %flag = phi i32 [ 1 , %left_handed ], [ 0 , %normal ]
48- store float %result , ptr %face_gep , align 4
49- %n0 = load i32 , ptr %nbr , align 4
51+ store float %result , float * %face_gep , align 4
52+ %n0 = load i32 , i32* %nbr , align 4
5053 %n0_ext = zext i32 %n0 to i64
51- %flags_elem = getelementptr i32 , ptr %flags_ptr , i64 %n0_ext
52- %old_flag = load i32 , ptr %flags_elem , align 4
54+ %flags_elem = getelementptr i32 , i32* %flags_ptr , i64 %n0_ext
55+ %old_flag = load i32 , i32* %flags_elem , align 4
5356 %new_flag = tail call i32 @llvm.umax.i32 (i32 %old_flag , i32 %flag )
54- store i32 %new_flag , ptr %flags_elem , align 4
57+ store i32 %new_flag , i32* %flags_elem , align 4
5558 ret void
5659}
5760
; @caller: test driver that hands @compute_alpha to @__enzyme_autodiff.
;   - %op and %d_op follow the "enzyme_dup" marker (presumably the primal
;     pointer and its shadow/derivative pointer -- confirm against Enzyme docs).
;   - %face_idx and %nbr each follow an "enzyme_const" marker.
; The removed (-) lines spell the call with opaque `ptr`; the added (+) lines
; spell the same call with typed pointers and bitcast @compute_alpha to i8*.
58- define void @caller (ptr %op , ptr %d_op , i32 %face_idx , ptr %nbr ) {
61+ define void @caller (i8* %op , i8* %d_op , i32 %face_idx , i32* %nbr ) {
5962entry:
6063 call void (...) @__enzyme_autodiff (
61- ptr @compute_alpha ,
62- metadata !"enzyme_dup" , ptr %op , ptr %d_op ,
64+ i8* bitcast ( void ( i8* , i32 , i32* )* @compute_alpha to i8* ) ,
65+ metadata !"enzyme_dup" , i8* %op , i8* %d_op ,
6366 metadata !"enzyme_const" , i32 %face_idx ,
64- metadata !"enzyme_const" , ptr %nbr
67+ metadata !"enzyme_const" , i32* %nbr
6568 )
6669 ret void
6770}
6871
6972declare void @__enzyme_autodiff (...)
7073
71- ; CHECK: define internal void @diffecompute_alpha(ptr noalias
72- ; CHECK-SAME: %op, ptr
73- ; CHECK-SAME: %"op'"
74- ; CHECK-SAME: i32 %face_idx, ptr noalias
74+ ; CHECK: define internal void @diffecompute_alpha(
75+ ; CHECK-SAME: %op,
76+ ; CHECK-SAME: %"op'",
77+ ; CHECK-SAME: i32 %face_idx,
7578; CHECK-SAME: %nbr)
7679
7780; Verify the primal block no longer uses the old zero-shadow fallback.
7881; CHECK: %new_flag = {{(tail )?}}call i32 @llvm.umax.i32(i32 %old_flag, i32 %flag)
79- ; CHECK-NOT: store i32 0, ptr %"flags_elem'ipg"
80- ; CHECK: store i32 {{.+}}, ptr %"flags_elem'ipg"
81- ; CHECK: store i32 %new_flag, ptr %flags_elem
82+ ; CHECK-NOT: store i32 0, {{(ptr|i32\*)}} %"flags_elem'ipg"
83+ ; CHECK: store i32 {{.+}}, {{(ptr|i32\*)}} %"flags_elem'ipg"
84+ ; CHECK: store i32 %new_flag, {{(ptr|i32\*)}} %flags_elem
8285
8386; Verify the reverse block correctly propagates float derivatives
8487; CHECK: invertentry:
0 commit comments