Version 0.10.0 (1 unstable release), Mar 1, 2024
Used in luminal_cuda
1MB, 27K SLoC
cudarc: minimal and safe API over the CUDA toolkit
Check out cudarc on crates.io and docs.rs.
Safe abstractions over:
- CUDA Driver API
- NVRTC API
- cuBLAS API
- cuBLASLt API
- cuRAND API
Pre-alpha state: expect breaking changes, and not all CUDA functions have a safe wrapper yet. Contributions are welcome for any that aren't included!
Design
Goals are:
- As safe as possible (there will still be a lot of `unsafe` due to FFI and async)
- As ergonomic as possible
- Allow mixing of high-level `safe` APIs with low-level `sys` APIs
To that end, there are three levels to each wrapper (by default, the safe API is exported):
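```rust
// Each module exposes the same three levels. Import from one module per scope:
// bringing `safe`, `result`, and `sys` in from several modules at once would clash.
use cudarc::driver::{safe, result, sys};
use cudarc::nvrtc::{safe, result, sys};
use cudarc::cublas::{safe, result, sys};
use cudarc::cublaslt::{safe, result, sys};
use cudarc::curand::{safe, result, sys};
```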
where:
- `sys` is the raw FFI API generated with bindgen
- `result` is a very small wrapper around `sys` that returns a `Result` from each function
- `safe` is a wrapper around `result`/`sys` that provides safe abstractions
Sticking with the safe APIs is heavily recommended.
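The three-level layering can be illustrated with a toy, CPU-only sketch. Everything in it is hypothetical: `fakecuda`, `fake_alloc`, `fake_free`, and `Buffer` stand in for real FFI calls and types, and none of it is cudarc's actual code. It only shows the general pattern: `sys` reports raw status codes, `result` turns them into `Result`, and `safe` wraps ownership in a type whose `Drop` releases the resource, with the safe layer re-exported at the module root.

```rust
// Toy three-level API; `fakecuda` and everything in it are hypothetical stand-ins.
mod fakecuda {
    /// Raw layer: C-style functions that report errors as integer status codes.
    pub mod sys {
        pub type Status = i32;
        pub const SUCCESS: Status = 0;
        pub const ERROR_INVALID_VALUE: Status = 1;

        /// Pretends to allocate `len` bytes and writes a fake handle through `out`.
        pub unsafe fn fake_alloc(out: *mut u64, len: usize) -> Status {
            if len == 0 {
                return ERROR_INVALID_VALUE;
            }
            unsafe { *out = len as u64 };
            SUCCESS
        }

        /// Pretends to free a handle; always succeeds.
        pub unsafe fn fake_free(_handle: u64) -> Status {
            SUCCESS
        }
    }

    /// Thin layer: the same calls, with status codes turned into `Result`.
    pub mod result {
        use super::sys;

        #[derive(Debug, PartialEq)]
        pub struct Error(pub sys::Status);

        fn check(status: sys::Status) -> Result<(), Error> {
            if status == sys::SUCCESS { Ok(()) } else { Err(Error(status)) }
        }

        pub unsafe fn alloc(len: usize) -> Result<u64, Error> {
            let mut handle = 0u64;
            check(unsafe { sys::fake_alloc(&mut handle, len) })?;
            Ok(handle)
        }

        pub unsafe fn free(handle: u64) -> Result<(), Error> {
            check(unsafe { sys::fake_free(handle) })
        }
    }

    /// Safe layer: an owning type that frees its allocation on drop.
    pub mod safe {
        use super::result;

        pub struct Buffer {
            handle: u64,
            len: usize,
        }

        impl Buffer {
            pub fn new(len: usize) -> Result<Self, result::Error> {
                // SAFETY: `alloc` only writes through a pointer to its own local.
                let handle = unsafe { result::alloc(len)? };
                Ok(Buffer { handle, len })
            }

            pub fn len(&self) -> usize {
                self.len
            }
        }

        impl Drop for Buffer {
            fn drop(&mut self) {
                // `drop` cannot return an error, so a failed free is ignored here.
                let _ = unsafe { result::free(self.handle) };
            }
        }
    }

    // Mirror the default export: the safe layer is re-exported at the module root.
    pub use self::safe::*;
}

fn main() {
    let buf = fakecuda::Buffer::new(16).expect("allocation failed");
    assert_eq!(buf.len(), 16);
    assert_eq!(fakecuda::Buffer::new(0).err(), Some(fakecuda::result::Error(1)));
}
```

In this sketch, the `Drop` impl is what lets the top layer be safe: a caller using only `Buffer` cannot forget to free the allocation or free it twice without writing `unsafe` themselves.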
API Preview
It's easy to create a new device and transfer data to the GPU:
```rust
let dev = cudarc::driver::CudaDevice::new(0)?;
// allocate buffers
let inp = dev.htod_copy(vec![1.0f32; 100])?;
let mut out = dev.alloc_zeros::<f32>(100)?;
```
You can also use the NVRTC API to compile kernels at runtime:
```rust
let ptx = cudarc::nvrtc::compile_ptx("
extern \"C\" __global__ void sin_kernel(float *out, const float *inp, const size_t numel) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < numel) {
        out[i] = sin(inp[i]);
    }
}")?;
// and dynamically load it into the device
dev.load_ptx(ptx, "my_module", &["sin_kernel"])?;
```
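To see what the kernel computes, and why it needs the `i < numel` guard, the same indexing can be replayed on the CPU in plain Rust. `sin_kernel_reference` and `blocks_for` are hypothetical helpers, not part of cudarc, and the block size of 32 is an arbitrary assumption chosen so that the launch has more threads than elements.

```rust
/// CPU replay of `sin_kernel`: each (block, thread) pair computes one global
/// index, and the `i < numel` guard skips threads that fall past the end.
fn sin_kernel_reference(out: &mut [f32], inp: &[f32], numel: usize, grid_dim_x: u32, block_dim_x: u32) {
    for block_idx_x in 0..grid_dim_x {
        for thread_idx_x in 0..block_dim_x {
            // Same arithmetic as `blockIdx.x * blockDim.x + threadIdx.x`.
            let i = (block_idx_x * block_dim_x + thread_idx_x) as usize;
            if i < numel {
                out[i] = inp[i].sin();
            }
        }
    }
}

/// Ceiling division: the fewest blocks of `block_dim_x` threads that cover `numel` elements.
fn blocks_for(numel: u32, block_dim_x: u32) -> u32 {
    (numel + block_dim_x - 1) / block_dim_x
}

fn main() {
    let numel: usize = 100;
    let block_dim_x = 32; // hypothetical block size, chosen only for illustration
    let grid_dim_x = blocks_for(numel as u32, block_dim_x); // 4 blocks, 128 threads
    let inp = vec![1.0f32; numel];
    let mut out = vec![0.0f32; numel];
    sin_kernel_reference(&mut out, &inp, numel, grid_dim_x, block_dim_x);
    // The 28 surplus threads were skipped by the guard, and every element was written.
    assert!(out.iter().all(|&x| x == 1.0f32.sin()));
}
```

Without the guard, the last 28 simulated threads would index past the end of the buffers; on the GPU that would be an out-of-bounds access rather than a Rust panic.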
`cudarc` provides a very simple interface to launch kernels: tuples are the arguments!
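```rust
use cudarc::driver::{LaunchAsync, LaunchConfig};

let sin_kernel = dev.get_func("my_module", "sin_kernel").unwrap();
let cfg = LaunchConfig::for_num_elems(100);
// `launch` is unsafe: the argument tuple must match the kernel's signature.
unsafe { sin_kernel.launch(cfg, (&mut out, &inp, 100usize)) }?;
```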
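Tuples fit naturally here because the CUDA driver's `cuLaunchKernel` receives kernel arguments as an array of pointers (`void **kernelParams`), one per parameter, in the order the kernel declares them. The sketch below shows one way a tuple can be flattened into such an array; the `KernelParams` trait and `as_param_ptrs` method are hypothetical and are not cudarc's actual implementation.

```rust
use std::ffi::c_void;

/// Hypothetical trait (not cudarc's API): flatten an argument tuple into the
/// array of per-argument pointers that `cuLaunchKernel` takes as `kernelParams`.
trait KernelParams {
    fn as_param_ptrs(&mut self) -> Vec<*mut c_void>;
}

// One impl per tuple arity; a real library would likely generate these with a macro.
impl<A, B, C> KernelParams for (A, B, C) {
    fn as_param_ptrs(&mut self) -> Vec<*mut c_void> {
        // Pointers are taken in tuple order, which must match the kernel's parameter order.
        vec![
            &mut self.0 as *mut A as *mut c_void,
            &mut self.1 as *mut B as *mut c_void,
            &mut self.2 as *mut C as *mut c_void,
        ]
    }
}

fn main() {
    // Stand-ins for `out` and `inp` (device addresses) and `numel`.
    let mut args = (0x1000u64, 0x2000u64, 100usize);
    let ptrs = args.as_param_ptrs();
    assert_eq!(ptrs.len(), 3);
    // Each entry is a pointer to an argument value, so reading through it recovers `numel`.
    let numel = unsafe { *(ptrs[2] as *const usize) };
    assert_eq!(numel, 100);
}
```

Because nothing in this scheme checks the pointed-to types against the kernel's declared parameters, a mismatched tuple is undefined behavior, which is consistent with `launch` being an `unsafe` call.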
And of course, it's easy to copy things back to the host after you're done:
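```rust
let out_host: Vec<f32> = dev.dtoh_sync_copy(&out)?;
// GPU `sin` may differ from the CPU's in the last bits, so compare with a tolerance.
let expected = [1.0f32; 100].map(f32::sin);
assert!(out_host.iter().zip(expected).all(|(a, b)| (a - b).abs() < 1e-6));
```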
License
Dual-licensed to be compatible with the Rust project.
Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms.
Dependencies: ~0–270KB