Skip to content

Commit c0b383c

Browse files
committed
Add more info about where autodiff can be applied
1 parent e161368 commit c0b383c

3 files changed

Lines changed: 155 additions & 2 deletions

File tree

library/core/src/autodiff.md

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
This module provides support for automatic differentiation. For precise information on
2+
differences between the `autodiff_forward` and `autodiff_reverse` macros and how to
3+
use them, see their respective documentation.
4+
5+
## General usage
6+
7+
Autodiff macros can be applied to almost all function definitions, see below for examples.
8+
They can be applied to functions accepting structs, arrays, slices, vectors, tuples, and more.
9+
10+
It is possible to apply multiple autodiff macros to the same function. As an example, this can
11+
be helpful to compute the partial derivatives with respect to `x` and `y` independently:
12+
```rust,ignore (optional component)
13+
#[autodiff_forward(dsquare1, Dual, Const, Dual)]
14+
#[autodiff_forward(dsquare2, Const, Dual, Dual)]
15+
#[autodiff_forward(dsquare3, Active, Active, Active)]
16+
fn square(x: f64, y: f64) -> f64 {
17+
x * x + 2.0 * y
18+
}
19+
```
20+
21+
We also support autodiff on functions with generic parameters:
22+
```rust,ignore (optional component)
23+
#[autodiff_forward(generic_derivative, Duplicated, Active)]
24+
fn generic_f<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
25+
x * x
26+
}
27+
```
28+
29+
or applying autodiff to nested functions:
30+
```rust,ignore (optional component)
31+
fn outer(x: f64) -> f64 {
32+
#[autodiff_forward(inner_derivative, Dual, Const)]
33+
fn inner(y: f64) -> f64 {
34+
y * y
35+
}
36+
inner_derivative(x, 1.0)
37+
}
38+
39+
fn main() {
40+
assert_eq!(outer(3.14), 6.28);
41+
}
42+
```
43+
The generated function will be available in the same scope as the function differentiated, and
44+
have the same private/pub usability.
45+
46+
## Traits and impls
47+
Autodiff macros can be used in multiple ways in combination with traits:
48+
```rust,ignore (optional component)
49+
struct Foo {
50+
a: f64,
51+
}
52+
53+
trait MyTrait {
54+
#[autodiff_reverse(df, Const, Active, Active)]
55+
fn f(&self, x: f64) -> f64;
56+
}
57+
58+
impl MyTrait for Foo {
59+
fn f(&self, x: f64) -> f64 {
60+
x.sin()
61+
}
62+
}
63+
64+
fn main() {
65+
let foo = Foo { a: 3.0f64 };
66+
assert_eq!(foo.f(2.0), 2.0_f64.sin());
67+
assert_eq!(foo.df(2.0, 1.0).1, 2.0_f64.cos());
68+
}
69+
```
70+
In this case `df` will be the default implementation provided by the library who provided the
71+
trait. A user implementing `MyTrait` could then decide to use the default implementation of
72+
`df`, or overwrite it with a custom implementation as a form of "custom derivatives".
73+
74+
On the other hand, a function generated by either autodiff macro can also be used to implement a
75+
trait:
76+
```rust,ignore (optional component)
77+
struct Foo {
78+
a: f64,
79+
}
80+
81+
trait MyTrait {
82+
fn f(&self, x: f64) -> f64;
83+
fn df(&self, x: f64, seed: f64) -> (f64, f64);
84+
}
85+
86+
impl MyTrait for Foo {
87+
#[autodiff_reverse(df, Const, Active, Active)]
88+
fn f(&self, x: f64) -> f64 {
89+
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
90+
}
91+
}
92+
```
93+
94+
Simple `impl` blocks without traits are also supported. Differentiating with respect to the
95+
implemented struct will then require the use of a "shadow struct" to hold the derivatives of the
96+
struct fields:
97+
98+
```rust,ignore (optional component)
99+
struct OptProblem {
100+
a: f64,
101+
b: f64,
102+
}
103+
104+
impl OptProblem {
105+
#[autodiff_reverse(d_objective, Duplicated, Duplicated, Duplicated)]
106+
fn objective(&self, x: &[f64], out: &mut f64) {
107+
*out = self.a + x[0].sqrt() * self.b
108+
}
109+
}
110+
fn main() {
111+
let p = OptProblem { a: 1., b: 2. };
112+
let mut p_shadow = OptProblem { a: 0., b: 0. };
113+
let mut dx = [0.0];
114+
let mut out = 0.0;
115+
let mut dout = 1.0;
116+
117+
p.d_objective(&mut p_shadow, &x, &mut dx, &mut out, &mut dout);
118+
}
119+
```
120+
121+
## Higher-order derivatives
122+
Finally, it is possible to generate higher-order derivatives (e.g. Hessian) by applying an
123+
autodiff macro to a function that is already generated by an autodiff macro, via a thin wrapper.
124+
The following example uses Forward mode over Reverse mode
125+
126+
```rust,ignore (optional component)
127+
#[autodiff_reverse(df, Duplicated, Duplicated)]
128+
fn f(x: &[f64;2], y: &mut f64) {
129+
*y = x[0] * x[0] + x[1] * x[0]
130+
}
131+
132+
#[autodiff_forward(h, Dual, Dual, Dual, Dual)]
133+
fn wrapper(x: &[f64;2], dx: &mut [f64;2], y: &mut f64, dy: &mut f64) {
134+
df(x, dx, y, dy);
135+
}
136+
137+
fn main() {
138+
let mut y = 0.0;
139+
let x = [2.0, 2.0];
140+
141+
let mut dy = 0.0;
142+
let mut dx = [1.0, 0.0];
143+
144+
let mut bx = [0.0, 0.0];
145+
let mut by = 1.0;
146+
let mut dbx = [0.0, 0.0];
147+
let mut dby = 0.0;
148+
h(&x, &mut dx, &mut bx, &mut dbx, &mut y, &mut dy, &mut by, &mut dby);
149+
assert_eq!(&dbx, [2.0, 1.0]);
150+
}
151+
```
152+
153+

library/core/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ pub mod from {
218218

219219
// We don't export this through #[macro_export] for now, to avoid breakage.
220220
#[unstable(feature = "autodiff", issue = "124509")]
221-
/// Unstable module containing the unstable `autodiff` macro.
221+
#[doc = include_str!("../../core/src/autodiff.md")]
222222
pub mod autodiff {
223223
#[unstable(feature = "autodiff", issue = "124509")]
224224
pub use crate::macros::builtin::{autodiff_forward, autodiff_reverse};

library/std/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ pub mod simd {
634634
}
635635

636636
#[unstable(feature = "autodiff", issue = "124509")]
637-
/// This module provides support for automatic differentiation.
637+
#[doc = include_str!("../../core/src/autodiff.md")]
638638
pub mod autodiff {
639639
/// This macro handles automatic differentiation.
640640
pub use core::autodiff::{autodiff_forward, autodiff_reverse};

0 commit comments

Comments
 (0)