5 releases
Uses new Rust 2024
new 0.1.5 | Apr 18, 2025 |
---|---|
0.1.3 | Apr 17, 2025 |
0.1.2 | Apr 16, 2025 |
0.1.1 | Apr 9, 2025 |
0.1.0 | Apr 8, 2025 |
#10 in #objectives
426 downloads per month
3MB
1.5K
SLoC
Overview
Rust-ML is a personal project that aims to combine Rust programming skills with machine learning concepts. The project leverages Rust's performance characteristics, memory safety guarantees, and expressive type system to build reliable machine learning tools while providing an opportunity to deepen understanding of both Rust and machine learning algorithms.
This is a personal project, and I will be working on it when I have time. If you find it useful, feel free to use it, but please keep in mind that it's not a production-ready library. This project is not intended to be a full-fledged machine learning library like TensorFlow or PyTorch, but rather a learning tool for me to explore the intersection of Rust and machine learning.
Objectives
- Implement common machine learning algorithms in pure Rust
- Provide high-performance implementations suitable for production use
- Ensure memory safety and thread safety through Rust's ownership model
- Create an ergonomic API that is easy to use for both ML beginners and experts
- Support integration with existing data processing pipelines
Project Structure
The project is organized into several key modules:
Core (src/core/
)
Contains fundamental data structures and utilities:
types.rs
: Defines common types used throughout the library likeModelParams
param_storage.rs
: Provides storage mechanisms for model parametersparam_manager.rs
: Manages parameter operations and transformationserror.rs
: Defines error types and handling for the core module
Model (src/model/
)
Houses machine learning model implementations:
-
core/
: Contains base traits and interfaces for models:base.rs
: Defines base model traitsoptimizable_model.rs
: Trait for models that can be optimizedregression_model.rs
: Trait for regression modelsclassification_model.rs
: Trait for classification models
-
linear_regression.rs
: Implementation of linear regression model -
Other model implementations
Optimization (src/optimization/
)
Provides optimization algorithms for training models:
-
core/
:optimizer.rs
: Defines the Optimizer trait and related interfaces
-
gradient_descent.rs
: Implementation of gradient descent optimization
Benchmarking (src/bench/
)
Tools for profiling and evaluating model performance:
-
core/
:profiler.rs
: Base trait for profiling model training and evaluationerror.rs
: Error handling for profiling operationstrain_metrics.rs
: Defines metrics for training evaluation
-
regression_profiler.rs
: Profiler for regression models -
regression_metrics.rs
: Metrics specific to regression models
Builders (src/builders/
)
Factory patterns for creating and configuring models:
builder.rs
: Defines builder interfaceslinear_regression.rs
: Builder for linear regression models
Examples (src/bin/
)
Executable examples demonstrating library usage:
linear_regression.rs
: Example usage of linear regressionsl_classifier.rs
: Example of supervised learning classifier
Error Handling
The library employs a robust error handling strategy using Rust's Result
type
and the thiserror
crate:
- Domain-specific error types are defined for each module
- Error conversion is facilitated with the
From
trait - The
?
operator is used for ergonomic error propagation
This approach ensures clear error messages and type-safe error handling throughout the codebase.
Getting Started
Prerequisites
- Rust (stable version 1.86.0 or higher recommended)
- Cargo (comes with Rust installation)
Installation
Add Rust-ML to your project by including it in your Cargo.toml
:
[dependencies]
rust-ml = "0.1.0"
Or clone the repository to contribute or run examples:
git clone https://github.com/username/rust-ml.git
cd rust-ml
cargo build
Running Tests
cargo test
Running Benchmarks
cargo bench
Example Usage
Linear Regression
use rust_ml::model::linear_regression::LinearRegression;
use rust_ml::model::ml_model::RegressionModel;
use rust_ml::optimization::gradient_descent::GradientDescent;
use rust_ml::optimization::optimizer::Optimizer;
use ndarray::{Array1, Array2};
fn main() {
// Create sample data
let x_train = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0]).unwrap();
let y_train = Array1::from_vec(vec![6.0, 9.0, 12.0]);
// Initialize model and optimizer
let mut model = LinearRegression::new(2); // 2 features
let mut optimizer = GradientDescent::new(0.01, 1000); // learning rate and iterations
// Train model
optimizer.fit(&mut model, &x_train, &y_train).expect("Training failed");
// Make predictions
let x_test = Array2::from_shape_vec((1, 2), vec![4.0, 5.0]).unwrap();
let predictions = model.predict(&x_test);
println!("Prediction: {:?}", predictions);
}
Profiling Models
The library provides profiling capabilities to measure model performance:
use rust_ml::bench::regression_profiler::RegressionProfiler;
use rust_ml::bench::profiler::Profiler;
use rust_ml::model::linear_regression::LinearRegression;
use rust_ml::optimization::gradient_descent::GradientDescent;
fn main() {
// Initialize data, model and optimizer
// [...]
// Create profiler
let profiler = RegressionProfiler::<LinearRegression, GradientDescent, _, _>::new();
// Profile training
let (train_metrics, eval_metrics) = profiler
.profile_training(&mut model, &mut optimizer, &x_train, &y_train)
.expect("Profiling failed");
println!("Training metrics: {:?}", train_metrics);
println!("Evaluation metrics: {:?}", eval_metrics);
}
Dependencies
The project relies on several high-quality Rust crates:
thiserror
: For ergonomic error handlingndarray
: For efficient numerical computationsndarray-rand
: For random matrix generationpolars
: For data manipulationrand
: For random number generationapprox
: For approximate floating-point comparisons
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
- Fork the repository
- Create your feature branch (
git checkout -b feature/amazing-feature
) - Commit your changes (
git commit -m 'Add some amazing feature'
) - Push to the branch (
git push origin feature/amazing-feature
) - Open a Pull Request
Dependencies
~23–34MB
~568K SLoC