# axonml-fusion

## Overview
axonml-fusion provides kernel fusion for AxonML: it combines sequences of ops into a single pass to cut memory traffic and kernel-launch overhead. It exposes a `FusedOp` trait, a `FusedLinear` (matmul + bias + activation) kernel backed by `matrixmultiply`, a `FusedElementwise` chain with a fluent builder, and a `FusionOptimizer` that runs pattern detection over an `OpType` graph and reports statistics.
## Features

- **Pattern Detection** — `detect_patterns(&[OpType])` returns fusion opportunities with start/end indices
- **Linear Fusion** — `FusedLinear` (MatMul + optional bias + activation) with `None`, `Relu`, `Gelu` (tanh approximation), `Sigmoid`, `Tanh`, `Silu`
- **Elementwise Fusion** — `FusedElementwise` chain with the `FusedElementwise::builder()` fluent API
- **Graph Optimizer** — `FusionOptimizer` with `FusionConfig::all_enabled()` / `FusionConfig::conservative()`; `OptimizationStats` with `fusions_applied`, `ops_eliminated`, `estimated_speedup`
- **Convenience Constructors** — `fuse_matmul_bias_relu`, `fuse_matmul_bias_gelu`, `fuse_matmul_bias`, `fused_add_relu`, `fused_mul_add`, `fused_scale_bias_relu`, `fuse_elementwise`, `optimize_graph`, `estimate_speedup`
- **Parallel Execution** — Rayon for tensor operations; `matrixmultiply` for the inner GEMM
## Modules

| Module | Description |
|---|---|
| `patterns` | `FusionPattern` enum, `OpType` enum, `detect_patterns` |
| `elementwise` | `ElementwiseOp`, `FusedElementwise`, `FusedElementwiseBuilder`, convenience constructors |
| `linear` | `Activation` enum, `FusedLinear` (MatMul + Bias + Activation), `fuse_matmul_bias*` constructors |
| `optimizer` | `FusionConfig`, `FusionOptimizer`, `OptimizationStats`, `optimize_graph`, `estimate_speedup` |
| `error` | `FusionError` / `FusionResult` |
## Usage

Add this to your `Cargo.toml`:

```toml
[dependencies]
axonml-fusion = "0.6.2"
```
### Fused Linear Operations

```rust
use axonml_fusion::{fuse_matmul_bias_relu, FusedLinear, Activation};
use axonml_tensor::Tensor;

// Create weight and bias tensors
let weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
let bias = Tensor::from_vec(vec![0.5, 0.5], &[2])?;

// Create a fused MatMul + Bias + ReLU operation
let fused = fuse_matmul_bias_relu(&weight, &bias)?;

// Execute the fused operation
let input = Tensor::from_vec(vec![1.0, 1.0], &[1, 2])?;
let output = fused.forward(&input)?;

// Or construct directly
let fl = FusedLinear::new(weight, Some(bias), Activation::Gelu)?;
```
`FusedLinear::forward` accepts 1D or 2D inputs (batch × in_features).
### Fused Elementwise Operations

```rust
use axonml_fusion::{FusedElementwise, ElementwiseOp};

// Build a fused elementwise chain using the builder
let fused = FusedElementwise::builder()
    .mul(2.0)   // Scale by 2
    .add(1.0)   // Add bias
    .relu()     // Apply ReLU
    .build();

let output = fused.forward(&input)?;

// Or construct from an explicit op list
let f = FusedElementwise::new(vec![
    ElementwiseOp::MulConst(2.0),
    ElementwiseOp::AddConst(1.0),
    ElementwiseOp::Relu,
]);
```
### Graph Optimization

```rust
use axonml_fusion::{optimize_graph, patterns::OpType};

// Define the operation sequence
let ops = vec![
    OpType::MatMul,
    OpType::Add,
    OpType::Relu,
    OpType::Add,
    OpType::Mul,
];

// Optimize with the default configuration (None uses all-enabled defaults)
let (patterns, stats) = optimize_graph(&ops, None)?;

println!("Fusions applied: {}", stats.fusions_applied);
println!("Operations eliminated: {}", stats.ops_eliminated);
println!("Estimated speedup: {:.2}x", stats.estimated_speedup);
```
### Custom Fusion Configuration

```rust
use axonml_fusion::{FusionOptimizer, FusionConfig};

// Preset configurations
let config = FusionConfig::conservative();
let config = FusionConfig::all_enabled();

// Or customize
let config = FusionConfig {
    fuse_elementwise: true,
    fuse_linear: true,
    fuse_conv: false,
    min_elementwise_chain: 3,
    aggressive: false,
};

let optimizer = FusionOptimizer::with_config(config);
let patterns = optimizer.analyze(&ops);
```
## Supported Fusion Patterns

All variants of `FusionPattern` with their `num_ops()` and `estimated_speedup()`:

| Pattern | Operations | num_ops | Estimated Speedup |
|---|---|---|---|
| `MatMulBias` | MatMul, Add | 2 | 1.2x |
| `MatMulBiasRelu` | MatMul, Add, ReLU | 3 | 1.3x |
| `MatMulBiasGelu` | MatMul, Add, GELU | 3 | 1.3x |
| `ConvBatchNorm` | Conv, BatchNorm | 3 | 1.3x |
| `ConvBatchNormRelu` | Conv, BatchNorm, ReLU | 4 | 1.4x |
| `ElementwiseChain` | 2+ elementwise ops | variable | 2.0x |
| `Softmax` | 3 ops | 3 | 1.2x |
| `LayerNorm` | 4 ops | 4 | 1.2x |
| `GeluApprox` | 5 ops | 5 | 1.2x |
| `AddRelu` | Add, ReLU | 2 | 1.8x |
| `MulAdd` | Mul, Add (FMA) | 2 | 1.5x |
## Elementwise Operations

`FusedElementwise::builder()` methods:

| Method | Op |
|---|---|
| `add(c: f32)` | `ElementwiseOp::AddConst(c)` |
| `mul(c: f32)` | `ElementwiseOp::MulConst(c)` |
| `relu()` | `ElementwiseOp::Relu` |
| `leaky_relu(alpha: f32)` | `ElementwiseOp::LeakyRelu(alpha)` |
| `sigmoid()` | `ElementwiseOp::Sigmoid` |
| `tanh()` | `ElementwiseOp::Tanh` |
| `exp()` | `ElementwiseOp::Exp` |
| `log()` | `ElementwiseOp::Log` (natural log) |
| `sqrt()` | `ElementwiseOp::Sqrt` |
| `square()` | `ElementwiseOp::Square` |
| `clamp(min, max)` | `ElementwiseOp::Clamp(min, max)` |
| `neg()` | `ElementwiseOp::Neg` |
| `abs()` | `ElementwiseOp::Abs` |
## Convenience Constructors

- `fuse_matmul_bias(weight, bias)` — `Activation::None`
- `fuse_matmul_bias_relu(weight, bias)` — `Activation::Relu`
- `fuse_matmul_bias_gelu(weight, bias)` — `Activation::Gelu`
- `fused_add_relu(bias: f32)` — add then ReLU
- `fused_mul_add(scale, bias)` — FMA-style chain
- `fused_scale_bias_relu(scale, bias)` — norm-style chain
- `fuse_elementwise(ops: Vec<ElementwiseOp>)` — explicit op chain
## Tests

```shell
cargo test -p axonml-fusion
```
## License

Licensed under either of:

- MIT License
- Apache License, Version 2.0

at your option.