1 unstable release

0.4.0 Aug 23, 2024

#205 in Machine learning

162 downloads per month
Used in 2 crates (via nox)

MIT/Apache

515KB
12K SLoC

C++ 7.5K SLoC // 0.1% comments
Rust 3.5K SLoC // 0.0% comments
CUDA 247 SLoC // 0.2% comments
Bazel 48 SLoC // 0.2% comments

noxla

Originally based on https://github.com/LaurentMazare/xla-rs

noxla is a Rust wrapper around XLA, a compiler for machine learning and linear algebra. The goal of this project is to provide a set of safe bindings that stay close to XLA's C++ API. The crate is intended as a building block for higher-level linear algebra and ML crates. It can be used on its own, but ergonomics and documentation are not its priority.

Why Fork?

This crate differs from the original xla-rs in two key ways. The biggest is that the bindings are rewritten to use the cpp and cxx crates. This lets the bindings be written inline with the Rust code and gives them clearer typing. The other large change is that noxla uses a fork of xla_extension, which allows building a fully static binary. That better fits the intended use of noxla as a low-level crate consumed by higher-level libraries.


lib.rs:

Rust bindings for XLA (Accelerated Linear Algebra).

XLA is a compiler library for machine learning. It can be used to run models efficiently on GPUs, TPUs, and CPUs.

XlaOps are used to build a computation graph. This graph can be built into an XlaComputation. The computation can then be compiled into a PjRtLoadedExecutable, and the executable can be run on a PjRtClient. Literal values represent tensors in host memory, and PjRtBuffer values represent views of tensors/memory on the target device.

The following example illustrates how to build and run a simple computation.

// Create a CPU client.
let client = xla::PjRtClient::cpu()?;

// A builder object is used to store the graph of XlaOp.
let builder = xla::XlaBuilder::new("test-builder");

// Build a simple graph summing two constants.
let cst20 = builder.constant_r0(20f32);
let cst22 = builder.constant_r0(22f32);
let sum = (cst20 + cst22)?;

// Create a computation from the final node.
let computation = sum.build()?;

// Compile this computation for the target device and then execute it.
let executable = client.compile(&computation)?;
let buffers = executable.execute::<xla::Literal>(&[])?;

// Retrieve the resulting value: copy the first output buffer back to the host.
let result = buffers[0][0].to_literal_sync()?.to_vec::<f32>()?;
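
The staging described above (graph of ops, then a computation, then a compiled executable run by a client) can also be illustrated without an XLA installation. The following is a minimal, self-contained sketch in plain Rust that does not use noxla at all. Op, Computation, Executable, Client, and sum_of_constants are hypothetical toy stand-ins for XlaOp, XlaComputation, PjRtLoadedExecutable, and PjRtClient, and the "compiler" is assumed to be nothing more than a lowering to a small stack program.

```rust
use std::ops::Add;

/// Toy stand-in for XlaOp: one node in a computation graph (hypothetical).
#[derive(Debug, Clone, PartialEq)]
enum Op {
    Constant(f32),
    Add(Box<Op>, Box<Op>),
}

/// Overload `+` so that adding two ops builds a graph node, not a number.
impl Add for Op {
    type Output = Op;
    fn add(self, rhs: Op) -> Op {
        Op::Add(Box::new(self), Box::new(rhs))
    }
}

/// Toy stand-in for XlaComputation: a finished graph, identified by its root.
struct Computation {
    root: Op,
}

impl Op {
    /// Freeze the graph rooted at this node into a computation.
    fn build(self) -> Computation {
        Computation { root: self }
    }
}

/// Instructions of the toy "device": a tiny stack machine.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Instr {
    Push(f32),
    Add,
}

/// Toy stand-in for PjRtLoadedExecutable: the lowered program.
struct Executable {
    program: Vec<Instr>,
}

/// Toy stand-in for PjRtClient.
struct Client;

impl Client {
    fn cpu() -> Client {
        Client
    }

    /// "Compile" by lowering the graph to postfix stack instructions.
    fn compile(&self, computation: &Computation) -> Executable {
        fn lower(op: &Op, out: &mut Vec<Instr>) {
            match op {
                Op::Constant(v) => out.push(Instr::Push(*v)),
                Op::Add(lhs, rhs) => {
                    lower(lhs, out);
                    lower(rhs, out);
                    out.push(Instr::Add);
                }
            }
        }
        let mut program = Vec::new();
        lower(&computation.root, &mut program);
        Executable { program }
    }
}

impl Executable {
    /// Run the stack program and return the single value left on the stack.
    fn execute(&self) -> Result<f32, String> {
        let mut stack: Vec<f32> = Vec::new();
        for instr in &self.program {
            match instr {
                Instr::Push(v) => stack.push(*v),
                Instr::Add => {
                    let rhs = stack.pop().ok_or("stack underflow")?;
                    let lhs = stack.pop().ok_or("stack underflow")?;
                    stack.push(lhs + rhs);
                }
            }
        }
        stack.pop().ok_or_else(|| "empty program".to_string())
    }
}

/// Same shape as the noxla example: build, compile, execute, read back.
fn sum_of_constants() -> Result<f32, String> {
    let client = Client::cpu();
    let sum = Op::Constant(20.0) + Op::Constant(22.0);
    let computation = sum.build();
    let executable = client.compile(&computation);
    executable.execute()
}
```

Calling sum_of_constants() returns Ok(42.0). The point of the sketch is the separation of phases: `+` only records structure, and no arithmetic happens until the compiled program is executed. That separation is what lets a real compiler such as XLA optimize the whole graph before running it.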

Dependencies

~64MB
~818K SLoC