1 stable release
22.5.23 | May 23, 2022 |
---|---|
22.5.17 | |
#303 in Machine learning
34KB
517 lines
scratch_genetic
Description
A from-scratch genetic-algorithm library used in my march-madness-predictor project
API Reference
Contents
genetic
module
The genetic module is the only public module under the scratch_genetic library.
It contains a set of functions implementing a "genetic algorithm", which mimics natural selection to produce a model that can predict an outcome from input data.
The workflow is: convert your data into a stream of input and output bits; create a set of randomized networks sized to that data, with the settings you choose; train by repeatedly running the test and reproduction functions; then export the best network at the end.
You can then load the exported model and use it to generate predictions from new input bits.
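As a rough illustration of the first step (turning data into input and output bits), the sketch below packs a made-up game record into bytes. The `Game` struct, its fields, and the helpers `to_data_set` and `bit_count` are hypothetical names invented for this example, not the encoding the march-madness-predictor project uses. It also assumes each `u8` in the data set carries eight bits, which this page does not state explicitly.

```rust
/// Hypothetical record; the fields are made up for this example.
struct Game {
    seed_a: u8,
    seed_b: u8,
    score_a: u8,
    score_b: u8,
}

impl Game {
    /// Input side: one byte per seed (16 bits in total).
    fn to_input_bits(&self) -> [u8; 2] {
        [self.seed_a, self.seed_b]
    }

    /// Output side: one byte per score (16 bits in total).
    fn to_output_bits(&self) -> [u8; 2] {
        [self.score_a, self.score_b]
    }
}

/// Builds (input, output) pairs in the shape that `test_and_sort` takes.
fn to_data_set(games: &[Game]) -> Vec<(Vec<u8>, Vec<u8>)> {
    games
        .iter()
        .map(|g| (g.to_input_bits().to_vec(), g.to_output_bits().to_vec()))
        .collect()
}

/// Number of bits carried by a byte buffer (assumes 8 bits per `u8`).
/// Whole bytes always give a count that is divisible by 8.
fn bit_count(bytes: &[u8]) -> usize {
    bytes.len() * 8
}
```

Under these assumptions, `bit_count` applied to one input buffer and one output buffer would give the values to pass as `num_inputs` and `num_outputs`.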
gen_pop
function
pub async fn gen_pop(
pop_size: usize,
layer_sizes: Vec<usize>, num_inputs: usize, num_outputs: usize,
activation_thresh: f64, trait_swap_chance: f64,
weight_mutate_chance: f64, weight_mutate_amount: f64,
offset_mutate_chance: f64, offset_mutate_amount: f64) -> Vec<Network> {
This function generates a vector of randomized instances of Network, an underlying private struct. It's private because you won't need to work with it directly; you only need to pass it between functions.
You can see it takes quite a few parameters. These are all the manual settings to give to your network to adjust how it trains.
Parameters:
pop_size - number of networks to train on. Bigger is better, but bigger is also slower
layer_sizes - a vector containing the size of each layer of the neural network
num_inputs - the number of bits that your input data produces after being converted (must be divisible by 8)
num_outputs - the expected number of bits generated by the output (must be divisible by 8)
activation_thresh - how hard it is for a neuron to turn on
trait_swap_chance - controls the variability of a child sharing different traits from each parent when reproducing
weight_mutate_chance - the chance that a weight on the connections between neurons changes
weight_mutate_amount - how strong the change above is
offset_mutate_chance and offset_mutate_amount - same as the above two, but for the base value of the connection
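To make activation_thresh and the weight and offset settings a bit more concrete, here is one plausible reading of how a threshold neuron could combine them. This is only a sketch under assumptions: `Connection` and `neuron_fires` are hypothetical names, the averaging rule is a guess, and the library's private Network may compute activation differently.

```rust
/// Hypothetical connection: a weight plus a base (offset) value.
struct Connection {
    weight: f64,
    offset: f64,
}

/// Assumed rule: each connection contributes `weight + offset` when its
/// input is on and just `offset` when it is off. The neuron turns on when
/// the average contribution reaches `activation_thresh`.
fn neuron_fires(inputs: &[bool], conns: &[Connection], activation_thresh: f64) -> bool {
    // Refuse mismatched or empty wiring instead of dividing by zero.
    if conns.is_empty() || inputs.len() != conns.len() {
        return false;
    }
    let total: f64 = inputs
        .iter()
        .zip(conns)
        .map(|(&on, c)| if on { c.weight + c.offset } else { c.offset })
        .sum();
    let average = total / conns.len() as f64;
    average >= activation_thresh
}
```

With this rule, a higher threshold means more (or stronger) active inputs are needed before the neuron turns on, which matches the "how hard it is for a neuron to turn on" description.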
test_and_sort
function
pub async fn test_and_sort(pop: &mut Vec<Network>, data_set: &Vec<(Vec<u8>, Vec<u8>)>) {
This takes the "population" (the vector of Networks created by gen_pop) and your test data, sees how close each network gets to reproducing each test entry's output, and then sorts the networks by that performance.
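One way to picture "how close each network gets" is to count the bits that differ between predicted and expected outputs, then sort by that count. The sketch below does this for stand-in predictors written as plain functions. The names `bit_errors`, `total_error`, `rank_candidates`, and `Predictor`, and the choice of Hamming distance, are assumptions for illustration; how test_and_sort actually scores networks isn't documented here.

```rust
/// Stand-in for a trained network: maps input bytes to output bytes.
type Predictor = fn(&[u8]) -> Vec<u8>;

// Three toy predictors used only to have something to rank.
fn identity(input: &[u8]) -> Vec<u8> {
    input.to_vec()
}
fn invert(input: &[u8]) -> Vec<u8> {
    input.iter().map(|&b| !b).collect()
}
fn zeros(input: &[u8]) -> Vec<u8> {
    vec![0; input.len()]
}

/// Number of differing bits between two byte buffers (Hamming distance).
fn bit_errors(predicted: &[u8], expected: &[u8]) -> u32 {
    predicted
        .iter()
        .zip(expected)
        .map(|(&p, &e)| (p ^ e).count_ones())
        .sum()
}

/// Total bit errors of one candidate over the whole data set.
fn total_error(predict: Predictor, data_set: &[(Vec<u8>, Vec<u8>)]) -> u32 {
    data_set
        .iter()
        .map(|(input, expected)| bit_errors(&predict(input), expected))
        .sum()
}

/// Sorts candidates so the one with the fewest bit errors comes first.
fn rank_candidates(candidates: &mut [(&str, Predictor)], data_set: &[(Vec<u8>, Vec<u8>)]) {
    candidates.sort_by_key(|&(_, predict)| total_error(predict, data_set));
}
```

Sorting best-first is what makes the later steps simple: the top half of the vector holds the survivors, and index 0 is the candidate to export.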
reproduce
function
pub async fn reproduce(pop: &mut Vec<Network>) {
After sorting, you'll want to reproduce. This takes your set of networks, keeps the upper half of them, and uses those to replace the bottom half with children sharing mixed genes and mutations based on the parameters you provided in gen_pop.
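The keep-the-top-half idea can be sketched on plain vectors of numbers. Everything below is illustrative rather than the library's implementation: genomes are flat `Vec<f64>` values instead of Networks, `Rng`, `crossover`, `mutate`, and `reproduce_sketch` are hypothetical names, the parent pairing is an arbitrary choice, and the reading of trait_swap_chance (the chance of switching to the other parent at each gene) is an assumption.

```rust
/// Tiny xorshift PRNG so the sketch needs no external crates.
/// The seed must be non-zero.
struct Rng(u64);

impl Rng {
    /// Returns a pseudo-random value in [0, 1).
    fn next_f64(&mut self) -> f64 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        (self.0 >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Copies genes from parent `a`, switching to the other parent with
/// probability `trait_swap_chance` at each gene (assumed interpretation).
fn crossover(a: &[f64], b: &[f64], trait_swap_chance: f64, rng: &mut Rng) -> Vec<f64> {
    let mut from_a = true;
    a.iter()
        .zip(b)
        .map(|(&gene_a, &gene_b)| {
            if rng.next_f64() < trait_swap_chance {
                from_a = !from_a;
            }
            if from_a { gene_a } else { gene_b }
        })
        .collect()
}

/// With probability `chance`, nudges each gene by up to `amount` either way.
fn mutate(genes: &mut [f64], chance: f64, amount: f64, rng: &mut Rng) {
    for gene in genes.iter_mut() {
        if rng.next_f64() < chance {
            *gene += (rng.next_f64() * 2.0 - 1.0) * amount;
        }
    }
}

/// Expects `pop` sorted best-first. Keeps the top half untouched and
/// overwrites the bottom half with mutated children of the survivors.
fn reproduce_sketch(pop: &mut Vec<Vec<f64>>, swap: f64, chance: f64, amount: f64, rng: &mut Rng) {
    let half = pop.len() / 2;
    if half == 0 {
        return; // Fewer than two members: nothing to breed from.
    }
    for i in half..pop.len() {
        // Pair neighbouring survivors; this pairing is an arbitrary choice.
        let p1 = (i - half) % half;
        let p2 = (p1 + 1) % half;
        let mut child = crossover(&pop[p1], &pop[p2], swap, rng);
        mutate(&mut child, chance, amount, rng);
        pop[i] = child;
    }
}
```

Because the survivors are never modified, the best member found so far stays at the front of the vector after this step.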
load_and_predict
function
pub async fn load_and_predict(file_name: &'static str, input_bits: &Vec<u8>) -> Vec<u8> {
Loads a previously exported model and, given the input bits you pass in, generates the output bits.
export_model
function
pub async fn export_model(file_name: &'static str, pop: &Network) {
Takes a single network (typically the best one after sorting) and exports it to a file.
Examples
The following examples use these constants:
// Neuron connection settings
pub const NEURON_ACTIVATION_THRESH: f64 = 0.60;
pub const TRAIT_SWAP_CHANCE: f64 = 0.80;
pub const WEIGHT_MUTATE_CHANCE: f64 = 0.65;
pub const WEIGHT_MUTATE_AMOUNT: f64 = 0.5;
pub const OFFSET_MUTATE_CHANCE: f64 = 0.25;
pub const OFFSET_MUTATE_AMOUNT: f64 = 0.05;
// Neural network settings
pub const LAYER_SIZES: [usize; 4] = [ 8, 32, 32, 16 ];
// Algorithm settings
const POP_SIZE: usize = 2000;
const DATA_FILE_NAME: &'static str = "NCAA Mens March Madness Historical Results.csv";
const MODEL_FILE_NAME: &'static str = "model.mmp";
const NUM_GENS: usize = 1000;
Training
use std::time::Instant;
use scratch_genetic::genetic::{export_model, gen_pop, reproduce, test_and_sort};
// NUM_INPUTS, NUM_OUTPUTS, and GameInfo come from the surrounding project and are not shown here.
println!("Training new March Madness Predictor Model");
// GameInfo is a custom type that structures the CSV data and can turn it into bits.
println!("Loading training data from {}", DATA_FILE_NAME);
let games = GameInfo::collection_from_file(DATA_FILE_NAME);
let games: Vec<(Vec<u8>, Vec<u8>)> = games.iter().map(|game| { // Shadows games with (input, output) pairs
    (game.to_input_bits().to_vec(), game.to_output_bits().to_vec())
}).collect();
println!("Generating randomized population");
let now = Instant::now();
let mut pop = gen_pop(
    POP_SIZE,
    LAYER_SIZES.to_vec(), NUM_INPUTS, NUM_OUTPUTS,
    NEURON_ACTIVATION_THRESH, TRAIT_SWAP_CHANCE,
    WEIGHT_MUTATE_CHANCE, WEIGHT_MUTATE_AMOUNT,
    OFFSET_MUTATE_CHANCE, OFFSET_MUTATE_AMOUNT
).await;
let elapsed = now.elapsed();
println!("Generation took {}s", elapsed.as_secs_f64());
println!("Starting training");
for i in 0..NUM_GENS {
    // Print a 1-based generation counter.
    println!("Generation {} / {}", i + 1, NUM_GENS);
    test_and_sort(&mut pop, &games).await;
    reproduce(&mut pop).await;
}
// Save the best network (reproduce keeps the top half in place, so it is still at index 0)
println!("Saving model to {}", MODEL_FILE_NAME);
export_model(MODEL_FILE_NAME, &pop[0]).await;
Predicting
pub async fn predict(team_names: &str) {
    // Expected fields: A team, A seed, B team, B seed, date, round, region
    let indexable_table_data: Vec<&str> = team_names.split(',').collect();
    if indexable_table_data.len() != 7 {
        println!("Invalid input string!");
        return;
    }
    // Like the other example, this part converts CSV data into a usable form
    println!("Converting input into data...");
    let entry = TableEntry {
        winner: String::from(indexable_table_data[0]),
        win_seed: String::from(indexable_table_data[1]),
        loser: String::from(indexable_table_data[2]),
        lose_seed: String::from(indexable_table_data[3]),
        date: String::from(indexable_table_data[4]),
        round: String::from(indexable_table_data[5]),
        region: String::from(indexable_table_data[6]),
        // Scores and overtime are unknown before the game, so use placeholders
        win_score: String::from("0"),
        lose_score: String::from("0"),
        overtime: String::from("")
    };
    let game = GameInfo::from_table_entry(&entry);
    // Here's where the library is used
    println!("Predicting!");
    let result = load_and_predict(MODEL_FILE_NAME, &game.to_input_bits().to_vec()).await;
    println!("Predicted score for {}: {}", indexable_table_data[0], result[0]);
    println!("Predicted score for {}: {}", indexable_table_data[2], result[1]);
    println!("Expected overtimes: {}", result[2]);
}
Dependencies
~3–9.5MB
~77K SLoC