1 stable release: 1.0.0 (May 18, 2021)
25KB
316 lines
standard-dist
An attribute macro for creating a Standard distribution for Rust types
standard-dist is a library for automatically deriving a rand Standard distribution for your types via a derive macro.
Usage examples
use rand::distributions::Uniform;
use standard_dist::StandardDist;

// Select heads or tails with equal probability
#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
enum Coin {
    Heads,
    Tails,
}

// Flip 3 coins, independently
#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
struct Coins {
    first: Coin,
    second: Coin,
    third: Coin,
}

// Use the `#[distribution]` attribute to customize the distribution used on
// a field
#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
struct Die {
    #[distribution(Uniform::from(1..=6))]
    value: u8,
}

// Use the `#[weight]` attribute to customize the relative probabilities of
// enum variants
#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
enum D20 {
    #[weight(18)]
    Normal,
    Critical,
    CriticalFail,
}
rand generates typed random values via the Distribution trait, which uses a source of randomness to produce values of the given type. Of particular note is the Standard distribution, which is the stateless "default" way to produce random values of a particular type. For instance:
- For integers, it chooses uniformly from all possible values of that integer type
- For bools, it chooses true or false with 50/50 probability
- For Option<T>, it chooses None or Some with 50/50 probability, and uses Standard to randomly populate the inner Some value.
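To make the behaviour in this list concrete, here is a small dependency-free sketch. ToyRng, ToyDistribution and ToyStandard are hypothetical stand-ins for rand's Rng, Distribution and Standard (they are not rand's API or its implementation), and the generator is SplitMix64, chosen only because it is short.

```rust
// Hypothetical, dependency-free stand-ins for rand's `Rng`, `Distribution`
// and `Standard`. They mirror the behaviour described above, not rand's code.

/// SplitMix64, a small 64-bit pseudo-random generator.
struct ToyRng(u64);

impl ToyRng {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

/// Same shape as rand's `Distribution<T>`, but tied to `ToyRng`.
trait ToyDistribution<T> {
    fn sample(&self, rng: &mut ToyRng) -> T;
}

/// Stateless "default" distribution, playing the role of `Standard`.
struct ToyStandard;

// Integers: every value of the type is equally likely.
impl ToyDistribution<u8> for ToyStandard {
    fn sample(&self, rng: &mut ToyRng) -> u8 {
        (rng.next_u64() >> 56) as u8 // top 8 bits
    }
}

// Bools: true or false with 50/50 probability.
impl ToyDistribution<bool> for ToyStandard {
    fn sample(&self, rng: &mut ToyRng) -> bool {
        rng.next_u64() >> 63 == 1 // top bit
    }
}

// Option<T>: None or Some with 50/50 probability; the inner value comes from
// the same default distribution. Fully qualified calls avoid an inference
// ambiguity with the `where` clause.
impl<T> ToyDistribution<Option<T>> for ToyStandard
where
    ToyStandard: ToyDistribution<T>,
{
    fn sample(&self, rng: &mut ToyRng) -> Option<T> {
        if <ToyStandard as ToyDistribution<bool>>::sample(self, rng) {
            Some(<ToyStandard as ToyDistribution<T>>::sample(self, rng))
        } else {
            None
        }
    }
}

fn main() {
    let mut rng = ToyRng(42);
    let byte: u8 = ToyStandard.sample(&mut rng);
    let maybe: Option<u8> = ToyStandard.sample(&mut rng);
    println!("byte = {}, maybe = {:?}", byte, maybe);
}
```

The Option<T> impl is generic over any T the default distribution already supports, which is why nested types such as Option<Option<u8>> work without extra code.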
Structs
When you derive StandardDist for one of your own structs, it creates an impl Distribution<YourStruct> for Standard implementation, allowing you to create randomized instances of the struct via Rng::gen (or rand::random, as in the example below). This implementation will in turn use the Standard distribution to populate all the fields of your type.
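As a rough picture of what such an impl does, the sketch below writes one by hand for the SimpleStruct example. It uses the same hypothetical stand-ins as before (ToyRng, ToyDistribution, ToyStandard) instead of rand, and the body is an assumption about the general shape of a derived impl, not the code standard-dist actually emits. With rand itself, the equivalent body would call rng.gen() once per field.

```rust
// Hypothetical stand-ins for rand's RNG, `Distribution` and `Standard`.
struct ToyRng(u64);

impl ToyRng {
    // One SplitMix64 step.
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

trait ToyDistribution<T> {
    fn sample(&self, rng: &mut ToyRng) -> T;
}

struct ToyStandard;

impl ToyDistribution<bool> for ToyStandard {
    fn sample(&self, rng: &mut ToyRng) -> bool {
        rng.next_u64() >> 63 == 1
    }
}

// Floats: 53 random bits scaled into the half-open range [0, 1).
impl ToyDistribution<f64> for ToyStandard {
    fn sample(&self, rng: &mut ToyRng) -> f64 {
        (rng.next_u64() >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
    }
}

struct SimpleStruct {
    coin: bool,
    percent: f64,
}

// Hand-written equivalent of a derived impl: every field is filled from the
// default distribution for its own type.
impl ToyDistribution<SimpleStruct> for ToyStandard {
    fn sample(&self, rng: &mut ToyRng) -> SimpleStruct {
        SimpleStruct {
            coin: self.sample(rng),
            percent: self.sample(rng),
        }
    }
}

fn main() {
    let mut rng = ToyRng(1);
    let s: SimpleStruct = ToyStandard.sample(&mut rng);
    println!("coin = {}, percent = {}", s.coin, s.percent);
}
```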
use standard_dist::StandardDist;

#[derive(StandardDist)]
struct SimpleStruct {
    coin: bool,
    percent: f64,
}

let mut heads = 0;
for _ in 0..2000 {
    let s: SimpleStruct = rand::random();
    // Standard samples f64 values from the half-open range [0, 1)
    assert!(0.0 <= s.percent);
    assert!(s.percent < 1.0);
    if s.coin {
        heads += 1;
    }
}
// Expect roughly 1000 heads out of 2000 flips
assert!(900 < heads, "heads: {}", heads);
assert!(heads < 1100, "heads: {}", heads);
Custom distributions
You can customize the distribution used for any field with the #[distribution] attribute:
use std::collections::HashMap;
use standard_dist::StandardDist;
use rand::distributions::Uniform;

#[derive(StandardDist)]
struct Die {
    #[distribution(Uniform::from(1..=6))]
    value: u8,
}

// Tally 6000 rolls; each face should come up roughly 1000 times
let mut counter: HashMap<u8, u32> = HashMap::new();
for _ in 0..6000 {
    let die: Die = rand::random();
    *counter.entry(die.value).or_insert(0) += 1;
}
assert_eq!(counter.len(), 6);
for i in 1..=6 {
    let count = counter[&i];
    assert!(900 < count, "{}: {}", i, count);
    assert!(count < 1100, "{}: {}", i, count);
}
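The idea behind #[distribution] can be sketched without rand: for that one field, the generated code samples from the given distribution instead of the default one. Everything below (ToyRng, ToyDistribution, ToyUniformInclusive, ToyStandard) is a hypothetical stand-in, and building the distribution inside sample on every call is an assumption that matches the uncached behaviour described under Distribution caching.

```rust
// Hypothetical stand-ins: a SplitMix64 RNG and a `Distribution`-like trait.
struct ToyRng(u64);

impl ToyRng {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

trait ToyDistribution<T> {
    fn sample(&self, rng: &mut ToyRng) -> T;
}

/// Uniform integers in `low..=high` (assumes `low <= high`). A plain modulo
/// is very slightly biased, which is acceptable for a sketch.
struct ToyUniformInclusive {
    low: u8,
    high: u8,
}

impl ToyDistribution<u8> for ToyUniformInclusive {
    fn sample(&self, rng: &mut ToyRng) -> u8 {
        let span = (self.high - self.low) as u64 + 1;
        self.low + (rng.next_u64() % span) as u8
    }
}

struct Die {
    value: u8,
}

/// Plays the role of `Standard` for `Die`.
struct ToyStandard;

// The custom distribution replaces the default one for `value`: the
// attribute's expression is built, then sampled for that field.
impl ToyDistribution<Die> for ToyStandard {
    fn sample(&self, rng: &mut ToyRng) -> Die {
        let dist = ToyUniformInclusive { low: 1, high: 6 };
        Die { value: dist.sample(rng) }
    }
}

fn main() {
    let mut rng = ToyRng(2024);
    let mut counts = [0u32; 7]; // index 0 is unused
    for _ in 0..6000 {
        let die: Die = ToyStandard.sample(&mut rng);
        counts[die.value as usize] += 1;
    }
    println!("{:?}", &counts[1..]);
}
```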
Enums
When applied to an enum type, the implementation will randomly select a variant (where each variant has an equal probability) and then populate all the fields of that variant in the same manner as with a struct. Enum variant fields may have custom distributions applied via #[distribution], just like struct fields.
use standard_dist::StandardDist;

#[derive(PartialEq, Eq, StandardDist)]
enum Coin {
    Heads,
    Tails,
}

let mut heads = 0;
for _ in 0..2000 {
    let coin: Coin = rand::random();
    if coin == Coin::Heads {
        heads += 1;
    }
}
assert!(900 < heads, "heads: {}", heads);
assert!(heads < 1100, "heads: {}", heads);
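The two steps described for enums (pick a variant with equal probability, then fill that variant's fields) can be sketched by hand. The Shape enum and the Toy* items are hypothetical and only illustrate the idea; they are not the crate's generated code.

```rust
// Hypothetical stand-ins: a SplitMix64 RNG and a `Distribution`-like trait.
struct ToyRng(u64);

impl ToyRng {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

trait ToyDistribution<T> {
    fn sample(&self, rng: &mut ToyRng) -> T;
}

struct ToyStandard;

// Floats in [0, 1), used to fill the variant fields below.
impl ToyDistribution<f64> for ToyStandard {
    fn sample(&self, rng: &mut ToyRng) -> f64 {
        (rng.next_u64() >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
    }
}

#[derive(Debug)]
enum Shape {
    Circle { radius: f64 },
    Square { side: f64 },
    Dot,
}

impl ToyDistribution<Shape> for ToyStandard {
    fn sample(&self, rng: &mut ToyRng) -> Shape {
        // Step 1: choose one of the three variants with equal probability.
        // Step 2: fill only the chosen variant's fields.
        match rng.next_u64() % 3 {
            0 => Shape::Circle { radius: self.sample(rng) },
            1 => Shape::Square { side: self.sample(rng) },
            _ => Shape::Dot,
        }
    }
}

fn main() {
    let mut rng = ToyRng(5);
    for _ in 0..3 {
        let shape: Shape = ToyStandard.sample(&mut rng);
        println!("{:?}", shape);
    }
}
```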
Weights
Enum variants may be weighted with the #[weight] attribute to make them relatively more or less likely to be randomly selected. A weight of 0 means that the variant will never be selected. Any variant without a #[weight] attribute has a weight of 1.
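One common way to implement weighted selection like this is to draw a number below the total weight and walk the running total of weights until it is passed. The sketch below shows that technique with hypothetical names (ToyRng, weighted_pick, D20_WEIGHTS); it is an assumption for illustration, and standard-dist's generated code may choose variants differently.

```rust
// Hypothetical SplitMix64 generator, so the sketch needs no crates.
struct ToyRng(u64);

impl ToyRng {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

/// Picks an item with probability proportional to its weight. An item with
/// weight 0 is always skipped. Assumes the weights sum to more than zero.
fn weighted_pick<T: Copy>(table: &[(T, u64)], rng: &mut ToyRng) -> T {
    let total: u64 = table.iter().map(|&(_, weight)| weight).sum();
    let mut roll = rng.next_u64() % total; // in 0..total, very slightly biased
    for &(item, weight) in table {
        if roll < weight {
            return item;
        }
        roll -= weight;
    }
    unreachable!("roll is always below the total weight")
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum D20 {
    Normal,
    CriticalHit,
    CriticalMiss,
}

// Mirrors #[weight(18)] on Normal; untagged variants get weight 1.
const D20_WEIGHTS: [(D20, u64); 3] = [
    (D20::Normal, 18),
    (D20::CriticalHit, 1),
    (D20::CriticalMiss, 1),
];

fn main() {
    let mut rng = ToyRng(20);
    let roll = weighted_pick(&D20_WEIGHTS, &mut rng);
    println!("{:?}", roll);
}
```

For real code, rand also ships rand::distributions::WeightedIndex, which implements weighted sampling over a list of weights.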
use standard_dist::StandardDist;

#[derive(StandardDist)]
enum D20 {
    #[weight(18)]
    Normal,
    CriticalHit,
    CriticalMiss,
}

let mut crits = 0;
for _ in 0..20000 {
    let roll: D20 = rand::random();
    if matches!(roll, D20::CriticalHit) {
        crits += 1;
    }
}
assert!(900 < crits, "crits: {}", crits);
assert!(crits < 1100, "crits: {}", crits);
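The bounds in this example follow from the weights. Assuming, as the text describes, that weights are relative, a variant with weight w_i is chosen with probability w_i divided by the sum of all weights, and the number of critical hits over n = 20000 independent rolls is binomial:

```latex
\begin{aligned}
P(\text{CriticalHit}) &= \frac{w_{\text{CriticalHit}}}{\sum_j w_j} = \frac{1}{18 + 1 + 1} = \frac{1}{20} = 0.05 \\
\mathbb{E}[\text{crits}] &= np = 20000 \times 0.05 = 1000 \\
\sigma &= \sqrt{np(1 - p)} = \sqrt{20000 \times 0.05 \times 0.95} = \sqrt{950} \approx 30.8
\end{aligned}
```

So the window 900 < crits < 1100 spans about 3.2 standard deviations on either side of the mean; by a normal approximation, a run lands outside it on the order of 0.1% of the time.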
Advanced custom distributions
Distribution types
You may optionally specify an explicit type for a distribution, written as Type = expression inside #[distribution]; this is sometimes necessary when using generic types.
use std::collections::HashMap;
use standard_dist::StandardDist;
use rand::distributions::Uniform;

#[derive(StandardDist)]
struct Die {
    #[distribution(Uniform<u8> = Uniform::from(1..=6))]
    value: u8,
}

let mut counter: HashMap<u8, u32> = HashMap::new();
for _ in 0..6000 {
    let die: Die = rand::random();
    *counter.entry(die.value).or_insert(0) += 1;
}
assert_eq!(counter.len(), 6);
for i in 1..=6 {
    let count = counter[&i];
    assert!(900 < count, "{}: {}", i, count);
    assert!(count < 1100, "{}: {}", i, count);
}
Distribution caching
In some cases, you may wish to cache a Distribution instance for reuse. Many distributions perform some initial calculations when constructed, and it can help performance to reuse existing distributions rather than recreate them every time a value is generated. standard-dist provides two ways to cache distributions: static and once. A static distribution is stored as a global static variable; this is the preferable option, but it requires the initializer to be usable in a const context. A once distribution is stored in a once_cell::sync::OnceCell; it is initialized the first time it's used, and then reused on subsequent invocations.
In either case, a cache policy is specified by prefixing the type with once or static. The type must be specified in order to use a cache policy.
use std::time::{Instant, Duration};
use standard_dist::StandardDist;
use rand::prelude::*;
use rand::distributions::Uniform;

// Uncached: the Uniform distribution is constructed for every generated value
#[derive(StandardDist)]
struct Die {
    #[distribution(Uniform::from(1..=6))]
    value: u8,
}

// Cached with the `once` policy: constructed on first use, then reused
#[derive(StandardDist)]
struct CachedDie {
    #[distribution(once Uniform<u8> = Uniform::from(1..=6))]
    value: u8,
}

// Runs `task` and returns its result along with how long it took
fn timed<T>(task: impl FnOnce() -> T) -> (T, Duration) {
    let start = Instant::now();
    (task(), start.elapsed())
}

// Count the 6s
let mut rng = StdRng::from_entropy();
let (count, plain_die_duration) = timed(|| (0..600000)
    .map(|_| rng.gen())
    .filter(|&Die { value }| value == 6)
    .count()
);
assert!(90000 < count);
assert!(count < 110000);
let (count, cache_die_duration) = timed(|| (0..600000)
    .map(|_| rng.gen())
    .filter(|&CachedDie { value }| value == 6)
    .count()
);
assert!(90000 < count);
assert!(count < 110000);
// Timing-based comparison: the margin depends on the machine and build settings
assert!(
    cache_die_duration < plain_die_duration,
    "cache: {:?}, plain: {:?}",
    cache_die_duration,
    plain_die_duration,
);
Note that, unless you're generating a huge quantity of random objects, using a once cache is likely a pessimization because of the upfront cost of initializing the cell. Make sure to benchmark your specific use case if performance is a concern.
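The two cache policies correspond to familiar Rust patterns, sketched here with std only: a static whose initializer is a const fn, and a static OnceLock (std's counterpart to once_cell::sync::OnceCell, stable since Rust 1.70) that is filled on first use. ToyRng, ToyUniform, STATIC_DIE, ONCE_DIE and once_die are hypothetical names for illustration, not items generated by standard-dist. The sketch also hints at why a cache policy needs an explicit type: Rust requires every static item to spell out its type, which is presumably why the crate asks for one.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

// Hypothetical SplitMix64 generator, so the sketch needs no crates.
struct ToyRng(u64);

impl ToyRng {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

/// Hypothetical uniform distribution over `low..=high` (assumes `low <= high`).
/// `new` is a `const fn`, so it can initialize a `static`.
struct ToyUniform {
    low: u8,
    span: u64,
}

impl ToyUniform {
    const fn new(low: u8, high: u8) -> Self {
        ToyUniform { low, span: (high - low) as u64 + 1 }
    }

    fn sample(&self, rng: &mut ToyRng) -> u8 {
        self.low + (rng.next_u64() % self.span) as u8
    }
}

// `static`-style cache: evaluated at compile time. The type must be written
// out and the initializer must be const-evaluable.
static STATIC_DIE: ToyUniform = ToyUniform::new(1, 6);

// `once`-style cache: built on first use, then reused.
static ONCE_DIE: OnceLock<ToyUniform> = OnceLock::new();
static ONCE_INITS: AtomicUsize = AtomicUsize::new(0);

fn once_die() -> &'static ToyUniform {
    ONCE_DIE.get_or_init(|| {
        ONCE_INITS.fetch_add(1, Ordering::Relaxed); // count initializations
        ToyUniform::new(1, 6)
    })
}

fn main() {
    let mut rng = ToyRng(6);
    let a = STATIC_DIE.sample(&mut rng);
    let b = once_die().sample(&mut rng);
    println!("{} {} (once inits: {})", a, b, ONCE_INITS.load(Ordering::Relaxed));
}
```

Every call to once_die() after the first returns the same cached value, but each call still checks whether the cell is initialized; the static version has no such check, which fits the text's preference for static when the initializer allows it.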