1 stable release: 1.19.2 (Sep 27, 2024)
#216 in Machine learning · 51KB · 1.5K SLoC
ortn
Yet another minimal Rust binding for the onnxruntime `c_api`, inspired by onnxruntime-rs.
What does "minimal" mean?
- Only a subset of the `c_api` is wrapped, just enough to run an ONNX model.
- Fewer `'a` lifetime generics.
- Less conceptual overhead when using Rust compared to using the onnxruntime `c_api` directly.
- Best effort to work with the latest onnxruntime version on different platforms; no feature flags are introduced to support multiple onnxruntime versions.
- Only shared libraries (`onnxruntime.dll`, `libonnxruntime.[so|dylib]`) are supported.
- `ndarray` is used to handle input/output tensors.
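Because tensors are passed around as `ndarray` arrays, it helps to recall how an array in standard (row-major, C-contiguous) layout maps onto the dense flat buffer that onnxruntime tensors use. The std-only sketch below is an illustration, not ortn code; `row_major_strides` and `flat_offset` are hypothetical helper names. It computes the strides and the flat offset of an index for the `[1, 1, 28, 28]` NCHW input used by the MNIST example in Getting Started.

```rust
/// Row-major (C-contiguous) strides, in elements, for a given shape.
/// The last axis has stride 1; each earlier axis steps over all later axes.
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// Offset of a multi-dimensional index in the flat buffer, or None if the
/// index has the wrong rank or is out of bounds on any axis.
fn flat_offset(shape: &[usize], index: &[usize]) -> Option<usize> {
    if shape.len() != index.len() || index.iter().zip(shape).any(|(i, d)| i >= d) {
        return None;
    }
    let strides = row_major_strides(shape);
    Some(index.iter().zip(&strides).map(|(i, s)| i * s).sum())
}

fn main() {
    // NCHW shape of the MNIST input: batch 1, 1 channel, 28x28 pixels.
    let shape = [1, 1, 28, 28];
    println!("strides = {:?}", row_major_strides(&shape));
    println!("offset of [0, 0, 2, 3] = {:?}", flat_offset(&shape, &[0, 0, 2, 3]));
}
```

In `ndarray` terms this is the order that `is_standard_layout()` checks for. A transposed or strided view may not satisfy it, and `as_standard_layout()` yields a contiguous version when needed; whether ortn accepts non-contiguous views directly is an assumption worth verifying.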
Test Matrix
| OS | onnxruntime version | Arch | CPU | CUDA | TensorRT | CANN |
|---|---|---|---|---|---|---|
| mac | 1.19.2 | aarch64 | ✅ | | | |
| mac | 1.19.2 | intel64 | ✅ | | | |
| linux | 1.19.2 | intel64 | ✅ | ✅ | ✅ | TODO |
| windows | TODO | intel64 | TODO | TODO | TODO | |
Getting Started
- Download onnxruntime first and unzip it, or build it from source.
- Before starting, set up environment variables to help `ortn` find the headers and libraries it needs:
  - `ORT_LIB_DIR`: folder where `libonnxruntime.[so|dylib]` or `onnxruntime.dll` is located
  - `ORT_INC_DIR`: folder where the header file `onnxruntime/onnxruntime_c_api.h` is located
  - `DYLD_LIBRARY_PATH`: (mac only) folder where `libonnxruntime.dylib` is located
  - `LD_LIBRARY_PATH`: (linux only) folder where `libonnxruntime.so` is located
  - `PATH`: folder where `onnxruntime.dll` is located
- Build the environment, build a session, and run it:
```rust
use ndarray::Array4;
use ndarray_rand::{rand_distr::Uniform, RandomExt};
use ortn::prelude::*;

// Assumes ortn's error type implements std::error::Error so `?` can box it.
fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
    // enable verbose logging through tracing
    std::env::set_var("RUST_LOG", "trace");
    let _ = tracing_subscriber::fmt::try_init();

    let output = Session::builder()
        // create env and use it as session's env
        .with_env(
            Environment::builder()
                .with_name("mnist")
                .with_level(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE)
                .build()?,
        )
        // disable all graph optimizations
        .with_graph_optimization_level(GraphOptimizationLevel::ORT_DISABLE_ALL)
        // set session intra-op threads to 4
        .with_intra_threads(4)
        // build the session from the model bytes
        .build(include_bytes!("../models/mnist.onnx"))?
        // run the model
        .run([
            // borrow the input tensor as an onnxruntime value
            ValueBorrowed::try_from(
                // create a random f32 input tensor of shape [1, 1, 28, 28]
                Array4::random([1, 1, 28, 28], Uniform::new(0f32, 1f32)).view(),
            )?,
        ])?
        // output is a vector, take the first result
        .into_iter()
        .next()
        .unwrap()
        // view output as an f32 array
        .view::<f32>()?
        // the output is owned by the session, copy it out as an owned tensor/ndarray
        .to_owned();
    tracing::info!(?output);
    Ok(())
}
```
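The example stops at logging the raw output. Assuming this `mnist.onnx` returns one unnormalized score per digit class (a `[1, 10]` tensor, which is typical for MNIST models but not stated above), a std-only sketch of turning those scores into a prediction is shown below. `softmax` and `argmax` are hypothetical helpers, and the scores in `main` are made-up values, not real model output. With the owned output array, a slice can usually be obtained via `output.as_slice()`, which returns `Some` for a standard-layout array.

```rust
/// Numerically stable softmax: subtract the max score before exponentiating.
fn softmax(scores: &[f32]) -> Vec<f32> {
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|&s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Index of the largest score (the predicted class), or None for an empty slice.
fn argmax(scores: &[f32]) -> Option<usize> {
    scores
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
}

fn main() {
    // Hypothetical scores for the 10 digit classes (illustrative values only).
    let scores = [0.1_f32, -1.2, 0.3, 4.5, 0.0, -0.7, 1.1, 0.2, 2.0, -0.4];
    let probs = softmax(&scores);
    if let Some(best) = argmax(&scores) {
        println!("predicted digit = {best}, probability = {:.3}", probs[best]);
    }
}
```

Softmax is monotonic, so `argmax` over the raw scores and over the probabilities picks the same class; the probabilities are only needed when a confidence value is wanted.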
Update bindings
If the bindings need to be updated:
- Clone this repo:

  ```bash
  git clone https://github.com/yexiangyu/ortn
  ```

- Export the environment variables:

  ```bash
  export ORT_LIB_DIR=/path/to/onnxruntime/lib
  export ORT_INC_DIR=/path/to/onnxruntime/include
  ```

- Build the bindings with the `bindgen` feature enabled:

  ```bash
  cargo build --features bindgen
  ```
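`ORT_LIB_DIR` and `ORT_INC_DIR` are the kind of input a Cargo build script consumes. The following is a minimal, hypothetical sketch of how such a script could turn them into link directives and a header path for bindgen, assuming the shared library is linked at build time; it is not ortn's actual `build.rs`, and `link_directives` / `header_path` are illustrative names. The `cargo:rustc-link-search`, `cargo:rustc-link-lib` and `cargo:rerun-if-env-changed` lines are standard Cargo build-script directives.

```rust
use std::env;
use std::path::{Path, PathBuf};

/// Cargo directives that point the linker at a shared onnxruntime library.
/// (Hypothetical helper; a real build script would print these lines.)
fn link_directives(lib_dir: &str) -> Vec<String> {
    vec![
        format!("cargo:rustc-link-search=native={lib_dir}"),
        "cargo:rustc-link-lib=dylib=onnxruntime".to_string(),
        // Rebuild when the variables point at a different install.
        "cargo:rerun-if-env-changed=ORT_LIB_DIR".to_string(),
        "cargo:rerun-if-env-changed=ORT_INC_DIR".to_string(),
    ]
}

/// Location of the C API header that bindgen would read, under ORT_INC_DIR.
fn header_path(inc_dir: &Path) -> PathBuf {
    inc_dir.join("onnxruntime").join("onnxruntime_c_api.h")
}

fn main() {
    // A real build script would fail when these are missing; here we only report.
    match env::var("ORT_LIB_DIR") {
        Ok(dir) => link_directives(&dir).iter().for_each(|d| println!("{d}")),
        Err(_) => eprintln!("ORT_LIB_DIR is not set"),
    }
    match env::var("ORT_INC_DIR") {
        Ok(dir) => {
            let header = header_path(Path::new(&dir));
            println!("header {} exists: {}", header.display(), header.exists());
        }
        Err(_) => eprintln!("ORT_INC_DIR is not set"),
    }
}
```

Keeping the directive construction in a pure function makes it testable without setting environment variables; only `main` touches the process environment.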
TODO
- More input data types like `f16`, `i64`, ...
- More execution providers like `rocm` and `cann`
- onnxruntime-agi
- Training API
Dependencies: ~5–7.5MB, ~141K SLoC