3 releases
new 0.1.2 | Feb 11, 2025 |
---|---|
0.1.1 | Feb 10, 2025 |
0.1.0 | Feb 8, 2025 |
#545 in Machine learning
98 downloads per month
39KB
1K SLoC
ort2
- ONNX Runtime wrapper for Rust (v2)
- Tested platforms:
OS | Linux | Windows | macOS |
---|---|---|---|
CPU | Y | Y | Y (aarch64) |
CUDA | Y | Y | N/A |
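Prerequisites
- Download an ONNX Runtime release from the onnxruntime GitHub repository and unzip it.
- Set the following environment variables to the `include` and `lib` directories of that installation, for example:
ORT_INC_PATH=/opt/homebrew/opt/onnxruntime/include
ORT_LIB_PATH=/opt/homebrew/opt/onnxruntime/lib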
Getting Started
use ort2::prelude::*;
// Minimal example: run the MNIST ONNX model on an all-zero input image.
// Embed the ONNX model file into the binary at compile time.
let model = include_bytes!("models/mnist-8.onnx");
// Create an inference session from the in-memory model bytes.
let session = Session::builder()
.build(model.as_ref())
.expect("failed to create session");
// Dummy input: 28 * 28 zeros, i.e. one blank 28x28 image.
let input = vec![0.0f32;28 * 28];
// Wrap the input buffer in an f32 tensor of shape [1, 1, 28, 28].
// `borrow` presumably references `input` rather than copying it (confirm in the
// ort2 docs); `input` must therefore stay alive while `value` is used.
let value = Value::tensor()
.with_shape([1, 1, 28, 28])
.with_typ(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
.borrow(&input)
.expect("failed to build value");
// Run the session with the single input and keep only the first output.
let output = session.run([&value])
.expect("failed to run")
.into_iter()
.next()
.expect("failed to get outputs");
// View the output as an ndarray array of f32 values.
let output = output
.view::<f32>()
.expect("failed to view output");
// The second output dimension has size 10 (presumably one score per digit class).
assert_eq!(output.shape()[1], 10);
Run with IoBinding
use ort2::prelude::*;
// Same MNIST example, but inputs and outputs are attached through an IoBinding
// instead of being passed to `session.run` directly.
// Embed the ONNX model file into the binary at compile time.
let model = include_bytes!("models/mnist-8.onnx");
// Create an inference session from the in-memory model bytes.
let session = Session::builder()
.build(model.as_ref())
.expect("failed to create session");
// Dummy input: 28 * 28 zeros, i.e. one blank 28x28 image.
let input = vec![0.0f32;28 * 28];
// Wrap the input buffer in an f32 tensor of shape [1, 1, 28, 28].
// `borrow` presumably references `input` rather than copying it (confirm in the
// ort2 docs); `input` must therefore stay alive while `value` is used.
let value = Value::tensor()
.with_shape([1, 1, 28, 28])
.with_typ(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
.borrow(&input)
.expect("failed to build value");
// Create an IoBinding for this session; it is mutated by the bind/run calls below.
let mut iobinding = session.iobinding()
.expect("failed to create iobinding");
// Bind the tensor to the model's first input, looked up by name.
iobinding.bind_input(
&session.get_inputs()
.expect("failed to get input")[0].name,
&value
)
.expect("failed to bind input");
// Memory description for the output. `MemoryInfo::default()` presumably means
// CPU memory; confirm in the ort2 docs before relying on it.
let mem_info = MemoryInfo::default();
// Bind the model's first output (looked up by name) to the memory/device
// described by `mem_info`, so the runtime produces the output there.
iobinding.bind_output_to_device(
&session.get_outputs()
.expect("failed to get outputs")[0].name,
&mem_info
)
.expect("failed to bind outputs");
// Run the session using the bound inputs and outputs.
session.run_with_iobinding(&mut iobinding)
.expect("failed to run");
// Allocator handed to `get_bound_outputs` (presumably used to materialize the
// bound output values; not shown here).
let alloc = DefaultAllocator::default();
// Fetch the bound outputs and keep only the first one.
let output = iobinding
.get_bound_outputs(&alloc)
.expect("failed to get output from iobinding")
.into_iter()
.next()
.expect("failed to get output");
// View the output as an array of f32 values.
let output = output
.view::<f32>()
.expect("failed to view output");
// The second output dimension has size 10 (presumably one score per digit class).
assert_eq!(output.shape()[1], 10);
Dependencies
~3.5–6.5MB
~121K SLoC