
Module autodiff

🔬 This is a nightly-only experimental API. (autodiff #124509)

This module provides support for automatic differentiation. For details on how the autodiff_forward and autodiff_reverse macros differ and how to use each of them, see their respective documentation.

§General usage

Autodiff macros can be applied to almost all function definitions; see below for examples. This includes functions that accept structs, arrays, slices, vectors, tuples, and more.

It is possible to apply multiple autodiff macros to the same function. For example, this can be used to compute the partial derivatives with respect to x and y independently:

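#![feature(autodiff)]
use std::autodiff::{autodiff_forward, autodiff_reverse};

// Forward mode: one partial derivative per generated function.
#[autodiff_forward(dsquare1, Dual, Const, Dual)] // with respect to x
#[autodiff_forward(dsquare2, Const, Dual, Dual)] // with respect to y
// Reverse mode: both partial derivatives at once (`Active` is a reverse-mode activity).
#[autodiff_reverse(dsquare3, Active, Active, Active)]
fn square(x: f64, y: f64) -> f64 {
  x * x + 2.0 * y
}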
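To make concrete what the three generated functions compute, the sketch below propagates derivatives through the body of square by hand, without std::autodiff. It is an illustration only: square_fwd and square_rev are hypothetical helpers whose argument orders are not the signatures the macros generate (see the macro documentation for those). Forward mode carries a tangent alongside each value and yields one directional derivative per call, so the two partials need two calls; reverse mode runs the body once, then pushes an output seed backwards and yields both partials, each scaled by that seed.

```rust
/// Hypothetical hand-written forward mode for `x * x + 2.0 * y`.
/// Every intermediate value carries a tangent; returns (value, directional derivative).
fn square_fwd(x: f64, dx: f64, y: f64, dy: f64) -> (f64, f64) {
    let t = x * x;
    let dt = 2.0 * x * dx; // d(x^2) = 2x dx
    let out = t + 2.0 * y;
    let dout = dt + 2.0 * dy; // d(t + 2y) = dt + 2 dy
    (out, dout)
}

/// Hypothetical hand-written reverse mode for the same body.
/// Runs the body forward, then propagates the output adjoint `seed` back to
/// both inputs; returns (value, seed * d/dx, seed * d/dy).
fn square_rev(x: f64, y: f64, seed: f64) -> (f64, f64, f64) {
    // Forward pass.
    let t = x * x;
    let out = t + 2.0 * y;
    // Backward pass, visiting the statements in reverse order.
    let t_bar = seed; // out = t + 2y  =>  d out / d t = 1
    let y_bar = 2.0 * seed; //           d out / d y = 2
    let x_bar = 2.0 * x * t_bar; // t = x^2  =>  d t / d x = 2x
    (out, x_bar, y_bar)
}
```

With x = 3 and y = 5, seeding dx = 1 gives the partial with respect to x (6), seeding dy = 1 gives the partial with respect to y (2), and a single reverse call with seed 1 returns both.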

We also support autodiff on functions with generic parameters:

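#![feature(autodiff)]
use std::autodiff::autodiff_reverse;

#[autodiff_reverse(generic_derivative, Duplicated, Active)]
fn generic_f<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
    *x * *x // dereference first: `x` is `&T`, and only `T: Mul` is required
}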
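As a loose analogy only, the sketch below instantiates the same generic_f body with a hand-rolled dual number. std::autodiff does not work this way (it generates derivative code at compile time rather than relying on a special number type), and the Dual type here is hypothetical, not part of the standard library. A dual number a + b·ε with ε² = 0 carries a value and a derivative, so evaluating the body at 3 + 1·ε yields f(3) and f'(3) together.

```rust
use std::ops::Mul;

/// Hypothetical dual number: `re` is the value, `eps` the derivative part.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Mul for Dual {
    type Output = Dual;
    // (a + a'ε)(b + b'ε) = ab + (ab' + a'b)ε, because ε² = 0.
    fn mul(self, rhs: Dual) -> Dual {
        Dual {
            re: self.re * rhs.re,
            eps: self.re * rhs.eps + self.eps * rhs.re,
        }
    }
}

/// Same body as the generic example, with explicit dereferences.
fn generic_f<T: Mul<Output = T> + Copy>(x: &T) -> T {
    *x * *x
}
```

Calling generic_f(&3.0) returns 9.0, while generic_f(&Dual { re: 3.0, eps: 1.0 }) returns Dual { re: 9.0, eps: 6.0 }, i.e. the value and the derivative 2x at x = 3.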

Autodiff can also be applied to nested functions:

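#![feature(autodiff)]
use std::autodiff::autodiff_forward;

fn outer(x: f64) -> f64 {
    // `DualOnly` on the return makes `inner_derivative` return only the derivative.
    #[autodiff_forward(inner_derivative, Dual, DualOnly)]
    fn inner(y: f64) -> f64 {
        y * y
    }
    // Seed the tangent with 1.0 to obtain d(y * y)/dy = 2y at y = x.
    inner_derivative(x, 1.0)
}

fn main() {
    assert_eq!(outer(3.14), 6.28);
}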

The generated function is available in the same scope as the differentiated function and has the same visibility (private or pub).

§Traits and impls

Autodiff macros can be used in multiple ways in combination with traits:

#![feature(autodiff)]
use std::autodiff::autodiff_reverse;

struct Foo {
    a: f64,
}

trait MyTrait {
    #[autodiff_reverse(df, Const, Active, Active)]
    fn f(&self, x: f64) -> f64;
}

impl MyTrait for Foo {
    fn f(&self, x: f64) -> f64 {
        x.sin()
    }
}

fn main() {
    let foo = Foo { a: 3.0f64 };
    assert_eq!(foo.f(2.0), 2.0_f64.sin());
    assert_eq!(foo.df(2.0, 1.0).1, 2.0_f64.cos());
}

In this case, df is the default implementation supplied by the library that provides the trait. A user implementing MyTrait can then either keep the default implementation of df or override it with a custom implementation, as a form of “custom derivatives”.
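The override pattern itself is ordinary Rust trait design, sketched below without std::autodiff. The default df is a hypothetical stand-in for a generated derivative, implemented here as a central finite difference; UsesDefault keeps that default, while CustomDerivative (also hypothetical) overrides it with an exact, hand-written derivative. The (value, seed * derivative) return layout mirrors the df(&self, x: f64, seed: f64) -> (f64, f64) signature used on this page, but is otherwise an assumption.

```rust
trait MyTrait {
    fn f(&self, x: f64) -> f64;

    /// Default derivative: hypothetical stand-in for a generated `df`,
    /// computed by central finite difference. Returns (f(x), seed * f'(x)).
    fn df(&self, x: f64, seed: f64) -> (f64, f64) {
        let h = 1e-6;
        let slope = (self.f(x + h) - self.f(x - h)) / (2.0 * h);
        (self.f(x), seed * slope)
    }
}

/// Keeps the default `df`.
struct UsesDefault;

impl MyTrait for UsesDefault {
    fn f(&self, x: f64) -> f64 {
        x.sin()
    }
}

/// Overrides `df` with an exact "custom derivative".
struct CustomDerivative;

impl MyTrait for CustomDerivative {
    fn f(&self, x: f64) -> f64 {
        x.sin()
    }

    fn df(&self, x: f64, seed: f64) -> (f64, f64) {
        (x.sin(), seed * x.cos())
    }
}
```

Both implementors expose the same df method, so callers cannot tell whether the derivative was generated or hand-written; only its accuracy differs.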

Conversely, a function generated by either autodiff macro can also be used to implement a trait:

#![feature(autodiff)]
use std::autodiff::autodiff_reverse;

struct Foo {
    a: f64,
}

trait MyTrait {
    fn f(&self, x: f64) -> f64;
    fn df(&self, x: f64, seed: f64) -> (f64, f64);
}

impl MyTrait for Foo {
    #[autodiff_reverse(df, Const, Active, Active)]
    fn f(&self, x: f64) -> f64 {
        self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
    }
}
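For reference, the derivative that the generated df should produce can be worked out by hand: d/dx [a/4 · (x² - 1 - 2 ln x)] = a/4 · (2x - 2/x) = a/2 · (x - 1/x). The sketch below implements the same trait manually with that formula, without std::autodiff, assuming that df returns the primal value first and the seeded derivative second.

```rust
struct Foo {
    a: f64,
}

trait MyTrait {
    fn f(&self, x: f64) -> f64;
    fn df(&self, x: f64, seed: f64) -> (f64, f64);
}

impl MyTrait for Foo {
    fn f(&self, x: f64) -> f64 {
        self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
    }

    // Hand-written in place of the generated `df`:
    // d/dx [a/4 (x^2 - 1 - 2 ln x)] = a/2 (x - 1/x).
    fn df(&self, x: f64, seed: f64) -> (f64, f64) {
        (self.f(x), seed * self.a * 0.5 * (x - 1.0 / x))
    }
}
```

With a = 1 and x = 2, the derivative is 0.5 · (2 - 0.5) = 0.75, which a central finite difference of f confirms.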

Inherent impl blocks (without a trait) are also supported. Differentiating with respect to the struct itself then requires a “shadow struct” of the same type to hold the derivatives of its fields:

#![feature(autodiff)]
use std::autodiff::autodiff_reverse;

struct OptProblem {
    a: f64,
    b: f64,
}

impl OptProblem {
    #[autodiff_reverse(d_objective, Duplicated, Duplicated, Duplicated)]
    fn objective(&self, x: &[f64], out: &mut f64) {
        *out = self.a + x[0].sqrt() * self.b;
    }
}

fn main() {
    let p = OptProblem { a: 1., b: 2. };
    // Shadow struct: receives the derivatives with respect to `a` and `b`.
    let mut p_shadow = OptProblem { a: 0., b: 0. };
    let x = [4.0];
    let mut dx = [0.0]; // receives the derivative with respect to x[0]
    let mut out = 0.0;
    let mut dout = 1.0; // seed for the output
    p.d_objective(&mut p_shadow, &x, &mut dx, &mut out, &mut dout);
}
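To see what ends up in the shadows, the adjoint can be derived by hand: with out = a + sqrt(x0) · b, the partials are 1 with respect to a, sqrt(x0) with respect to b, and b / (2 · sqrt(x0)) with respect to x0. The sketch below writes that adjoint manually, with the same argument layout as the d_objective call. The method name d_objective_by_hand is hypothetical, and the convention that seed-scaled partials are added (+=) into the shadows is an assumption based on common reverse-mode practice, not a statement about the generated code.

```rust
struct OptProblem {
    a: f64,
    b: f64,
}

impl OptProblem {
    fn objective(&self, x: &[f64], out: &mut f64) {
        *out = self.a + x[0].sqrt() * self.b;
    }

    /// Hypothetical hand-written adjoint of `objective`. Computes the primal
    /// into `out`, then accumulates seed-scaled partials into the shadows:
    /// `d_self` for the struct fields and `dx` for the slice.
    fn d_objective_by_hand(
        &self,
        d_self: &mut OptProblem,
        x: &[f64],
        dx: &mut [f64],
        out: &mut f64,
        dout: &mut f64,
    ) {
        self.objective(x, out);
        let seed = *dout;
        let s = x[0].sqrt();
        d_self.a += seed; // d out / d a = 1
        d_self.b += seed * s; // d out / d b = sqrt(x0)
        dx[0] += seed * self.b / (2.0 * s); // d out / d x0 = b / (2 sqrt(x0))
    }
}
```

For a = 1, b = 2 and x = [4.0], this gives out = 5, shadow fields a = 1 and b = 2, and dx = [0.5].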

§Higher-order derivatives

Finally, it is possible to generate higher-order derivatives (e.g. a Hessian) by applying an autodiff macro, through a thin wrapper, to a function that was itself generated by an autodiff macro. The following example uses forward mode over reverse mode:

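#![feature(autodiff)]
use std::autodiff::{autodiff_forward, autodiff_reverse};

#[autodiff_reverse(df, Duplicated, Duplicated)]
fn f(x: &[f64;2], y: &mut f64) {
  *y = x[0] * x[0] + x[1] * x[0]
}

// Thin wrapper so that the generated `df` can itself be differentiated.
#[autodiff_forward(h, Dual, Dual, Dual, Dual)]
fn wrapper(x: &[f64;2], dx: &mut [f64;2], y: &mut f64, dy: &mut f64) {
  df(x, dx, y, dy);
}

fn main() {
    let mut y = 0.0;
    let x = [2.0, 2.0];

    // Forward-mode tangents of x and y: differentiate along the direction [1, 0].
    let mut dy = 0.0;
    let mut dx = [1.0, 0.0];

    // Reverse-mode buffers passed through `wrapper` to `df`: `bx` receives the
    // gradient and `by` is the output seed; `dbx` and `dby` are their tangents.
    let mut bx = [0.0, 0.0];
    let mut by = 1.0;
    let mut dbx = [0.0, 0.0];
    let mut dby = 0.0;
    h(&x, &mut dx, &mut bx, &mut dbx, &mut y, &mut dy, &mut by, &mut dby);
    // `dbx` is the Hessian of f times [1, 0], i.e. its first column.
    assert_eq!(dbx, [2.0, 1.0]);
}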
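The same forward-over-reverse idea can be sketched without std::autodiff: write the reverse pass of f by hand (giving the gradient [2 · x0 + x1, x0]), make it generic over the number type, and evaluate it on hypothetical dual numbers seeded with the direction v = [1, 0]. The value parts of the result are the gradient, and the derivative parts are the Hessian-vector product H · v, with H = [[2, 1], [1, 0]] for this f. The Dual type and grad_f are illustrative assumptions; this operator-overloading analogy is not how the macros are implemented.

```rust
use std::ops::{Add, Mul};

/// Hypothetical dual number: value `re` plus derivative part `eps`.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        Dual { re: self.re + rhs.re, eps: self.eps + rhs.eps }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule: (a + a'ε)(b + b'ε) = ab + (ab' + a'b)ε.
    fn mul(self, rhs: Dual) -> Dual {
        Dual { re: self.re * rhs.re, eps: self.re * rhs.eps + self.eps * rhs.re }
    }
}

/// Hypothetical hand-written reverse pass of f(x) = x0*x0 + x1*x0 with
/// output seed `ybar`. Returns the seeded gradient [ybar*(2*x0 + x1), ybar*x0].
fn grad_f<T>(x: [T; 2], ybar: T) -> [T; 2]
where
    T: Copy + Add<Output = T> + Mul<Output = T>,
{
    // x0*x0 sends ybar*x0 to x0 twice; x1*x0 sends ybar*x1 to x0.
    let x0_bar = ybar * x[0] + ybar * x[0] + ybar * x[1];
    // x1*x0 sends ybar*x0 to x1.
    let x1_bar = ybar * x[0];
    [x0_bar, x1_bar]
}
```

At x = [2, 2] with seed 1, the value parts are the gradient [6, 2] and the derivative parts are [2, 1], matching dbx in the example above.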

§Current limitations

  • Differentiating a function that accepts a dyn Trait is not yet supported.
  • Builds without lto="fat" are not yet supported.
  • Debug builds are currently more likely to fail to compile.

§Attribute Macros

autodiff_forward (Experimental)
This macro uses forward-mode automatic differentiation to generate a new function. It may only be applied to a function. The new function will compute the derivative of the function to which the macro was applied.
autodiff_reverse (Experimental)
This macro uses reverse-mode automatic differentiation to generate a new function. It may only be applied to a function. The new function will compute the derivative of the function to which the macro was applied.