8 releases
Version | Date
---|---
0.1.7 | Dec 12, 2023
0.1.6 | Dec 3, 2023
0.1.5 | Nov 15, 2023
Used in candlelighter
42KB
946 lines
Candle Extensions
An extension library for Candle that provides PyTorch functions not currently available in Candle.
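```rust
use candle_ext::{
    candle::{D, DType, Device, Result, Tensor},
    TensorExt, F,
};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Query: (batch = 3, heads = 3, q_len = 2, head_dim = 4).
    let q = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
    // Key and value: (1, 3, 3, 4), i.e. batch 1 against the query's batch of 3.
    let k = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
    let v = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
    // Lower-triangular mask of shape (q_len, kv_len) = (2, 3); tril comes from TensorExt.
    let m = Tensor::ones((q.dim(D::Minus2)?, k.dim(D::Minus2)?), DType::U8, &device)?.tril(0)?;
    // The remaining optional arguments are left as None.
    let o = F::scaled_dot_product_attention(&q, &k, &v, Some(&m), None, None, None)?;
    println!("{:?}", o.dims());
    Ok(())
}
```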
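To make the math concrete, here is a naive, stdlib-only Rust sketch of what scaled dot-product attention computes for a single batch and head, following PyTorch's definition softmax(QKᵀ/√d with masked scores set to −∞)·V. It is not candle-ext's implementation. Several things are assumptions: the slice-of-rows representation, the helper name, and the boolean mask convention (true means "may attend", as in PyTorch; whether candle-ext reads its U8 mask the same way is not confirmed here). Dropout and a custom scale are omitted.

```rust
/// Naive single-head scaled dot-product attention (illustrative sketch, not candle-ext code).
/// q: [l][d], k: [s][d], v: [s][dv], mask: [l][s] where `true` means the query may attend to that key.
/// Assumes every query row has at least one allowed key; a fully masked row would yield NaN.
fn scaled_dot_product_attention(
    q: &[Vec<f64>],
    k: &[Vec<f64>],
    v: &[Vec<f64>],
    mask: Option<&[Vec<bool>]>,
) -> Vec<Vec<f64>> {
    let d = q[0].len() as f64;
    let scale = 1.0 / d.sqrt();
    let dv = v[0].len();
    q.iter()
        .enumerate()
        .map(|(i, qi)| {
            // Scaled dot-product score of query row i against every key row; masked keys get -inf.
            let scores: Vec<f64> = k
                .iter()
                .enumerate()
                .map(|(j, kj)| {
                    if mask.map_or(true, |m| m[i][j]) {
                        qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() * scale
                    } else {
                        f64::NEG_INFINITY
                    }
                })
                .collect();
            // Numerically stable softmax: subtract the row max before exponentiating.
            let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
            let sum: f64 = exps.iter().sum();
            // Output row = softmax-weighted sum of the value rows.
            let mut out = vec![0.0; dv];
            for (w, vj) in exps.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += w / sum * x;
                }
            }
            out
        })
        .collect()
}

fn main() {
    // One batch/head with the README's sequence sizes: 2 queries, 3 keys, head dim 4.
    let q = vec![vec![0.1, 0.2, 0.3, 0.4], vec![0.5, -0.1, 0.0, 0.2]];
    let k = vec![
        vec![1.0, 0.0, 0.0, 0.0],
        vec![0.0, 1.0, 0.0, 0.0],
        vec![0.0, 0.0, 1.0, 0.0],
    ];
    let v = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
    // Lower-triangular (2, 3) mask, analogous to ones((2, 3)).tril(0).
    let mask: Vec<Vec<bool>> = (0..2).map(|i| (0..3).map(|j| j <= i).collect()).collect();
    let o = scaled_dot_product_attention(&q, &k, &v, Some(mask.as_slice()));
    println!("{:?}", o);
}
```

With this mask the first query can only see the first key, so its output row is exactly the first value row; the second query mixes the first two value rows.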
Currently provides the following (see also the tests):
- F::scaled_dot_product_attention
- F::chunk2..5 / Tensor::chunk2..5
- F::cumsum / Tensor::cumsum
- F::equal / Tensor::equal
- F::eye / Tensor::eye
- F::full / Tensor::full
- F::full_like / Tensor::full_like
- F::triu / Tensor::triu
- F::tril / Tensor::tril
- F::masked_fill / Tensor::masked_fill
- F::logical_not / Tensor::logical_not
- F::logical_or / Tensor::logical_or
- F::outer / Tensor::outer
- F::unbind / Tensor::unbind / F::unbind2..5 / Tensor::unbind2..5
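The functions above are named after their PyTorch counterparts. As a hedged illustration of the PyTorch semantics a few of them follow on 2-D inputs, here is a stdlib-only Rust sketch. It is not candle-ext's code: the `Vec<Vec<f64>>` representation and the helper names (including `cumsum_last`) are hypothetical, and the `diagonal` convention for `tril`/`triu` is assumed to match `torch.tril`/`torch.triu`, as the README's `tril(0)` call suggests.

```rust
/// tril: keep entries with column - row <= diagonal, zero the rest (torch.tril convention).
fn tril(x: &[Vec<f64>], diagonal: i64) -> Vec<Vec<f64>> {
    x.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &val)| if j as i64 - i as i64 <= diagonal { val } else { 0.0 })
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// triu: keep entries with column - row >= diagonal, zero the rest (torch.triu convention).
fn triu(x: &[Vec<f64>], diagonal: i64) -> Vec<Vec<f64>> {
    x.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &val)| if j as i64 - i as i64 >= diagonal { val } else { 0.0 })
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// cumsum along the last dimension: out[i][j] = x[i][0] + ... + x[i][j].
fn cumsum_last(x: &[Vec<f64>]) -> Vec<Vec<f64>> {
    x.iter()
        .map(|row| {
            let mut acc = 0.0;
            row.iter()
                .map(|&v| {
                    acc += v;
                    acc
                })
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// outer product of two vectors: out[i][j] = a[i] * b[j].
fn outer(a: &[f64], b: &[f64]) -> Vec<Vec<f64>> {
    a.iter()
        .map(|&ai| b.iter().map(|&bj| ai * bj).collect::<Vec<f64>>())
        .collect()
}

/// n x n identity matrix.
fn eye(n: usize) -> Vec<Vec<f64>> {
    (0..n)
        .map(|i| (0..n).map(|j| if i == j { 1.0 } else { 0.0 }).collect::<Vec<f64>>())
        .collect()
}

/// masked_fill: replace entries where mask is true with `value`.
fn masked_fill(x: &[Vec<f64>], mask: &[Vec<bool>], value: f64) -> Vec<Vec<f64>> {
    x.iter()
        .zip(mask)
        .map(|(row, mrow)| {
            row.iter()
                .zip(mrow)
                .map(|(&v, &m)| if m { value } else { v })
                .collect::<Vec<f64>>()
        })
        .collect()
}

fn main() {
    let ones = vec![vec![1.0; 3]; 3];
    println!("tril(0): {:?}", tril(&ones, 0));
    println!("triu(1): {:?}", triu(&ones, 1));
    println!("cumsum: {:?}", cumsum_last(&[vec![1.0, 2.0, 3.0]]));
    println!("outer: {:?}", outer(&[1.0, 2.0], &[3.0, 4.0, 5.0]));
    let mask = vec![vec![true, false], vec![false, false]];
    println!("masked_fill: {:?}", masked_fill(&eye(2), &mask, 9.0));
}
```

A `tril(0)` mask keeps the main diagonal, which is why every query row in a causal mask can attend to at least one key.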
License
Licensed under either of
- Apache License, Version 2.0, (LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0)
- MIT license (LICENSE-MIT or https://opensource.org/licenses/MIT)
at your option.
Contribution
Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.
Dependencies: ~8–12MB, ~275K SLoC