
rust-optimal-transport

A library of optimal transport solvers for Rust

3 unstable releases

0.2.0 Feb 16, 2022
0.1.1 Jan 12, 2022
0.1.0 Jan 12, 2022


MIT license


Rust Optimal Transport

This library provides solvers for performing regularized and unregularized Optimal Transport in Rust.

Inspired by Python Optimal Transport, this library provides the following solvers:

  • Network simplex algorithm for the linear program / Earth Mover's Distance
  • Entropic regularization OT solvers, including Sinkhorn-Knopp and Greedy Sinkhorn
  • Unbalanced Sinkhorn-Knopp
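To make the entropic solvers above more concrete, here is a minimal, dependency-free sketch of Sinkhorn-Knopp scaling in plain Rust. It is not this crate's API: the function `sinkhorn` and its parameters (`reg`, `reg_m`, `n_iter`) are hypothetical names chosen for illustration, and it uses `Vec<f64>` instead of ndarray. Passing `reg_m = Some(tau)` switches to the unbalanced (KL-relaxed marginals) variant, whose only change is the exponent `tau / (tau + reg)` on the scaling updates.

```rust
/// Hypothetical sketch of entropic OT via Sinkhorn-Knopp scaling (not the crate's API).
/// `a`, `b`: source/target weights; `cost`: n x m ground cost matrix (row-major);
/// `reg`: entropic regularization; `reg_m`: Some(tau) for the unbalanced variant,
/// None for balanced OT; `n_iter`: fixed number of scaling iterations.
fn sinkhorn(
    a: &[f64],
    b: &[f64],
    cost: &[Vec<f64>],
    reg: f64,
    reg_m: Option<f64>,
    n_iter: usize,
) -> Vec<Vec<f64>> {
    let (n, m) = (a.len(), b.len());
    // Gibbs kernel K = exp(-M / reg)
    let k: Vec<Vec<f64>> = cost
        .iter()
        .map(|row| row.iter().map(|c| (-c / reg).exp()).collect())
        .collect();
    // Exponent on the updates: 1 for balanced OT, tau / (tau + reg) for unbalanced OT
    let fi = match reg_m {
        Some(tau) => tau / (tau + reg),
        None => 1.0,
    };
    let mut u = vec![1.0; n];
    let mut v = vec![1.0; m];
    for _ in 0..n_iter {
        // u = (a / (K v))^fi
        for i in 0..n {
            let kv: f64 = (0..m).map(|j| k[i][j] * v[j]).sum();
            u[i] = (a[i] / kv).powf(fi);
        }
        // v = (b / (K^T u))^fi
        for j in 0..m {
            let ktu: f64 = (0..n).map(|i| k[i][j] * u[i]).sum();
            v[j] = (b[j] / ktu).powf(fi);
        }
    }
    // Transport plan P = diag(u) K diag(v)
    (0..n)
        .map(|i| (0..m).map(|j| u[i] * k[i][j] * v[j]).collect())
        .collect()
}

fn main() {
    let a = [0.5, 0.5];
    let b = [0.5, 0.5];
    let cost = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
    // Balanced plan: almost all mass stays on the zero-cost diagonal
    let plan = sinkhorn(&a, &b, &cost, 0.1, None, 200);
    println!("{:?}", plan);
}
```

A real solver would stop on a marginal-error tolerance rather than a fixed iteration count, and would work in the log domain when `reg` is small, since `exp(-M / reg)` can underflow.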

Installation

The library has been tested on macOS. It requires a C++ compiler for building the EMD solver and relies on the following Rust libraries:

  • cxx 1.0
  • thiserror 1.0
  • ndarray 0.15

Cargo installation

To use rust-optimal-transport in your project, add the following to your Cargo.toml:

[dependencies]
rust-optimal-transport = "0.1"

Features

To enable the LAPACK backend (currently only OpenBLAS is supported):

[dependencies]
rust-optimal-transport = { version = "0.1", features = ["blas"] }

This will link against an installed instance of OpenBLAS on your system. For more details see the ndarray-linalg crate.

Examples

Short examples

  • Import the library
use rust_optimal_transport as ot;
use ot::prelude::*;

  • Compute the OT matrix with the Earth Mover's Distance solver
// Generate data
let n_samples = 100;

// Mean, Covariance of the source distribution
let mu_source = array![0., 0.];
let cov_source = array![[1., 0.], [0., 1.]];

// Mean, Covariance of the target distribution
let mu_target = array![4., 4.];
let cov_target = array![[1., -0.8], [-0.8, 1.]];

// Samples of a 2D gaussian distribution
let source = ot::utils::sample_2D_gauss(n_samples, &mu_source, &cov_source).unwrap();
let target = ot::utils::sample_2D_gauss(n_samples, &mu_target, &cov_target).unwrap();

// Uniform weights on the source and target distributions
let mut source_weights = Array1::<f64>::from_elem(n_samples, 1. / (n_samples as f64));
let mut target_weights = Array1::<f64>::from_elem(n_samples, 1. / (n_samples as f64));

// Compute ground cost matrix - Squared Euclidean distance
let mut cost = dist(&source, &target, SqEuclidean);
let max_cost = cost.max().unwrap();

// Normalize cost matrix for numerical stability
cost = &cost / *max_cost;

// Compute optimal transport matrix as the Earth Mover's Distance
// (`?` requires the enclosing function to return a compatible Result)
let ot_matrix = EarthMovers::new(
    &mut source_weights,
    &mut target_weights,
    &mut cost
).solve()?;
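The ground-cost step in the example (squared Euclidean `dist` followed by division by the maximum entry) can be sketched without ndarray as below. This is an illustrative, dependency-free version only: `sq_euclidean_cost` and `transport_cost` are hypothetical helper names, not functions exported by this crate. `transport_cost` computes the total cost of a plan, the sum over all cells of P[i][j] * M[i][j], which is the quantity the EMD solver minimizes.

```rust
/// Hypothetical helper: pairwise squared Euclidean cost between two 2D point
/// clouds, normalized by its largest entry so all costs lie in [0, 1].
fn sq_euclidean_cost(source: &[[f64; 2]], target: &[[f64; 2]]) -> Vec<Vec<f64>> {
    let mut cost: Vec<Vec<f64>> = source
        .iter()
        .map(|s| {
            target
                .iter()
                .map(|t| (s[0] - t[0]).powi(2) + (s[1] - t[1]).powi(2))
                .collect()
        })
        .collect();
    // Largest entry of the matrix, used for normalization
    let max_cost = cost.iter().flatten().cloned().fold(0.0_f64, f64::max);
    if max_cost > 0.0 {
        for row in cost.iter_mut() {
            for c in row.iter_mut() {
                *c /= max_cost;
            }
        }
    }
    cost
}

/// Hypothetical helper: total cost sum_ij P[i][j] * M[i][j] of a transport plan.
fn transport_cost(plan: &[Vec<f64>], cost: &[Vec<f64>]) -> f64 {
    plan.iter()
        .zip(cost)
        .map(|(p_row, c_row)| p_row.iter().zip(c_row).map(|(p, c)| p * c).sum::<f64>())
        .sum()
}

fn main() {
    let source = [[0.0, 0.0], [1.0, 0.0]];
    let target = [[0.0, 1.0], [2.0, 0.0]];
    let cost = sq_euclidean_cost(&source, &target);
    // A plan sending each source point to one target point with weight 0.5
    let plan = vec![vec![0.5, 0.0], vec![0.0, 0.5]];
    println!("cost = {:?}, total = {}", cost, transport_cost(&plan, &cost));
}
```

Normalizing by the maximum does not change which plan is optimal; it only rescales the objective, which mainly matters for the entropic solvers, where `reg` is interpreted relative to the magnitude of the costs.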

Acknowledgements

This library is inspired by Python Optimal Transport (POT). The original authors and contributors of that project are credited on the POT project page.
