

Generic automatic differentiation for Rust

3 releases (breaking)

0.2.0 Apr 27, 2021
0.1.0 Apr 22, 2021
0.0.0 Apr 20, 2021


Generic Automatic Differentiation (GAD)

This library provides automatic differentiation by backward propagation (also known as "autograd") in Rust. It was designed to allow first-class user extensions (e.g. with new array types or new operators) and to support multiple modes of execution with minimal overhead.

The following modes of execution are currently supported for all library-defined operators:

  • first-order differentiation,
  • higher-order differentiation,
  • forward-only evaluation, and
  • dimension checking.

Design Principles

The core of this library consists of a tape-based implementation of automatic differentiation in reverse mode.
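
As a rough illustration of what "tape-based reverse mode" means, here is a minimal, self-contained sketch in plain Rust. The names `MiniTape`, `Node`, and `gradients` are hypothetical and are not part of this library's API; the sketch only handles scalars, addition, and multiplication.

```rust
/// One recorded operation: the ids of its two inputs and the local
/// partial derivatives of the output with respect to each input.
#[derive(Clone, Copy)]
struct Node {
    parents: [usize; 2],
    partials: [f64; 2],
}

/// Hypothetical minimal tape (not the library's `Graph1`).
#[derive(Default)]
struct MiniTape {
    nodes: Vec<Node>,
    values: Vec<f64>,
}

impl MiniTape {
    /// Appends a node to the tape and returns its id.
    fn push(&mut self, value: f64, parents: [usize; 2], partials: [f64; 2]) -> usize {
        self.nodes.push(Node { parents, partials });
        self.values.push(value);
        self.nodes.len() - 1
    }

    /// Records a leaf. A leaf points at itself with zero partials.
    fn variable(&mut self, value: f64) -> usize {
        let id = self.nodes.len();
        self.push(value, [id, id], [0.0, 0.0])
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let value = self.values[a] + self.values[b];
        self.push(value, [a, b], [1.0, 1.0])
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.values[a], self.values[b]);
        self.push(va * vb, [a, b], [vb, va])
    }

    /// Backward pass from `output`. It takes `&self`, so the tape is
    /// left unchanged and can be reused for other starting points.
    fn gradients(&self, output: usize, direction: f64) -> Vec<f64> {
        let mut grads = vec![0.0; self.nodes.len()];
        grads[output] = direction;
        // Nodes are stored in creation order, so walking the tape
        // backwards visits each node after all of its consumers.
        for id in (0..=output).rev() {
            let node = self.nodes[id];
            let g = grads[id];
            for k in 0..2 {
                grads[node.parents[k]] += g * node.partials[k];
            }
        }
        grads
    }
}
```

With `a = 1`, `b = 2` and `c = g.mul(a, b)`, calling `g.gradients(c, 1.0)` yields `2.0` for `a` and `1.0` for `b`. Every forward operation mutates the tape explicitly through `&mut self`, which mirrors the design choice described in this section.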

We have chosen to prioritize idiomatic Rust in order to make this library as re-usable as possible:

  • The core differentiation algorithm does not use unsafe Rust features or interior mutability (e.g. RefCell). All differentiable expressions explicitly mutate a tape when they are constructed. (Below, the tape variable is named graph or g.)

  • Fallible operations never panic and always return a Result type. For instance, the sum of two arrays x and y may be written g.add(&x, &y)?.

  • All structures and values implement Send and Sync to support concurrent programming.

  • Generic programming is encouraged so that user formulas can be interpreted in different modes of execution (forward evaluation, dimension checking, etc) with minimal overhead. (See the section below for a code example.)

While this library is primarily motivated by machine learning applications, it is meant to cover other use cases of automatic differentiation in reverse mode. In the sections below, we show how a user may define new operators and add new modes of execution while retaining automatic differentiability.


Limitations

  • Currently, the usual syntax of operators +, -, *, etc. is not available for differentiable values. All operations are method calls of the form g.op(x1, .. xN) (or typically g.op(x1, .. xN)? for fallible operations).

  • Because of a current limitation of the Rust borrow checker, expressions cannot be nested: g.add(&x, &g.mul(&y, &z)?)? must be written let v = g.mul(&y, &z)?; g.add(&x, &v)?.

We believe that this state of affairs could be improved in the future using Rust macros. Alternatively, future extensions of the library could define a new category of differentiable values that contain an implicit RefCell reference to a common tape and provide (implicitly fallible, thread unsafe) operator traits for these values.
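
The nesting restriction can be reproduced without the library. The sketch below uses a hypothetical `ToyTape` whose operations all take `&mut self`, as the operations of a tape do; the commented-out line is the nested form, which the compiler rejects because it needs a second mutable borrow of `g` while the outer call is pending.

```rust
/// Hypothetical stand-in for a tape: every operation needs `&mut self`.
#[derive(Default)]
struct ToyTape {
    ops: usize,
}

impl ToyTape {
    fn add(&mut self, x: &f64, y: &f64) -> Result<f64, String> {
        self.ops += 1;
        Ok(x + y)
    }

    fn mul(&mut self, x: &f64, y: &f64) -> Result<f64, String> {
        self.ops += 1;
        Ok(x * y)
    }
}

/// Computes x + y * z on the toy tape.
fn formula(g: &mut ToyTape, x: f64, y: f64, z: f64) -> Result<f64, String> {
    // Rejected with error E0499 (second mutable borrow of `*g`):
    // g.add(&x, &g.mul(&y, &z)?)

    // Accepted: the inner result is bound first, so the two mutable
    // borrows of `g` do not overlap.
    let v = g.mul(&y, &z)?;
    g.add(&x, &v)
}
```

Calling `formula(&mut g, 1.0, 2.0, 3.0)` on a fresh `ToyTape` returns `Ok(7.0)` and records two operations.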

Quick Start

To compute gradients, we first build an expression using operations provided by a fresh tape g of type Graph1. Successive algebraic operations modify the internal state of g to track all relevant computations and enable future backward propagation passes.

We then call g.evaluate_gradients(..) to run a backward propagation algorithm from the desired starting point, using an initial gradient value (called direction below).

Unless a one-time optimized variant g.evaluate_gradients_once(..) is used, backward propagation with g.evaluate_gradients(..) does not modify g. This allows successive (or concurrent) backward propagations to be run from different starting points or with different gradient values.

// A new tape supporting first-order differentials (aka gradients)
let mut g = Graph1::new();
// Compute forward values.
let a = g.variable(1f32);
let b = g.variable(2f32);
let c = g.mul(&a, &b)?;
// Compute the derivatives of `c` relative to `a` and `b`
let gradients = g.evaluate_gradients(c.gid()?, 1f32)?;
// Read the `dc/da` component.
assert_eq!(*gradients.get(a.gid()?).unwrap(), 2.0);

Because Graph1, the type of g, provides algebraic operations as methods, below we refer to such a type as an "algebra". GAD uses particular Rust traits to represent the set of operations supported by a given algebra.

Computations with Arrayfire

The default array operations of the library are currently based on Arrayfire, a portable array library supporting GPUs and JIT-compilation.

use arrayfire as af;
// A new tape supporting first-order differentials (aka gradients)
let mut g = Graph1::new();
// Compute forward values using Arrayfire arrays
let dims = af::Dim4::new(&[4, 3, 1, 1]);
let a = g.variable(af::randu::<f32>(dims));
let b = g.variable(af::randu::<f32>(dims));
let c = g.mul(&a, &b)?;
// Compute gradient of c
let direction = af::constant(1f32, dims);
let gradients = g.evaluate_gradients_once(c.gid()?, direction)?;

After installing the arrayfire library on your system, make sure to

  • select the package feature "arrayfire" in your build file Cargo.toml (e.g. gad = { version = "XX", features = ["arrayfire"]}),

  • run cargo with the environment variable AF_PATH set appropriately (e.g. after export AF_PATH=/usr/local).

Using Generics for Forward Evaluation and Fast Dimension Checking

The algebra Graph1 used in the examples above is one choice amongst several "default" algebras offered by the library:

  • We also provide a special algebra Eval for forward evaluation, that is, running only primitive operations and dimension checks (no tape, no gradients);

  • Similarly, using the algebra Check will check dimensions without evaluating or allocating any array data;

  • Finally, differentiation is obtained using Graph1 for first-order differentials, and GraphN for higher-order differentials.

Users are encouraged to program formulas in a generic way so that any of the default algebras can be chosen.

The following example illustrates such a programming style in the case of array operations:

use arrayfire as af;

fn get_value<A>(g: &mut A) -> Result<<A as AfAlgebra<f32>>::Value>
where
    A: AfAlgebra<f32>,
{
    let dims = af::Dim4::new(&[4, 3, 1, 1]);
    let a = g.variable(af::randu::<f32>(dims));
    let b = g.variable(af::randu::<f32>(dims));
    g.mul(&a, &b)
}

// Direct evaluation. The result type is a primitive (non-differentiable) value.
let mut g = Eval::default();
let c : af::Array<f32> = get_value(&mut g)?;

// Fast dimension-checking. The result type is a dimension.
let mut g = Check::default();
let d : af::Dim4 = get_value(&mut g)?;
assert_eq!(c.dims(), d);
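
The same programming style can be sketched without Arrayfire. In the block below, `ToyAlgebra`, `ToyEval`, and `ToyCheck` are hypothetical stand-ins for the library's traits and algebras, with `Vec<f32>` playing the role of an array and its length playing the role of a dimension. Unlike the library's Check, this toy version still receives the input data and only discards it.

```rust
/// Hypothetical trait standing in for the library's `*Algebra` traits.
trait ToyAlgebra {
    type Value;
    fn variable(&mut self, data: Vec<f32>) -> Self::Value;
    fn mul(&mut self, a: &Self::Value, b: &Self::Value) -> Result<Self::Value, String>;
}

/// Forward evaluation: values are the arrays themselves.
struct ToyEval;

/// Dimension checking: values are lengths only.
struct ToyCheck;

impl ToyAlgebra for ToyEval {
    type Value = Vec<f32>;

    fn variable(&mut self, data: Vec<f32>) -> Vec<f32> {
        data
    }

    fn mul(&mut self, a: &Vec<f32>, b: &Vec<f32>) -> Result<Vec<f32>, String> {
        if a.len() != b.len() {
            return Err(format!("length mismatch: {} vs {}", a.len(), b.len()));
        }
        Ok(a.iter().zip(b).map(|(x, y)| x * y).collect())
    }
}

impl ToyAlgebra for ToyCheck {
    type Value = usize;

    fn variable(&mut self, data: Vec<f32>) -> usize {
        data.len()
    }

    fn mul(&mut self, a: &usize, b: &usize) -> Result<usize, String> {
        if a != b {
            return Err(format!("length mismatch: {} vs {}", a, b));
        }
        Ok(*a)
    }
}

/// A formula written once and interpreted by whichever algebra is passed in.
fn product<A: ToyAlgebra>(g: &mut A, x: Vec<f32>, y: Vec<f32>) -> Result<A::Value, String> {
    let a = g.variable(x);
    let b = g.variable(y);
    g.mul(&a, &b)
}
```

Here `product(&mut ToyEval, ..)` returns the pointwise product, while `product(&mut ToyCheck, ..)` returns only the common length or an error on mismatch. The formula itself is unchanged between the two modes.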

Higher-Order Differentials

Higher-order differentials are computed using the algebra GraphN. In this case, gradients are themselves values whose computation is also tracked.

// A new tape supporting differentials of any order.
let mut g = GraphN::new();

// Compute forward values using floats.
let x = g.variable(1.0f32);
let y = g.variable(0.4f32);
// z = x * y^2
let z = {
    let h = g.mul(&x, &y)?;
    g.mul(&h, &y)?
};
// Use short names for gradient ids.
let (x, y, z) = (x.gid()?, y.gid()?, z.gid()?);

// Compute gradient.
let dz = g.constant(1f32);
let dz_d = g.compute_gradients(z, dz)?;
let dz_dx = dz_d.get(x).unwrap();

// Compute some 2nd-order differentials.
let ddz = g.constant(1f32);
let ddz_dxd = g.compute_gradients(dz_dx.gid()?, ddz)?;
let ddz_dxdy = ddz_dxd.get(y).unwrap().data();
assert_eq!(*ddz_dxdy, 0.8); // 2y

// Compute some 3rd-order differentials.
let dddz = g.constant(1f32);
let dddz_dxdyd = g.compute_gradients(ddz_dxd.get(y).unwrap().gid()?, dddz)?;
let dddz_dxdydy = dddz_dxdyd.get(y).unwrap().data();
assert_eq!(*dddz_dxdydy, 2.0);
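
The values asserted in this example can be checked by hand. With x = 1 and y = 0.4, the successive partial derivatives of z are:

```latex
\begin{aligned}
z &= x\,y^{2}, \\
\frac{\partial z}{\partial x} &= y^{2}, \\
\frac{\partial^{2} z}{\partial x\,\partial y} &= 2y = 0.8 \quad \text{at } y = 0.4, \\
\frac{\partial^{3} z}{\partial x\,\partial y^{2}} &= 2.
\end{aligned}
```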

Extending Automatic Differentiation

Operations and algebras

The default algebras Eval, Check, Graph1, GraphN are meant to provide interchangeable sets of operations in each of the default modes of operation (respectively, evaluation, dimension-checking, first-order differentiation, and higher-order differentiation).

Default operations are grouped into several traits named *Algebra and implemented by each of the four default algebras above.

  • The special trait CoreAlgebra<Data> defines the mapping from underlying data (e.g. array) to differentiable values. In particular, the method fn variable(&mut self, data: Data) -> Self::Value creates differentiable variables x whose gradient value can be referred to later by an id written x.gid()? (assuming the algebra is Graph1 or GraphN).

  • Other traits are parameterized over one or several value types. E.g. ArithAlgebra<Value> provides pointwise negation, multiplication, subtraction, etc over Value.

The motivation for using several *Algebra traits is twofold:

  • Users may define their own operations (see next paragraph).

  • Certain operations are more broadly applicable than others.

The following example illustrates gradient computations over integers:

let mut g = Graph1::new();
let a = g.variable(1i32);
let b = g.variable(2i32);
let c = g.sub(&a, &b)?;
assert_eq!(*c.data(), -1);
let gradients = g.evaluate_gradients_once(c.gid()?, 1)?;
assert_eq!(*gradients.get(a.gid()?).unwrap(), 1);
assert_eq!(*gradients.get(b.gid()?).unwrap(), -1);

User-defined operations

Users may define new differentiable operations by defining their own *Algebra trait and providing implementations to the default algebras Eval, Check, Graph1, GraphN.

In the following example, we define a new operation square over integers and af-arrays and add support for first-order differentials:

use arrayfire as af;

pub trait UserAlgebra<Value> {
    fn square(&mut self, v: &Value) -> Result<Value>;
}

impl UserAlgebra<i32> for Eval {
    fn square(&mut self, v: &i32) -> Result<i32> { Ok(v * v) }
}

impl<T> UserAlgebra<af::Array<T>> for Eval
where
    T: af::HasAfEnum + af::ImplicitPromote<T, Output = T>,
{
    fn square(&mut self, v: &af::Array<T>) -> Result<af::Array<T>> { Ok(v * v) }
}

impl<D> UserAlgebra<Value<D>> for Graph1
where
    Eval: CoreAlgebra<D, Value = D>
        + UserAlgebra<D>
        + ArithAlgebra<D>
        + LinkedAlgebra<Value<D>, D>,
    D: HasDims + Clone + 'static + Send + Sync,
    D::Dims: PartialEq + std::fmt::Debug + Clone + 'static + Send + Sync,
{
    fn square(&mut self, v: &Value<D>) -> Result<Value<D>> {
        // Forward value, computed by the underlying evaluation algebra.
        let result = self.eval().square(v.data())?;
        let value = self.make_node(result, vec![v.input()], {
            let v = v.clone();
            // Backward step: propagate gradient * v + v * gradient to the input.
            move |graph, store, gradient| {
                if let Some(id) = v.id() {
                    let c = graph.link(&v);
                    let grad1 = graph.mul(&gradient, c)?;
                    let grad2 = graph.mul(c, &gradient)?;
                    let grad = graph.add(&grad1, &grad2)?;
                    store.add_gradient(graph, id, &grad)?;
                }
                Ok(())
            }
        });
        Ok(value)
    }
}

fn main() -> Result<()> {
    let mut g = Graph1::new();
    let a = g.variable(3i32);
    let b = g.square(&a)?;
    assert_eq!(*b.data(), 9);
    let gradients = g.evaluate_gradients_once(b.gid()?, 1)?;
    assert_eq!(*gradients.get(a.gid()?).unwrap(), 6);
    Ok(())
}

The implementation for GraphN would be identical to Graph1. We have omitted dimension-checking for simplicity. We refer readers to the test files of the library for a more complete example.
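
The two products in the backward closure follow from the product rule applied to the pointwise square. Writing \bar{s} for the incoming gradient of s = v · v and \bar{v} for the gradient contributed to v (notation introduced here for illustration only):

```latex
\begin{aligned}
s &= v \cdot v, \\
\bar{v} &= \bar{s} \cdot v + v \cdot \bar{s} \\
        &= 2\,v\,\bar{s} \quad \text{(when multiplication commutes)}, \\
\bar{v} &= 2 \cdot 3 \cdot 1 = 6 \quad \text{for } v = 3,\ \bar{s} = 1.
\end{aligned}
```

The code keeps the two terms separate rather than multiplying by 2, so it relies only on the add and mul operations of the algebra.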

User-defined algebras

Users may define new "evaluation" algebras (similar to Eval) by implementing a subset of operation traits that includes CoreAlgebra<Data, Value=Data> for each supported Data type.

An evaluation-only algebra can be turned into algebras supporting differentiation (similar to Graph1 and GraphN) using the Graph construction provided by the library.

The following example illustrates how to define a new evaluation algebra SymEval and then derive its counterpart SymGraph1:

use std::sync::Arc;

/// A custom algebra for forward-only symbolic evaluation.
#[derive(Clone, Default)]
struct SymEval;

/// Symbolic expressions of type T.
#[derive(Debug, PartialEq)]
enum Exp_<T> {
    Num(T),
    Add(Exp<T>, Exp<T>),
    Mul(Exp<T>, Exp<T>),
    // ...
}

type Exp<T> = Arc<Exp_<T>>;

impl<T> CoreAlgebra<Exp<T>> for SymEval {
    type Value = Exp<T>;
    fn variable(&mut self, data: Exp<T>) -> Self::Value {
        data
    }
    fn constant(&mut self, data: Exp<T>) -> Self::Value {
        data
    }
    fn add(&mut self, v1: &Self::Value, v2: &Self::Value) -> Result<Self::Value> {
        Ok(Arc::new(Exp_::Add(v1.clone(), v2.clone())))
    }
}

impl<T> ArithAlgebra<Exp<T>> for SymEval {
    fn neg(&mut self, v: &Exp<T>) -> Exp<T> {
        // ...
    }
    fn sub(&mut self, v1: &Exp<T>, v2: &Exp<T>) -> Result<Exp<T>> {
        let v2 = self.neg(v2);
        Ok(Arc::new(Exp_::Add(v1.clone(), v2)))
    }
    fn mul(&mut self, v1: &Exp<T>, v2: &Exp<T>) -> Result<Exp<T>> {
        Ok(Arc::new(Exp_::Mul(v1.clone(), v2.clone())))
    }
    // ...
}

// No dimension checks.
impl<T> HasDims for Exp_<T> {
    type Dims = ();
    fn dims(&self) {}
}

impl<T: std::fmt::Display> std::fmt::Display for Exp_<T> {
    // ...
}

/// Apply the `graph` module to derive an algebra supporting gradients.
type SymGraph1 = Graph<Config1<SymEval>>;

fn main() -> Result<()> {
    let mut g = SymGraph1::new();
    let a = g.variable(Arc::new(Exp_::Num("a")));
    let b = g.variable(Arc::new(Exp_::Num("b")));
    let c = g.mul(&a, &b)?;
    let d = g.mul(&a, &c)?;
    assert_eq!(format!("{}", d.data()), "aab");
    let gradients = g.evaluate_gradients_once(d.gid()?, Arc::new(Exp_::Num("1")))?;
    assert_eq!(format!("{}", gradients.get(a.gid()?).unwrap()), "(1ab+a1b)");
    assert_eq!(format!("{}", gradients.get(b.gid()?).unwrap()), "aa1");
    Ok(())
}
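
The Display implementation is elided in the example above. One formatting that is consistent with the asserted strings is sketched below as a standalone program. It is an assumption rather than the library's own code: numbers print as themselves, products print by juxtaposition, and sums print in parentheses. The helpers `num`, `add`, and `mul` are hypothetical and only build expressions directly, without any tape.

```rust
use std::fmt;
use std::sync::Arc;

/// Reduced copy of the symbolic expression type, limited to three variants.
#[derive(Debug, PartialEq)]
enum Exp_<T> {
    Num(T),
    Add(Exp<T>, Exp<T>),
    Mul(Exp<T>, Exp<T>),
}

type Exp<T> = Arc<Exp_<T>>;

// Assumed formatting: "ab" for a product, "(a+b)" for a sum.
impl<T: fmt::Display> fmt::Display for Exp_<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Exp_::Num(x) => write!(f, "{}", x),
            Exp_::Add(l, r) => write!(f, "({}+{})", l, r),
            Exp_::Mul(l, r) => write!(f, "{}{}", l, r),
        }
    }
}

/// Hypothetical constructors for building expressions by hand.
fn num(name: &'static str) -> Exp<&'static str> {
    Arc::new(Exp_::Num(name))
}

fn add<T>(l: &Exp<T>, r: &Exp<T>) -> Exp<T> {
    Arc::new(Exp_::Add(l.clone(), r.clone()))
}

fn mul<T>(l: &Exp<T>, r: &Exp<T>) -> Exp<T> {
    Arc::new(Exp_::Mul(l.clone(), r.clone()))
}
```

With these definitions, `mul(&a, &mul(&a, &b))` formats as `aab`, and the sum of `1 * (a * b)` and `(a * 1) * b` formats as `(1ab+a1b)`, matching the strings in the example.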


Contributing

See the CONTRIBUTING file for how to help out.


License

This project is available under the terms of either the Apache 2.0 license or the MIT license.

