0.4.0 (1 unstable release) | Aug 23, 2024 | #305 in Machine learning | Used in 2 crates (via nox) | 515KB | 12K SLoC
noxla
Originally based on https://github.com/LaurentMazare/xla-rs
noxla is a Rust wrapper around XLA, a compiler for machine learning and linear algebra. The goal of this project is to create a set of safe bindings close to XLA's C++ API. The intended use of this crate is in other higher-level linear algebra and ML crates. It can be used on its own, but it doesn't promise to be the most ergonomic or well-documented.
Why Fork?
This crate differs from the original xla-rs in a few key ways. The biggest is that the bindings are rewritten to use the cpp and cxx crates. This allows our bindings to be written inline alongside the Rust code and to have clearer typing. The other large change is that we use a fork of xla_extension, which allows building a fully static binary. This better fits the intended use of noxla as a low-level crate consumed by higher-level libraries.
lib.rs:
Rust bindings for XLA (Accelerated Linear Algebra).
XLA is a compiler library for Machine Learning. It can be used to run models efficiently on GPUs, TPUs, and on CPUs too.
XlaOp values are used to build a computation graph. This graph can be built into an XlaComputation. The computation can then be compiled into a PjRtLoadedExecutable, and this executable can be run on a PjRtClient. Literal values represent tensors in host memory, while PjRtBuffer values represent views of tensors/memory on the target device.
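To make the build, compile, and execute stages above concrete without linking the native XLA library, here is a toy, std-only Rust model of the same flow. Every type in it (Op, Computation, Client, Executable, Buffer) is a hypothetical stand-in invented for this sketch and is not part of the noxla API; "compiling" here just lowers the expression tree to a tiny stack program rather than generating device code. The nested Vec<Vec<Buffer>> return type is an assumption chosen to mirror the result[0][0] indexing used in the real example that follows.

```rust
use std::ops::Add;

// Hypothetical stand-in for XlaOp: one node of an expression graph.
#[derive(Clone, Debug)]
enum Op {
    Constant(f32),
    Add(Box<Op>, Box<Op>),
}

// `a + b` records an Add node instead of computing a value, like building a graph.
impl Add for Op {
    type Output = Op;
    fn add(self, rhs: Op) -> Op {
        Op::Add(Box::new(self), Box::new(rhs))
    }
}

// Hypothetical stand-in for XlaComputation: a finished graph with a single root.
struct Computation {
    root: Op,
}

impl Op {
    // Freeze the graph rooted at this node (analogous to calling build on the final op).
    fn build(self) -> Computation {
        Computation { root: self }
    }
}

// Toy "compiled" form: a flat stack-machine program.
enum Instr {
    Push(f32),
    Add,
}

// Hypothetical stand-in for PjRtLoadedExecutable.
struct Executable {
    program: Vec<Instr>,
}

// Hypothetical stand-in for a device-side PjRtBuffer.
struct Buffer {
    data: Vec<f32>,
}

impl Buffer {
    // Copy the data back to "host" memory.
    fn to_vec(&self) -> Vec<f32> {
        self.data.clone()
    }
}

// Hypothetical stand-in for PjRtClient.
struct Client;

impl Client {
    // Lower the graph into the stack program (a post-order walk of the tree).
    fn compile(&self, computation: &Computation) -> Executable {
        let mut program = Vec::new();
        lower(&computation.root, &mut program);
        Executable { program }
    }
}

fn lower(op: &Op, out: &mut Vec<Instr>) {
    match op {
        Op::Constant(v) => out.push(Instr::Push(*v)),
        Op::Add(a, b) => {
            lower(a, out);
            lower(b, out);
            out.push(Instr::Add);
        }
    }
}

impl Executable {
    // Run the program; the outer Vec stands for devices, the inner Vec for outputs.
    fn execute(&self) -> Vec<Vec<Buffer>> {
        let mut stack: Vec<f32> = Vec::new();
        for instr in &self.program {
            match instr {
                Instr::Push(v) => stack.push(*v),
                Instr::Add => {
                    let b = stack.pop().expect("stack underflow");
                    let a = stack.pop().expect("stack underflow");
                    stack.push(a + b);
                }
            }
        }
        vec![vec![Buffer { data: stack }]]
    }
}

fn main() {
    let client = Client;
    let sum = Op::Constant(20.0) + Op::Constant(22.0);
    let computation = sum.build();
    let executable = client.compile(&computation);
    let result = executable.execute();
    let values = result[0][0].to_vec();
    println!("{:?}", values); // prints [42.0]
}
```

Keeping graph construction separate from compilation is what lets XLA see and optimize the whole graph before anything runs; this toy compiler performs no optimization and only illustrates the ordering of the stages.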
The following example illustrates how to build and run a simple computation.
// The `?` operators assume this code runs inside a function that returns a Result.
// Create a CPU client.
let client = xla::PjRtClient::cpu()?;
// A builder object is used to store the graph of XlaOp nodes.
let builder = xla::XlaBuilder::new("test-builder");
// Build a simple graph summing two constants.
let cst20 = builder.constant_r0(20f32);
let cst22 = builder.constant_r0(22f32);
let sum = (cst20 + cst22)?;
// Create a computation from the final node.
let computation = sum.build()?;
// Compile this computation for the target device and then execute it.
let executable = client.compile(&computation)?;
let buffers = executable.execute::<xla::Literal>(&[])?;
// Retrieve the resulting value by copying the first result buffer back to the host.
let result = buffers[0][0].to_literal_sync()?.to_vec::<f32>()?;
// result == [42.0]
Dependencies: ~65MB, ~820K SLoC