27 releases

0.5.9 Mar 28, 2022
0.5.5 Feb 28, 2022
0.3.8 Aug 16, 2020
0.3.6 Jul 7, 2020

MIT license

555KB
13K SLoC

A simple machine learning toolset

Introduction

This is a machine learning library based on automatic differentiation (autodiff).
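To make "automatic differentiation" concrete, here is a minimal reverse-mode sketch over f64 scalars, assuming a tape (Wengert list) representation: each operation is recorded in evaluation order, and backward() walks the tape in reverse, accumulating gradients with the chain rule. Tape, Op and their methods are hypothetical names for illustration only; they are not the auto_diff API, which works on tensors rather than scalars.

```rust
// Minimal tape-based reverse-mode autodiff over f64 scalars.
// Hypothetical illustration only: not the auto_diff API.

#[derive(Clone, Copy)]
enum Op {
    Leaf,
    Add(usize, usize),
    Mul(usize, usize),
}

struct Tape {
    values: Vec<f64>,
    ops: Vec<Op>,
}

impl Tape {
    fn new() -> Self {
        Tape { values: Vec::new(), ops: Vec::new() }
    }

    // Record a node and return its index on the tape.
    fn push(&mut self, value: f64, op: Op) -> usize {
        self.values.push(value);
        self.ops.push(op);
        self.values.len() - 1
    }

    fn leaf(&mut self, value: f64) -> usize {
        self.push(value, Op::Leaf)
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.values[a] + self.values[b];
        self.push(v, Op::Add(a, b))
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let v = self.values[a] * self.values[b];
        self.push(v, Op::Mul(a, b))
    }

    /// Returns d(output)/d(node) for every node on the tape.
    fn backward(&self, output: usize) -> Vec<f64> {
        let mut grads = vec![0.0; self.values.len()];
        grads[output] = 1.0;
        // Nodes are stored in evaluation order, so walking backwards visits
        // each node only after all nodes that consume it.
        for i in (0..=output).rev() {
            let g = grads[i];
            match self.ops[i] {
                Op::Leaf => {}
                Op::Add(a, b) => {
                    grads[a] += g;
                    grads[b] += g;
                }
                Op::Mul(a, b) => {
                    grads[a] += g * self.values[b];
                    grads[b] += g * self.values[a];
                }
            }
        }
        grads
    }
}

fn main() {
    // f(x, y) = x * y + x at x = 3, y = 4, so df/dx = y + 1 and df/dy = x.
    let mut t = Tape::new();
    let x = t.leaf(3.0);
    let y = t.leaf(4.0);
    let xy = t.mul(x, y);
    let f = t.add(xy, x);
    let grads = t.backward(f);
    // prints "f = 15, df/dx = 5, df/dy = 3"
    println!("f = {}, df/dx = {}, df/dy = {}", t.values[f], grads[x], grads[y]);
}
```

A single backward pass yields the gradient with respect to every input at once, which is why reverse mode suits training, where one scalar loss depends on many parameters.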

Features

  • A type-less tensor.
  • Variables built on top of tensors, with support for backpropagation.
  • Support for common operators, including convolution.
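As an illustration of the convolution operator, here is a minimal single-channel 2-D "valid" convolution (stride 1, no padding) on row-major f32 buffers. Like most deep learning libraries' convolution it computes a cross-correlation (the kernel is not flipped). conv2d_valid is a hypothetical helper; auto_diff's own convolution may differ in layout, padding, stride and channel handling.

```rust
// Single-channel "valid" 2-D convolution (cross-correlation), stride 1.
// Hypothetical helper for illustration, not the auto_diff API.
// Returns the output buffer together with its height and width.
fn conv2d_valid(
    input: &[f32], h: usize, w: usize,
    kernel: &[f32], kh: usize, kw: usize,
) -> (Vec<f32>, usize, usize) {
    assert_eq!(input.len(), h * w);
    assert_eq!(kernel.len(), kh * kw);
    assert!(kh <= h && kw <= w);
    let (oh, ow) = (h - kh + 1, w - kw + 1);
    let mut out = vec![0.0f32; oh * ow];
    for i in 0..oh {
        for j in 0..ow {
            // Dot product of the kernel with the window whose top-left corner is (i, j).
            let mut acc = 0.0f32;
            for di in 0..kh {
                for dj in 0..kw {
                    acc += input[(i + di) * w + (j + dj)] * kernel[di * kw + dj];
                }
            }
            out[i * ow + j] = acc;
        }
    }
    (out, oh, ow)
}

fn main() {
    // 3x3 input and 2x2 kernel give a 2x2 output.
    let input: [f32; 9] = [1., 2., 3.,
                           4., 5., 6.,
                           7., 8., 9.];
    let kernel: [f32; 4] = [1., 2.,
                            3., 4.];
    let (out, oh, ow) = conv2d_valid(&input, 3, 3, &kernel, 2, 2);
    // prints "2x2: [37.0, 47.0, 67.0, 77.0]"
    println!("{}x{}: {:?}", oh, ow, out);
}
```

For example, the top-left output is 1*1 + 2*2 + 4*3 + 5*4 = 37.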

Example

use tensor_rs::tensor::Tensor;
use auto_diff::rand::RNG;
use auto_diff::var::Module;
use auto_diff::optim::{SGD, Optimizer};

fn main() {
    // Ground-truth function to learn: y = x * [2, 3]^T + 1.
    fn func(input: &Tensor) -> Tensor {
        input.matmul(&Tensor::from_vec_f32(&vec![2., 3.], &vec![2, 1]))
            .add(&Tensor::from_vec_f32(&vec![1.], &vec![1]))
    }

    // Generate n random 2-D inputs and their labels.
    let n = 100;
    let mut rng = RNG::new();
    rng.set_seed(123);
    let data = rng.normal(&vec![n, 2], 0., 2.);
    let label = func(&data);

    let mut m = Module::new();

    // Linear layer with 2 inputs, 1 output and a bias;
    // initialize its weight and bias with random values.
    let op1 = m.linear(Some(2), Some(1), true);
    let weights = op1.get_values().unwrap();
    rng.normal_(&weights[0], 0., 1.);
    rng.normal_(&weights[1], 0., 1.);
    op1.set_values(&weights);

    // Move a clone of the layer into the block's closure;
    // op1 is kept so the trained weights can be read at the end.
    let op2 = op1.clone();
    let block = m.func(
        move |x| {
            op2.call(x)
        }
    );

    // Mean squared error loss and a plain SGD optimizer.
    let loss_func = m.mse_loss();
    let mut opt = SGD::new(3.);

    for i in 0..200 {
        let input = m.var_value(data.clone());
        let y = block.call(&[&input]);

        let loss = loss_func.call(&[&y, &m.var_value(label.clone())]);
        println!("index: {}, loss: {}", i, loss.get().get_scale_f32());

        // Backpropagate from the loss, then update the block's parameters.
        loss.backward(-1.);
        opt.step2(&block);
    }

    let weights = op1.get_values().expect("linear layer should have weight and bias");
    println!("{:?}, {:?}", weights[0], weights[1]);
}
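For readers who want to see what the training loop computes, here is a dependency-free sketch of the same task: fitting y = 2*x1 + 3*x2 + 1 by full-batch gradient descent on the mean squared error, with hand-derived gradients standing in for backpropagation. It does not use auto_diff; the hypothetical Lcg generator, the uniform inputs in [-1, 1) (instead of normal samples), the zero initialization and the learning rate 0.2 are all assumptions chosen for this illustration.

```rust
// Std-only sketch: fit y = 2*x1 + 3*x2 + 1 with gradient descent on MSE.
// Assumptions: a tiny LCG replaces auto_diff's RNG, and inputs are uniform
// in [-1, 1) rather than normally distributed.

/// Hypothetical 64-bit linear congruential generator (MMIX constants).
struct Lcg(u64);

impl Lcg {
    /// Uniform sample in [-1, 1).
    fn next_f64(&mut self) -> f64 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let unit = (self.0 >> 11) as f64 / (1u64 << 53) as f64; // in [0, 1)
        2.0 * unit - 1.0
    }
}

/// Returns (weights, bias, loss measured during the last step).
fn fit(n: usize, steps: usize, lr: f64) -> ([f64; 2], f64, f64) {
    let mut rng = Lcg(123);
    let data: Vec<[f64; 2]> = (0..n).map(|_| [rng.next_f64(), rng.next_f64()]).collect();
    // Same target function as func() in the example above.
    let label: Vec<f64> = data.iter().map(|x| 2.0 * x[0] + 3.0 * x[1] + 1.0).collect();

    let (mut w, mut b) = ([0.0f64; 2], 0.0f64);
    let mut loss = 0.0;
    for _ in 0..steps {
        let (mut gw, mut gb) = ([0.0f64; 2], 0.0f64);
        loss = 0.0;
        for (x, y) in data.iter().zip(&label) {
            let err = w[0] * x[0] + w[1] * x[1] + b - y;
            loss += err * err / n as f64;
            // d(err^2 / n)/dw_j = 2 * err * x_j / n; same for b with x_j = 1.
            gw[0] += 2.0 * err * x[0] / n as f64;
            gw[1] += 2.0 * err * x[1] / n as f64;
            gb += 2.0 * err / n as f64;
        }
        // Gradient descent step.
        w[0] -= lr * gw[0];
        w[1] -= lr * gw[1];
        b -= lr * gb;
    }
    (w, b, loss)
}

fn main() {
    let (w, b, loss) = fit(100, 200, 0.2);
    println!("w = {:?}, b = {:.4}, loss = {:.2e}", w, b, loss);
}
```

Because the labels are noise-free, the weights should approach [2, 3] and the bias 1; these are also the values the final println! in the auto_diff example should print after training converges.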

Build requirements

Install gfortran if openblas-src = "0.9" is used.

Contributing

Contributions are welcome: please open an issue or create a pull request.
