candle-ext

An extension library to Candle that provides PyTorch functions not currently available in Candle

8 releases

0.1.7 Dec 12, 2023
0.1.6 Dec 3, 2023
0.1.5 Nov 15, 2023

Used in candlelighter

MIT/Apache

42KB
946 lines

Candle Extensions

An extension library for Candle that provides PyTorch functions not currently available in Candle. For example, scaled dot-product attention with a lower-triangular mask:

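 use candle_ext::{
     candle::{D, DType, Device, Result, Tensor},
     TensorExt, F,
 };

 fn main() -> Result<()> {
     let device = Device::Cpu;
     // Random query, key and value tensors; the leading dimension of k and v is 1.
     let q = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
     let k = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
     let v = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
     // Lower-triangular U8 mask of shape (2, 3); `tril` comes from the TensorExt trait.
     let m = Tensor::ones((q.dim(D::Minus2)?, k.dim(D::Minus2)?), DType::U8, &device)?.tril(0)?;

     // Masked attention; the remaining optional arguments are left as None.
     let o = F::scaled_dot_product_attention(&q, &k, &v, Some(&m), None, None, None)?;
     println!("{:?}", o.shape());

     Ok(())
 }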
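To make the example above easier to follow, here is a minimal plain-Rust sketch of what masked scaled dot-product attention computes for a single head, following the PyTorch definition softmax(Q Kᵀ / √E + mask) V. It is not the crate's implementation: `softmax` and `attention_reference` are hypothetical helper names, the inputs are nested `Vec`s rather than Candle tensors, and it assumes a mask entry of `true` (or a nonzero U8 value) means "may attend", as in PyTorch.

```rust
/// Numerically stable softmax over one row of scores.
/// Assumes at least one entry is finite (true for a lower-triangular mask).
fn softmax(row: &[f64]) -> Vec<f64> {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = row.iter().map(|x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Single-head reference attention (hypothetical helper, for illustration only).
/// Shapes: q is (L, E), k is (S, E), v is (S, Ev), mask is (L, S).
/// mask[i][j] == true means query i may attend to key j.
fn attention_reference(
    q: &[Vec<f64>],
    k: &[Vec<f64>],
    v: &[Vec<f64>],
    mask: &[Vec<bool>],
) -> Vec<Vec<f64>> {
    // Default scale factor: 1 / sqrt(E).
    let scale = 1.0 / (q[0].len() as f64).sqrt();
    q.iter()
        .enumerate()
        .map(|(i, qi)| {
            // Scaled dot products; masked-out positions get -inf so softmax gives them weight 0.
            let scores: Vec<f64> = k
                .iter()
                .enumerate()
                .map(|(j, kj)| {
                    if mask[i][j] {
                        qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() * scale
                    } else {
                        f64::NEG_INFINITY
                    }
                })
                .collect();
            let weights = softmax(&scores);
            // Weighted sum of the value rows, one output column at a time.
            (0..v[0].len())
                .map(|c| weights.iter().zip(v).map(|(w, vj)| w * vj[c]).sum::<f64>())
                .collect()
        })
        .collect()
}
```

With a lower-triangular mask, the first query can only see the first key, so its output row equals the first row of `v`; later rows are convex combinations of the value rows they are allowed to see.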

Currently provides (see also tests):

  • F::scaled_dot_product_attention

  • F::chunk2..5 / Tensor::chunk2..5

  • F::cumsum / Tensor::cumsum

  • F::equal / Tensor::equal

  • F::eye / Tensor::eye

  • F::full / Tensor::full

  • F::full_like / Tensor::full_like

  • F::triu / Tensor::triu

  • F::tril / Tensor::tril

  • F::masked_fill / Tensor::masked_fill

  • F::logical_not / Tensor::logical_not

  • F::logical_or / Tensor::logical_or

  • F::outer / Tensor::outer

  • F::unbind / Tensor::unbind / F::unbind2..5 / Tensor::unbind2..5
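For readers less familiar with the PyTorch functions listed above, the following plain-Rust sketch shows the semantics a few of them have in PyTorch, on small `Vec`-based inputs. These standalone helpers are hypothetical illustrations, not the crate's API: the real `F::` and `Tensor::` versions operate on Candle tensors, and their exact signatures should be checked against the crate's tests.

```rust
/// Keep entries on or below the given diagonal (j <= i + diagonal); zero the rest.
fn tril(m: &[Vec<f64>], diagonal: i64) -> Vec<Vec<f64>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| if (j as i64) <= (i as i64) + diagonal { x } else { 0.0 })
                .collect()
        })
        .collect()
}

/// Keep entries on or above the given diagonal (j >= i + diagonal); zero the rest.
fn triu(m: &[Vec<f64>], diagonal: i64) -> Vec<Vec<f64>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| if (j as i64) >= (i as i64) + diagonal { x } else { 0.0 })
                .collect()
        })
        .collect()
}

/// Running sum along a 1-D sequence.
fn cumsum(xs: &[f64]) -> Vec<f64> {
    xs.iter()
        .scan(0.0, |acc, &x| {
            *acc += x;
            Some(*acc)
        })
        .collect()
}

/// Outer product of two vectors: result[i][j] = a[i] * b[j].
fn outer(a: &[f64], b: &[f64]) -> Vec<Vec<f64>> {
    a.iter()
        .map(|&x| b.iter().map(|&y| x * y).collect())
        .collect()
}

/// Replace entries where the mask is true with `value`; leave the others unchanged.
fn masked_fill(xs: &[f64], mask: &[bool], value: f64) -> Vec<f64> {
    xs.iter()
        .zip(mask)
        .map(|(&x, &m)| if m { value } else { x })
        .collect()
}
```

For instance, `tril` with a diagonal of 0 on a matrix of ones yields the kind of lower-triangular mask used in the attention example, and `masked_fill` with a value of negative infinity is the usual way to apply such a mask to attention scores before the softmax.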

License

Licensed under either of

  • Apache License, Version 2.0

  • MIT license

at your option.

Contribution

Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.

Dependencies

~8.5MB
~181K SLoC