3 releases (breaking)
Version | Released |
---|---|
0.11.0 | May 14, 2024 |
0.10.0 | Mar 26, 2024 |
0.7.0 | Jan 23, 2024 |
RAI
An ML framework in Rust with ergonomic APIs, lazy computation, and composable function transformations in the style of JAX.
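To make "lazy computation" concrete, here is a minimal, dependency-free sketch of the general idea: building an expression only records operations, and nothing is computed until `eval` is called. The `Expr` type and its methods are hypothetical names for illustration and are not rai's implementation or API.

```rust
/// A hypothetical lazy expression: constructing one only records operations.
enum Expr {
    Const(f64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Sin(Box<Expr>),
}

impl Expr {
    /// Walks the recorded graph and computes the value on demand.
    fn eval(&self) -> f64 {
        match self {
            Expr::Const(v) => *v,
            Expr::Add(a, b) => a.eval() + b.eval(),
            Expr::Mul(a, b) => a.eval() * b.eval(),
            Expr::Sin(a) => a.eval().sin(),
        }
    }

    /// Renders the graph as text without evaluating it.
    fn describe(&self) -> String {
        match self {
            Expr::Const(v) => format!("{v}"),
            Expr::Add(a, b) => format!("({} + {})", a.describe(), b.describe()),
            Expr::Mul(a, b) => format!("({} * {})", a.describe(), b.describe()),
            Expr::Sin(a) => format!("sin({})", a.describe()),
        }
    }
}

fn main() {
    // Build sin(1 + 2 * 0.5); no arithmetic happens here.
    let product = Expr::Mul(Box::new(Expr::Const(2.0)), Box::new(Expr::Const(0.5)));
    let sum = Expr::Add(Box::new(Expr::Const(1.0)), Box::new(product));
    let e = Expr::Sin(Box::new(sum));

    println!("graph: {}", e.describe());
    // The value is computed only when explicitly requested.
    println!("value: {}", e.eval());
}
```

Deferring evaluation this way lets a framework inspect or transform the whole graph before running it, which is what makes transformations such as `grad` composable.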
Installation
cargo add rai
Code snippets
Function transformations (jvp, vjp, grad, value_and_grad)
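use rai::{grad, Cpu, Tensor, F32};

fn f(x: &Tensor) -> Tensor {
    x.sin()
}

fn main() {
    // grad composes with itself: this is the second derivative of f.
    let grad_fn = grad(grad(f));
    let x = &Tensor::ones([1], F32, &Cpu);
    let grad = grad_fn(x);
    // Print the computation graph, then the computed value.
    println!("{}", grad.dot_graph());
    println!("{}", grad);
}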
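As a sanity check on what `grad(grad(f))` should yield for `f = sin`, recall that the second derivative of sin is -sin, so at x = 1 the expected value is -sin(1). The sketch below reproduces the same composition with central finite differences on plain `f64` values. It does not use rai; the local `grad` helper and the step size `H` are illustrative assumptions, and it only approximates what an autodiff transformation computes exactly.

```rust
/// Step size for the central difference (an illustrative choice).
const H: f64 = 1e-4;

/// Returns a function that numerically approximates the derivative of `f`.
/// This is a stand-in for an autodiff `grad`, not rai's implementation.
fn grad<F: Fn(f64) -> f64>(f: F) -> impl Fn(f64) -> f64 {
    move |x| (f(x + H) - f(x - H)) / (2.0 * H)
}

fn main() {
    // Composing the transformation twice gives the second derivative.
    let d2 = grad(grad(f64::sin));
    let x: f64 = 1.0;
    println!("numeric: {:.6}", d2(x));
    println!("exact:   {:.6}", -x.sin());
}
```

Because the helper returns an ordinary closure, it can be nested to any depth, which mirrors how the transformations in the snippet above compose.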
NN Modules, Optimizer and loss functions
fn loss_fn<M: TrainableModule<Input = Tensor, Output = Tensor>>(
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) -> (Tensor, Aux<Tensor>) {
    let logits = model.forward(input);
    let loss = softmax_cross_entropy(&logits, labels).mean(..);
    // Return the logits as auxiliary output alongside the loss.
    (loss, Aux(logits))
}

fn train_step<M: TrainableModule<Input = Tensor, Output = Tensor>, O: Optimizer>(
    optimizer: &mut O,
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) {
    let vg_fn = value_and_grad(loss_fn);
    let ((_loss, Aux(_logits)), (grads, ..)) = vg_fn((model, input, labels));
    let mut params = optimizer.step(&grads);
    // Force evaluation of the lazily computed parameters before updating the model.
    eval(&params);
    model.update_params(&mut params);
}
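The shape of the training step above (compute loss and gradients together, let the optimizer turn gradients into new parameters, write them back) can be sketched without rai. The version below uses a 1-D linear model, mean squared error with a hand-derived gradient, and plain SGD. Every name in it (`Params`, `sgd_step`, the local `value_and_grad`) is hypothetical and chosen only to echo the snippet; it is not rai's API.

```rust
/// Parameters of a hypothetical 1-D linear model y = w * x + b.
#[derive(Clone, Copy, Debug)]
struct Params {
    w: f64,
    b: f64,
}

/// Mean squared error and its gradient, returned together
/// (the role `value_and_grad` plays in the snippet above).
fn value_and_grad(p: Params, xs: &[f64], ys: &[f64]) -> (f64, Params) {
    let n = xs.len() as f64;
    let (mut loss, mut gw, mut gb) = (0.0, 0.0, 0.0);
    for (&x, &y) in xs.iter().zip(ys) {
        let err = p.w * x + p.b - y;
        loss += err * err / n;
        gw += 2.0 * err * x / n; // d(loss)/dw
        gb += 2.0 * err / n; // d(loss)/db
    }
    (loss, Params { w: gw, b: gb })
}

/// Plain SGD: turns gradients into updated parameters.
fn sgd_step(p: Params, g: Params, lr: f64) -> Params {
    Params { w: p.w - lr * g.w, b: p.b - lr * g.b }
}

/// One training step: loss and gradients, then the optimizer update.
fn train_step(p: Params, xs: &[f64], ys: &[f64], lr: f64) -> (f64, Params) {
    let (loss, grads) = value_and_grad(p, xs, ys);
    (loss, sgd_step(p, grads, lr))
}

fn main() {
    // Data generated from y = 2x + 1.
    let xs = [0.0, 1.0, 2.0, 3.0];
    let ys = [1.0, 3.0, 5.0, 7.0];
    let mut p = Params { w: 0.0, b: 0.0 };
    for step in 0..2000 {
        let (loss, next) = train_step(p, &xs, &ys, 0.05);
        p = next;
        if step % 500 == 0 {
            println!("step {step}: loss = {loss:.6}");
        }
    }
    println!("w = {:.3}, b = {:.3}", p.w, p.b);
}
```

In an autodiff framework the gradient would come from the transformation rather than from hand-written formulas; the control flow of the step stays the same.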
Examples
- linear_regression
cargo run --bin linear_regression --release
- mnist
cargo run --bin mnist --release
cargo run --bin mnist --release --features=cuda
- mnist-cnn
cargo run --bin mnist-cnn --release
cargo run --bin mnist-cnn --release --features=cuda
- phi2
cargo run --bin phi2 --release
cargo run --bin phi2 --release --features=cuda
- phi3
cargo run --bin phi3 --release
cargo run --bin phi3 --release --features=cuda
- qwen2
cargo run --bin qwen2 --release
cargo run --bin qwen2 --release --features=cuda
- gemma
  - accept the license agreement at https://huggingface.co/google/gemma-2b
  - install the Hugging Face CLI
    pip install huggingface_hub
  - log in to Hugging Face
    huggingface-cli login
cargo run --bin gemma --release
cargo run --bin gemma --release --features=cuda
- vit
cargo run --bin vit --release
cargo run --bin vit --release --features=cuda
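The `--features=cuda` flag in the commands above enables a Cargo feature at compile time. As a general illustration of how such a flag can gate code, here is a minimal sketch using the standard `cfg!` macro. The function `backend_name` is hypothetical, and whether rai's example binaries select their device this way is an assumption.

```rust
/// Name of the backend this build would use.
/// The `cuda` feature name matches the commands above; the function is hypothetical.
fn backend_name() -> &'static str {
    // `cfg!` is resolved at compile time from the enabled Cargo features.
    if cfg!(feature = "cuda") {
        "cuda"
    } else {
        "cpu"
    }
}

fn main() {
    println!("running on {}", backend_name());
}
```

Built without the feature, the `cuda` branch is never taken, which is why the same example can be run on machines without a GPU toolchain.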
LICENSE
This project is licensed under either of
- Apache License, Version 2.0 (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
- MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)
at your option.