9 breaking releases
Version | Release date |
---|---|
0.10.0 | Mar 26, 2024 |
0.8.0 | Feb 20, 2024 |
0.2.0 | Dec 28, 2023 |
320KB, 9K SLoC
RAI
A machine learning framework for Rust with ergonomic APIs, lazy computation, and composable transformations.
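"Lazy computation" presumably means that tensor operations record a computation graph that is only evaluated on demand (the training snippet further down calls `eval` explicitly). rai's internals are not described here, so the following is a minimal plain-Rust sketch of the general technique, a lazily evaluated, memoized expression graph, and not rai's implementation. `Lazy`, `Op`, `Node`, and the `computed` counter are hypothetical names chosen for illustration.

```rust
use std::cell::Cell;
use std::rc::Rc;

// A recorded operation (hypothetical); building one performs no arithmetic.
enum Op {
    Const(f64),
    Add(Lazy, Lazy),
    Mul(Lazy, Lazy),
    Sin(Lazy),
}

// A graph node plus a memo slot that is filled on first evaluation.
struct Node {
    op: Op,
    cached: Cell<Option<f64>>,
}

// Cheap-to-clone handle; clones share one node, so shared subexpressions are computed once.
#[derive(Clone)]
struct Lazy(Rc<Node>);

impl Lazy {
    fn new(op: Op) -> Self {
        Lazy(Rc::new(Node { op, cached: Cell::new(None) }))
    }
    fn constant(v: f64) -> Self {
        Lazy::new(Op::Const(v))
    }
    fn add(&self, other: &Lazy) -> Lazy {
        Lazy::new(Op::Add(self.clone(), other.clone()))
    }
    fn mul(&self, other: &Lazy) -> Lazy {
        Lazy::new(Op::Mul(self.clone(), other.clone()))
    }
    fn sin(&self) -> Lazy {
        Lazy::new(Op::Sin(self.clone()))
    }

    // Forces evaluation; `computed` counts nodes actually computed (cache misses).
    fn eval(&self, computed: &mut usize) -> f64 {
        if let Some(v) = self.0.cached.get() {
            return v;
        }
        let v = match &self.0.op {
            Op::Const(c) => *c,
            Op::Add(a, b) => a.eval(computed) + b.eval(computed),
            Op::Mul(a, b) => a.eval(computed) * b.eval(computed),
            Op::Sin(a) => a.eval(computed).sin(),
        };
        *computed += 1;
        self.0.cached.set(Some(v));
        v
    }
}

fn main() {
    let x = Lazy::constant(2.0);
    let y = x.mul(&x); // y = x * x, recorded only
    let z = y.sin().add(&y); // z = sin(y) + y, reusing the node y
    // Nothing has been computed until eval is called.
    let mut computed = 0;
    let v = z.eval(&mut computed);
    println!("z = {v}, nodes computed = {computed}");
}
```

Deferring work this way lets a framework inspect or transform the whole graph (for example to differentiate it) before any arithmetic runs; a second `eval` on `z` hits the cache and computes nothing new.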
Installation
cargo add rai
Code snippets
Function transformations (jvp, vjp, grad, value_and_grad)
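```rust
use rai::{grad, Cpu, Func, Tensor, F32};

// f(x) = sin(x)
fn f(x: &Tensor) -> Tensor {
    x.sin()
}

fn main() {
    // Composing grad twice yields the second derivative, -sin(x).
    let grad_fn = grad(grad(f));
    // A one-element f32 tensor of ones on the CPU backend.
    let x = Tensor::ones([1], F32, &Cpu);
    // `apply` (from the `Func` trait imported above) takes the arguments as a tuple.
    let grad = grad_fn.apply((&x,));
    // Print the computation graph in DOT format, then the resulting tensor.
    println!("{}", grad.dot_graph());
    println!("{}", grad);
}
```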
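To make the idea of composing `grad` concrete, here is a self-contained sketch that reproduces `grad(grad(f))` for `f = sin` using forward-mode differentiation with nested dual numbers. This is a swapped-in technique chosen for brevity, not rai's implementation (rai may well use reverse mode for `grad`); the `Scalar` trait, `Dual` type, and this `grad` function are hypothetical.

```rust
use std::ops::{Add, Mul, Neg};

// Hypothetical trait: the operations a scalar needs so `sin` can be differentiated.
trait Scalar: Copy + Add<Output = Self> + Mul<Output = Self> + Neg<Output = Self> {
    fn zero() -> Self;
    fn one() -> Self;
    fn sin(self) -> Self;
    fn cos(self) -> Self;
}

impl Scalar for f64 {
    fn zero() -> Self { 0.0 }
    fn one() -> Self { 1.0 }
    fn sin(self) -> Self { f64::sin(self) }
    fn cos(self) -> Self { f64::cos(self) }
}

// Dual number re + eps*e with e^2 = 0; the `eps` part carries the derivative.
#[derive(Clone, Copy)]
struct Dual<T> {
    re: T,
    eps: T,
}

impl<T: Scalar> Add for Dual<T> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        Dual { re: self.re + o.re, eps: self.eps + o.eps }
    }
}

impl<T: Scalar> Mul for Dual<T> {
    type Output = Self;
    // Product rule: (a + a'e)(b + b'e) = ab + (ab' + a'b)e
    fn mul(self, o: Self) -> Self {
        Dual { re: self.re * o.re, eps: self.re * o.eps + self.eps * o.re }
    }
}

impl<T: Scalar> Neg for Dual<T> {
    type Output = Self;
    fn neg(self) -> Self {
        Dual { re: -self.re, eps: -self.eps }
    }
}

impl<T: Scalar> Scalar for Dual<T> {
    fn zero() -> Self { Dual { re: T::zero(), eps: T::zero() } }
    fn one() -> Self { Dual { re: T::one(), eps: T::zero() } }
    // Chain rule: sin(a + a'e) = sin(a) + cos(a) * a' e
    fn sin(self) -> Self { Dual { re: self.re.sin(), eps: self.re.cos() * self.eps } }
    // Chain rule: cos(a + a'e) = cos(a) - sin(a) * a' e
    fn cos(self) -> Self { Dual { re: self.re.cos(), eps: -self.re.sin() * self.eps } }
}

// The snippet's f(x) = sin(x), written generically over the scalar type.
fn f<S: Scalar>(x: S) -> S {
    x.sin()
}

// Hypothetical forward-mode grad: seed eps = 1 and read the eps part of the output.
fn grad<T: Scalar>(f: impl Fn(Dual<T>) -> Dual<T>) -> impl Fn(T) -> T {
    move |x: T| f(Dual { re: x, eps: T::one() }).eps
}

fn main() {
    // Nesting grad mirrors `grad(grad(f))`: d^2/dx^2 sin(x) = -sin(x).
    let d2f = grad(grad(f::<Dual<Dual<f64>>>));
    println!("{}", d2f(1.0)); // about -0.8415
}
```

Each `grad` adds one level of `Dual` nesting, which is why `f` is instantiated at `Dual<Dual<f64>>` for the second derivative.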
NN modules, optimizers, and loss functions
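```rust
// Assumes `Tensor`, `TrainableModule`, `Optimizer`, `Aux`, `softmax_cross_entropy`,
// `value_and_grad` and `eval` from the rai crate are in scope; import paths depend on the version.

// Forward pass plus mean softmax cross-entropy loss. The logits are returned
// wrapped in `Aux` as auxiliary output alongside the loss.
fn loss_fn<M: TrainableModule<Input = Tensor, Output = Tensor>>(
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) -> (Tensor, Aux<Tensor>) {
    let logits = model.forward(input);
    let loss = softmax_cross_entropy(&logits, labels).mean(..);
    (loss, Aux(logits))
}

// One optimization step: compute loss and gradients, update and evaluate the
// parameters, then write them back into the model.
fn train_step<M: TrainableModule<Input = Tensor, Output = Tensor>, O: Optimizer>(
    optimizer: &mut O,
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) {
    let vg_fn = value_and_grad(loss_fn);
    // The value is (loss, Aux(logits)); only the gradients for the first argument (the model) are kept.
    let ((_loss, Aux(_logits)), (grads, ..)) = vg_fn.apply((model, input, labels));
    let mut params = optimizer.step(&grads);
    // Force evaluation of the lazily built parameter update.
    eval(&params);
    model.update_params(&mut params);
}
```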
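rai's `softmax_cross_entropy` and optimizers operate on `Tensor`s. As a hedged illustration of what `loss_fn` and `optimizer.step` compute conceptually, here is a plain-Rust sketch on `Vec<f64>` rows. It assumes integer class labels (the snippet's `labels` tensor may instead be one-hot) and plain SGD (rai's `Optimizer` implementations may use a different update rule). All function names below are hypothetical, not rai's API.

```rust
// Hypothetical: softmax cross-entropy for one row of logits and an integer label,
// computed with the log-sum-exp trick for numerical stability.
fn softmax_cross_entropy(logits: &[f64], label: usize) -> f64 {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let lse = max + logits.iter().map(|z| (z - max).exp()).sum::<f64>().ln();
    lse - logits[label]
}

// Hypothetical: gradient of that loss w.r.t. the logits, softmax(logits) - one_hot(label).
fn softmax_cross_entropy_grad(logits: &[f64], label: usize) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter()
        .enumerate()
        .map(|(i, e)| e / sum - if i == label { 1.0 } else { 0.0 })
        .collect()
}

// Mean loss over a batch, mirroring `.mean(..)` in `loss_fn`.
fn mean_loss(batch: &[Vec<f64>], labels: &[usize]) -> f64 {
    let total: f64 = batch
        .iter()
        .zip(labels)
        .map(|(row, &y)| softmax_cross_entropy(row, y))
        .sum();
    total / batch.len() as f64
}

// Plain SGD update: theta <- theta - lr * grad.
fn sgd_step(params: &mut [f64], grads: &[f64], lr: f64) {
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= lr * g;
    }
}

fn main() {
    // Toy setup (hypothetical): treat the logits themselves as trainable parameters.
    let mut batch = vec![vec![0.0, 0.0], vec![1.0, -1.0]];
    let labels: [usize; 2] = [0, 1];
    let before = mean_loss(&batch, &labels);
    let n = batch.len() as f64;
    for (row, &y) in batch.iter_mut().zip(&labels) {
        // Per-example gradient scaled by 1/N because the loss is a batch mean.
        let g: Vec<f64> = softmax_cross_entropy_grad(&row[..], y)
            .iter()
            .map(|v| v / n)
            .collect();
        sgd_step(&mut row[..], &g, 0.5);
    }
    let after = mean_loss(&batch, &labels);
    println!("loss before = {before:.4}, after = {after:.4}");
}
```

In the rai snippet these pieces are split differently: `value_and_grad` produces the gradients automatically, and the optimizer applies its update rule to them.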
Examples
- linear_regression
  cargo run --bin linear_regression --release
- mnist
  cargo run --bin mnist --release
- phi2
  cargo run --bin phi2 --release
- qwen2
  cargo run --bin qwen2 --release
- gemma
  - Accept the license agreement at https://huggingface.co/google/gemma-2b
  - Install the Hugging Face Hub client:
    pip install huggingface_hub
  - Log in to Hugging Face:
    huggingface-cli login
  - Run the example:
    cargo run --bin gemma --release
- vit
  cargo run --bin vit --release
LICENSE
This project is licensed under either of
- Apache License, Version 2.0, (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
- MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)
at your option.
Dependencies: ~7–21MB, ~257K SLoC