Skip to content

Commit 94a91ce

Browse files
minansys and Copilot
authored and committed
resolve the comments
Co-authored-by: Copilot <copilot@github.com>
1 parent 5d1ed80 commit 94a91ce

2 files changed

Lines changed: 33 additions & 10 deletions

File tree

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -6323,6 +6323,34 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
63236323
switch (II->getIntrinsicID()) {
63246324
default:
63256325
goto end;
6326+
#if LLVM_VERSION_MAJOR >= 12
6327+
case Intrinsic::smax:
6328+
case Intrinsic::smin:
6329+
case Intrinsic::umax:
6330+
case Intrinsic::umin: {
6331+
Value *val0 = invertPointerM(II->getArgOperand(0), bb, nullShadow);
6332+
Value *val1 = invertPointerM(II->getArgOperand(1), bb, nullShadow);
6333+
assert(val0->getType() == val1->getType());
6334+
6335+
auto rule = [&](Value *val0, Value *val1) {
6336+
Value *args[] = {val0, val1};
6337+
auto shadow = cast<CallInst>(
6338+
bb.CreateCall(II->getCalledFunction(), args,
6339+
II->getName() + "'ipbi"));
6340+
shadow->setAttributes(II->getAttributes());
6341+
shadow->setCallingConv(II->getCallingConv());
6342+
shadow->setTailCallKind(II->getTailCallKind());
6343+
shadow->setDebugLoc(getNewFromOriginal(II->getDebugLoc()));
6344+
return shadow;
6345+
};
6346+
6347+
Value *shadow = applyChainRule(II->getType(), bb, rule, val0, val1);
6348+
6349+
invertedPointers.insert(
6350+
std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow)));
6351+
return shadow;
6352+
}
6353+
#endif
63266354
#if LLVM_VERSION_MAJOR < 20
63276355
case Intrinsic::nvvm_ldg_global_i:
63286356
case Intrinsic::nvvm_ldg_global_p:
@@ -6573,13 +6601,6 @@ end:;
65736601
return Constant::getNullValue(getShadowType(oval->getType()));
65746602
}
65756603

6576-
if (looseTypeAnalysis && oval->getType()->isIntOrIntVectorTy()) {
6577-
auto *shadow = Constant::getNullValue(getShadowType(oval->getType()));
6578-
invertedPointers.insert(
6579-
std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow)));
6580-
return shadow;
6581-
}
6582-
65836604
if (CustomErrorHandler) {
65846605
std::string str;
65856606
raw_string_ostream ss(str);

enzyme/test/Enzyme/ReverseMode/loosetypes_umax.ll

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
; @llvm.umax.i32. Without the fix, this hits:
1414
; assert(0 && "cannot find deal with ptr that isnt arg")
1515
;
16-
; The fix returns a zero shadow for integer values under loose-types analysis.
16+
; The fix handles llvm.umax as a binop-like intrinsic during shadow
17+
; reconstruction, instead of falling through to the generic no-shadow path.
1718

1819
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
1920
target triple = "x86_64-unknown-linux-gnu"
@@ -73,9 +74,10 @@ declare void @__enzyme_autodiff(...)
7374
; CHECK-SAME: i32 %face_idx, ptr noalias
7475
; CHECK-SAME: %nbr)
7576

76-
; Verify the primal block correctly computes and stores i32 umax with zero shadow
77+
; Verify the primal block computes both the original and shadow i32 umax values.
7778
; CHECK: %new_flag = tail call i32 @llvm.umax.i32(i32 %old_flag, i32 %flag)
78-
; CHECK-NEXT: store i32 0, ptr %"flags_elem'ipg"
79+
; CHECK-NEXT: %[[NEW_FLAG_IP:.+]] = {{(tail )?}}call i32 @llvm.umax.i32(i32 %{{.+}}, i32 %{{.+}})
80+
; CHECK-NEXT: store i32 %[[NEW_FLAG_IP]], ptr %"flags_elem'ipg"
7981
; CHECK-NEXT: store i32 %new_flag, ptr %flags_elem
8082

8183
; Verify the reverse block correctly propagates float derivatives

0 commit comments

Comments
 (0)