
axonml-fusion

AxonML Logo

License: Apache-2.0 | Rust: 1.75+ | Version: 0.6.1 | Part of AxonML

Overview

axonml-fusion provides kernel fusion for AxonML: combining sequences of ops into a single pass to cut memory traffic and kernel-launch overhead. It exposes a FusedOp trait, a FusedLinear (matmul + bias + activation) kernel backed by matrixmultiply, an ElementwiseOp chain with a fluent builder, and a FusionOptimizer that runs pattern detection over an [OpType] graph and reports statistics.

Features

  • Pattern Detection — detect_patterns(&[OpType]) returns fusion opportunities with start/end indices
  • Linear Fusion — FusedLinear (MatMul + optional bias + activation) with None, Relu, Gelu (tanh approximation), Sigmoid, Tanh, Silu
  • Elementwise Fusion — FusedElementwise chain with FusedElementwise::builder() fluent API
  • Graph Optimizer — FusionOptimizer with FusionConfig::all_enabled() / FusionConfig::conservative(); OptimizationStats with fusions_applied, ops_eliminated, estimated_speedup
  • Convenience Constructors — fuse_matmul_bias_relu, fuse_matmul_bias_gelu, fuse_matmul_bias, fused_add_relu, fused_mul_add, fused_scale_bias_relu, fuse_elementwise, optimize_graph, estimate_speedup
  • Parallel Execution — Rayon for tensor operations; matrixmultiply for the inner GEMM
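The memory-traffic argument behind elementwise fusion can be sketched in plain Rust, independent of this crate's API (the function names below are illustrative, not part of axonml-fusion):

```rust
// Unfused: three passes over memory, two heap-allocated temporaries.
fn unfused(x: &[f32]) -> Vec<f32> {
    let a: Vec<f32> = x.iter().map(|v| v * 2.0).collect();
    let b: Vec<f32> = a.iter().map(|v| v + 1.0).collect();
    b.iter().map(|v| v.max(0.0)).collect()
}

// Fused: one pass, no temporaries; conceptually what FusedElementwise does.
fn fused(x: &[f32]) -> Vec<f32> {
    x.iter().map(|v| (v * 2.0 + 1.0).max(0.0)).collect()
}

fn main() {
    let x = [-1.0f32, 0.5, 2.0];
    assert_eq!(unfused(&x), fused(&x));
    println!("{:?}", fused(&x)); // [0.0, 2.0, 5.0]
}
```

Both variants compute the same values; the fused form simply touches each element once, which is where the memory-traffic savings come from.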

Modules

Module — Description
patterns — FusionPattern enum, OpType enum, detect_patterns
elementwise — ElementwiseOp, FusedElementwise, FusedElementwiseBuilder, convenience constructors
linear — Activation enum, FusedLinear (MatMul + Bias + Activation), fuse_matmul_bias* constructors
optimizer — FusionConfig, FusionOptimizer, OptimizationStats, optimize_graph, estimate_speedup
error — FusionError / FusionResult

Usage

Add this to your Cargo.toml:

[dependencies]
axonml-fusion = "0.6.1"

Fused Linear Operations

use axonml_fusion::{fuse_matmul_bias_relu, FusedLinear, Activation};
use axonml_tensor::Tensor;

// Create weight and bias tensors
let weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
let bias   = Tensor::from_vec(vec![0.5, 0.5], &[2])?;

// Create fused MatMul + Bias + ReLU operation
let fused = fuse_matmul_bias_relu(&weight, &bias)?;

// Execute fused operation
let input  = Tensor::from_vec(vec![1.0, 1.0], &[1, 2])?;
let output = fused.forward(&input)?;

// Or construct directly
let fl = FusedLinear::new(weight, Some(bias), Activation::Gelu)?;

FusedLinear::forward accepts 1D or 2D inputs (batch × in_features).
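As a naive plain-Rust reference for the computation being fused, y = activation(x · W + b): the row-major layout and x · W (rather than x · Wᵀ) convention below are assumptions for illustration only, and the crate's real kernel goes through matrixmultiply.

```rust
// Naive reference: y[i][j] = relu(sum_p x[i][p] * w[p][j] + b[j]).
// Layout and transpose conventions here are assumptions, not axonml-fusion's.
fn fused_linear_relu(x: &[f32], w: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut y = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = b[j];
            for p in 0..k {
                acc += x[i * k + p] * w[p * n + j];
            }
            // Bias and ReLU happen in the same pass as the matmul accumulation.
            y[i * n + j] = acc.max(0.0);
        }
    }
    y
}

fn main() {
    // Same shapes as the README example: weight [2, 2], bias [2], input [1, 2].
    let w = [1.0, 2.0, 3.0, 4.0];
    let b = [0.5, 0.5];
    let x = [1.0, 1.0];
    let y = fused_linear_relu(&x, &w, &b, 1, 2, 2);
    assert_eq!(y, vec![4.5, 6.5]);
}
```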

Fused Elementwise Operations

use axonml_fusion::{FusedElementwise, ElementwiseOp};
use axonml_tensor::Tensor;

// Input tensor for the chain
let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3])?;

// Build a fused elementwise chain using the builder
let fused = FusedElementwise::builder()
    .mul(2.0)      // Scale by 2
    .add(1.0)      // Add bias
    .relu()        // Apply ReLU
    .build();

let output = fused.forward(&input)?;

// Or construct from an explicit op list
let f = FusedElementwise::new(vec![
    ElementwiseOp::MulConst(2.0),
    ElementwiseOp::AddConst(1.0),
    ElementwiseOp::Relu,
]);
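Conceptually, such a chain is just a list of ops interpreted once per element in a single pass. A simplified stand-in (the Op enum below is illustrative, not the crate's ElementwiseOp) can be sketched as:

```rust
// Illustrative stand-in for ElementwiseOp; not the crate's actual enum.
#[derive(Clone, Copy)]
enum Op {
    AddConst(f32),
    MulConst(f32),
    Relu,
}

// Apply the whole op chain to each element in one pass over the data.
fn apply_chain(ops: &[Op], x: &[f32]) -> Vec<f32> {
    x.iter()
        .map(|&v| {
            ops.iter().copied().fold(v, |acc, op| match op {
                Op::AddConst(c) => acc + c,
                Op::MulConst(c) => acc * c,
                Op::Relu => acc.max(0.0),
            })
        })
        .collect()
}

fn main() {
    // Same chain as the builder example: mul(2.0).add(1.0).relu()
    let chain = [Op::MulConst(2.0), Op::AddConst(1.0), Op::Relu];
    let y = apply_chain(&chain, &[-1.0, 0.5]);
    assert_eq!(y, vec![0.0, 2.0]);
}
```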

Graph Optimization

use axonml_fusion::{optimize_graph, patterns::OpType};

// Define operation sequence
let ops = vec![
    OpType::MatMul,
    OpType::Add,
    OpType::Relu,
    OpType::Add,
    OpType::Mul,
];

// Optimize with default configuration (None uses all-enabled defaults)
let (patterns, stats) = optimize_graph(&ops, None)?;

println!("Fusions applied:      {}", stats.fusions_applied);
println!("Operations eliminated: {}", stats.ops_eliminated);
println!("Estimated speedup:    {:.2}x", stats.estimated_speedup);
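One way to picture the pattern-detection step is a scan over the op list that matches known windows, longest first. This is a hypothetical sketch of that idea, not the crate's actual algorithm, and the Op enum here is a stand-in for axonml_fusion::patterns::OpType:

```rust
// Illustrative op type; the real one lives in axonml_fusion::patterns.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Op { MatMul, Add, Relu, Mul }

// Returns (pattern name, start index, end index) for each greedy, longest-first match.
fn detect(ops: &[Op]) -> Vec<(&'static str, usize, usize)> {
    let mut found = Vec::new();
    let mut i = 0;
    while i < ops.len() {
        if ops[i..].starts_with(&[Op::MatMul, Op::Add, Op::Relu]) {
            found.push(("MatMulBiasRelu", i, i + 2));
            i += 3;
        } else if ops[i..].starts_with(&[Op::Mul, Op::Add]) {
            found.push(("MulAdd", i, i + 1));
            i += 2;
        } else if ops[i..].starts_with(&[Op::Add, Op::Relu]) {
            found.push(("AddRelu", i, i + 1));
            i += 2;
        } else {
            i += 1;
        }
    }
    found
}

fn main() {
    // Same op sequence as the optimize_graph example above.
    let ops = [Op::MatMul, Op::Add, Op::Relu, Op::Add, Op::Mul];
    assert_eq!(detect(&ops), vec![("MatMulBiasRelu", 0, 2)]);
}
```

Note that greedy matching consumes the MatMul/Add/Relu prefix, so the trailing Add, Mul pair (which is not in Mul-then-Add order) matches nothing in this toy version.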

Custom Fusion Configuration

use axonml_fusion::{FusionOptimizer, FusionConfig, patterns::OpType};

// Preset configurations
let config = FusionConfig::conservative();
let config = FusionConfig::all_enabled();

// Or customize
let config = FusionConfig {
    fuse_elementwise: true,
    fuse_linear: true,
    fuse_conv: false,
    min_elementwise_chain: 3,
    aggressive: false,
};

// Operation sequence to analyze
let ops = vec![OpType::MatMul, OpType::Add, OpType::Relu];

let optimizer = FusionOptimizer::with_config(config);
let patterns  = optimizer.analyze(&ops);

Supported Fusion Patterns

All variants of FusionPattern with their num_ops() and estimated_speedup():

Pattern             Operations              num_ops    Estimated Speedup
MatMulBias          MatMul, Add             2          1.2x
MatMulBiasRelu      MatMul, Add, ReLU       3          1.3x
MatMulBiasGelu      MatMul, Add, GELU       3          1.3x
ConvBatchNorm       Conv, BatchNorm         3          1.3x
ConvBatchNormRelu   Conv, BatchNorm, ReLU   4          1.4x
ElementwiseChain    2+ elementwise ops      variable   2.0x
Softmax             3 ops                   3          1.2x
LayerNorm           4 ops                   4          1.2x
GeluApprox          5 ops                   5          1.2x
AddRelu             Add, ReLU               2          1.8x
MulAdd              Mul, Add (FMA)          2          1.5x
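The MulAdd row corresponds to collapsing a multiply and an add into a single fused multiply-add. In plain Rust that is f32::mul_add, which performs both operations with one rounding step:

```rust
fn main() {
    let (x, a, b) = (0.5f32, 2.0f32, 1.0f32);
    // Fused multiply-add: x * a + b in one operation, one rounding.
    let fused = x.mul_add(a, b);
    // The same value computed as two separate ops.
    let unfused = x * a + b;
    assert_eq!(fused, 2.0);
    assert_eq!(unfused, 2.0);
}
```

For these exactly representable inputs the fused and unfused results agree bit-for-bit; in general the fused form can differ slightly (and be more accurate) because it rounds once instead of twice.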

Elementwise Operations

FusedElementwise::builder() methods:

Method — Op
add(c: f32) — ElementwiseOp::AddConst(c)
mul(c: f32) — ElementwiseOp::MulConst(c)
relu() — ElementwiseOp::Relu
leaky_relu(alpha: f32) — ElementwiseOp::LeakyRelu(alpha)
sigmoid() — ElementwiseOp::Sigmoid
tanh() — ElementwiseOp::Tanh
exp() — ElementwiseOp::Exp
log() — ElementwiseOp::Log (natural log)
sqrt() — ElementwiseOp::Sqrt
square() — ElementwiseOp::Square
clamp(min, max) — ElementwiseOp::Clamp(min, max)
neg() — ElementwiseOp::Neg
abs() — ElementwiseOp::Abs

Convenience Constructors

  • fuse_matmul_bias(weight, bias) — Activation::None
  • fuse_matmul_bias_relu(weight, bias) — Activation::Relu
  • fuse_matmul_bias_gelu(weight, bias) — Activation::Gelu
  • fused_add_relu(bias: f32) — add then ReLU
  • fused_mul_add(scale, bias) — FMA-style chain
  • fused_scale_bias_relu(scale, bias) — norm-style chain
  • fuse_elementwise(ops: Vec<ElementwiseOp>) — explicit op chain

Tests

cargo test -p axonml-fusion

License

Licensed under either of:

  • MIT License
  • Apache License, Version 2.0

at your option.
