#machine-learning #tensor #deep-learning

rai-nn

ML framework with ergonomic APIs in Rust

9 breaking releases

0.10.0 Mar 26, 2024
0.8.0 Feb 20, 2024
0.2.0 Dec 28, 2023

#383 in Machine learning

165 downloads per month
Used in 3 crates (via rai)

MIT/Apache

305KB
9K SLoC

RAI

ML framework with ergonomic APIs in Rust, built around lazy computation and composable function transformations.

Installation

cargo add rai
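
A minimal sketch of the lazy-computation model, useful as a quick check of the install. It only uses calls that appear in the snippets below (Tensor::ones, sin, mean(..), dot_graph); the note that printing triggers evaluation is an assumption based on the lazy design rather than documented behavior.

use rai::{Cpu, Tensor, F32};

fn main() {
    // Building the expression records a lazy computation graph;
    // nothing is computed eagerly here.
    let x = Tensor::ones([4], F32, &Cpu);
    let y = x.sin().mean(..);

    // Inspect the recorded graph.
    println!("{}", y.dot_graph());
    // Displaying the tensor is assumed to force evaluation of the graph.
    println!("{}", y);
}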

Code snippets

Function transformations (jvp, vjp, grad, value_and_grad)

use rai::{grad, Cpu, Func, Tensor, F32};

fn f(x: &Tensor) -> Tensor {
    x.sin()
}

fn main() {
    // Transformations compose: grad(grad(f)) is the second derivative of f.
    let grad_fn = grad(grad(f));
    let x = Tensor::ones([1], F32, &Cpu);
    let grad = grad_fn.apply((&x,));
    // Print the recorded computation graph, then the gradient value.
    println!("{}", grad.dot_graph());
    println!("{}", grad);
}

NN modules, optimizers, and loss functions

// Assumes TrainableModule, Optimizer, Aux, Tensor, value_and_grad,
// softmax_cross_entropy, and eval are in scope from the rai crate.
fn loss_fn<M: TrainableModule<Input = Tensor, Output = Tensor>>(
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) -> (Tensor, Aux<Tensor>) {
    let logits = model.forward(input);
    let loss = softmax_cross_entropy(&logits, labels).mean(..);
    (loss, Aux(logits))
}

fn train_step<M: TrainableModule<Input = Tensor, Output = Tensor>, O: Optimizer>(
    optimizer: &mut O,
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) {
    let vg_fn = value_and_grad(loss_fn);
    // value_and_grad returns both the (loss, aux) value and the gradients
    // with respect to the model parameters.
    let ((_loss, Aux(_logits)), (grads, ..)) = vg_fn.apply((model, input, labels));
    // Take an optimizer step, evaluate the new parameters, and write them back.
    let mut params = optimizer.step(&grads);
    eval(&params);
    model.update_params(&mut params);
}
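
A hedged usage sketch: train_step could be driven from an epoch loop like the one below. The batched (input, labels) data is a placeholder for a real data pipeline, and the model and optimizer stay generic over the same bounds as train_step.

fn train<M, O>(model: &M, optimizer: &mut O, batches: &[(Tensor, Tensor)], epochs: usize)
where
    M: TrainableModule<Input = Tensor, Output = Tensor>,
    O: Optimizer,
{
    for epoch in 0..epochs {
        for (input, labels) in batches {
            // Reuses train_step from the snippet above.
            train_step(optimizer, model, input, labels);
        }
        println!("finished epoch {}", epoch);
    }
}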

Examples

LICENSE

This project is licensed under either of

Apache License, Version 2.0
MIT license

at your option.

Dependencies

~5–17MB
~198K SLoC