Prefrontal
A blazing-fast text classifier for real-time agent routing, built in Rust. Prefrontal provides a simple yet powerful interface for routing conversations to the right agent or department based on text content, powered by transformer-based models.
System Requirements
To use Prefrontal in your project, you need:
- Build Dependencies
  - Linux: sudo apt-get install cmake pkg-config libssl-dev
  - macOS: brew install cmake pkg-config openssl
  - Windows: Install CMake and OpenSSL
- Runtime Requirements
  - Internet connection for first-time model downloads
  - Models are automatically downloaded and cached
Note: the ONNX Runtime bindings (ort v2.0.0-rc.9) are automatically downloaded and managed by the crate.
Model Downloads
On first use, Prefrontal will automatically download the required model files from HuggingFace:
- Model: https://huggingface.co/axar-ai/minilm/resolve/main/model.onnx
- Tokenizer: https://huggingface.co/axar-ai/minilm/resolve/main/tokenizer.json
The files are cached locally (see the Model Cache Location section). You can control the download behavior using the ModelManager:
let manager = ModelManager::new_default()?;
let model = BuiltinModel::MiniLM;

// Check if the model is already downloaded
if !manager.is_model_downloaded(model) {
    println!("Downloading model...");
    manager.download_model(model).await?;
}
Features
- ⚡️ Blazing-fast classification (~10 ms on an M1 MacBook Pro) for real-time routing
- 🎯 Optimized for agent routing and conversation classification
- 🚀 Easy-to-use builder pattern interface
- 🔧 Support for both built-in and custom ONNX models
- 📊 Multiple class classification with confidence scores
- 🛠️ Automatic model management and downloading
- 🧪 Thread-safe for high-concurrency environments
- 📝 Comprehensive logging and error handling
Installation
Add this to your Cargo.toml:
[dependencies]
prefrontal = "0.1.0"
Quick Start
use prefrontal::{Classifier, BuiltinModel, ClassDefinition, ModelManager};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Download the model if it is not already present
    let manager = ModelManager::new_default()?;
    let model = BuiltinModel::MiniLM;
    if !manager.is_model_downloaded(model) {
        println!("Downloading model...");
        manager.download_model(model).await?;
    }

    // Initialize the classifier with the built-in MiniLM model
    let classifier = Classifier::builder()
        .with_model(model)?
        .add_class(
            ClassDefinition::new(
                "technical_support",
                "Technical issues and product troubleshooting",
            )
            .with_examples(vec![
                "my app keeps crashing",
                "can't connect to the server",
                "integration setup help",
            ]),
        )?
        .add_class(
            ClassDefinition::new(
                "billing",
                "Billing and subscription inquiries",
            )
            .with_examples(vec![
                "update credit card",
                "cancel subscription",
                "billing cycle question",
            ]),
        )?
        .build()?;

    // Route an incoming message
    let (department, scores) = classifier.predict("I need help setting up the API integration")?;
    println!("Route to: {}", department);
    println!("Confidence scores: {:?}", scores);

    Ok(())
}
Built-in Models
MiniLM
MiniLM is a small, efficient model optimized for text classification:
- Embedding Size: 384 dimensions
  - Determines the size of the vectors used for similarity calculations
  - Balanced for capturing semantic relationships while being memory efficient
- Max Sequence Length: 256 tokens
  - Each token roughly corresponds to 4-5 characters
  - Longer inputs are automatically truncated (see the sketch after this list)
- Model Size: ~85MB
  - Compact size for efficient deployment
  - Good balance of speed and accuracy
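Truncation is transparent to the caller, so long inputs can be passed to predict as-is. A minimal sketch, assuming a classifier built as in the Quick Start (the repeated text is just filler):

// Inputs beyond 256 tokens are truncated before embedding; no manual
// chunking is required for a single routing decision.
let long_text = "my app keeps crashing and restarting ".repeat(50);
let (department, _scores) = classifier.predict(&long_text)?;
println!("Route to: {}", department);

Note that only the first 256 tokens influence the result; if the decisive content may appear late in a long transcript, consider classifying a summary or the most recent messages instead.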
Runtime Configuration
The ONNX runtime configuration is optional. If not specified, the classifier uses a default configuration optimized for most use cases:
// Using the default configuration (recommended for most cases)
let classifier = Classifier::builder()
    .with_model(BuiltinModel::MiniLM)?
    .build()?;

// Or with a custom runtime configuration
use prefrontal::{Classifier, RuntimeConfig};
use ort::session::builder::GraphOptimizationLevel;

let config = RuntimeConfig {
    inter_threads: 4, // Threads for parallel model execution (0 = auto)
    intra_threads: 2, // Threads for parallel computation within nodes (0 = auto)
    optimization_level: GraphOptimizationLevel::Level3, // Maximum optimization
};

let classifier = Classifier::builder()
    .with_runtime_config(config) // Optional: customize ONNX Runtime behavior
    .with_model(BuiltinModel::MiniLM)?
    .build()?;
The default configuration (RuntimeConfig::default()) uses:
- Automatic thread selection for both inter and intra operations (0 threads = auto)
- Maximum optimization level (Level3) for best performance
You can customize these settings if needed:
- inter_threads: Number of threads for parallel model execution
  - Controls parallelism between independent nodes in the model graph
  - Set to 0 to let ONNX Runtime choose automatically based on system resources
- intra_threads: Number of threads for parallel computation within nodes
  - Controls parallelism within individual operations like matrix multiplication
  - Set to 0 to let ONNX Runtime choose automatically based on system resources
- optimization_level: The level of graph optimization to apply
  - Higher levels perform more aggressive optimizations but may take longer to compile
  - Level3 (maximum) is recommended for production use
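If you only need to change one of these settings, struct-update syntax keeps the rest at their defaults. A minimal sketch, assuming RuntimeConfig's fields are public as the example above implies:

// Keep the defaults and override a single field.
let config = RuntimeConfig {
    intra_threads: 1, // e.g. one compute thread per instance when running many classifiers
    ..RuntimeConfig::default()
};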
Custom Models
You can use your own ONNX models:
let classifier = Classifier::builder()
    .with_custom_model(
        "path/to/model.onnx",
        "path/to/tokenizer.json",
        Some(512), // Optional: custom sequence length
    )?
    .add_class(
        ClassDefinition::new(
            "custom_class",
            "Description of the custom class",
        )
        .with_examples(vec!["example text"]),
    )?
    .build()?;
Model Management
The library includes a model management system that handles downloading and verifying models:
use prefrontal::{ModelManager, BuiltinModel};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let manager = ModelManager::new_default()?;
    let model = BuiltinModel::MiniLM;

    // Check if the model is downloaded
    if !manager.is_model_downloaded(model) {
        println!("Downloading model...");
        manager.download_model(model).await?;
    }

    // Verify model integrity
    if !manager.verify_model(model)? {
        println!("Model verification failed, redownloading...");
        manager.download_model(model).await?;
    }

    Ok(())
}
Model Cache Location
Models are stored in one of the following locations, in order of precedence:
- The directory specified by the PREFRONTAL_CACHE environment variable, if set
- The platform-specific cache directory:
  - Linux: ~/.cache/prefrontal/
  - macOS: ~/Library/Caches/prefrontal/
  - Windows: %LOCALAPPDATA%\prefrontal\Cache\
- A .prefrontal directory in the current working directory
You can override the default location with the PREFRONTAL_CACHE environment variable:
# Set a custom cache directory
export PREFRONTAL_CACHE=/path/to/your/cache
# Or when running your application
PREFRONTAL_CACHE=/path/to/your/cache cargo run
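You can also select the cache directory from code. A sketch, assuming ModelManager resolves PREFRONTAL_CACHE when it is constructed, per the precedence rules above:

// Point the cache at a project-local directory before the first model
// lookup; set the variable before spawning any threads.
std::env::set_var("PREFRONTAL_CACHE", "./model-cache");
let manager = ModelManager::new_default()?;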
Class Definitions
Classes (departments or routing destinations) are defined with labels, descriptions, and optional examples:
// With examples for few-shot classification
let class = ClassDefinition::new(
    "technical_support",
    "Technical issues and product troubleshooting",
)
.with_examples(vec![
    "integration problems",
    "api errors",
]);

// Without examples for zero-shot classification
let class = ClassDefinition::new(
    "billing",
    "Payment and subscription related inquiries",
);

// Get the routing configuration
let info = classifier.info();
println!("Number of departments: {}", info.num_classes);
Each class requires:
- A unique label that identifies the group (non-empty)
- A description that explains the category (non-empty, max 1000 characters)
- At least one example (when using examples)
- No empty example texts
- Maximum of 100 classes per classifier
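These constraints are enforced by the builder. A sketch of what a violation looks like; whether the empty description is rejected by add_class or by build is an assumption:

// An empty description violates the constraints above and should
// surface as a ValidationError somewhere in the builder chain.
let result = Classifier::builder()
    .with_model(BuiltinModel::MiniLM)?
    .add_class(ClassDefinition::new("misc", ""))
    .and_then(|builder| builder.build());
assert!(result.is_err());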
Error Handling
The library provides detailed error types:
pub enum ClassifierError {
    TokenizerError(String),  // Tokenizer-related errors
    ModelError(String),      // ONNX model errors
    BuildError(String),      // Construction errors
    PredictionError(String), // Prediction-time errors
    ValidationError(String), // Input validation errors
    DownloadError(String),   // Model download errors
    IoError(String),         // File system errors
}
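Errors can be matched by variant to decide between surfacing the problem to the caller or falling back to a default route. A sketch, assuming ClassifierError is exported and derives Debug:

use prefrontal::ClassifierError;

// user_message stands in for the incoming text.
match classifier.predict(user_message) {
    Ok((label, _scores)) => println!("Route to: {}", label),
    Err(ClassifierError::ValidationError(msg)) => eprintln!("Invalid input: {}", msg),
    Err(e) => eprintln!("Classification failed: {:?}", e),
}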
Logging
The library uses the log crate for logging. Enable debug logging for detailed information:
use prefrontal::init_logger;
fn main() {
    init_logger(); // Initialize with the default configuration
    // Or configure env_logger directly for more control
}
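If init_logger is backed by env_logger, as the comment above suggests, the standard RUST_LOG filter applies:

# Hypothetical example: enable debug output for the prefrontal crate only
RUST_LOG=prefrontal=debug cargo run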
Thread Safety
The classifier is thread-safe and can be shared across threads using Arc:
use std::sync::Arc;
use std::thread;

let classifier = Arc::new(classifier);
let mut handles = vec![];

for _ in 0..3 {
    let classifier = Arc::clone(&classifier);
    handles.push(thread::spawn(move || {
        classifier.predict("test text").unwrap();
    }));
}

for handle in handles {
    handle.join().unwrap();
}
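The same pattern works under tokio. Since predict is synchronous, spawn_blocking keeps inference off the async worker threads; a sketch, assuming this runs inside an async fn whose error type can absorb both the join error and the classifier error:

use std::sync::Arc;

let classifier = Arc::new(classifier);
let c = Arc::clone(&classifier);
let (department, _scores) =
    tokio::task::spawn_blocking(move || c.predict("test text")).await??;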
Performance
Benchmarks were run on a MacBook Pro (M1, 16 GB RAM):
Text Processing
Operation | Text Length | Time |
---|---|---|
Tokenization | Short (< 50 chars) | ~23.2 µs |
Tokenization | Medium (~100 chars) | ~51.5 µs |
Tokenization | Long (~200 chars) | ~107.0 µs |
Classification
Operation | Scenario | Time |
---|---|---|
Embedding | Single text | ~11.1 ms |
Classification | Single class | ~11.1 ms |
Classification | 10 classes | ~11.1 ms |
Build | 10 classes | ~450 ms |
Key Performance Characteristics:
- Sub-millisecond tokenization
- Consistent classification time regardless of the number of classes (the ~11.1 ms embedding pass dominates; scoring additional classes adds negligible time)
- Thread-safe for concurrent processing
- Optimized for real-time routing decisions
Memory Usage
- Model size: ~85MB (MiniLM)
- Runtime memory: ~100MB per instance
- Shared memory when using multiple instances (thread-safe)
Run the benchmarks yourself:
cargo bench
Contributing
Developer Prerequisites
If you want to contribute to Prefrontal or build it from source, you'll need:
- Build Dependencies (as listed in System Requirements)
- Rust Toolchain (latest stable version)
Development Setup
- Clone the repository:
  git clone https://github.com/yourusername/prefrontal.git
  cd prefrontal
- Install dependencies as described in System Requirements
- Build and test:
  cargo build
  cargo test
Running Tests
# Run all tests
cargo test
# Run specific test
cargo test test_name
# Run benchmarks
cargo bench
License
This project is licensed under the Apache License, Version 2.0 - see the LICENSE file for details.