Rust bindings for TensorFlow Lite

This crate provides TensorFlow Lite APIs. Please read the API documentation on docs.rs.
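
To try the examples below, add the crate to your project's Cargo.toml. The version shown here is illustrative; use the latest published release:

[dependencies]
tflite = "0.6"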

Using the interpreter from a model file

The following example shows how to build a TensorFlow Lite interpreter from a FlatBuffer model file and run inference on input data.

extern crate failure;
extern crate tflite;

use std::fs::File;
use std::io::Read;
use std::path::Path;

use failure::Fallible;

use tflite::ops::builtin::BuiltinOpResolver;
use tflite::{FlatBufferModel, InterpreterBuilder};

fn test_mnist<P: AsRef<Path>>(model_path: P) -> Fallible<()> {
    // Load the model FlatBuffer and the resolver for TensorFlow Lite's builtin ops.
    let model = FlatBufferModel::build_from_file(model_path)?;
    let resolver = BuiltinOpResolver::default();

    let builder = InterpreterBuilder::new(&model, &resolver)?;
    let mut interpreter = builder.build()?;

    // Allocate buffers for all tensors before reading or writing tensor data.
    interpreter.allocate_tensors()?;

    let inputs = interpreter.inputs().to_vec();
    assert_eq!(inputs.len(), 1);

    let input_index = inputs[0];

    let outputs = interpreter.outputs().to_vec();
    assert_eq!(outputs.len(), 1);

    let output_index = outputs[0];

    // The MNIST model takes a single 1x28x28x1 input tensor and produces 10 class scores.
    let input_tensor = interpreter.tensor_info(input_index)?;
    assert_eq!(input_tensor.dims, vec![1, 28, 28, 1]);

    let output_tensor = interpreter.tensor_info(output_index)?;
    assert_eq!(output_tensor.dims, vec![1, 10]);

    let mut input_file = File::open("data/mnist10.bin")?;
    for i in 0..10 {
        // Read the next 28x28 image directly into the input tensor's byte buffer.
        input_file.read_exact(interpreter.tensor_data_mut(input_index)?)?;

        interpreter.invoke()?;

        // The output holds one quantized score per digit; the prediction is the argmax.
        let output: &[u8] = interpreter.tensor_data(output_index)?;
        let guess = output
            .iter()
            .enumerate()
            .max_by(|x, y| x.1.cmp(y.1))
            .unwrap()
            .0;

        println!("{}: {:?}", i, output);
        assert_eq!(i, guess);
    }
    Ok(())
}

#[test]
fn mobilenetv1_mnist() {
    test_mnist("data/MNISTnet_uint8_quant.tflite").unwrap();
}

#[test]
fn mobilenetv2_mnist() {
    test_mnist("data/MNISTnet_v2_uint8_quant.tflite").unwrap();
}
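
The test above doubles as a usage guide; condensed into a plain helper, the same calls look roughly like the sketch below. This is an illustrative function, not part of the crate: it assumes a quantized model whose input and output tensors are raw byte buffers (as in the MNIST example) and that `input` is exactly as long as the model's input tensor, since `copy_from_slice` panics on a length mismatch.

use std::path::Path;

use failure::Fallible;

use tflite::ops::builtin::BuiltinOpResolver;
use tflite::{FlatBufferModel, InterpreterBuilder};

// Illustrative helper (not part of the tflite API): run one inference pass
// and return a copy of the output tensor's bytes.
fn run_once<P: AsRef<Path>>(model_path: P, input: &[u8]) -> Fallible<Vec<u8>> {
    let model = FlatBufferModel::build_from_file(model_path)?;
    let resolver = BuiltinOpResolver::default();

    let builder = InterpreterBuilder::new(&model, &resolver)?;
    let mut interpreter = builder.build()?;
    interpreter.allocate_tensors()?;

    let input_index = interpreter.inputs().to_vec()[0];
    let output_index = interpreter.outputs().to_vec()[0];

    // Copy the caller's bytes into the input tensor, run the graph,
    // then copy the output tensor's bytes back out.
    interpreter.tensor_data_mut(input_index)?.copy_from_slice(input);
    interpreter.invoke()?;
    Ok(interpreter.tensor_data(output_index)?.to_vec())
}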

Using the FlatBuffers model APIs

This crate also provides a limited set of FlatBuffers model APIs.

use std::ffi::CString;

use tflite::model::stl::vector::{VectorInsert, VectorErase, VectorSlice};
use tflite::model::{BuiltinOperator, BuiltinOptions, Model, SoftmaxOptionsT};

#[test]
fn flatbuffer_model_apis_inspect() {
    // Load the model and inspect its top-level structure.
    let model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
    assert_eq!(model.version, 3);
    assert_eq!(model.operator_codes.size(), 5);
    assert_eq!(model.subgraphs.size(), 1);
    assert_eq!(model.buffers.size(), 24);
    assert_eq!(
        model.description.c_str().to_string_lossy(),
        "TOCO Converted."
    );

    assert_eq!(
        model.operator_codes[0].builtin_code,
        BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D
    );

    assert_eq!(
        model
            .operator_codes
            .iter()
            .map(|oc| oc.builtin_code)
            .collect::<Vec<_>>(),
        vec![
            BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D,
            BuiltinOperator::BuiltinOperator_CONV_2D,
            BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D,
            BuiltinOperator::BuiltinOperator_SOFTMAX,
            BuiltinOperator::BuiltinOperator_RESHAPE
        ]
    );

    let subgraph = &model.subgraphs[0];
    assert_eq!(subgraph.tensors.size(), 23);
    assert_eq!(subgraph.operators.size(), 9);
    assert_eq!(subgraph.inputs.as_slice(), &[22]);
    assert_eq!(subgraph.outputs.as_slice(), &[21]);

    // Locate the SOFTMAX operator within the subgraph.
    let softmax = subgraph
        .operators
        .iter()
        .position(|op| {
            model.operator_codes[op.opcode_index as usize].builtin_code
                == BuiltinOperator::BuiltinOperator_SOFTMAX
        })
        .unwrap();

    assert_eq!(subgraph.operators[softmax].inputs.as_slice(), &[4]);
    assert_eq!(subgraph.operators[softmax].outputs.as_slice(), &[21]);
    assert_eq!(
        subgraph.operators[softmax].builtin_options.type_,
        BuiltinOptions::BuiltinOptions_SoftmaxOptions
    );

    let softmax_options: &SoftmaxOptionsT = subgraph.operators[softmax].builtin_options.as_ref();
    assert_eq!(softmax_options.beta, 1.);
}

#[test]
fn flatbuffer_model_apis_mutate() {
    let mut model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
    // Mutate the model in place: change the version, remove an operator code
    // and two buffers, and replace the description string.
    model.version = 2;
    model.operator_codes.erase(4);
    model.buffers.erase(22);
    model.buffers.erase(23);
    model
        .description
        .assign(CString::new("flatbuffer").unwrap());

    {
        let subgraph = &mut model.subgraphs[0];
        subgraph.inputs.erase(0);
        subgraph.outputs.assign(vec![1, 2, 3, 4]);
    }

    // Serialize the mutated model and parse it back to verify the changes round-trip.
    let model_buffer = model.to_buffer();
    let model = Model::from_buffer(&model_buffer);
    assert_eq!(model.version, 2);
    assert_eq!(model.operator_codes.size(), 4);
    assert_eq!(model.subgraphs.size(), 1);
    assert_eq!(model.buffers.size(), 22);
    assert_eq!(model.description.c_str().to_string_lossy(), "flatbuffer");

    let subgraph = &model.subgraphs[0];
    assert_eq!(subgraph.tensors.size(), 23);
    assert_eq!(subgraph.operators.size(), 9);
    assert!(subgraph.inputs.as_slice().is_empty());
    assert_eq!(subgraph.outputs.as_slice(), &[1, 2, 3, 4]);
}
