tch-tensor-like

Derive convenience methods for structs or enums of tch tensors

5 releases (breaking)

0.6.0 Apr 3, 2022
0.5.0 Nov 14, 2021
0.4.0 Jun 26, 2021
0.3.0 Mar 26, 2021
0.2.0 Nov 16, 2020

MIT license

tch-tensor-like: Derive Tensor-like Types for tch-rs

About this crate

If you use tch-rs, you may have worked with a complex model input type like this:

use tch::Tensor;

struct ModelInput {
    pub images: Vec<Tensor>,
    pub kind: Tensor,
    pub label: Option<Tensor>,
}

Before feeding a batch of this type into a model, you have to move it to the appropriate device, and calling tensor.to_device() on every member by hand is tedious. The TensorLike derive macro comes to your rescue.

use tch::Tensor;
use tch_tensor_like::TensorLike;

#[derive(TensorLike)]
struct ModelInput {
    pub images: Vec<Tensor>,
    pub kind: Tensor,
    pub label: Option<Tensor>,
}

By deriving the macro, you get to_device(), to_kind() and shallow_clone() out of the box.

use tch::{Device, Kind};

// fetch_data() is a placeholder for your own data-loading code.
let input: ModelInput = fetch_data();
let input = input
    .to_device(Device::cuda_if_available())
    .to_kind(Kind::Float)
    .shallow_clone();
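To illustrate the per-field plumbing that the derive spares you, here is a hand-written sketch of a to_device() over ModelInput. It is an assumption about the general idea, not the macro's actual expansion: the trait name `DeviceMove` is hypothetical, only to_device() is shown, and `Tensor` and `Device` are minimal stand-ins defined in the block (not tch's types) so that it compiles without libtorch.

```rust
// Minimal stand-ins so the sketch compiles without tch/libtorch.
// These are NOT tch::Tensor or tch::Device.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Cuda(usize),
}

#[derive(Debug, PartialEq)]
struct Tensor {
    device: Device,
}

// Hypothetical trait: one method every "tensor-like" type implements.
trait DeviceMove {
    fn to_device(&self, device: Device) -> Self;
}

impl DeviceMove for Tensor {
    fn to_device(&self, device: Device) -> Self {
        // Stand-in for a real device copy: it only records the target device.
        Tensor { device }
    }
}

// Containers forward the call to each element, so nesting composes.
impl<T: DeviceMove> DeviceMove for Vec<T> {
    fn to_device(&self, device: Device) -> Self {
        self.iter().map(|t| t.to_device(device)).collect()
    }
}

impl<T: DeviceMove> DeviceMove for Option<T> {
    fn to_device(&self, device: Device) -> Self {
        self.as_ref().map(|t| t.to_device(device))
    }
}

struct ModelInput {
    images: Vec<Tensor>,
    kind: Tensor,
    label: Option<Tensor>,
}

// The struct-level impl just calls to_device() on every field.
impl DeviceMove for ModelInput {
    fn to_device(&self, device: Device) -> Self {
        ModelInput {
            images: self.images.to_device(device),
            kind: self.kind.to_device(device),
            label: self.label.to_device(device),
        }
    }
}

fn main() {
    let input = ModelInput {
        images: vec![Tensor { device: Device::Cpu }, Tensor { device: Device::Cpu }],
        kind: Tensor { device: Device::Cpu },
        label: None,
    };
    let moved = input.to_device(Device::Cuda(0));
    assert!(moved.images.iter().all(|t| t.device == Device::Cuda(0)));
    assert_eq!(moved.kind.device, Device::Cuda(0));
    assert!(moved.label.is_none());
    println!("moved {} images", moved.images.len());
}
```

Because Vec<T> and Option<T> forward to their elements in this sketch, only the struct-level impl is field-specific, and that impl is the repetitive part a derive macro can generate. Whether tch-tensor-like is structured exactly this way internally is not stated in this README.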

For non-tensor fields, add an attribute so that the value is copied or cloned instead.

#[derive(TensorLike)]
struct ModelInput {
    // primitives are copied by default
    pub number: i32,

    // copy the field
    #[tensor_like(copy)]
    pub text: &'static str,

    // clone the field
    #[tensor_like(clone)]
    pub desc: String,
}
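The attributes matter because a method such as to_device(&self) only has a shared reference to the struct: Copy fields can be duplicated directly, but a String field must be cloned. The sketch below writes that logic by hand under stated assumptions: it is not the macro's generated code, and `Tensor` and `Device` are minimal stand-ins defined in the block (not tch's types).

```rust
// Minimal stand-ins so the sketch compiles without tch/libtorch.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Cuda(usize),
}

#[derive(Debug, PartialEq)]
struct Tensor {
    device: Device,
}

impl Tensor {
    fn to_device(&self, device: Device) -> Tensor {
        Tensor { device }
    }
}

struct ModelInput {
    image: Tensor,
    number: i32,        // primitive: copied
    text: &'static str, // like #[tensor_like(copy)]: &T is Copy
    desc: String,       // like #[tensor_like(clone)]: String is not Copy
}

impl ModelInput {
    // Hypothetical hand-written counterpart of a derived to_device().
    fn to_device(&self, device: Device) -> Self {
        ModelInput {
            image: self.image.to_device(device), // tensor field: moved to the device
            number: self.number,                 // copy
            text: self.text,                     // copy
            // Plain `self.desc` would not compile: it cannot move out of
            // a field behind a shared reference, so the value is cloned.
            desc: self.desc.clone(),
        }
    }
}

fn main() {
    let input = ModelInput {
        image: Tensor { device: Device::Cpu },
        number: 7,
        text: "cat",
        desc: String::from("a photo of a cat"),
    };
    let moved = input.to_device(Device::Cuda(0));
    assert_eq!(moved.image.device, Device::Cuda(0));
    assert_eq!(moved.number, 7);
    assert_eq!(moved.desc, input.desc);
    println!("{} {} {}", moved.number, moved.text, moved.desc);
}
```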

Usage

The crate is not published to crates.io yet. Add the repository as a git dependency to include this crate in your project.

[dependencies]
tch-tensor-like = { git = "https://github.com/jerry73204/tch-tensor-like.git", features = ["derive"] }

License

MIT License. See the LICENSE file.