Overview

stochy is a collection of stochastic approximation algorithms:

  • RSPSA (Resilient Simultaneous Perturbation Stochastic Approximation)
  • SPSA (Simultaneous Perturbation Stochastic Approximation)

You can use stochy to:

  • Minimize functions with multiple parameters, without needing a gradient function
  • Optimize parameters in game-playing programs using relative difference functions (see the sketch below)

stochy is compatible with both the stepwise solver API and the argmin solver API (the latter enabled via the argmin feature flag).
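
A relative difference function reports an estimate of f(y) - f(x) for two parameter vectors directly, without ever evaluating f on its own; in a game-playing program this is typically the average score of matches played between engines configured with x and y. Below is a minimal conceptual sketch. The names play_match and rel_diff are illustrative, not part of stochy's API; see the crate docs for the constructor that accepts such a function.

// Conceptual sketch only: `play_match` stands in for a real game harness.
// Here it fakes the match outcome from a known objective so the example runs.
fn play_match(x: &[f64], y: &[f64]) -> f64 {
    let f = |p: &[f64]| (p[0] - 1.5).powi(2) + p[1] * p[1];
    f(y) - f(x) // a real harness would return a noisy score from played games
}

// Averaging several noisy match outcomes yields the relative difference
// estimate that an SPSA/RSPSA-style optimizer can consume in place of f.
let rel_diff = |x: &[f64], y: &[f64]| -> f64 {
    let games = 100;
    (0..games).map(|_| play_match(x, y)).sum::<f64>() / games as f64
};

assert!(rel_diff(&[1.0, 1.0], &[1.5, 0.0]) < 0.0); // y is the better point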

Usage

Example Cargo.toml:

[dependencies]
stochy = "0.0.2" 

# if using argmin, replace above with...
# stochy = { version = "0.0.2", features = ["argmin"] } 

Example

use stepwise::{Driver as _, fixed_iters, assert_approx_eq};
use stochy::{SpsaAlgo, SpsaParams};

// Objective to minimize: a quadratic with its minimum at (1.5, 0.0).
let f = |x: &[f64]| (x[0] - 1.5).powi(2) + x[1] * x[1];

let hyperparams = SpsaParams::default();
// Start the search from (1.0, 1.0); construction fails on invalid hyperparameters.
let spsa = SpsaAlgo::from_fn(hyperparams, vec![1.0, 1.0], f).expect("bad hyperparams!");

// Drive the algorithm for a fixed 20,000 iterations, printing progress each step.
let (solved, final_step) = fixed_iters(spsa, 20_000)
    .on_step(|algo, step| println!("{:>4} {:.8?}", step.iteration(), algo.x()))
    .solve()
    .expect("solving failed!");

assert_approx_eq!(solved.x(), &[1.5, 0.0]);
println!("solved in {} iters", final_step.iteration());

Example (argmin)

use stepwise::assert_approx_eq;

// The objective is supplied to argmin by implementing its CostFunction trait.
// (Requires stochy's "argmin" feature; see the Cargo.toml snippet above.)
struct MySimpleCost;

impl argmin::core::CostFunction for MySimpleCost {
    type Param = Vec<f64>;
    type Output = f64;

    fn cost(&self, x: &Self::Param) -> Result<Self::Output, argmin::core::Error> {
        // Same quadratic as the stepwise example, with minimum at (1.5, 0.0).
        Ok((x[0] - 1.5).powi(2) + x[1] * x[1])
    }
}

let hyperparams = stochy::SpsaParams::default();
let algo = stochy::SpsaSolverArgmin::new(hyperparams);

// argmin's Executor pairs the cost function with the solver.
let exec = argmin::core::Executor::new(MySimpleCost, algo);

// Set the starting point and iteration budget, then run to completion.
let initial_param = vec![1.0, 1.0];
let result = exec
    .configure(|state| state.param(initial_param).max_iters(20_000))
    .run()
    .unwrap();

let best_param = result.state.best_param.unwrap();
assert_approx_eq!(best_param.as_slice(), &[1.5, 0.0]);

Comparison

Gradient Descent (for comparison):

  • Requires a gradient function
  • One gradient evaluation per iteration
  • Single learning-rate hyperparameter
  • Continuous convergence progression

RSPSA:

  • No gradient function required; accepts a relative difference function
  • Two function evaluations per iteration
  • Less sensitive to hyperparameters than SPSA
  • Convergence can saturate (stall) short of the optimum

SPSA:

  • No gradient function required; accepts a relative difference function
  • Two function evaluations per iteration
  • Very sensitive to hyperparameters
  • Continuous convergence progression
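
The "two function evaluations per iteration" rows are the core SPSA trick: every coordinate is perturbed simultaneously by a random +/-1 direction, so a single pair of evaluations estimates the entire gradient regardless of the number of parameters. A minimal sketch of the standard estimator follows (illustrative, not stochy's internals):

// Standard simultaneous-perturbation gradient estimate:
//   g_i ~= (f(x + c*delta) - f(x - c*delta)) / (2 * c * delta_i)
// where each delta_i is drawn as +/-1. Two evaluations of f suffice
// for the whole gradient, whatever the dimension.
fn spsa_gradient(f: impl Fn(&[f64]) -> f64, x: &[f64], c: f64, delta: &[f64]) -> Vec<f64> {
    let plus: Vec<f64> = x.iter().zip(delta).map(|(xi, di)| xi + c * di).collect();
    let minus: Vec<f64> = x.iter().zip(delta).map(|(xi, di)| xi - c * di).collect();
    let diff = f(&plus) - f(&minus); // the only two function evaluations
    delta.iter().map(|di| diff / (2.0 * c * di)).collect()
}

let f = |x: &[f64]| (x[0] - 1.5).powi(2) + x[1] * x[1];
let g = spsa_gradient(f, &[1.0, 1.0], 0.1, &[1.0, -1.0]); // delta would normally be random
assert_eq!(g.len(), 2); // one estimate per parameter from just two evals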
