3 releases (breaking)

| Version | Release date |
|---|---|
| 0.2.0 | Apr 27, 2021 |
| 0.1.0 | Apr 22, 2021 |
| 0.0.0 | Apr 20, 2021 |
Generic Automatic Differentiation (GAD)
This library provides automatic differentiation by backward propagation (a.k.a. "autograd") in Rust. It was designed to allow first-class user extensions (e.g. with new array types or new operators) and to support multiple modes of execution with minimal overhead.
The following modes of execution are currently supported for all library-defined operators:
- first-order differentiation,
- higher-order differentiation,
- forward-only evaluation, and
- dimension checking.
Design Principles
The core of this library consists of a tape-based implementation of automatic differentiation in reverse mode.
We have chosen to prioritize idiomatic Rust in order to make this library as re-usable as possible:
- The core differentiation algorithm does not use unsafe Rust features or interior mutability (e.g. `RefCell`). All differentiable expressions explicitly mutate a tape when they are constructed. (Below, the tape variable is denoted `graph` or `g`.)
- Fallible operations never panic and always return a `Result` type. For instance, the sum of two arrays `x` and `y` may be written `g.add(&x, &y)?`.
- All structures and values implement `Send` and `Sync` to support concurrent programming.
- Generic programming is encouraged so that user formulas can be interpreted in different modes of execution (forward evaluation, dimension checking, etc.) with minimal overhead. (See the section on generics below for a code example.)
While this library is primarily motivated by machine learning applications, it is meant to cover other use cases of automatic differentiation in reverse mode. In the sections below, we show how a user may define new operators and add new modes of execution while retaining automatic differentiability.
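The principles above can be illustrated with a deliberately small sketch. It is not GAD's implementation: the names `Tape`, `Var`, and `DimError` are hypothetical, only forward values are recorded, and arrays are plain `Vec<f64>`. Operations take `&mut self` instead of going through a `RefCell`, a length mismatch comes back as an `Err` instead of a panic, and a generic helper makes the compiler check that the types are `Send + Sync`.

```rust
use std::fmt;

/// Hypothetical error for operands whose lengths differ.
#[derive(Debug, Clone, PartialEq)]
struct DimError {
    left: usize,
    right: usize,
}

impl fmt::Display for DimError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "dimension mismatch: {} vs {}", self.left, self.right)
    }
}

impl std::error::Error for DimError {}

/// A plain handle to a node: it does not borrow the tape.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Var(usize);

/// A minimal tape: every operation takes `&mut self` and appends a node (no `RefCell`).
#[derive(Debug, Default)]
struct Tape {
    values: Vec<Vec<f64>>,
}

impl Tape {
    fn push(&mut self, data: Vec<f64>) -> Var {
        self.values.push(data);
        Var(self.values.len() - 1)
    }

    fn variable(&mut self, data: Vec<f64>) -> Var {
        self.push(data)
    }

    fn value(&self, v: Var) -> &[f64] {
        &self.values[v.0]
    }

    /// Pointwise addition: a length mismatch is returned as `Err` instead of panicking.
    fn add(&mut self, x: Var, y: Var) -> Result<Var, DimError> {
        let (a, b) = (&self.values[x.0], &self.values[y.0]);
        if a.len() != b.len() {
            return Err(DimError { left: a.len(), right: b.len() });
        }
        let sum: Vec<f64> = a.iter().zip(b).map(|(p, q)| p + q).collect();
        Ok(self.push(sum))
    }
}

/// Compiles only if `T` can be sent and shared across threads.
fn assert_send_sync<T: Send + Sync>() {}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    assert_send_sync::<Tape>();
    assert_send_sync::<Var>();
    let mut g = Tape::default();
    let x = g.variable(vec![1.0, 2.0]);
    let y = g.variable(vec![3.0, 4.0]);
    let s = g.add(x, y)?;
    assert_eq!(g.value(s), &[4.0f64, 6.0][..]);
    let z = g.variable(vec![1.0]);
    assert_eq!(g.add(x, z), Err(DimError { left: 2, right: 1 }));
    Ok(())
}
```

Because `Var` is only an index, a caller can hold many handles while still passing `&mut g` to each operation, which is what keeps the sketch free of interior mutability.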
Limitations
- Currently, the usual operator syntax (`+`, `-`, `*`, etc.) is not available for differentiable values. All operations are method calls of the form `g.op(x1, .. xN)` (typically `g.op(x1, .. xN)?` for fallible operations).
- Because of a current limitation of the Rust borrow checker (both calls need a mutable borrow of `g` at the same time), expressions cannot be nested: `g.add(&x, &g.mul(&y, &z)?)?` must be written `let v = g.mul(&y, &z)?; g.add(&x, &v)?`.
We believe that this state of affairs could be improved in the future using Rust macros. Alternatively, future extensions of the library could define a new category of differentiable values that contain an implicit `RefCell` reference to a common tape and provide (implicitly fallible, thread-unsafe) operator traits for these values.
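To illustrate the `RefCell` alternative, here is a hypothetical sketch that is not part of GAD. Each value keeps an `Rc<RefCell<..>>` handle to a shared tape, which lets the standard `+` and `*` operators record nodes and makes nested expressions possible. The trade-offs are the ones named above: `Rc` values are neither `Send` nor `Sync`, and errors such as mixing tapes surface as panics rather than `Result`s. The names `Var`, `Node`, `SharedTape`, and `record` are illustrative.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};
use std::rc::Rc;

/// A recorded operation with the ids of its inputs (hypothetical layout).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Node {
    Leaf,
    Add(usize, usize),
    Mul(usize, usize),
}

/// Shared, single-threaded tape: `Rc<RefCell<..>>` is neither `Send` nor `Sync`.
type SharedTape = Rc<RefCell<Vec<Node>>>;

/// A differentiable value carrying an implicit reference to its tape.
#[derive(Debug, Clone)]
struct Var {
    tape: SharedTape,
    id: usize,
    value: f64,
}

/// Appends a node; `borrow_mut` panics (rather than returning `Err`) if the tape is already borrowed.
fn record(tape: &SharedTape, node: Node, value: f64) -> Var {
    let mut nodes = tape.borrow_mut();
    nodes.push(node);
    Var { tape: Rc::clone(tape), id: nodes.len() - 1, value }
}

fn variable(tape: &SharedTape, value: f64) -> Var {
    record(tape, Node::Leaf, value)
}

/// Mixing values from different tapes is treated as a bug and panics.
fn shared_tape<'a>(x: &'a Var, y: &Var) -> &'a SharedTape {
    assert!(Rc::ptr_eq(&x.tape, &y.tape), "operands must share a tape");
    &x.tape
}

impl Add for &Var {
    type Output = Var;
    fn add(self, rhs: Self) -> Var {
        record(shared_tape(self, rhs), Node::Add(self.id, rhs.id), self.value + rhs.value)
    }
}

impl Mul for &Var {
    type Output = Var;
    fn mul(self, rhs: Self) -> Var {
        record(shared_tape(self, rhs), Node::Mul(self.id, rhs.id), self.value * rhs.value)
    }
}

fn main() {
    let tape: SharedTape = Rc::new(RefCell::new(Vec::new()));
    let x = variable(&tape, 1.0);
    let y = variable(&tape, 2.0);
    let z = variable(&tape, 3.0);
    // Nesting works: each operator borrows the tape only while it records one node.
    let w = &x + &(&y * &z);
    assert_eq!(w.value, 7.0);
    assert_eq!(tape.borrow().len(), 5);
    assert_eq!(tape.borrow()[w.id], Node::Add(x.id, 3));
}
```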
Quick Start
To compute gradients, we first build an expression using operations provided by a fresh tape `g` of type `Graph1`. Successive algebraic operations modify the internal state of `g` to track all relevant computations and enable future backward propagation passes.

We then call `g.evaluate_gradients(..)` to run a backward propagation algorithm from the desired starting point, using an initial gradient value `direction`.

Unless the one-time optimized variant `g.evaluate_gradients_once(..)` is used, backward propagation with `g.evaluate_gradients(..)` does not modify `g`. This allows successive (or concurrent) backward propagations to be run from different starting points or with different gradient values.
// A new tape supporting first-order differentials (aka gradients)
let mut g = Graph1::new();
// Compute forward values.
let a = g.variable(1f32);
let b = g.variable(2f32);
let c = g.mul(&a, &b)?;
// Compute the derivatives of `c` relative to `a` and `b`
let gradients = g.evaluate_gradients(c.gid()?, 1f32)?;
// Read the `dc/da` component.
assert_eq!(*gradients.get(a.gid()?).unwrap(), 2.0);
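As a check on the asserted value, the gradients can be derived by hand with the product rule. Here a bar denotes the gradient propagated back to a variable (notation introduced only for this note), with `a = 1`, `b = 2`, and an initial gradient value of 1:

```math
\begin{aligned}
c &= a\,b, \qquad \bar c = 1 \\
\bar a &= \bar c\,\frac{\partial c}{\partial a} = \bar c\, b = 2 \\
\bar b &= \bar c\,\frac{\partial c}{\partial b} = \bar c\, a = 1
\end{aligned}
```

The `dc/db` component, which the example does not assert, is therefore 1.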
Because `Graph1`, the type of `g`, provides algebraic operations as methods, below we refer to such a type as an "algebra". GAD uses particular Rust traits to represent the set of operations supported by a given algebra.
Computations with Arrayfire
The default array operations of the library are currently based on Arrayfire, a portable array library supporting GPUs and JIT-compilation.
use arrayfire as af;
// A new tape supporting first-order differentials (aka gradients)
let mut g = Graph1::new();
// Compute forward values using Arrayfire arrays
let dims = af::Dim4::new(&[4, 3, 1, 1]);
let a = g.variable(af::randu::<f32>(dims));
let b = g.variable(af::randu::<f32>(dims));
let c = g.mul(&a, &b)?;
// Compute gradient of c
let direction = af::constant(1f32, dims);
let gradients = g.evaluate_gradients_once(c.gid()?, direction)?;
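Since `mul` on arrays is pointwise, the backward step for `c = a * b` multiplies the incoming direction elementwise by the other operand, so with an all-ones `direction` the gradient with respect to `a` is `b`, and vice versa. The sketch below shows this rule on plain slices; `mul_backward` is a hypothetical helper, not part of GAD or Arrayfire.

```rust
/// Reverse-mode rule for the pointwise product `c[i] = a[i] * b[i]` (hypothetical helper):
/// given the incoming direction `dc`, returns `(da, db)` with
/// `da[i] = dc[i] * b[i]` and `db[i] = a[i] * dc[i]`.
fn mul_backward(a: &[f32], b: &[f32], dc: &[f32]) -> Result<(Vec<f32>, Vec<f32>), String> {
    if a.len() != b.len() || a.len() != dc.len() {
        return Err(format!("length mismatch: {} / {} / {}", a.len(), b.len(), dc.len()));
    }
    let da: Vec<f32> = dc.iter().zip(b).map(|(d, y)| d * y).collect();
    let db: Vec<f32> = a.iter().zip(dc).map(|(x, d)| x * d).collect();
    Ok((da, db))
}

fn main() {
    let a = [1.0f32, 2.0, 3.0];
    let b = [4.0f32, 5.0, 6.0];
    // An all-ones direction, playing the role of `af::constant(1f32, dims)`.
    let dc = [1.0f32; 3];
    let (da, db) = mul_backward(&a, &b, &dc).expect("lengths match");
    assert_eq!(da, b.to_vec()); // the gradient of `a` is `b`
    assert_eq!(db, a.to_vec()); // the gradient of `b` is `a`
}
```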
After installing the Arrayfire library on your system, make sure to
- select the package feature "arrayfire" in your build file `Cargo.toml` (e.g. `gad = { version = "XX", features = ["arrayfire"] }`),
- run `cargo` with the environment variable `AF_PATH` set appropriately (e.g. after `export AF_PATH=/usr/local`).
Using Generics for Forward Evaluation and Fast Dimension Checking
The algebra `Graph1` used in the examples above is one choice amongst several "default" algebras offered by the library:
- We also provide a special algebra `Eval` for forward evaluation, that is, running only primitive operations and dimension checks (no tape, no gradients);
- Similarly, using the algebra `Check` will check dimensions without evaluating or allocating any array data;
- Finally, differentiation is obtained using `Graph1` for first-order differentials, and `GraphN` for higher-order differentials.
Users are encouraged to program formulas in a generic way so that any of the default algebras can be chosen.
The following example illustrates such a programming style in the case of array operations:
use arrayfire as af;
fn get_value<A>(g: &mut A) -> Result<<A as AfAlgebra<f32>>::Value>
where A : AfAlgebra<f32>
{
let dims = af::Dim4::new(&[4, 3, 1, 1]);
let a = g.variable(af::randu::<f32>(dims));
let b = g.variable(af::randu::<f32>(dims));
g.mul(&a, &b)
}
// Direct evaluation. The result type is a primitive (non-differentiable) value.
let mut g = Eval::default();
let c : af::Array<f32> = get_value(&mut g)?;
// Fast dimension-checking. The result type is a dimension.
let mut g = Check::default();
let d : af::Dim4 = get_value(&mut g)?;
assert_eq!(c.dims(), d);
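Stripped of Arrayfire, the pattern reduces to a trait with an associated `Value` type. In the hypothetical sketch below (`MiniAlgebra`, `MiniEval`, `MiniCheck`, and `fill` are illustrative names, not GAD's API), evaluation keeps the array data while dimension checking keeps only a length, so it never allocates the arrays. The formula `get_value` is written once and interpreted by both algebras, mirroring the `c.dims() == d` check above.

```rust
/// Hypothetical operation trait: each algebra chooses its own `Value` type.
trait MiniAlgebra {
    type Value;
    /// A constant array of `len` copies of `x`.
    fn fill(&mut self, len: usize, x: f32) -> Self::Value;
    /// Pointwise multiplication, failing on mismatched lengths.
    fn mul(&mut self, a: &Self::Value, b: &Self::Value) -> Result<Self::Value, String>;
}

/// Forward evaluation: values are the array data.
struct MiniEval;

/// Dimension checking: values are lengths only, so no array data is allocated.
struct MiniCheck;

impl MiniAlgebra for MiniEval {
    type Value = Vec<f32>;
    fn fill(&mut self, len: usize, x: f32) -> Vec<f32> {
        vec![x; len]
    }
    fn mul(&mut self, a: &Vec<f32>, b: &Vec<f32>) -> Result<Vec<f32>, String> {
        if a.len() != b.len() {
            return Err(format!("length mismatch: {} vs {}", a.len(), b.len()));
        }
        Ok(a.iter().zip(b).map(|(p, q)| p * q).collect())
    }
}

impl MiniAlgebra for MiniCheck {
    type Value = usize;
    fn fill(&mut self, len: usize, _x: f32) -> usize {
        len
    }
    fn mul(&mut self, a: &usize, b: &usize) -> Result<usize, String> {
        if a != b {
            return Err(format!("length mismatch: {} vs {}", a, b));
        }
        Ok(*a)
    }
}

/// The formula is written once, generically over the algebra.
fn get_value<A: MiniAlgebra>(g: &mut A) -> Result<A::Value, String> {
    let a = g.fill(3, 2.0);
    let b = g.fill(3, 0.5);
    g.mul(&a, &b)
}

fn main() -> Result<(), String> {
    let c = get_value(&mut MiniEval)?;
    let d = get_value(&mut MiniCheck)?;
    assert_eq!(c, vec![1.0f32; 3]);
    assert_eq!(c.len(), d);
    Ok(())
}
```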
Higher-Order Differentials
Higher-order differentials are computed using the algebra `GraphN`. In this case, gradients are values whose computation is also tracked.
// A new tape supporting differentials of any order.
let mut g = GraphN::new();
// Compute forward values using floats.
let x = g.variable(1.0f32);
let y = g.variable(0.4f32);
// z = x * y^2
let z = {
let h = g.mul(&x, &y)?;
g.mul(&h, &y)?
};
// Use short names for gradient ids.
let (x, y, z) = (x.gid()?, y.gid()?, z.gid()?);
// Compute gradient.
let dz = g.constant(1f32);
let dz_d = g.compute_gradients(z, dz)?;
let dz_dx = dz_d.get(x).unwrap();
// Compute some 2nd-order differentials.
let ddz = g.constant(1f32);
let ddz_dxd = g.compute_gradients(dz_dx.gid()?, ddz)?;
let ddz_dxdy = ddz_dxd.get(y).unwrap().data();
assert_eq!(*ddz_dxdy, 0.8); // 2y
// Compute some 3rd-order differentials.
let dddz = g.constant(1f32);
let dddz_dxdyd = g.compute_gradients(ddz_dxd.get(y).unwrap().gid()?, dddz)?;
let dddz_dxdydy = dddz_dxdyd.get(y).unwrap().data();
assert_eq!(*dddz_dxdydy, 2.0);
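The asserted values can be checked by hand. Each call to `compute_gradients` differentiates the result of the previous one, which is why `GraphN` must also track the computation of gradients:

```math
\begin{aligned}
z &= x\,y^2 \\
\frac{\partial z}{\partial x} &= y^2 \\
\frac{\partial^2 z}{\partial x\,\partial y} &= 2y = 0.8 \qquad (y = 0.4) \\
\frac{\partial^3 z}{\partial x\,\partial y^2} &= 2
\end{aligned}
```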
Extending Automatic Differentiation
Operations and algebras
The default algebras `Eval`, `Check`, `Graph1`, and `GraphN` are meant to provide interchangeable sets of operations in each of the default modes of operation (respectively, evaluation, dimension checking, first-order differentiation, and higher-order differentiation).
Default operations are grouped into several traits named `*Algebra` and implemented by each of the four default algebras above.
- The special trait `CoreAlgebra<Data>` defines the mapping from underlying data (e.g. arrays) to differentiable values. In particular, the method `fn variable(&mut self, data: Data) -> Self::Value` creates differentiable variables `x` whose gradient value can be referred to later by an id written `x.gid()?` (assuming the algebra is `Graph1` or `GraphN`).
- Other traits are parameterized over one or several value types. E.g. `ArithAlgebra<Value>` provides pointwise negation, multiplication, subtraction, etc. over `Value`.
The motivation for using several `*Algebra` traits is twofold:
- Users may define their own operations (see the next section).
- Certain operations are more broadly applicable than others.
The following example illustrates gradient computations over integers:
let mut g = Graph1::new();
let a = g.variable(1i32);
let b = g.variable(2i32);
let c = g.sub(&a, &b)?;
assert_eq!(*c.data(), -1);
let gradients = g.evaluate_gradients_once(c.gid()?, 1)?;
assert_eq!(*gradients.get(a.gid()?).unwrap(), 1);
assert_eq!(*gradients.get(b.gid()?).unwrap(), -1);
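The second motivation can be sketched with two hypothetical traits (`MiniArith` and `MiniAnalytic` are illustrative and are not GAD's trait names). Arithmetic is implemented generically for any `Copy` type with `-` and `*`, integers included, whereas an analytic operation such as `exp` is implemented only for floating-point types. Requesting `exp` on an `i32` is then rejected at compile time rather than at run time.

```rust
/// Hypothetical arithmetic operations: implementable for any type with `-` and `*`.
trait MiniArith<V> {
    fn sub(&mut self, x: &V, y: &V) -> V;
    fn mul(&mut self, x: &V, y: &V) -> V;
}

/// Hypothetical analytic operations: only meaningful for floating-point values.
trait MiniAnalytic<V> {
    fn exp(&mut self, x: &V) -> V;
}

/// A forward-evaluation algebra implementing each trait where it makes sense.
struct MiniEval;

impl<V> MiniArith<V> for MiniEval
where
    V: Copy + std::ops::Sub<Output = V> + std::ops::Mul<Output = V>,
{
    fn sub(&mut self, x: &V, y: &V) -> V {
        *x - *y
    }
    fn mul(&mut self, x: &V, y: &V) -> V {
        *x * *y
    }
}

impl MiniAnalytic<f32> for MiniEval {
    fn exp(&mut self, x: &f32) -> f32 {
        x.exp()
    }
}

impl MiniAnalytic<f64> for MiniEval {
    fn exp(&mut self, x: &f64) -> f64 {
        x.exp()
    }
}

fn main() {
    let mut g = MiniEval;
    // Integers support arithmetic...
    assert_eq!(g.sub(&1i32, &2i32), -1);
    assert_eq!(g.mul(&3i64, &4i64), 12);
    // ...while `exp` exists only for floats; `g.exp(&1i32)` would not compile.
    assert_eq!(g.exp(&0.0f64), 1.0);
}
```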
User-defined operations
Users may define new differentiable operations by defining their own `*Algebra` trait and providing implementations for the default algebras `Eval`, `Check`, `Graph1`, and `GraphN`.
In the following example, we define a new operation `square` over integers and Arrayfire arrays and add support for first-order differentials:
use arrayfire as af;
pub trait UserAlgebra<Value> {
fn square(&mut self, v: &Value) -> Result<Value>;
}
impl UserAlgebra<i32> for Eval
{
fn square(&mut self, v: &i32) -> Result<i32> { Ok(v * v) }
}
impl<T> UserAlgebra<af::Array<T>> for Eval
where
T: af::HasAfEnum + af::ImplicitPromote<T, Output = T>
{
fn square(&mut self, v: &af::Array<T>) -> Result<af::Array<T>> { Ok(v * v) }
}
impl<D> UserAlgebra<Value<D>> for Graph1
where
Eval: CoreAlgebra<D, Value = D>
+ UserAlgebra<D>
+ ArithAlgebra<D>
+ LinkedAlgebra<Value<D>, D>,
D: HasDims + Clone + 'static + Send + Sync,
D::Dims: PartialEq + std::fmt::Debug + Clone + 'static + Send + Sync,
{
fn square(&mut self, v: &Value<D>) -> Result<Value<D>> {
let result = self.eval().square(v.data())?;
let value = self.make_node(result, vec![v.input()], {
let v = v.clone();
move |graph, store, gradient| {
if let Some(id) = v.id() {
let c = graph.link(&v);
let grad1 = graph.mul(&gradient, c)?;
let grad2 = graph.mul(c, &gradient)?;
let grad = graph.add(&grad1, &grad2)?;
store.add_gradient(graph, id, &grad)?;
}
Ok(())
}
});
Ok(value)
}
}
fn main() -> Result<()> {
let mut g = Graph1::new();
let a = g.variable(3i32);
let b = g.square(&a)?;
assert_eq!(*b.data(), 9);
let gradients = g.evaluate_gradients_once(b.gid()?, 1)?;
assert_eq!(*gradients.get(a.gid()?).unwrap(), 6);
Ok(())
}
The implementation for `GraphN` would be identical to the one for `Graph1`. We have omitted dimension checking (the `Check` algebra) for simplicity. We refer readers to the test files of the library for a more complete example.
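For reference, the backward closure of the `Graph1` implementation encodes the derivative of `square`. Writing a bar for the gradient flowing back to a value (notation introduced only here), the closure computes `grad1 = gradient * v` and `grad2 = v * gradient` and adds their sum to the gradient of `v`:

```math
\begin{aligned}
y &= v \cdot v \\
\bar v &= \bar y \cdot v + v \cdot \bar y
\end{aligned}
```

For integers and pointwise array products the two terms are equal, so this is the familiar rule `2 * v * gradient`, which yields 6 at `v = 3` with an initial gradient of 1, as asserted in `main`.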
User-defined algebras
Users may define new "evaluation" algebras (similar to `Eval`) by implementing a subset of operation traits that includes `CoreAlgebra<Data, Value=Data>` for each supported `Data` type.
An evaluation-only algebra can be turned into algebras supporting differentiation (similar to `Graph1` and `GraphN`) using the `Graph` construction provided by the library.
The following example illustrates how to define a new evaluation algebra `SymEval` and then derive its counterpart `SymGraph1`:
use std::sync::Arc;
/// A custom algebra for forward-only symbolic evaluation.
#[derive(Clone, Default)]
struct SymEval;
/// Symbolic expressions of type T.
#[derive(Debug, PartialEq)]
enum Exp_<T> {
Num(T),
Neg(Exp<T>),
Add(Exp<T>, Exp<T>),
Mul(Exp<T>, Exp<T>),
// ...
}
type Exp<T> = Arc<Exp_<T>>;
impl<T> CoreAlgebra<Exp<T>> for SymEval {
type Value = Exp<T>;
fn variable(&mut self, data: Exp<T>) -> Self::Value {
data
}
fn constant(&mut self, data: Exp<T>) -> Self::Value {
data
}
fn add(&mut self, v1: &Self::Value, v2: &Self::Value) -> Result<Self::Value> {
Ok(Arc::new(Exp_::Add(v1.clone(), v2.clone())))
}
}
impl<T> ArithAlgebra<Exp<T>> for SymEval {
fn neg(&mut self, v: &Exp<T>) -> Exp<T> {
Arc::new(Exp_::Neg(v.clone()))
}
fn sub(&mut self, v1: &Exp<T>, v2: &Exp<T>) -> Result<Exp<T>> {
let v2 = self.neg(v2);
Ok(Arc::new(Exp_::Add(v1.clone(), v2)))
}
fn mul(&mut self, v1: &Exp<T>, v2: &Exp<T>) -> Result<Exp<T>> {
Ok(Arc::new(Exp_::Mul(v1.clone(), v2.clone())))
}
// ...
}
// No dimension checks.
impl<T> HasDims for Exp_<T> {
type Dims = ();
fn dims(&self) {}
}
impl<T: std::fmt::Display> std::fmt::Display for Exp_<T> {
// ...
}
/// Apply the `graph` module to derive an algebra supporting gradients.
type SymGraph1 = Graph<Config1<SymEval>>;
fn main() -> Result<()> {
let mut g = SymGraph1::new();
let a = g.variable(Arc::new(Exp_::Num("a")));
let b = g.variable(Arc::new(Exp_::Num("b")));
let c = g.mul(&a, &b)?;
let d = g.mul(&a, &c)?;
assert_eq!(format!("{}", d.data()), "aab");
let gradients = g.evaluate_gradients_once(d.gid()?, Arc::new(Exp_::Num("1")))?;
assert_eq!(format!("{}", gradients.get(a.gid()?).unwrap()), "(1ab+a1b)");
assert_eq!(format!("{}", gradients.get(b.gid()?).unwrap()), "aa1");
Ok(())
}
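The idea behind deriving `SymGraph1` from `SymEval` can be shown with a toy version of the construction. In this hypothetical sketch (`EvalAlgebra`, `FloatEval`, `StrEval`, `Graph`, and `gradients` are illustrative names; GAD's `Graph<Config1<..>>` is more general), the tape is generic over an evaluation algebra and uses that algebra's `add` and `mul` both for forward values and for the backward pass. Only addition and multiplication are supported. With a string-based algebra that prints products by juxtaposition and sums in parentheses, it produces the same strings as the assertions in the example above.

```rust
/// Hypothetical forward-only algebra over values of type `V`.
trait EvalAlgebra {
    type V: Clone;
    fn add(&mut self, a: &Self::V, b: &Self::V) -> Self::V;
    fn mul(&mut self, a: &Self::V, b: &Self::V) -> Self::V;
}

/// Numeric evaluation.
struct FloatEval;
impl EvalAlgebra for FloatEval {
    type V = f64;
    fn add(&mut self, a: &f64, b: &f64) -> f64 { a + b }
    fn mul(&mut self, a: &f64, b: &f64) -> f64 { a * b }
}

/// Symbolic evaluation on strings: products by juxtaposition, sums in parentheses.
struct StrEval;
impl EvalAlgebra for StrEval {
    type V = String;
    fn add(&mut self, a: &String, b: &String) -> String { format!("({}+{})", a, b) }
    fn mul(&mut self, a: &String, b: &String) -> String { format!("{}{}", a, b) }
}

/// A recorded operation with the node indices of its inputs.
#[derive(Clone, Copy)]
enum Op { Leaf, Add(usize, usize), Mul(usize, usize) }

/// A tape deriving reverse-mode differentiation from any evaluation algebra `E`.
struct Graph<E: EvalAlgebra> { eval: E, values: Vec<E::V>, ops: Vec<Op> }

impl<E: EvalAlgebra> Graph<E> {
    fn new(eval: E) -> Self { Graph { eval, values: Vec::new(), ops: Vec::new() } }
    fn push(&mut self, value: E::V, op: Op) -> usize {
        self.values.push(value);
        self.ops.push(op);
        self.values.len() - 1
    }
    fn variable(&mut self, value: E::V) -> usize { self.push(value, Op::Leaf) }
    // Forward values are delegated to the evaluation algebra.
    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.eval.add(&self.values[a], &self.values[b]);
        self.push(v, Op::Add(a, b))
    }
    fn mul(&mut self, a: usize, b: usize) -> usize {
        let v = self.eval.mul(&self.values[a], &self.values[b]);
        self.push(v, Op::Mul(a, b))
    }
    /// Backward pass from `out` seeded with `seed`; gradients are computed with `E` as well.
    fn gradients(&mut self, out: usize, seed: E::V) -> Vec<Option<E::V>> {
        let mut grads: Vec<Option<E::V>> = vec![None; self.values.len()];
        grads[out] = Some(seed);
        // Inputs always have smaller indices, so a reverse scan visits outputs first.
        for i in (0..=out).rev() {
            let g = match grads[i].clone() { Some(g) => g, None => continue };
            match self.ops[i] {
                Op::Leaf => {}
                Op::Add(a, b) => {
                    accumulate(&mut self.eval, &mut grads, a, g.clone());
                    accumulate(&mut self.eval, &mut grads, b, g);
                }
                Op::Mul(a, b) => {
                    let ga = self.eval.mul(&g, &self.values[b]);
                    let gb = self.eval.mul(&self.values[a], &g);
                    accumulate(&mut self.eval, &mut grads, a, ga);
                    accumulate(&mut self.eval, &mut grads, b, gb);
                }
            }
        }
        grads
    }
}

/// Adds `g` to the gradient already stored for node `i`, if any.
fn accumulate<E: EvalAlgebra>(eval: &mut E, grads: &mut [Option<E::V>], i: usize, g: E::V) {
    grads[i] = Some(match grads[i].take() {
        Some(prev) => eval.add(&prev, &g),
        None => g,
    });
}

fn main() {
    // Numeric: d = a * (a * b) at a = 1, b = 2, so dd/da = 2ab = 4 and dd/db = a^2 = 1.
    let mut g = Graph::new(FloatEval);
    let (a, b) = (g.variable(1.0), g.variable(2.0));
    let c = g.mul(a, b);
    let d = g.mul(a, c);
    let grads = g.gradients(d, 1.0);
    assert_eq!((grads[a], grads[b]), (Some(4.0), Some(1.0)));
    // Symbolic: the same formula evaluated on strings.
    let mut s = Graph::new(StrEval);
    let (a, b) = (s.variable("a".to_string()), s.variable("b".to_string()));
    let c = s.mul(a, b);
    let d = s.mul(a, c);
    assert_eq!(s.values[d], "aab");
    let grads = s.gradients(d, "1".to_string());
    assert_eq!(grads[a].as_deref(), Some("(1ab+a1b)"));
    assert_eq!(grads[b].as_deref(), Some("aa1"));
}
```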
Contributing
See the CONTRIBUTING file for how to help out.
License
This project is available under the terms of either the Apache 2.0 license or the MIT license.