3 releases

new 0.1.2 Feb 11, 2025
0.1.1 Feb 10, 2025
0.1.0 Feb 8, 2025

#545 in Machine learning

Download history 98/week @ 2025-02-03

98 downloads per month

MIT license

39KB
1K SLoC

ort2

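  • ONNX Runtime (onnxruntime) wrapper for Rust ("rust 2")
  • Tested platforms:

| Backend | Linux | Windows | macOS          |
|---------|-------|---------|----------------|
| CPU     | Y     | Y       | Y (aarch64)    |
| CUDA    | Y     | Y       | N/A            |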

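Prerequisites

  • Download a prebuilt onnxruntime release from the onnxruntime GitHub repository and unzip it.
  • Set the following environment variables to point at the unpacked headers and libraries (the example paths below are for a Homebrew install on macOS):
    • ORT_INC_PATH=/opt/homebrew/opt/onnxruntime/include
    • ORT_LIB_PATH=/opt/homebrew/opt/onnxruntime/lib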

Getting Started

use ort2::prelude::*;

// Embed the model file's bytes into the binary at compile time.
let model_bytes = include_bytes!("models/mnist-8.onnx");

// Build an inference session from the in-memory model bytes.
let session = Session::builder()
    .build(model_bytes.as_ref())
    .expect("failed to create session");

// Dummy input: a zero-filled buffer of 28 * 28 f32 values.
let pixels = vec![0.0f32; 28 * 28];

// Wrap the buffer as a float tensor of shape [1, 1, 28, 28].
// `borrow` takes a reference to `pixels`, so `pixels` must stay alive while the tensor is used.
let input_tensor = Value::tensor()
    .with_shape([1, 1, 28, 28])
    .with_typ(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
    .borrow(&pixels)
    .expect("failed to build value");

// Run inference, then keep only the first output value.
let outputs = session.run([&input_tensor]).expect("failed to run");
let first_output = outputs
    .into_iter()
    .next()
    .expect("failed to get outputs");

// Look at the output as an ndarray array of f32.
let scores = first_output
    .view::<f32>()
    .expect("failed to view output");

// The model produces 10 values along its second dimension.
assert_eq!(scores.shape()[1], 10);

Run with IoBinding

use ort2::prelude::*;

// Embed the model file's bytes into the binary at compile time.
let model = include_bytes!("models/mnist-8.onnx");

// Build an inference session from the in-memory model bytes.
let session = Session::builder()
    .build(model.as_ref())
    .expect("failed to create session");

// Dummy input: a zero-filled buffer of 28 * 28 f32 values.
let input = vec![0.0f32;28 * 28];

// Wrap the buffer as a float tensor of shape [1, 1, 28, 28].
// `borrow` takes a reference to `input`, so `input` must outlive `value`.
let value = Value::tensor()
    .with_shape([1, 1, 28, 28])
    .with_typ(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
    .borrow(&input)
    .expect("failed to build value");

// Create an IoBinding for this session; binding calls below mutate it, hence `mut`.
let mut iobinding = session.iobinding()
    .expect("failed to create iobinding");

// Bind `value` to the model's first declared input, looked up by name.
// NOTE(review): indexing `[0]` panics if the model declares no inputs.
iobinding.bind_input(
        &session.get_inputs()
        .expect("failed to get input")[0].name,
        &value
    )
    .expect("failed to bind input");

// Default memory info used as the target device for the bound output.
// NOTE(review): presumably CPU memory — confirm against MemoryInfo::default.
let mem_info = MemoryInfo::default();

// Bind the model's first declared output to the device described by `mem_info`,
// letting the runtime allocate it there rather than supplying a pre-built tensor.

iobinding.bind_output_to_device(
        &session.get_outputs()
        .expect("failed to get outputs")[0].name,
        &mem_info
    )
    .expect("failed to bind outputs");

// Run the session using the inputs/outputs registered on the binding.
// Must happen after both bind calls and before reading bound outputs.
session.run_with_iobinding(&mut iobinding)
    .expect("failed to run");

// Allocator passed to `get_bound_outputs` when retrieving the results.
let alloc = DefaultAllocator::default();

// Fetch the values produced for the bound outputs and keep only the first one.
let output = iobinding
    .get_bound_outputs(&alloc)
    .expect("failed to get output from iobinding")
    .into_iter()
    .next()
    .expect("failed to get output");

// Look at the output as an ndarray array of f32.
let output = output
    .view::<f32>()
    .expect("failed to view output");

// The model produces 10 values along its second dimension.
assert_eq!(output.shape()[1], 10);

Dependencies

~3.5–6.5MB
~121K SLoC