# gllm-kernels

Low-level attention kernels for gllm with CUDA/ROCm support.
## Features
- FlashAttention: Memory-efficient attention with O(N) memory complexity
- Hierarchical Attention: Multi-level attention for ultra-long contexts (2M+ tokens)
- CUDA Kernels: Native CUDA implementation with PTX for NVIDIA GPUs
- ROCm/HIP Kernels: AMD GPU support (experimental)
- Multiple Backends: CPU (ndarray), CUDA, WebGPU via Burn
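The O(N) memory of FlashAttention comes from the online-softmax trick: attention weights for a row are never materialized all at once; instead a running maximum, normalizer, and weighted sum are updated as scores stream in block by block. A scalar sketch in plain Rust (illustration only, not the crate's kernels):

```rust
// Online softmax: a single pass over the scores with O(1) extra state
// per output, kept numerically stable via a running maximum.
fn online_softmax_weighted_sum(scores: &[f32], values: &[f32]) -> f32 {
    let mut m = f32::NEG_INFINITY; // running max of scores seen so far
    let mut d = 0.0f32;            // running softmax normalizer
    let mut acc = 0.0f32;          // running weighted sum of values
    for (&s, &v) in scores.iter().zip(values) {
        let m_new = m.max(s);
        let correction = (m - m_new).exp(); // rescale previously accumulated state
        d = d * correction + (s - m_new).exp();
        acc = acc * correction + (s - m_new).exp() * v;
        m = m_new;
    }
    acc / d
}

fn main() {
    let scores = [1.0f32, 2.0, 3.0];
    let values = [10.0f32, 20.0, 30.0];

    // Naive two-pass softmax for comparison: materializes every weight.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let z: f32 = exps.iter().sum();
    let naive: f32 = exps.iter().zip(&values).map(|(&e, &v)| e / z * v).sum();

    let online = online_softmax_weighted_sum(&scores, &values);
    assert!((naive - online).abs() < 1e-3);
    println!("online = {online}, naive = {naive}");
}
```

The same rescaling generalizes to whole key/value tiles, which is how flash-style kernels avoid storing the full attention matrix.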
## Performance
| Implementation | Time (seq=512) | vs burn_cuda |
|---|---|---|
| cuda_kernel | 21.27ms | 37% faster |
| burn_cuda | 33.83ms | baseline |
## Installation

Add to your `Cargo.toml`:

```toml
[dependencies]
gllm-kernels = "0.1"
```
## Feature Flags

| Feature | Description | Default |
|---|---|---|
| `cpu` | CPU backend via `burn-ndarray` | Yes |
| `cuda` | CUDA backend via `burn-cuda` | No |
| `cuda-kernel` | Native CUDA kernels (requires CUDA toolkit) | No |
| `wgpu` | WebGPU backend | No |
| `rocm-kernel` | ROCm/HIP kernels (experimental) | No |
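For example, to build with the native CUDA kernels instead of the default CPU backend (feature names are taken from the table above; treat the exact `default-features` interaction as an assumption and check the crate's own `Cargo.toml`):

```toml
[dependencies]
gllm-kernels = { version = "0.1", default-features = false, features = ["cuda-kernel"] }
```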
## Usage

### Basic FlashAttention

```rust
use gllm_kernels::ops::flash_attention::{
    HierarchicalFlashAttention, AttentionConfig,
};

// Create attention module
let attention = HierarchicalFlashAttention::new(
    num_heads,
    head_dim,
    AttentionConfig::default(),
);

// Forward pass
let output = attention.forward(q, k, v, mask)?;
```
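As a reference point for what a forward pass computes, here is a plain-`Vec` sketch of (non-flash) single-head scaled dot-product attention, softmax(QK^T/sqrt(d))·V. It is illustration only, not the crate's tensor API:

```rust
// Reference single-head attention over row vectors:
// each output row is a softmax-weighted mix of value rows.
fn attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let scale = 1.0 / (q[0].len() as f32).sqrt();
    q.iter()
        .map(|qi| {
            // Scores for this query against every key.
            let scores: Vec<f32> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale)
                .collect();
            // Numerically stable softmax over the scores.
            let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
            let z: f32 = exps.iter().sum();
            // Weighted sum of value rows, column by column.
            (0..v[0].len())
                .map(|c| exps.iter().zip(v).map(|(w, vj)| w / z * vj[c]).sum::<f32>())
                .collect()
        })
        .collect()
}

fn main() {
    let q = vec![vec![1.0f32, 0.0], vec![0.0, 1.0]];
    let out = attention(&q, &q, &q);
    println!("{out:?}");
}
```

Flash-style kernels compute the same function without ever holding the full `scores` matrix for long sequences.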
### CUDA Kernel (Native)

```rust
use gllm_kernels::cuda_kernels::FlashAttentionKernel;
use cudarc::driver::CudaContext;
use std::sync::Arc;

let ctx = Arc::new(CudaContext::new(0)?);
let stream = ctx.default_stream();
let kernel = FlashAttentionKernel::new(&ctx)?;

let output = kernel.forward(
    &stream, &q, &k, &v,
    batch_size, num_heads, seq_len, head_dim,
    is_causal, scale, position_offset,
)?;
```
### Deterministic Mode
For reproducible results in ultra-long context scenarios:
```rust
use gllm_kernels::ops::flash_attention::{AttentionConfig, DeterministicConfig};

let config = AttentionConfig {
    determinism: DeterministicConfig::strict(),
    ..Default::default()
};
```
## Architecture

```text
gllm-kernels
├── ops/
│   ├── flash_attention.rs      # HierarchicalFlashAttention
│   ├── flash_attention_v3.rs   # Advanced attention variants
│   ├── paged_attention.rs      # KV cache paging
│   ├── ring_attention.rs       # Distributed attention
│   ├── sparse_attention.rs     # Sparse patterns
│   ├── mla.rs                  # Multi-head Latent Attention
│   ├── mamba.rs                # State space models
│   └── kv_compression.rs       # KV cache compression
├── cuda_kernels/
│   ├── flash_attn.rs           # CUDA kernel bindings
│   └── kernels/
│       ├── tiled_attention.cu  # CUDA source
│       └── tiled_attention.ptx # Compiled PTX (sm_61)
├── hip_kernels/                # ROCm/HIP (experimental)
└── comm/                       # Distributed communication
```
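KV cache paging, as in `paged_attention.rs`, stores the cache in fixed-size physical blocks and maps each sequence's logical token positions to them through a block table, so sequences can grow without reallocating one large contiguous buffer. A hypothetical bookkeeping sketch (the names and `BLOCK_SIZE` here are ours, not the crate's):

```rust
// Minimal paged-KV-cache bookkeeping: a per-sequence block table maps
// logical token index -> (physical block id, offset within block).
const BLOCK_SIZE: usize = 16;

struct BlockTable {
    blocks: Vec<usize>, // physical block ids, in logical order
    len: usize,         // number of tokens stored so far
}

impl BlockTable {
    fn new() -> Self {
        Self { blocks: Vec::new(), len: 0 }
    }

    /// Reserve a slot for one new token, allocating a fresh physical
    /// block whenever the last one is full. Returns (block, offset).
    fn append(&mut self, alloc: &mut impl FnMut() -> usize) -> (usize, usize) {
        if self.len % BLOCK_SIZE == 0 {
            self.blocks.push(alloc());
        }
        let slot = (self.blocks[self.len / BLOCK_SIZE], self.len % BLOCK_SIZE);
        self.len += 1;
        slot
    }

    /// Physical location of logical token `i`.
    fn lookup(&self, i: usize) -> (usize, usize) {
        (self.blocks[i / BLOCK_SIZE], i % BLOCK_SIZE)
    }
}

fn main() {
    // Trivial bump allocator standing in for a real block pool.
    let mut next = 0usize;
    let mut alloc = || {
        let b = next;
        next += 1;
        b
    };

    let mut table = BlockTable::new();
    for _ in 0..20 {
        table.append(&mut alloc);
    }
    // 20 tokens at BLOCK_SIZE = 16 occupy two physical blocks.
    assert_eq!(table.blocks.len(), 2);
    assert_eq!(table.lookup(17), (1, 1));
    println!("blocks in use: {}", table.blocks.len());
}
```

The indirection is what lets an attention kernel gather keys and values block by block instead of from one flat tensor.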
## Building CUDA Kernels

If you need to recompile the PTX for a different GPU architecture:

```sh
cd src/cuda_kernels/kernels
nvcc -ptx -arch=sm_XX tiled_attention.cu -o tiled_attention.ptx
```
Replace sm_XX with your GPU's compute capability (e.g., sm_61 for GTX 1060, sm_86 for RTX 3090).
Or set the environment variable to use a custom PTX:

```sh
export GLLM_FLASH_ATTN_PTX=/path/to/your/compiled.ptx
```
## License
Apache-2.0