8 releases (4 breaking)

0.7.0 Oct 16, 2023
0.6.1 Dec 3, 2022
0.6.0 Jun 15, 2022
0.5.1 Mar 1, 2022
0.3.0 Jan 20, 2021

MIT/Apache

270KB
5K SLoC

Elastic Net

linfa-elasticnet provides a pure Rust implementation of elastic net linear regression.

The Big Picture

linfa-elasticnet is a crate in the linfa ecosystem, an effort to create a toolkit for classical machine learning implemented in pure Rust, akin to Python's scikit-learn.

Current state

The linfa-elasticnet crate provides linear regression with ridge (L2) and LASSO (L1) penalties. The solver uses coordinate descent to find the optimal coefficients.
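
To make the coordinate descent idea concrete, here is a minimal, std-only Rust sketch. It is not the crate's solver: it assumes the scikit-learn-style objective (1/2n) * ||y - Xw||_2^2 + penalty * (l1_ratio * ||w||_1 + (1 - l1_ratio)/2 * ||w||_2^2), assumes the features and target are already centered so that no intercept is fitted, and the function names elastic_net_cd and soft_threshold are hypothetical. Each pass updates one weight at a time in closed form, using soft-thresholding for the L1 part.

```rust
/// Soft-thresholding operator S(x, t) = sign(x) * max(|x| - t, 0).
fn soft_threshold(x: f64, t: f64) -> f64 {
    if x > t {
        x - t
    } else if x < -t {
        x + t
    } else {
        0.0
    }
}

/// Hypothetical elastic net fit by cyclic coordinate descent.
/// `x` is row-major (x[i][j] = feature j of sample i), with at least one sample.
/// Assumes centered features and target, so no intercept is fitted.
fn elastic_net_cd(
    x: &[Vec<f64>],
    y: &[f64],
    penalty: f64,
    l1_ratio: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let nf = y.len() as f64;
    let p = x[0].len();
    let mut w = vec![0.0; p];
    // Residual r = y - X w; equals y at the start because w = 0.
    let mut r: Vec<f64> = y.to_vec();
    // (1/n) * sum_i x_ij^2 for every column j.
    let col_sq: Vec<f64> = (0..p)
        .map(|j| x.iter().map(|row| row[j] * row[j]).sum::<f64>() / nf)
        .collect();

    for _ in 0..max_iter {
        let mut max_change: f64 = 0.0;
        for j in 0..p {
            if col_sq[j] == 0.0 {
                continue; // all-zero column: its weight stays at 0
            }
            let old = w[j];
            // Correlation of column j with the partial residual (r + x_j * w_j).
            let rho = x
                .iter()
                .zip(&r)
                .map(|(row, ri)| row[j] * (ri + row[j] * old))
                .sum::<f64>()
                / nf;
            // L1 part shrinks via soft-thresholding, L2 part enlarges the denominator.
            let new = soft_threshold(rho, penalty * l1_ratio)
                / (col_sq[j] + penalty * (1.0 - l1_ratio));
            if new != old {
                // Update the residual in place instead of recomputing X w.
                for (row, ri) in x.iter().zip(r.iter_mut()) {
                    *ri -= row[j] * (new - old);
                }
                w[j] = new;
            }
            max_change = max_change.max((new - old).abs());
        }
        if max_change < tol {
            break;
        }
    }
    w
}

fn main() {
    // Hypothetical toy data: centered, orthogonal columns.
    let x = vec![
        vec![1.0, 0.0],
        vec![-1.0, 0.0],
        vec![0.0, 1.0],
        vec![0.0, -1.0],
    ];
    let y = vec![2.0, -2.0, 1.0, -1.0];
    // l1_ratio = 1.0 is a pure LASSO penalty.
    for penalty in [0.0, 0.6, 1.2] {
        let w = elastic_net_cd(&x, &y, penalty, 1.0, 100, 1e-9);
        println!("penalty {penalty}: weights {w:?}");
    }
}
```

With this toy data, a penalty of 0 recovers the least-squares weights 2.0 and 1.0, a penalty of 0.6 shrinks the first weight to 0.8 and sets the second to exactly zero, and a penalty of 1.2 zeroes both.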

This library contains an elastic net implementation for linear regression models. It combines the L1 and L2 penalties of the LASSO and ridge methods and therefore offers greater flexibility for feature selection. As the penalty increases, certain parameters become exactly zero and their corresponding variables are dropped from the model.
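
For reference, the combined penalty can be written out as below. This assumes the common scikit-learn-style parameterization, with lambda standing for penalty and alpha for l1_ratio; the exact scaling used internally by linfa-elasticnet is not confirmed here. The second line is the usual coordinate-wise update, where r_i is the current residual and S is the soft-thresholding operator.

```latex
\min_{w,\, b}\; \frac{1}{2n} \sum_{i=1}^{n} \left(y_i - b - x_i^\top w\right)^2
  + \lambda \left( \alpha \lVert w \rVert_1 + \frac{1 - \alpha}{2} \lVert w \rVert_2^2 \right)

w_j \leftarrow \frac{S(\rho_j,\ \lambda\alpha)}{\frac{1}{n}\sum_{i=1}^{n} x_{ij}^2 + \lambda(1 - \alpha)},
\qquad \rho_j = \frac{1}{n}\sum_{i=1}^{n} x_{ij}\left(r_i + x_{ij} w_j\right),
\qquad r_i = y_i - b - x_i^\top w,
\qquad S(z, t) = \operatorname{sign}(z)\max(|z| - t,\ 0)
```

Whenever the magnitude of rho_j is at most lambda * alpha, the update sets w_j to exactly zero, which is how variables drop out as the penalty grows; alpha = 1 gives the LASSO and alpha = 0 gives ridge regression.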

See also:

BLAS/LAPACK backend

See this section to enable an external BLAS/LAPACK backend.

Examples

There is a usage example in the examples/ directory. To run it, use:

$ cargo run --example elasticnet
use linfa::prelude::*;
use linfa_elasticnet::{ElasticNet, Result};

fn main() -> Result<()> {
    // load Diabetes dataset (needs the `diabetes` feature of linfa-datasets)
    let (train, valid) = linfa_datasets::diabetes().split_with_ratio(0.90);

    // train pure LASSO model (l1_ratio = 1.0) with 0.3 penalty
    let model = ElasticNet::params()
        .penalty(0.3)
        .l1_ratio(1.0)
        .fit(&train)?;

    println!("intercept:  {}", model.intercept());
    println!("params: {}", model.hyperplane());

    // z scores of the fitted parameters (returned as a Result)
    println!("z score: {:?}", model.z_score());

    // validate: r2 is the coefficient of determination on the held-out split
    let y_est = model.predict(&valid);
    println!("r2 score: {}", valid.r2(&y_est)?);

    Ok(())
}
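
The example's last line reports r2, the coefficient of determination on the held-out split. As a from-scratch illustration, assumed to follow the standard definition R^2 = 1 - SS_res / SS_tot rather than copied from linfa, it can be computed as below; the helper name r2 and the sample numbers are hypothetical.

```rust
/// Hypothetical helper: coefficient of determination R^2 = 1 - SS_res / SS_tot.
fn r2(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len() as f64;
    let mean = y_true.iter().sum::<f64>() / n;
    // Sum of squared prediction errors.
    let ss_res: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(t, p)| (t - p).powi(2))
        .sum();
    // Total variance of the targets around their mean (times n).
    let ss_tot: f64 = y_true.iter().map(|t| (t - mean).powi(2)).sum();
    1.0 - ss_res / ss_tot
}

fn main() {
    let y_true = [3.0, -0.5, 2.0, 7.0];
    let y_pred = [2.5, 0.0, 2.0, 8.0];
    println!("r2 = {:.4}", r2(&y_true, &y_pred)); // prints "r2 = 0.9486"
}
```

A score of 1.0 is a perfect fit, 0.0 matches always predicting the mean of the targets, and negative values are worse than that baseline.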

Dependencies

~3–12MB
~185K SLoC