Candle Optimisers
A crate of optimisers for use with candle, the minimalist ML framework.
Optimisers implemented are:

- SGD (including momentum and weight decay)
- RMSprop

Adaptive methods:

- AdaDelta
- AdaGrad
- AdaMax
- Adam
- AdamW (included with Adam as `decoupled_weight_decay`)
- NAdam
- RAdam
These are all checked against their PyTorch implementations (see pytorch_test.ipynb) and should provide the same functionality (though without some of the input checking).
Additionally, all of the adaptive methods listed and SGD implement decoupled weight decay, as described in Decoupled Weight Decay Regularization, alongside the standard weight decay as implemented in PyTorch.
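As a rough sketch of how the two decay variants might be selected (the `Decay` enum, the `ParamsAdam` struct, and its field names here are assumptions about the crate's API, not taken from its documentation):

```rust
use candle_optimisers::{adam::ParamsAdam, Decay};

fn main() {
    // Standard (coupled) weight decay, as in torch.optim.Adam's `weight_decay`.
    // NOTE: `Decay`, `ParamsAdam` and the field names below are assumed.
    let _adam = ParamsAdam {
        lr: 1e-3,
        weight_decay: Some(Decay::WeightDecay(1e-2)),
        ..Default::default()
    };

    // Decoupled weight decay, giving AdamW-style behaviour.
    let _adamw = ParamsAdam {
        lr: 1e-3,
        weight_decay: Some(Decay::DecoupledWeightDecay(1e-2)),
        ..Default::default()
    };
}
```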
Pseudo-second order methods:

- LBFGS

This is not implemented identically to the PyTorch version, but is checked against the 2D Rosenbrock function.
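For reference, a minimal sketch of the 2D Rosenbrock function with its conventional constants (the exact constants used in the crate's own check are not stated here and are assumed):

```rust
/// 2D Rosenbrock function f(x, y) = (a - x)^2 + b * (y - x^2)^2 with the
/// conventional constants a = 1, b = 100; its global minimum of 0 is at (1, 1).
fn rosenbrock(x: f64, y: f64) -> f64 {
    let (a, b) = (1.0_f64, 100.0_f64);
    (a - x).powi(2) + b * (y - x * x).powi(2)
}

fn main() {
    assert_eq!(rosenbrock(1.0, 1.0), 0.0);
    println!("f(0, 0) = {}", rosenbrock(0.0, 0.0)); // prints 1
}
```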
Examples
There is an MNIST toy program along with a simple example of AdaGrad. While the parameters of each method aren't tuned (all defaults, with a user-supplied learning rate), the following converges quite nicely:

```sh
cargo r -r --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025
```

For even faster training, use the CUDA backend:

```sh
cargo r -r --features cuda --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025
```
Usage
```sh
cargo add --git https://github.com/KGrewal1/optimisers.git candle-optimisers
```
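Once the dependency is added, a minimal training loop might look roughly like this (a sketch assuming the crate's Adam optimiser implements candle-nn's `Optimizer` trait and that `ParamsAdam` provides a `Default` impl; exact names may differ):

```rust
use candle_core::{DType, Device, Result, Tensor, Var};
use candle_nn::Optimizer;
// The module path, struct names and fields below are assumptions about this
// crate's API; consult the crate documentation for the exact names.
use candle_optimisers::adam::{Adam, ParamsAdam};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // A single trainable scalar, fitted to minimise (w - 3)^2.
    let w = Var::from_tensor(&Tensor::zeros((), DType::F32, &dev)?)?;

    // Assumed: ParamsAdam implements Default, so only the learning rate is set.
    let params = ParamsAdam {
        lr: 0.1,
        ..Default::default()
    };
    let mut opt = Adam::new(vec![w.clone()], params)?;

    for _ in 0..200 {
        let loss = (w.as_tensor() - 3.0)?.sqr()?;
        // backward_step (from candle-nn's Optimizer trait) computes the
        // gradients of `loss` and applies a single update to the variables.
        opt.backward_step(&loss)?;
    }
    println!("w = {}", w.as_tensor());
    Ok(())
}
```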
To do
Currently unimplemented from pytorch:

- SparseAdam (unsure how to treat sparse tensors in candle)
- ASGD (no pseudocode)
- Rprop (need to reformulate in terms of tensors)
Notes
For development, to track the state of the PyTorch methods, use:

```python
print(optimiser.state)
```