Skip to content

Commit c988835

Browse files
committed
Implement WeightedIndex, SliceRandom::choose_weighted and SliceRandom::choose_weighted_mut
1 parent af1303c commit c988835

File tree

5 files changed

+348
-27
lines changed

5 files changed

+348
-27
lines changed

benches/distributions.rs

+5
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ distr_int!(distr_binomial, u64, Binomial::new(20, 0.7));
115115
distr_int!(distr_poisson, u64, Poisson::new(4.0));
116116
distr!(distr_bernoulli, bool, Bernoulli::new(0.18));
117117

118+
// Weighted
119+
distr_int!(distr_weighted_i8, usize, WeightedIndex::new(&[1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
120+
distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
121+
distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
122+
distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());
118123

119124
// construct and sample from a range
120125
macro_rules! gen_range_int {

src/distributions/mod.rs

+25-22
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@
7373
//! numbers of the `char` type; in contrast [`Standard`] may sample any valid
7474
//! `char`.
7575
//!
76+
//! [`WeightedIndex`] can be used to do weighted sampling from a set of items,
77+
//! such as from an array.
7678
//!
7779
//! # Non-uniform probability distributions
7880
//!
@@ -167,12 +169,15 @@
167169
//! [`Uniform`]: struct.Uniform.html
168170
//! [`Uniform::new`]: struct.Uniform.html#method.new
169171
//! [`Uniform::new_inclusive`]: struct.Uniform.html#method.new_inclusive
172+
//! [`WeightedIndex`]: struct.WeightedIndex.html
170173
171174
use Rng;
172175

173176
#[doc(inline)] pub use self::other::Alphanumeric;
174177
#[doc(inline)] pub use self::uniform::Uniform;
175178
#[doc(inline)] pub use self::float::{OpenClosed01, Open01};
179+
#[cfg(feature="alloc")]
180+
#[doc(inline)] pub use self::weighted::WeightedIndex;
176181
#[cfg(feature="std")]
177182
#[doc(inline)] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
178183
#[cfg(feature="std")]
@@ -192,6 +197,8 @@ use Rng;
192197
#[doc(inline)] pub use self::dirichlet::Dirichlet;
193198

194199
pub mod uniform;
200+
#[cfg(feature="alloc")]
201+
#[doc(hidden)] pub mod weighted;
195202
#[cfg(feature="std")]
196203
#[doc(hidden)] pub mod gamma;
197204
#[cfg(feature="std")]
@@ -372,6 +379,8 @@ pub struct Standard;
372379

373380

374381
/// A value with a particular weight for use with `WeightedChoice`.
382+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
383+
#[allow(deprecated)]
375384
#[derive(Copy, Clone, Debug)]
376385
pub struct Weighted<T> {
377386
/// The numerical weight of this item
@@ -382,34 +391,18 @@ pub struct Weighted<T> {
382391

383392
/// A distribution that selects from a finite collection of weighted items.
384393
///
385-
/// Each item has an associated weight that influences how likely it
386-
/// is to be chosen: higher weight is more likely.
387-
///
388-
/// The `Clone` restriction is a limitation of the `Distribution` trait.
389-
/// Note that `&T` is (cheaply) `Clone` for all `T`, as is `u32`, so one can
390-
/// store references or indices into another vector.
391-
///
392-
/// # Example
393-
///
394-
/// ```
395-
/// use rand::distributions::{Weighted, WeightedChoice, Distribution};
396-
///
397-
/// let mut items = vec!(Weighted { weight: 2, item: 'a' },
398-
/// Weighted { weight: 4, item: 'b' },
399-
/// Weighted { weight: 1, item: 'c' });
400-
/// let wc = WeightedChoice::new(&mut items);
401-
/// let mut rng = rand::thread_rng();
402-
/// for _ in 0..16 {
403-
/// // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
404-
/// println!("{}", wc.sample(&mut rng));
405-
/// }
406-
/// ```
394+
/// Deprecated: use [`WeightedIndex`] instead.
395+
/// [`WeightedIndex`]: distributions/struct.WeightedIndex.html
396+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
397+
#[allow(deprecated)]
407398
#[derive(Debug)]
408399
pub struct WeightedChoice<'a, T:'a> {
409400
items: &'a mut [Weighted<T>],
410401
weight_range: Uniform<u32>,
411402
}
412403

404+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
405+
#[allow(deprecated)]
413406
impl<'a, T: Clone> WeightedChoice<'a, T> {
414407
/// Create a new `WeightedChoice`.
415408
///
@@ -447,6 +440,8 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
447440
}
448441
}
449442

443+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
444+
#[allow(deprecated)]
450445
impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
451446
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
452447
// we want to find the first element that has cumulative
@@ -556,9 +551,11 @@ fn ziggurat<R: Rng + ?Sized, P, Z>(
556551
#[cfg(test)]
557552
mod tests {
558553
use rngs::mock::StepRng;
554+
#[allow(deprecated)]
559555
use super::{WeightedChoice, Weighted, Distribution};
560556

561557
#[test]
558+
#[allow(deprecated)]
562559
fn test_weighted_choice() {
563560
// this makes assumptions about the internal implementation of
564561
// WeightedChoice. It may fail when the implementation in
@@ -618,6 +615,7 @@ mod tests {
618615
}
619616

620617
#[test]
618+
#[allow(deprecated)]
621619
fn test_weighted_clone_initialization() {
622620
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
623621
let clone = initial.clone();
@@ -626,6 +624,7 @@ mod tests {
626624
}
627625

628626
#[test] #[should_panic]
627+
#[allow(deprecated)]
629628
fn test_weighted_clone_change_weight() {
630629
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
631630
let mut clone = initial.clone();
@@ -634,6 +633,7 @@ mod tests {
634633
}
635634

636635
#[test] #[should_panic]
636+
#[allow(deprecated)]
637637
fn test_weighted_clone_change_item() {
638638
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
639639
let mut clone = initial.clone();
@@ -643,15 +643,18 @@ mod tests {
643643
}
644644

645645
#[test] #[should_panic]
646+
#[allow(deprecated)]
646647
fn test_weighted_choice_no_items() {
647648
WeightedChoice::<isize>::new(&mut []);
648649
}
649650
#[test] #[should_panic]
651+
#[allow(deprecated)]
650652
fn test_weighted_choice_zero_weight() {
651653
WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
652654
Weighted { weight: 0, item: 1}]);
653655
}
654656
#[test] #[should_panic]
657+
#[allow(deprecated)]
655658
fn test_weighted_choice_weight_overflows() {
656659
let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow
657660
WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },

src/distributions/weighted.rs

+182
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// https://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
use Rng;
12+
use distributions::Distribution;
13+
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
14+
use ::core::cmp::PartialOrd;
15+
use ::{Error, ErrorKind};
16+
17+
// Note that this whole module is only imported if feature="alloc" is enabled.
18+
#[cfg(not(feature="std"))] use alloc::Vec;
19+
20+
/// A distribution using weighted sampling to pick a discretely selected item.
21+
///
22+
/// Sampling a `WeightedIndex` distribution returns the index of a randomly
23+
/// selected element from the iterator used when the `WeightedIndex` was
24+
/// created. The chance of a given element being picked is proportional to the
25+
/// value of the element. The weights can use any type `X` for which an
26+
/// implementation of [`Uniform<X>`] exists.
27+
///
28+
/// # Example
29+
///
30+
/// ```
31+
/// use rand::prelude::*;
32+
/// use rand::distributions::WeightedIndex;
33+
///
34+
/// let choices = ['a', 'b', 'c'];
35+
/// let weights = [2, 1, 1];
36+
/// let dist = WeightedIndex::new(&weights).unwrap();
37+
/// let mut rng = thread_rng();
38+
/// for _ in 0..100 {
39+
/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
40+
/// println!("{}", choices[dist.sample(&mut rng)]);
41+
/// }
42+
///
43+
/// let items = [('a', 0), ('b', 3), ('c', 7)];
44+
/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
45+
/// for _ in 0..100 {
46+
/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
47+
/// println!("{}", items[dist2.sample(&mut rng)].0);
48+
/// }
49+
/// ```
50+
#[derive(Debug, Clone)]
51+
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
52+
cumulative_weights: Vec<X>,
53+
weight_distribution: X::Sampler,
54+
}
55+
56+
impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
57+
/// Creates a new a `WeightedIndex` [`Distribution`] using the values
58+
/// in `weights`. The weights can use any type `X` for which an
59+
/// implementation of [`Uniform<X>`] exists.
60+
///
61+
/// Returns an error if the iterator is empty, or its total value is 0.
62+
///
63+
/// # Panics
64+
///
65+
/// If a value in the iterator is `< 0`.
66+
///
67+
/// [`Distribution`]: trait.Distribution.html
68+
/// [`Uniform<X>`]: struct.Uniform.html
69+
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
70+
where I: IntoIterator,
71+
I::Item: SampleBorrow<X>,
72+
X: for<'a> ::core::ops::AddAssign<&'a X> +
73+
Clone +
74+
Default {
75+
let mut iter = weights.into_iter();
76+
let mut total_weight: X = iter.next()
77+
.ok_or(Error::new(ErrorKind::Unexpected, "Empty iterator in WeightedIndex::new"))?
78+
.borrow()
79+
.clone();
80+
81+
let zero = <X as Default>::default();
82+
let weights = iter.map(|w| {
83+
assert!(*w.borrow() >= zero, "Negative weight in WeightedIndex::new");
84+
let prev_weight = total_weight.clone();
85+
total_weight += w.borrow();
86+
prev_weight
87+
}).collect::<Vec<X>>();
88+
89+
if total_weight == zero {
90+
return Err(Error::new(ErrorKind::Unexpected, "Total weight is zero in WeightedIndex::new"));
91+
}
92+
let distr = X::Sampler::new(zero, total_weight);
93+
94+
Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
95+
}
96+
}
97+
98+
impl<X> Distribution<usize> for WeightedIndex<X> where
99+
X: SampleUniform + PartialOrd {
100+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
101+
let chosen_weight = self.weight_distribution.sample(rng);
102+
// Invariants: indexes in range [start, end] (inclusive) are candidate indexes
103+
// cumulative_weights[start-1] <= chosen_weight
104+
// chosen_weight < cumulative_weights[end]
105+
// The returned index is the first one whose value is >= chosen_weight
106+
let mut start = 0usize;
107+
let mut end = self.cumulative_weights.len();
108+
while start < end {
109+
let mid = (start + end) / 2;
110+
if chosen_weight >= * unsafe { self.cumulative_weights.get_unchecked(mid) } {
111+
start = mid + 1;
112+
} else {
113+
end = mid;
114+
}
115+
}
116+
debug_assert_eq!(start, end);
117+
start
118+
}
119+
}
120+
121+
#[cfg(test)]
122+
mod test {
123+
use super::*;
124+
#[cfg(feature="std")]
125+
use core::panic::catch_unwind;
126+
127+
#[test]
128+
fn test_weightedindex() {
129+
let mut r = ::test::rng(700);
130+
const N_REPS: u32 = 5000;
131+
let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
132+
let total_weight = weights.iter().sum::<u32>() as f32;
133+
134+
let verify = |result: [i32; 14]| {
135+
for (i, count) in result.iter().enumerate() {
136+
let exp = (weights[i] * N_REPS) as f32 / total_weight;
137+
let mut err = (*count as f32 - exp).abs();
138+
if err != 0.0 {
139+
err /= exp;
140+
}
141+
assert!(err <= 0.25);
142+
}
143+
};
144+
145+
// WeightedIndex from vec
146+
let mut chosen = [0i32; 14];
147+
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
148+
for _ in 0..N_REPS {
149+
chosen[distr.sample(&mut r)] += 1;
150+
}
151+
verify(chosen);
152+
153+
// WeightedIndex from slice
154+
chosen = [0i32; 14];
155+
let distr = WeightedIndex::new(&weights[..]).unwrap();
156+
for _ in 0..N_REPS {
157+
chosen[distr.sample(&mut r)] += 1;
158+
}
159+
verify(chosen);
160+
161+
// WeightedIndex from iterator
162+
chosen = [0i32; 14];
163+
let distr = WeightedIndex::new(weights.iter()).unwrap();
164+
for _ in 0..N_REPS {
165+
chosen[distr.sample(&mut r)] += 1;
166+
}
167+
verify(chosen);
168+
169+
assert!(WeightedIndex::new(&[10][0..0]).is_err());
170+
assert!(WeightedIndex::new(&[0]).is_err());
171+
}
172+
173+
#[test]
174+
#[cfg(all(feature="std",
175+
not(target_arch = "wasm32"),
176+
not(target_arch = "asmjs")))]
177+
fn test_weighted_assertions() {
178+
assert!(catch_unwind(|| WeightedIndex::new(&[1, 2, 3])).is_ok());
179+
assert!(catch_unwind(|| WeightedIndex::new(&[10, -1, 10])).is_err());
180+
assert!(catch_unwind(|| WeightedIndex::new(&[1, -1])).is_err());
181+
}
182+
}

src/lib.rs

-5
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,6 @@
134134
//!
135135
//! For more slice/sequence related functionality, look in the [`seq` module].
136136
//!
137-
//! There is also [`distributions::WeightedChoice`], which can be used to pick
138-
//! elements at random with some probability. But it does not work well at the
139-
//! moment and is going through a redesign.
140-
//!
141137
//!
142138
//! # Error handling
143139
//!
@@ -187,7 +183,6 @@
187183
//!
188184
//!
189185
//! [`distributions` module]: distributions/index.html
190-
//! [`distributions::WeightedChoice`]: distributions/struct.WeightedChoice.html
191186
//! [`EntropyRng`]: rngs/struct.EntropyRng.html
192187
//! [`Error`]: struct.Error.html
193188
//! [`gen_range`]: trait.Rng.html#method.gen_range

0 commit comments

Comments
 (0)