5 releases
| Version | Date |
|---|---|
| 0.1.4 (latest) | Apr 15, 2025 |
| 0.1.3 | Apr 15, 2025 |
| 0.1.2 | Apr 15, 2025 |
| 0.1.1 | Apr 15, 2025 |
| 0.1.0 | Apr 15, 2025 |
#135 in Machine learning, 85 downloads per month, 120KB, 1.5K SLoC
Hibachi
Efficient batched inference for tensor models
Hibachi is a Rust library for efficient batched inference with autoregressive (and soon feedforward) models. It dynamically groups multiple generation requests into batches, manages tensor operations, and streams results back to clients as they become available.
Key Features
- Dynamic Batching - Groups concurrent generation requests into batches to improve hardware utilization
- Asynchronous Processing - Non-blocking architecture built on Tokio
- Stream-Based API - Tokens are streamed back to clients as they are generated
- Backend Agnostic - Works with any tensor library that implements the Backend trait; includes implementations for the Candle and Burn backends (maximum Burn tensor rank of 9)
- Memory Efficient - Manages tensor padding, concatenation, and cleanup
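As a rough illustration of what dynamic batching means, the std-only sketch below drains a queue of pending requests into batches of at most `max_batch` entries, in arrival order. It is not Hibachi's implementation: the `Request` type and `next_batch` function are hypothetical, and the real engine does this asynchronously on Tokio while generation is in progress.

```rust
use std::collections::VecDeque;

/// Hypothetical request: an id plus its prompt tokens.
#[derive(Debug, Clone, PartialEq)]
struct Request {
    id: usize,
    prompt: Vec<u32>,
}

/// Take up to `max_batch` requests from the front of the queue (FIFO order).
/// Returns an empty Vec when nothing is pending.
fn next_batch(queue: &mut VecDeque<Request>, max_batch: usize) -> Vec<Request> {
    let n = queue.len().min(max_batch);
    queue.drain(..n).collect()
}

fn main() {
    // Five pending requests with prompts of increasing length.
    let mut queue: VecDeque<Request> = (0..5)
        .map(|id| Request { id, prompt: vec![id as u32; id + 1] })
        .collect();

    // With a maximum batch size of 2, the queue is served as batches of 2, 2, and 1.
    while !queue.is_empty() {
        let batch = next_batch(&mut queue, 2);
        let ids: Vec<usize> = batch.iter().map(|r| r.id).collect();
        let tokens: usize = batch.iter().map(|r| r.prompt.len()).sum();
        println!("batch {:?} with {} prompt tokens", ids, tokens);
    }
}
```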
Installation
Add this to your Cargo.toml:
[dependencies]
hibachi = { version = "0.1.0", features = ["candle", "autoregressive"] } # a "burn" feature is also available
tokio = { version = "1", features = ["full"] }
Quick Start
use hibachi::autoregressive::{Autoregressive, AutoregressiveBatcher, AutoregressiveBatchInference};
use hibachi::backend::{Backend, Unsqueezable};
use async_trait::async_trait;
use candle_core::{Tensor, Device, DType};
// Assumes the stream returned by `run` implements futures::Stream
use futures::StreamExt;

// 1. Implement the Autoregressive trait for your model
struct MyModel { /* ... */ }

#[async_trait]
impl Autoregressive<Tensor> for MyModel {
    async fn forward(&self, tensor: <Tensor as Unsqueezable>::Unsqueezed) -> Tensor {
        // Implement your model's forward pass
        todo!()
    }
}

// 2. Create the batched inference engine
#[tokio::main]
async fn main() {
    // Initialize model
    let model = MyModel::new();
    let device = Device::Cpu;

    // Stop and padding tokens (batched tensors will be of rank + 1)
    let stop_token = Tensor::ones(&[1], DType::U8, &device).unwrap();
    let padding_token = Tensor::zeros(&[1], DType::U8, &device).unwrap();

    // Create inference engine with max batch size of 16
    let engine = AutoregressiveBatchInference::<Tensor, 16>::new(
        model,
        &stop_token,
        &padding_token,
    );

    // 3. Process requests (input uses the same U8 dtype as the stop/padding tokens)
    let input = Tensor::arange(2u8, 5u8, &device).unwrap();
    let mut stream = engine.run(input).await;

    // Stream results
    while let Some(token) = stream.next().await {
        println!("Generated token: {:?}", token);
    }
}
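To make the stop-token and streaming behavior more concrete, here is a std-only sketch of a single request's generation loop: a hypothetical `next_token` closure stands in for the model's forward pass, each token is sent over a channel as soon as it is produced, and generation ends at the stop token or after `max_new` tokens. This is an illustration under those assumptions, not Hibachi's code; the library uses Tokio channels and runs many such requests together in one batch.

```rust
use std::sync::mpsc;
use std::thread;

/// Spawn a generator thread and return the receiving end of its token stream.
/// `next_token` is a hypothetical stand-in for a model forward pass: it maps the
/// current sequence to the next token. In this sketch the stop token is sent to the
/// client and then generation ends; `max_new` caps the number of generated tokens.
fn generate_stream<F>(
    prompt: Vec<u32>,
    stop_token: u32,
    max_new: usize,
    next_token: F,
) -> mpsc::Receiver<u32>
where
    F: Fn(&[u32]) -> u32 + Send + 'static,
{
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        let mut seq = prompt;
        for _ in 0..max_new {
            let tok = next_token(&seq);
            // Stream each token immediately; stop early if the client hung up.
            if tx.send(tok).is_err() {
                break;
            }
            if tok == stop_token {
                break;
            }
            seq.push(tok);
        }
        // Dropping `tx` here closes the channel, which ends the receiver's iteration.
    });
    rx
}

fn main() {
    // Toy "model": the next token is the last token plus one, so [2, 3, 4] continues 5, 6, 7.
    let rx = generate_stream(vec![2, 3, 4], 7, 10, |seq| seq.last().copied().unwrap_or(0) + 1);
    for token in rx {
        println!("Generated token: {}", token);
    }
}
```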
Architecture
Hibachi consists of four core components:
- Backend Abstraction
  - Traits that define the required tensor operations
  - Enables support for different tensor libraries
- Autoregressive Models
  - Interface for models that predict the next token based on previous tokens
  - Supports variable batch and sequence dimensions
- Batching Engine
  - Dynamically manages multiple generation requests
  - Handles tensor padding, concatenation, and state management
  - Streams generated tokens back to clients
- Communication Layer
  - Asynchronous channels for efficient token streaming
  - Error handling and resource cleanup
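The padding and concatenation step of the batching engine can be pictured with plain vectors: sequences of different lengths are padded with the padding token up to the longest length so they stack into one rectangular (batch, seq) block. This sketch left-pads, a common choice for autoregressive decoding because every sequence's newest token then sits in the last column; whether Hibachi pads on the left or the right is an assumption here, and `pad_batch` is a hypothetical helper.

```rust
/// Left-pad every sequence with `pad` to the length of the longest one, producing a
/// rectangular (batch, seq) matrix. Also returns how many pad tokens each row received,
/// which is what you would need to strip the padding again later.
fn pad_batch(seqs: &[Vec<u32>], pad: u32) -> (Vec<Vec<u32>>, Vec<usize>) {
    let max_len = seqs.iter().map(Vec::len).max().unwrap_or(0);
    let mut batch = Vec::with_capacity(seqs.len());
    let mut pads = Vec::with_capacity(seqs.len());
    for s in seqs {
        let n_pad = max_len - s.len();
        let mut row = vec![pad; n_pad];
        row.extend_from_slice(s);
        batch.push(row);
        pads.push(n_pad);
    }
    (batch, pads)
}

fn main() {
    let seqs = vec![vec![2, 3, 4], vec![9], vec![5, 6]];
    let (batch, pads) = pad_batch(&seqs, 0);
    // Every row now has length 3, with the newest token of each sequence in the last column.
    for row in &batch {
        println!("{:?}", row);
    }
    println!("pad counts: {:?}", pads);
}
```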
Advanced Usage
Custom Tensor Backends
To use a custom tensor library, implement the Backend and Unsqueezable traits:
use hibachi::backend::{Backend, Unsqueezable};

impl Backend for MyCustomTensor {
    fn shape(&self) -> Vec<usize> { todo!() }
    fn clone(&self) -> Self { todo!() }
    // ... implement other required methods
}

impl Unsqueezable for MyCustomTensor {
    // The tensor type one rank higher than MyCustomTensor
    type Unsqueezed = MyCustomTensorHigherDim;
    fn unsqueeze(&self, dim: usize) -> Self::Unsqueezed { todo!() }
}
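For a tensor type that tracks its rank at runtime, the Unsqueezed associated type can simply be the type itself. The sketch below uses stand-in traits (`ToyBackend`, `ToyUnsqueezable`) and a `VecTensor` type, all hypothetical, that mirror only the two methods shown above; the real `hibachi::backend::Backend` trait requires more methods, so this type would not satisfy it as-is.

```rust
/// Stand-in for the shape part of a backend trait (hypothetical, not hibachi's trait).
trait ToyBackend {
    fn shape(&self) -> Vec<usize>;
}

/// Stand-in for an unsqueeze trait with an associated higher-rank type (hypothetical).
trait ToyUnsqueezable {
    type Unsqueezed;
    fn unsqueeze(&self, dim: usize) -> Self::Unsqueezed;
}

/// A minimal dense tensor: row-major data plus a shape of any rank.
#[derive(Debug, Clone, PartialEq)]
struct VecTensor {
    data: Vec<f32>,
    shape: Vec<usize>,
}

impl ToyBackend for VecTensor {
    fn shape(&self) -> Vec<usize> {
        self.shape.clone()
    }
}

impl ToyUnsqueezable for VecTensor {
    // Rank is tracked at runtime, so the unsqueezed tensor is the same Rust type.
    type Unsqueezed = VecTensor;

    fn unsqueeze(&self, dim: usize) -> VecTensor {
        assert!(dim <= self.shape.len(), "unsqueeze dim out of range");
        let mut shape = self.shape.clone();
        // Inserting a size-1 axis does not change the row-major data layout.
        shape.insert(dim, 1);
        VecTensor { data: self.data.clone(), shape }
    }
}

fn main() {
    let t = VecTensor { data: vec![2.0, 3.0, 4.0], shape: vec![3] };
    let batched = t.unsqueeze(0);
    println!("{:?} -> {:?}", t.shape(), batched.shape());
}
```

A statically-ranked library such as Burn encodes rank in the type, which is why a separate higher-rank type like MyCustomTensorHigherDim appears in the example above.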
Custom Autoregressive Models
Implement the Autoregressive trait for your model:
use hibachi::autoregressive::Autoregressive;
use hibachi::backend::Unsqueezable;
use async_trait::async_trait;

#[async_trait]
impl Autoregressive<MyTensor> for MyTransformerModel {
    async fn forward(&self, tensor: <MyTensor as Unsqueezable>::Unsqueezed) -> MyTensor {
        // Your transformer forward logic here
        // Input shape: (batch, seq, ...)
        // Output shape: (batch, ...)
        todo!()
    }
}
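The shape contract in the comments above, (batch, seq, ...) in and (batch, ...) out, means the model returns exactly one next-token prediction per batch row, with the sequence dimension consumed. Here is a synchronous, std-only sketch of that contract using nested vectors in place of tensors; `ToyModel` and its "last token plus one" rule are purely illustrative.

```rust
/// Hypothetical toy model: predicts (last token + 1) modulo the vocabulary size.
struct ToyModel {
    vocab_size: u32,
}

impl ToyModel {
    /// Input shape (batch, seq): one token sequence per row.
    /// Output shape (batch): exactly one next-token prediction per row.
    fn forward(&self, batch: &[Vec<u32>]) -> Vec<u32> {
        batch
            .iter()
            .map(|row| {
                let last = row.last().copied().unwrap_or(0);
                (last + 1) % self.vocab_size
            })
            .collect()
    }
}

fn main() {
    let model = ToyModel { vocab_size: 10 };
    let batch = vec![vec![0, 2, 3], vec![0, 0, 9], vec![5, 6, 7]];
    let next = model.forward(&batch);
    // The seq dimension is consumed: 3 rows in, 3 predictions out.
    assert_eq!(next.len(), batch.len());
    println!("{:?}", next);
}
```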
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
License
This project is licensed under the MIT License - see the LICENSE file for details.
Building Docs
RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features
Dependencies: ~4–43MB, ~696K SLoC