2 releases

| Version | Release date |
|---|---|
| 0.1.1 | Apr 1, 2024 |
| 0.1.0 | Mar 2, 2024 |
ZeNu
ZeNu is a simple and intuitive deep learning library written in Rust. It provides the building blocks for creating and training neural networks, with a focus on ease of use and flexibility.
Features
- High-level API for defining and training neural networks
- Integration with popular datasets like MNIST
- Modular design for easy extensibility
- Efficient computation using the underlying zenu-matrix and zenu-autograd libraries
Getting Started
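To use ZeNu in your Rust project, add the following to your `Cargo.toml` file:

```toml
[dependencies]
zenu = "0.1.0"
```

The example below also imports the companion crates `zenu_autograd`, `zenu_layer`, `zenu_matrix` and `zenu_optimizer` by name, so they must be listed under `[dependencies]` as well.

Here's a skeleton showing how to define a model and a dataset and set up training and evaluation with ZeNu. The training and evaluation loops are left for you to fill in: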
```rust
// Some of these imports are only used once the elided dataset,
// training and evaluation logic is filled in.
use zenu::{
    dataset::{train_val_split, DataLoader, Dataset},
    mnist::minist_dataset,
    update_parameters, Model,
};
use zenu_autograd::{
    creator::from_vec::from_vec,
    functions::{activation::sigmoid::sigmoid, loss::cross_entropy::cross_entropy},
    Variable,
};
use zenu_layer::{layers::linear::Linear, Layer};
use zenu_matrix::{
    matrix::{IndexItem, ToViewMatrix},
    operation::max::MaxIdx,
};
use zenu_optimizer::sgd::SGD;

// Define your model: one linear layer (784 inputs -> 10 classes) followed by a sigmoid
struct SingleLayerModel {
    linear: Linear<f32>,
}

impl SingleLayerModel {
    fn new() -> Self {
        let mut linear = Linear::new(784, 10);
        linear.init_parameters(None);
        Self { linear }
    }
}

impl Model<f32> for SingleLayerModel {
    fn predict(&self, inputs: &[Variable<f32>]) -> Variable<f32> {
        let x = &inputs[0];
        let x = self.linear.call(x.clone());
        sigmoid(x)
    }
}

// Define your dataset: each entry is (raw pixel bytes, class label)
struct MnistDataset {
    data: Vec<(Vec<u8>, u8)>,
}

impl Dataset<f32> for MnistDataset {
    type Item = (Vec<u8>, u8);

    fn item(&self, item: usize) -> Vec<Variable<f32>> {
        // Implement your dataset logic here: convert self.data[item]
        // into the Variables the model and loss expect.
        // todo!() keeps the skeleton compiling until this is filled in.
        todo!("convert self.data[item] into input and target Variables")
    }

    fn len(&self) -> usize {
        self.data.len()
    }

    fn all_data(&mut self) -> &mut [Self::Item] {
        &mut self.data as &mut [Self::Item]
    }
}

fn main() {
    // Load and prepare your data
    let (train, test) = minist_dataset().unwrap();
    let (train, val) = train_val_split(&train, 0.8, true);
    let test_dataloader = DataLoader::new(MnistDataset { data: test }, 1);

    // Create your model and optimizer
    let sgd = SGD::new(0.01);
    let model = SingleLayerModel::new();

    // Train your model (this loop is where `train`, `val`, `model` and `sgd` are used)
    for epoch in 0..10 {
        // ... Implement your training loop
    }

    // Evaluate your model
    let mut test_loss = 0.;
    let mut num_iter_test = 0;
    let mut correct = 0;
    let mut total = 0;
    for batch in test_dataloader {
        // ... Implement your evaluation logic
    }
    println!("Accuracy: {}", correct as f32 / total as f32);
    println!("Test Loss: {}", test_loss / num_iter_test as f32);
}
```
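The `item` method in the example is left unimplemented. A common way to prepare an MNIST sample is to scale the 784 pixel bytes into `[0, 1]` and one-hot encode the label over the 10 classes that `Linear::new(784, 10)` produces. The sketch below assumes that preprocessing; it is a guess at what the elided logic might do, not ZeNu's own code. The helper name `to_input_and_target` is hypothetical, and it uses only the standard library and returns plain `Vec<f32>` values. Wrapping them into `Variable<f32>` (for example with the imported `from_vec`) is left out because that function's exact signature isn't shown here.

```rust
/// Hypothetical helper (not part of ZeNu): converts one MNIST-style sample,
/// given as raw pixel bytes (784 of them for MNIST) and a class label 0..=9,
/// into plain f32 vectors ready to be wrapped as model input and target.
fn to_input_and_target(pixels: &[u8], label: u8) -> (Vec<f32>, Vec<f32>) {
    const NUM_CLASSES: usize = 10;
    assert!((label as usize) < NUM_CLASSES, "label out of range");

    // Scale each pixel from 0..=255 into 0.0..=1.0.
    let input: Vec<f32> = pixels.iter().map(|&p| p as f32 / 255.0).collect();

    // One-hot encode the label: 1.0 at the label's index, 0.0 everywhere else.
    let mut target = vec![0.0f32; NUM_CLASSES];
    target[label as usize] = 1.0;

    (input, target)
}
```

For instance, `to_input_and_target(&[0, 255], 2)` yields the input `[0.0, 1.0]` and a 10-element target with `1.0` at index 2. Scaling to `[0, 1]` is only one option; other normalizations, such as standardizing by the dataset mean and standard deviation, fit the same slot.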
For more details and examples, please refer to the documentation.
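The evaluation loop in the example declares `test_loss`, `num_iter_test`, `correct` and `total` but leaves their updates elided. Here is a minimal sketch of that bookkeeping. It assumes each batch yields a scalar loss plus per-sample score vectors and one-hot targets. The predicted class is the index of the largest score, and a sample counts as correct when that index matches the position of the `1.0` in its target. The names `argmax`, `EvalStats` and `record_batch` are hypothetical and use only the standard library. The example imports zenu-matrix's `MaxIdx`, which looks intended for the argmax step, but its signature isn't shown here, so the sketch does not use it.

```rust
/// Hypothetical helper: index of the largest score (the predicted class).
/// Returns None for an empty slice. `total_cmp` gives a total order over f32,
/// so NaN scores cannot cause a panic; on ties, `max_by` returns the last maximum.
fn argmax(scores: &[f32]) -> Option<usize> {
    scores
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
}

/// Hypothetical running statistics mirroring the example's evaluation counters.
#[derive(Default)]
struct EvalStats {
    test_loss: f32,
    num_iter_test: usize,
    correct: usize,
    total: usize,
}

impl EvalStats {
    /// Record one batch: its loss, plus per-sample scores and one-hot targets.
    fn record_batch(&mut self, loss: f32, scores: &[Vec<f32>], targets: &[Vec<f32>]) {
        self.test_loss += loss;
        self.num_iter_test += 1;
        for (s, t) in scores.iter().zip(targets) {
            let predicted = argmax(s);
            if predicted.is_some() && predicted == argmax(t) {
                self.correct += 1;
            }
            self.total += 1;
        }
    }

    /// Fraction of samples classified correctly (NaN if nothing was recorded).
    fn accuracy(&self) -> f32 {
        self.correct as f32 / self.total as f32
    }

    /// Mean loss per batch (NaN if nothing was recorded).
    fn mean_loss(&self) -> f32 {
        self.test_loss / self.num_iter_test as f32
    }
}
```

As in the example's final `println!` lines, accuracy is `correct / total` and the reported loss is averaged over batches. Both come out as NaN if no batch was recorded.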
License
ZeNu is licensed under the MIT License.