@@ -3,6 +3,9 @@ use std::ptr;
33use rustc_ast:: expand:: autodiff_attrs:: { DiffActivity , DiffMode } ;
44use rustc_ast:: expand:: typetree:: FncTree ;
55use rustc_codegen_ssa:: common:: TypeKind ;
6+ use rustc_codegen_ssa:: mir:: IntrinsicResult ;
7+ use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
8+ use rustc_codegen_ssa:: mir:: place:: PlaceValue ;
69use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
710use rustc_data_structures:: thin_vec:: ThinVec ;
811use rustc_hir:: attrs:: RustcAutodiff ;
@@ -11,7 +14,7 @@ use rustc_middle::{bug, ty};
1114use rustc_target:: callconv:: PassMode ;
1215use tracing:: debug;
1316
14- use crate :: builder:: { Builder , PlaceRef , UNNAMED } ;
17+ use crate :: builder:: { Builder , UNNAMED } ;
1518use crate :: context:: SimpleCx ;
1619use crate :: declare:: declare_simple_fn;
1720use crate :: llvm:: { self , TRUE , Type , Value } ;
@@ -296,9 +299,10 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
296299 ret_ty : & ' ll Type ,
297300 fn_args : & [ & ' ll Value ] ,
298301 attrs : & RustcAutodiff ,
299- dest : PlaceRef < ' tcx , & ' ll Value > ,
302+ dest_layout : ty:: layout:: TyAndLayout < ' tcx > ,
303+ dest_place : Option < PlaceValue < & ' ll Value > > ,
300304 fnc_tree : FncTree ,
301- ) {
305+ ) -> IntrinsicResult < ' tcx , & ' ll Value > {
302306 // We have to pick the name depending on whether we want forward or reverse mode autodiff.
303307 let mut ad_name: String = match attrs. mode {
304308 DiffMode :: Forward => "__enzyme_fwddiff" ,
@@ -381,11 +385,18 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
381385 let call = builder. call ( enzyme_ty, None , None , ad_fn, & args, None , None ) ;
382386
383387 let fn_ret_ty = builder. cx . val_ty ( call) ;
384- if fn_ret_ty != builder. cx . type_void ( ) && fn_ret_ty ! = builder. cx . type_struct ( & [ ] , false ) {
388+ if fn_ret_ty == builder. cx . type_void ( ) || fn_ret_ty = = builder. cx . type_struct ( & [ ] , false ) {
385389 // If we return void or an empty struct, then our caller (due to how we generated it)
386390 // does not expect a return value. As such, we have no pointer (or place) into which
387391 // we could store our value, and would store into an undef, which would cause UB.
388392 // As such, we just ignore the return value in those cases.
389- builder. store_to_place ( call, dest. val ) ;
393+ IntrinsicResult :: Operand ( OperandValue :: ZeroSized )
394+ } else if let Some ( dest_place) = dest_place {
395+ builder. store_to_place ( call, dest_place) ;
396+ IntrinsicResult :: WroteIntoPlace
397+ } else {
398+ IntrinsicResult :: Operand (
399+ OperandRef :: from_immediate_or_packed_pair ( builder, call, dest_layout) . val ,
400+ )
390401 }
391402}
0 commit comments