5 releases
0.2.0 | Apr 3, 2022 |
---|---|
0.1.3 | Nov 14, 2021 |
0.1.2 | Jun 26, 2021 |
0.1.1 | Mar 26, 2021 |
0.1.0 | Nov 16, 2020 |
#147 in #tensor
Used in tch-tensor-like
17KB
362 lines
tch-tensor-like: Derive Tensor-like Types for tch-rs
About this crate
If you are a user of tch-rs, perhaps you ever worked with a complex model input type like this.
struct ModelInput {
pub images: Vec<Tensor>,
pub kind: Tensor,
pub label: Option<Tensor>,
}
Before you feed a batch input of this type into a model, you have to move it to the appropriate device.
It could be tedious to call tensor.to_device()
for each member of the type.
The TensorLike
derive macro comes to your rescue.
use tch_tensor_like::TensorLike;
#[derive(TensorLike)]
struct ModelInput {
pub images: Vec<Tensor>,
pub kind: Tensor,
pub label: Option<Tensor>,
}
By deriving the macro, you have to_device()
, to_kind()
and shallow_clone()
out of box.
let input: ModelInput = fetch_data();
let input = input.to_device(Device::cuda_if_available())
.to_kind(Kind::Float)
.shallow_clone();
For non-tensor members, you can mark the attributes to clone the value instead.
#[derive(TensorLike)]
struct ModelInput {
// primitives are copied by default
pub number: i32,
// copy the field
#[tensor_like(copy)]
pub text: &'static str,
// clone the field
#[tensor_like(clone)]
pub desc: String,
}
Usage
The crate is not published to crates.io yet. Add the repo link to include this crate in your project.
[dependencies]
tch-tensor-like = { git = "https://github.com/jerry73204/tch-tensor-like.git", features = ["derive"] }
License
MIT License. See LICENSE file.
Dependencies
~1.5MB
~37K SLoC