
Zyx

Machine learning library written in Rust.

Zyx is fairly small and has only a few dependencies (i.e. Zyx was built from scratch in Rust and OpenCL).

Three backends are currently supported: native Rust (with and without multithreading), OpenCL, and C++ libtorch.
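
As a rough illustration (based only on the constructors that appear later in this README; the constructor for the libtorch backend is not shown here), the backend is chosen when the Context is created:

# use zyx::context::Context;
// Default Context, as used in the first examples below (presumably the native Rust backend).
let ctx = Context::new();
# #[cfg(feature = "opencl")] {
// OpenCL backend, available when the "opencl" feature is enabled.
let ctx_cl = Context::opencl().unwrap();
# }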

Tensors

The Tensor is the basic unit of Zyx. Tensors are immutable. The Context manages all tensors and connects them with backends.

Tensors are just reference-counted fat pointers, so feel free to clone them.

# use zyx::context::Context;
let mut ctx = Context::new();
let x = ctx.tensor([[2, 4, 3], [4, 2, 3]]);
let y = x.clone();

Automatic differentiation/Backpropagation

Every operation is traced automatically, so you can calculate the derivative of any tensor with respect to any other tensor. There is no need for a gradient tape and you don't need to set requires_grad. See [automatic differentiation](Automatic differentiation.md) for more details.

The following code calculates gradients for w1 and w2 (i.e. the derivative of loss w.r.t. w1 and w.r.t. w2).

# use zyx::{context::Context, dtype::DType};
let mut ctx = Context::new();
let x = ctx.randn((2, 4));
let mut w1 = ctx.randn((4, 4));
let mut w2 = ctx.randn((4, 3));
let out = x.dot(&w1).tanh().dot(&w2);
let y = ctx.tensor([2, 1, 4]).cast(DType::F32);
let loss = (out - y).pow(2.);
loss.backward([&mut w1, &mut w2]);

Graph realization

Neural networks are directed acyclic graphs of many tensors. Thus one of the big tradeoffs of modern machine learning libraries is when and how to do calculations. There is no clear winner here; each method has its pros and cons.

Zyx uses a fully dynamic graph. Realization of tensors happens only when you call realize, which means tensors are evaluated lazily.

# #[cfg(feature = "opencl")] {
# use zyx::context::Context;
let mut ctx = Context::opencl().unwrap();
let x = ctx.randn((256, 1024));
let w1 = ctx.randn((1024, 1024));
let mut z = x.dot(w1);
z.realize().unwrap();
# }

This enables certain optimizations, but you need to call realize during the training loop, as in the sketch below.
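
A hypothetical sketch (using only ops that already appear in this README) of why that matters: without the realize call inside the loop, every iteration would just keep extending the lazy graph instead of computing anything.

# #[cfg(feature = "opencl")] {
# use zyx::context::Context;
# let mut ctx = Context::opencl().unwrap();
let x = ctx.randn((128, 128));
let mut acc = x.clone();
for _ in 0..10 {
    acc = acc.dot(&x).tanh();
    // Realizing inside the loop evaluates the accumulated graph
    // instead of letting it grow across iterations.
    acc.realize().unwrap();
}
# }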

Neural networks

Implementing the Module trait allows for custom high-level constructs.

# #[cfg(feature = "opencl")] {
# use zyx::prelude::*;
# use zyx::nn::*;
# use zyx::optim::*;
struct TinyNet {
    l0: Linear,
    l1: Linear,
}

impl Module for TinyNet {
    fn forward(&self, x: &Tensor) -> Tensor {
        self.l1.forward(&self.l0.forward(x).tanh())
    }

    fn parameters(&mut self) -> Parameters {
        self.l0.parameters().join(self.l1.parameters())
    }
}

let mut ctx = Context::opencl().unwrap();
let mut net = TinyNet {
    l0: ctx.linear(12, 1024),
    l1: ctx.linear(1024, 53),
};
let mut opt = SGD::new().set_lr(0.01);
let x = ctx.randn((32, 12)).set_label("x");
let y = ctx.randn(53);
for _ in 0..10 {
    let out = net.forward(&x);
    let loss = out.mse(&y).sum(());
    loss.backward(net.parameters());
    // optimizer.step realizes parameters and zeros gradients
    opt.step(net.parameters()).unwrap();
}
# }

Goals

These are general directions for further development of Zyx.

  1. Correctness
  2. Performance
  3. Hardware support

Visualization

Networks can be visualized using the dot language.

# use zyx::context::Context;
# use std::io::Write;
# let ctx = Context::new();
let graph = ctx.dot_graph();
std::fs::File::create("graph.dot").unwrap().write_all(graph.as_bytes()).unwrap();
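
The resulting file can then be rendered with Graphviz, for example with dot -Tsvg graph.dot -o graph.svg.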

tiny_net example forward pass:

[Image: tiny_net forward pass graph]

Backends

Zyx has three backends: CPU, OpenCL (all OpenCL versions) and Torch.

Backends are easy to add. Only a few ops are needed and automatic differentiation works with all backends. However, making them fast is very hard.
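
A purely hypothetical sketch of what that means (this is not Zyx's actual backend API): a backend essentially has to provide a small set of primitive buffer operations, and everything else, including automatic differentiation, is built on top of them.

// Hypothetical illustration only; the real backend interface in Zyx differs.
trait BackendOps {
    type Buffer;
    // unary elementwise op
    fn exp(&mut self, x: &Self::Buffer) -> Self::Buffer;
    // binary elementwise op
    fn add(&mut self, x: &Self::Buffer, y: &Self::Buffer) -> Self::Buffer;
    // reduce op
    fn sum(&mut self, x: &Self::Buffer, axes: &[usize]) -> Self::Buffer;
    // movement op
    fn reshape(&mut self, x: &Self::Buffer, shape: &[usize]) -> Self::Buffer;
}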

Performance

Here is a comparison of Zyx, tinygrad, dfdx and PyTorch running the TinyNet example (forward + backward). This is a cherry-picked benchmark; take it with a grain of salt.

The table shows running time in seconds. PyTorch uses a compiled model. Tinygrad runs its OpenCL backend on the GPU and numpy on the CPU.

We couldn't get dfdx and PyTorch working with the given GPU.

Device           Zyx     tinygrad   dfdx   PyTorch
GPU RX 550       4.58    5.51       -      -
CPU i5 Haswell   14.39   11.03      7.79   4.74

As you can see, Zyx is OK on the GPU, but needs to be further optimized for the CPU. PyTorch looks really impressive here, given that it can only utilize the CPU. Note that both Zyx and tinygrad have a PyTorch backend, so they can both run at the same speed as PyTorch.

Load/Save

Zyx works with the .safetensors format from Hugging Face. Enable the io feature for this to work.

net.save("model.safetensors");

Loading is not much more complex.

net.load("model.safetensors");

No-std

Zyx is a no-std library, but alloc is required.
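
For a no_std build you will typically want to disable default features in Cargo.toml, e.g. zyx = { version = "0.11", default-features = false }, since the cpu and io features listed below enable std (which features are enabled by default is an assumption here).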

Features

  • opencl - enables OpenCL backend
  • cpu - enables multithreading, faster cpu operations and std
  • io - enables file operations and std
  • debug1 - enables printing of debug information during runtime and std
  • torch - enables support for torch using the tch crate; set the following environment variables:
    export LIBTORCH=/path/to/libtorch
    export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH

Multiple GPUs

Zyx should work with multiple GPUs within a single OpenCL platform, but this has not been tested.

Syntax

Zyx has syntax similar to other ML libraries (e.g. PyTorch).

                 Zyx                             PyTorch
random tensor    ctx.randn((5, 3))               torch.randn((5, 3))
zeros tensor     ctx.zeros((5, 3))               torch.zeros((5, 3))
uniform tensor   ctx.uniform((4, 6), 0.0..4.0)   torch.zeros((4, 6)).uniform_(0, 4)
matmul           let z = x.dot(y);               z = x @ y
tanh             let y = x.tanh();               y = x.tanh()
binary ops       let z = x + 2;                  z = x + 2
dtypes           let z = x.cast(DType::I32);     z = x.type(torch.int32)
saving           net.save("net.safetensors");    torch.save(net.state_dict(), "net.pt")
loading          net.load("net.safetensors");    net.load_state_dict(torch.load("net.pt"))
backpropagation  let y = x.exp();                x.requires_grad = True
                 y.backward(&mut x);             y = x.exp()
                                                 y.backward()
optimizers       let opt = SGD::new();           opt = SGD(net.parameters())
                 opt.step(&mut net);             opt.step()

Missing features

Zyx is very much experimental software. Some notable missing features are convolutions and padding.

Contributing

Any contributions are welcome. If you're interested in simple contributions, improving documentation and adding examples is a great start. For those interested in adding modules, there is the nn folder, where you can leverage the tensor's existing ops. And if you are interested in low-level programming, improving the performance of the OpenCL kernels is the most difficult option.

Bugs

Please report any correctness and performance bugs. In particular, please report incorrect or insufficient tests.

Thanks

The libraries used as dependencies of Zyx deserve special thanks, because Zyx would not be possible without them.

We would also like to thank the users of Zyx for their continued interest and for showing that there is demand for this library.

License

Zyx is free software licensed under the terms of both the MIT license and the Apache License, Version 2.0. For OpenCL licensing, see its website.
