# axonml-nn

## Overview
axonml-nn provides neural network building blocks for the AxonML framework. It includes layers, activation functions, loss functions, and utilities for constructing and training deep learning models with a PyTorch-like API.
## Features

- **Module Trait** - Core interface for all neural network components, with parameter management and train/eval modes (see the sketch after this list).
- **Comprehensive Layers** - Linear, Conv1d/Conv2d, RNN/LSTM/GRU, Embedding, BatchNorm, LayerNorm, Dropout, and MultiHeadAttention. The RNN/LSTM/GRU layers use batched matmul to process all timesteps at once, avoiding per-step allocation overhead.
- **Activation Functions** - ReLU, Sigmoid, Tanh, GELU, SiLU, ELU, LeakyReLU, Softmax, and LogSoftmax.
- **Loss Functions** - MSELoss, CrossEntropyLoss, BCELoss, BCEWithLogitsLoss, NLLLoss, L1Loss, and SmoothL1Loss.
- **Weight Initialization** - Xavier/Glorot, Kaiming/He, orthogonal, sparse, and custom initialization schemes.
- **Sequential Container** - Easy model composition by chaining layers in order.
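The `Module` trait is the extension point for custom components. Its exact definition lives in the `module` module; the following is a minimal sketch of a custom block, assuming `forward` and `parameters` methods shaped like the calls used on the built-in layers elsewhere in this README (the real trait signature may differ):

```rust
use axonml_nn::prelude::*;
use axonml_autograd::Variable;

// Hypothetical custom block: a Linear layer followed by ReLU.
struct LinearRelu {
    linear: Linear,
}

impl LinearRelu {
    fn new(in_features: usize, out_features: usize) -> Self {
        Self { linear: Linear::new(in_features, out_features) }
    }
}

// Assumed trait methods; consult the `module` docs for the real signature.
impl Module for LinearRelu {
    fn forward(&self, input: &Variable) -> Variable {
        // Delegate to the inner layer, then apply the activation.
        ReLU.forward(&self.linear.forward(input))
    }

    fn parameters(&self) -> Vec<Parameter> {
        // Expose the inner layer's learnable weights to optimizers.
        self.linear.parameters()
    }
}
```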
## Modules

| Module | Description |
|---|---|
| `module` | Core `Module` trait and `ModuleList` container for neural network components |
| `parameter` | `Parameter` wrapper for learnable weights with gradient tracking |
| `sequential` | `Sequential` container for chaining modules in order |
| `layers` | Neural network layers (Linear, Conv, RNN, Attention, Norm, Pooling, Embedding, Dropout) |
| `activation` | Activation function modules (ReLU, Sigmoid, Tanh, GELU, etc.) |
| `loss` | Loss function modules (MSE, CrossEntropy, BCE, etc.) |
| `init` | Weight initialization functions (Xavier, Kaiming, orthogonal, etc.) |
| `functional` | Stateless functional versions of operations |
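Because every module has a stateless counterpart in `functional`, operations can be applied without constructing a module first. A minimal sketch, assuming the function names mirror the module forms (e.g. `relu`); check the `functional` docs for the exact set:

```rust
use axonml_nn::functional as F; // import path assumed
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

let x = Variable::new(
    Tensor::from_vec(vec![-1.0, 0.0, 2.0], &[3]).unwrap(),
    false,
);
// Stateless call: no module is constructed or stored.
let y = F::relu(&x); // function name assumed to mirror the ReLU module
```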
## Usage

Add this to your `Cargo.toml`:

```toml
[dependencies]
axonml-nn = "0.1.0"
```
### Building a Simple MLP

```rust
use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// Build the model using the Sequential container.
let model = Sequential::new()
    .add(Linear::new(784, 256))
    .add(ReLU)
    .add(Linear::new(256, 128))
    .add(ReLU)
    .add(Linear::new(128, 10));

// Create a dummy input: one flattened 28x28 image.
let input = Variable::new(
    Tensor::from_vec(vec![0.5; 784], &[1, 784]).unwrap(),
    false,
);

// Forward pass.
let output = model.forward(&input);
assert_eq!(output.shape(), vec![1, 10]);

// Collect all learnable parameters.
let params = model.parameters();
println!("Total parameters: {}", model.num_parameters());
```
### Convolutional Neural Network

```rust
use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

let model = Sequential::new()
    .add(Conv2d::new(1, 32, 3))   // [B, 1, 28, 28] -> [B, 32, 26, 26]
    .add(ReLU)
    .add(MaxPool2d::new(2))       // -> [B, 32, 13, 13]
    .add(Conv2d::new(32, 64, 3))  // -> [B, 64, 11, 11]
    .add(ReLU)
    .add(MaxPool2d::new(2));      // -> [B, 64, 5, 5]

// One 28x28 grayscale image (784 values) in NCHW layout.
let input = Variable::new(
    Tensor::from_vec(vec![0.5; 784], &[1, 1, 28, 28]).unwrap(),
    false,
);
let features = model.forward(&input);
```
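After the second pooling layer the feature map is `[B, 64, 5, 5]`, i.e. 64 × 5 × 5 = 1600 values per example, which is the width a flattening step plus a Linear classifier head would consume.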
### Recurrent Neural Network

```rust
use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// Two-layer LSTM for sequence modeling.
let lstm = LSTM::new(
    64,  // input_size
    128, // hidden_size
    2,   // num_layers
);

// Input: [batch, seq_len, input_size] = [2, 5, 64] (640 values).
let input = Variable::new(
    Tensor::from_vec(vec![0.5; 640], &[2, 5, 64]).unwrap(),
    false,
);
let output = lstm.forward(&input); // [2, 5, 128]
```
### Transformer Attention

```rust
use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

let attention = MultiHeadAttention::new(
    512, // embed_dim
    8,   // num_heads
);

// Input: [batch, seq_len, embed_dim] = [2, 5, 512] (5120 values).
let input = Variable::new(
    Tensor::from_vec(vec![0.5; 5120], &[2, 5, 512]).unwrap(),
    false,
);
let output = attention.forward(&input); // [2, 5, 512]
```
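Since `forward` takes a single input here, query, key, and value are all derived from the same tensor (self-attention). With `embed_dim = 512` split across 8 heads, each head attends over 512 / 8 = 64 dimensions.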
### Loss Functions

```rust
use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// Cross-entropy loss for classification.
let logits = Variable::new(
    Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3]).unwrap(),
    true,
);
let targets = Variable::new(
    Tensor::from_vec(vec![2.0, 0.0], &[2]).unwrap(),
    false,
);
let loss_fn = CrossEntropyLoss::new();
let loss = loss_fn.compute(&logits, &targets);
loss.backward();

// MSE loss for regression.
let mse = MSELoss::new();
let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
let target = Variable::new(Tensor::from_vec(vec![1.5, 2.5], &[2]).unwrap(), false);
let loss = mse.compute(&pred, &target);
```
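Note the target conventions above: `CrossEntropyLoss` takes raw logits of shape `[batch, classes]` together with one class index per example (here classes `2` and `0`), while `MSELoss` compares two tensors of matching shape.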
### Weight Initialization

```rust
use axonml_nn::init::*;

// Xavier/Glorot initialization.
let weights = xavier_uniform(256, 128);
let weights = xavier_normal(256, 128);

// Kaiming/He initialization (suited to ReLU networks).
let weights = kaiming_uniform(256, 128);
let weights = kaiming_normal(256, 128);

// Other initializations.
let zeros_tensor = zeros(&[3, 3]);
let ones_tensor = ones(&[3, 3]);
let eye_tensor = eye(4);
let ortho_tensor = orthogonal(64, 64);
```
### Training/Evaluation Mode

```rust
use axonml_nn::prelude::*;

let mut model = Sequential::new()
    .add(Linear::new(10, 5))
    .add(Dropout::new(0.5))
    .add(Linear::new(5, 2));

// Training mode (dropout active).
model.train();
assert!(model.is_training());

// Evaluation mode (dropout disabled).
model.eval();
assert!(!model.is_training());

// Zero gradients before the next backward pass.
model.zero_grad();
```
### Differentiable Structured Sparsity (novel)

Learn which weights to prune end-to-end: the pruning mask itself is differentiable.

```rust
use axonml_nn::layers::sparse::{SparseLinear, GroupSparsity, LotteryTicket};
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// SparseLinear: soft-thresholded magnitude pruning.
let sparse = SparseLinear::new(256, 128, 0.5, 10.0); // threshold=0.5, temperature=10.0
let input = Variable::new(Tensor::randn(&[4, 256]), true);
let output = sparse.forward(&input); // [4, 128]

// Check the realized sparsity.
let sparsity = sparse.sparsity(); // fraction of weights effectively pruned
println!("Sparsity: {:.1}%", sparsity * 100.0);

// Group sparsity regularization.
let group_reg = GroupSparsity::new(0.01, "row"); // L1 penalty on row norms
let reg_loss = group_reg.compute(&sparse);

// Lottery Ticket Hypothesis workflow.
let mut ticket = LotteryTicket::new(&sparse);
ticket.snapshot(); // save initial weights
// ... train for a while ...
ticket.prune(0.2); // prune bottom 20% by magnitude
ticket.rewind(&mut sparse); // rewind to initial weights, keeping the discovered mask
```
## Tests

Run the test suite (171 tests):

```sh
cargo test -p axonml-nn
```
## License

Licensed under either of:

- Apache License, Version 2.0 (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
- MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)

at your option.