Skip to content

Commit 8a3d0f4

Browse files
committed
refactor(autodiff): Cast primal to fn ptr, drop Instance::try_resolve for source
1 parent dd99221 commit 8a3d0f4

5 files changed

Lines changed: 93 additions & 54 deletions

File tree

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ mod llvm_enzyme {
2020
};
2121
use rustc_expand::base::{Annotatable, ExtCtxt};
2222
use rustc_hir::attrs::RustcAutodiff;
23-
use rustc_span::{Ident, Span, Symbol, sym};
23+
use rustc_span::{Ident, Span, Symbol, kw, sym};
2424
use thin_vec::{ThinVec, thin_vec};
2525
use tracing::{debug, trace};
2626

@@ -197,7 +197,7 @@ mod llvm_enzyme {
197197
/// }
198198
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
199199
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
200-
/// std::intrinsics::autodiff(sin::<>, cos_box::<>, (x, dx, dret))
200+
/// std::intrinsics::autodiff(sin::<> as fn(..) -> .., cos_box::<>, (x, dx, dret))
201201
/// }
202202
/// ```
203203
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@@ -326,6 +326,7 @@ mod llvm_enzyme {
326326
primal,
327327
first_ident(&meta_item_vec[0]),
328328
span,
329+
&sig,
329330
&d_sig,
330331
&generics,
331332
impl_of_trait,
@@ -496,18 +497,62 @@ mod llvm_enzyme {
496497

497498
// Generate `autodiff` intrinsic call
498499
// ```
499-
// std::intrinsics::autodiff(source, diff, (args))
500+
// std::intrinsics::autodiff(source as fn(..) -> .., diff, (args))
500501
// ```
501502
fn call_autodiff(
502503
ecx: &ExtCtxt<'_>,
503504
primal: Ident,
504505
diff: Ident,
505506
span: Span,
507+
p_sig: &FnSig,
506508
d_sig: &FnSig,
507509
generics: &Generics,
508510
is_impl: bool,
509511
) -> rustc_ast::Stmt {
510512
let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);
513+
514+
let self_ty = || ecx.ty_path(ast::Path::from_ident(Ident::with_dummy_span(kw::SelfUpper)));
515+
let fn_ptr_params: ThinVec<ast::Param> = p_sig
516+
.decl
517+
.inputs
518+
.iter()
519+
.map(|param| {
520+
let ty = match &param.ty.kind {
521+
TyKind::ImplicitSelf => self_ty(),
522+
TyKind::Ref(lt, mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => ecx.ty(
523+
span,
524+
TyKind::Ref(lt.clone(), ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }),
525+
),
526+
TyKind::Ptr(mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => {
527+
ecx.ty(span, TyKind::Ptr(ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }))
528+
}
529+
_ => param.ty.clone(),
530+
};
531+
ast::Param {
532+
attrs: ast::AttrVec::new(),
533+
ty,
534+
pat: Box::new(ecx.pat_wild(span)),
535+
id: ast::DUMMY_NODE_ID,
536+
span,
537+
is_placeholder: false,
538+
}
539+
})
540+
.collect();
541+
let fn_ptr_ty = ecx.ty(
542+
span,
543+
TyKind::FnPtr(Box::new(ast::FnPtrTy {
544+
safety: p_sig.header.safety,
545+
ext: p_sig.header.ext,
546+
generic_params: ThinVec::new(),
547+
decl: Box::new(ast::FnDecl {
548+
inputs: fn_ptr_params,
549+
output: p_sig.decl.output.clone(),
550+
}),
551+
decl_span: span,
552+
})),
553+
);
554+
let primal_fn_ptr = ecx.expr(span, ast::ExprKind::Cast(primal_path_expr, fn_ptr_ty));
555+
511556
let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);
512557

513558
let tuple_expr = ecx.expr_tuple(
@@ -529,7 +574,7 @@ mod llvm_enzyme {
529574
let call_expr = ecx.expr_call(
530575
span,
531576
ecx.expr_path(enzyme_path),
532-
vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
577+
vec![primal_fn_ptr, diff_path_expr, tuple_expr].into(),
533578
);
534579

535580
ecx.stmt_expr(call_expr)

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,29 +1323,8 @@ fn codegen_autodiff<'ll, 'tcx>(
13231323
let ret_ty = sig.output();
13241324
let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);
13251325

1326-
// Get source, diff, and attrs
1327-
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
1328-
ty::FnDef(def_id, source_params) => (def_id, source_params),
1329-
_ => bug!("invalid autodiff intrinsic args"),
1330-
};
1331-
1332-
let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) {
1333-
Ok(Some(instance)) => instance,
1334-
Ok(None) => bug!(
1335-
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
1336-
source_id,
1337-
source_args
1338-
),
1339-
Err(_) => {
1340-
// An error has already been emitted
1341-
return;
1342-
}
1343-
};
1344-
1345-
let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
1346-
let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else {
1347-
bug!("could not find source function")
1348-
};
1326+
let source_fn_ptr_ty = fn_args.into_type_list(tcx)[0];
1327+
let fn_to_diff = args[0].immediate();
13491328

13501329
let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
13511330
ty::FnDef(def_id, diff_args) => (def_id, diff_args),
@@ -1374,17 +1353,14 @@ fn codegen_autodiff<'ll, 'tcx>(
13741353
bug!("could not find autodiff attrs")
13751354
};
13761355

1377-
let fn_ptr_ty =
1378-
Ty::new_fn_ptr(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()).fn_sig(tcx));
13791356
adjust_activity_to_abi(
13801357
tcx,
1381-
fn_ptr_ty,
1358+
source_fn_ptr_ty,
13821359
TypingEnv::fully_monomorphized(),
13831360
&mut diff_attrs.input_activity,
13841361
);
13851362

1386-
let fnc_tree =
1387-
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
1363+
let fnc_tree = rustc_middle::ty::fnc_typetrees(tcx, source_fn_ptr_ty);
13881364

13891365
// Build body
13901366
generate_enzyme_call(

tests/pretty/autodiff/autodiff_forward.pp

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,43 +35,51 @@
3535
}
3636
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
3737
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
38-
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, bx_0, y))
38+
::core::intrinsics::autodiff(f1::<> as fn(_: &[f64], _: f64) -> f64,
39+
df1::<>, (x, bx_0, y))
3940
}
4041
#[rustc_autodiff]
4142
pub fn f2(x: &[f64], y: f64) -> f64 {
4243
::core::panicking::panic("not implemented")
4344
}
4445
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
4546
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
46-
::core::intrinsics::autodiff(f2::<>, df2::<>, (x, bx_0, y))
47+
::core::intrinsics::autodiff(f2::<> as fn(_: &[f64], _: f64) -> f64,
48+
df2::<>, (x, bx_0, y))
4749
}
4850
#[rustc_autodiff]
4951
pub fn f3(x: &[f64], y: f64) -> f64 {
5052
::core::panicking::panic("not implemented")
5153
}
5254
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
5355
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
54-
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, bx_0, y))
56+
::core::intrinsics::autodiff(f3::<> as fn(_: &[f64], _: f64) -> f64,
57+
df3::<>, (x, bx_0, y))
5558
}
5659
#[rustc_autodiff]
5760
pub fn f4() {}
5861
#[rustc_autodiff(Forward, 1, None)]
59-
pub fn df4() -> () { ::core::intrinsics::autodiff(f4::<>, df4::<>, ()) }
62+
pub fn df4() -> () {
63+
::core::intrinsics::autodiff(f4::<> as fn(), df4::<>, ())
64+
}
6065
#[rustc_autodiff]
6166
pub fn f5(x: &[f64], y: f64) -> f64 {
6267
::core::panicking::panic("not implemented")
6368
}
6469
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
6570
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
66-
::core::intrinsics::autodiff(f5::<>, df5_y::<>, (x, y, by_0))
71+
::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64,
72+
df5_y::<>, (x, y, by_0))
6773
}
6874
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
6975
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
70-
::core::intrinsics::autodiff(f5::<>, df5_x::<>, (x, bx_0, y))
76+
::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64,
77+
df5_x::<>, (x, bx_0, y))
7178
}
7279
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
7380
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
74-
::core::intrinsics::autodiff(f5::<>, df5_rev::<>, (x, dx_0, y, dret))
81+
::core::intrinsics::autodiff(f5::<> as fn(_: &[f64], _: f64) -> f64,
82+
df5_rev::<>, (x, dx_0, y, dret))
7583
}
7684
struct DoesNotImplDefault;
7785
#[rustc_autodiff]
@@ -80,50 +88,55 @@
8088
}
8189
#[rustc_autodiff(Forward, 1, Const)]
8290
pub fn df6() -> DoesNotImplDefault {
83-
::core::intrinsics::autodiff(f6::<>, df6::<>, ())
91+
::core::intrinsics::autodiff(f6::<> as fn() -> DoesNotImplDefault,
92+
df6::<>, ())
8493
}
8594
#[rustc_autodiff]
8695
pub fn f7(x: f32) -> () {}
8796
#[rustc_autodiff(Forward, 1, Const, None)]
8897
pub fn df7(x: f32) -> () {
89-
::core::intrinsics::autodiff(f7::<>, df7::<>, (x,))
98+
::core::intrinsics::autodiff(f7::<> as fn(_: f32) -> (), df7::<>, (x,))
9099
}
91100
#[no_mangle]
92101
#[rustc_autodiff]
93102
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
94103
#[rustc_autodiff(Forward, 4, Dual, Dual)]
95104
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
96105
-> [f32; 5usize] {
97-
::core::intrinsics::autodiff(f8::<>, f8_3::<>,
106+
::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_3::<>,
98107
(x, bx_0, bx_1, bx_2, bx_3))
99108
}
100109
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
101110
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
102111
-> [f32; 4usize] {
103-
::core::intrinsics::autodiff(f8::<>, f8_2::<>,
112+
::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_2::<>,
104113
(x, bx_0, bx_1, bx_2, bx_3))
105114
}
106115
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
107116
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
108-
::core::intrinsics::autodiff(f8::<>, f8_1::<>, (x, bx_0))
117+
::core::intrinsics::autodiff(f8::<> as fn(_: &f32) -> f32, f8_1::<>,
118+
(x, bx_0))
109119
}
110120
pub fn f9() {
111121
#[rustc_autodiff]
112122
fn inner(x: f32) -> f32 { x * x }
113123
#[rustc_autodiff(Forward, 1, Dual, Dual)]
114124
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
115-
::core::intrinsics::autodiff(inner::<>, d_inner_2::<>, (x, bx_0))
125+
::core::intrinsics::autodiff(inner::<> as fn(_: f32) -> f32,
126+
d_inner_2::<>, (x, bx_0))
116127
}
117128
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
118129
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
119-
::core::intrinsics::autodiff(inner::<>, d_inner_1::<>, (x, bx_0))
130+
::core::intrinsics::autodiff(inner::<> as fn(_: f32) -> f32,
131+
d_inner_1::<>, (x, bx_0))
120132
}
121133
}
122134
#[rustc_autodiff]
123135
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
124136
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
125137
pub fn d_square<T: std::ops::Mul<Output = T> +
126138
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
127-
::core::intrinsics::autodiff(f10::<T>, d_square::<T>, (x, dx_0, dret))
139+
::core::intrinsics::autodiff(f10::<T> as fn(_: &T) -> T, d_square::<T>,
140+
(x, dx_0, dret))
128141
}
129142
fn main() {}

tests/pretty/autodiff/autodiff_reverse.pp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,32 +28,37 @@
2828
}
2929
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
3030
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
31-
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, dx_0, y, dret))
31+
::core::intrinsics::autodiff(f1::<> as fn(_: &[f64], _: f64) -> f64,
32+
df1::<>, (x, dx_0, y, dret))
3233
}
3334
#[rustc_autodiff]
3435
pub fn f2() {}
3536
#[rustc_autodiff(Reverse, 1, None)]
36-
pub fn df2() { ::core::intrinsics::autodiff(f2::<>, df2::<>, ()) }
37+
pub fn df2() { ::core::intrinsics::autodiff(f2::<> as fn(), df2::<>, ()) }
3738
#[rustc_autodiff]
3839
pub fn f3(x: &[f64], y: f64) -> f64 {
3940
::core::panicking::panic("not implemented")
4041
}
4142
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
4243
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
43-
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, dx_0, y, dret))
44+
::core::intrinsics::autodiff(f3::<> as fn(_: &[f64], _: f64) -> f64,
45+
df3::<>, (x, dx_0, y, dret))
4446
}
4547
enum Foo { Reverse, }
4648
use Foo::Reverse;
4749
#[rustc_autodiff]
4850
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
4951
#[rustc_autodiff(Reverse, 1, Const, None)]
50-
pub fn df4(x: f32) { ::core::intrinsics::autodiff(f4::<>, df4::<>, (x,)) }
52+
pub fn df4(x: f32) {
53+
::core::intrinsics::autodiff(f4::<> as fn(_: f32), df4::<>, (x,))
54+
}
5155
#[rustc_autodiff]
5256
pub fn f5(x: *const f32, y: &f32) {
5357
::core::panicking::panic("not implemented")
5458
}
5559
#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
5660
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
57-
::core::intrinsics::autodiff(f5::<>, df5::<>, (x, dx_0, y, dy_0))
61+
::core::intrinsics::autodiff(f5::<> as fn(_: *const f32, _: &f32),
62+
df5::<>, (x, dx_0, y, dy_0))
5863
}
5964
fn main() {}

tests/pretty/autodiff/inherent_impl.pp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
}
3131
#[rustc_autodiff(Reverse, 1, Const, Active, Active)]
3232
fn df(&self, x: f64, dret: f64) -> (f64, f64) {
33-
::core::intrinsics::autodiff(Self::f::<>, Self::df::<>,
34-
(self, x, dret))
33+
::core::intrinsics::autodiff(Self::f::<> as
34+
fn(_: &Self, _: f64) -> f64, Self::df::<>, (self, x, dret))
3535
}
3636
}

0 commit comments

Comments
 (0)