23
23
24
24
use super :: Rng ;
25
25
use distributions:: uniform:: { SampleUniform , SampleBorrow } ;
26
+ #[ cfg( feature = "alloc" ) ]
26
27
use distributions:: Distribution ;
27
28
28
29
/// Extension trait on slices, providing random mutation and sampling methods.
@@ -102,9 +103,9 @@ pub trait SliceRandom {
102
103
///
103
104
/// ```
104
105
/// use rand::prelude::*;
105
- /// //use rand::distributions::uniform::SampleBorrow;
106
106
///
107
107
/// let choices = [('a', 2), ('b', 1), ('c', 1)];
108
+ /// // In rustc version XXX and newer, you can use a closure instead
108
109
/// fn mapping_func(item: &(char, usize)) -> usize {
109
110
/// item.1
110
111
/// }
@@ -126,22 +127,10 @@ pub trait SliceRandom {
126
127
/// likelyhood of getting returned. The likelyhood of a given item getting
127
128
/// returned is proportional to the value returned by the mapping function
128
129
/// `func`.
129
- ///
130
- /// # Example
131
- ///
132
- /// ```
133
- /// use rand::prelude::*;
134
- /// //use rand::distributions::uniform::SampleBorrow;
135
130
///
136
- /// let choices = [('a', 2), ('b', 1), ('c', 1)];
137
- /// fn mapping_func(item: &(char, usize)) -> usize {
138
- /// item.1
139
- /// }
140
- /// let mut rng = thread_rng();
141
- /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
142
- /// println!("{:?}", choices.choose_weighted(&mut rng, &mapping_func).unwrap().0);
143
- /// ```
131
+ /// See also [`choose_weighted`].
144
132
/// [`choose_mut`]: trait.SliceRandom.html#method.choose_mut
133
+ /// [`choose_weighted`]: trait.SliceRandom.html#method.choose_weighted
145
134
fn choose_weighted_mut < R , F , B , X > ( & mut self , rng : & mut R , func : F ) -> Option < & mut Self :: Item >
146
135
where R : Rng + ?Sized ,
147
136
F : Fn ( & Self :: Item ) -> B + Clone ,
@@ -296,18 +285,23 @@ pub trait IteratorRandom: Iterator + Sized {
296
285
}
297
286
}
298
287
299
- /// Extension trait on iterators , providing random sampling methods.
288
+ /// Extension trait on IntoIterator , providing random sampling methods.
300
289
pub trait IntoIteratorRandom : IntoIterator + Sized {
301
290
302
291
/// Return a the index of a random element from this iterator where the
303
292
/// chance of a given item being picked is proportional to the value of
304
293
/// the item.
305
294
///
295
+ /// If calling this on a [`Vec`], make sure to use
296
+ /// `(&vec).choose_index_weighted(...)`. Otherwise the [`Vec`] and all of
297
+ /// its contents will be cloned.
298
+ ///
306
299
/// All values returned by the iterator must be `>= 0`.
307
300
///
308
301
/// This function iterates over the iterator twice. Once to get the total
309
- /// weight, and once while choosing the random item. If you plan to call
310
- /// this function multiple times, it would be more performant to use
302
+ /// weight, and once while choosing the random item. This is done by
303
+ /// cloning the iterator. If you plan to call this function multiple times,
304
+ /// it would be more performant to use
311
305
/// [`into_weighted_index_distribution`].
312
306
///
313
307
/// Return `None` if self is empty, or contains only `0`s.
@@ -328,6 +322,7 @@ pub trait IntoIteratorRandom: IntoIterator + Sized {
328
322
/// println!("{}", weights.choose_index_weighted(&mut rng).map(|ix| choices[ix]).unwrap());
329
323
/// ```
330
324
/// [`into_weighted_index_distribution`]: trait.IntoIteratorRandom.html#method.into_weighted_index_distribution
325
+ /// [`Vec`]: https://doc.rust-lang.org/std/vec/struct.Vec.html
331
326
fn choose_index_weighted < R , X > ( self , rng : & mut R ) -> Option < usize >
332
327
where Self :: IntoIter : Clone ,
333
328
R : Rng + ?Sized ,
@@ -361,9 +356,9 @@ pub trait IntoIteratorRandom: IntoIterator + Sized {
361
356
panic ! ( "Weights changed during cloning" ) ;
362
357
}
363
358
364
- /// Returns a [`Distribution`] which when sampled from, returns a random
365
- /// element from this iterator where the chance of a given item being
366
- /// picked is proportional to the value of the item.
359
+ /// Returns a [`Distribution`] which when sampled from, returns the index
360
+ /// of a random element from this iterator where the chance of a given
361
+ /// item being picked is proportional to the value of the item.
367
362
///
368
363
/// # Panics
369
364
///
@@ -454,6 +449,10 @@ impl<T> SliceRandom for [T] {
454
449
for < ' a > :: core:: ops:: AddAssign < & ' a X > +
455
450
:: core:: cmp:: PartialOrd < X > +
456
451
Default {
452
+ // Technically we could move choose_weighted to IntoIteratorRandom.
453
+ // That does carry the small risk that someone will do something silly like:
454
+ // my_vec.clone().choose_weighted(&mut r, |item| ...)
455
+ // which would clone the vector and its contents 3 times.
457
456
self . iter ( ) . map ( func) . choose_index_weighted ( rng) . map ( |ix| & self [ ix] )
458
457
}
459
458
@@ -465,6 +464,10 @@ impl<T> SliceRandom for [T] {
465
464
for < ' a > :: core:: ops:: AddAssign < & ' a X > +
466
465
:: core:: cmp:: PartialOrd < X > +
467
466
Default {
467
+ // Technically we could move choose_weighted_mut to IntoIteratorRandom.
468
+ // That does carry the small risk that someone will do something silly like:
469
+ // my_vec.clone().choose_weighted(&mut r, |item| ...)
470
+ // which would clone the vector and its contents 3 times.
468
471
let index = self . iter ( ) . map ( func) . choose_index_weighted ( rng) ;
469
472
index. map ( move |ix| & mut self [ ix] )
470
473
}
@@ -925,15 +928,15 @@ mod test {
925
928
}
926
929
} ;
927
930
928
- // choose_weighted array
931
+ // choose_index_weighted array
929
932
let mut chosen = [ 0i32 ; 14 ] ;
930
933
for _ in 0 ..1000 {
931
934
let picked = weights. choose_index_weighted ( & mut r) . unwrap ( ) ;
932
935
chosen[ picked] += 1 ;
933
936
}
934
937
verify ( chosen) ;
935
938
936
- // choose_weighted ref iterator
939
+ // choose_index_weighted ref iterator
937
940
chosen = [ 0i32 ; 14 ] ;
938
941
for _ in 0 ..1000 {
939
942
let picked = weights. iter ( )
@@ -942,7 +945,7 @@ mod test {
942
945
}
943
946
verify ( chosen) ;
944
947
945
- // choose_weighted value iterator
948
+ // choose_index_weighted value iterator
946
949
chosen = [ 0i32 ; 14 ] ;
947
950
for _ in 0 ..1000 {
948
951
let picked = weights. iter ( )
@@ -952,33 +955,51 @@ mod test {
952
955
}
953
956
verify ( chosen) ;
954
957
955
- // choose_weighted value iterator
956
- chosen = [ 0i32 ; 14 ] ;
957
- let distr = weights. into_weighted_index_distribution ( ) ;
958
- for _ in 0 ..1000 {
959
- chosen[ distr. sample ( & mut r) ] += 1 ;
958
+ // choose_index_weighted Vec<...>
959
+ #[ cfg( feature = "alloc" ) ]
960
+ {
961
+ chosen = [ 0i32 ; 14 ] ;
962
+ let vec_weights = weights. to_vec ( ) ;
963
+ for _ in 0 ..1000 {
964
+ let picked = ( & vec_weights) . choose_index_weighted ( & mut r) . unwrap ( ) ;
965
+ chosen[ picked] += 1 ;
966
+ }
967
+ verify ( chosen) ;
968
+ }
969
+
970
+ // into_weighted_index_distribution
971
+ #[ cfg( feature = "alloc" ) ]
972
+ {
973
+ chosen = [ 0i32 ; 14 ] ;
974
+ let distr = weights. into_weighted_index_distribution ( ) ;
975
+ for _ in 0 ..1000 {
976
+ chosen[ distr. sample ( & mut r) ] += 1 ;
977
+ }
978
+ verify ( chosen) ;
960
979
}
961
- verify ( chosen) ;
962
980
963
981
// choose_weighted
982
+ fn get_weight < T > ( item : & ( u32 , T ) ) -> u32 {
983
+ item. 0
984
+ }
964
985
chosen = [ 0i32 ; 14 ] ;
965
- let mut items = [ ( 0u32 , 0usize ) ; 14 ] ;
986
+ let mut items = [ ( 0u32 , 0usize ) ; 14 ] ; // (weight, index)
966
987
for ( i, item) in items. iter_mut ( ) . enumerate ( ) {
967
988
* item = ( weights[ i] , i) ;
968
989
}
969
990
for _ in 0 ..1000 {
970
- let item = items. choose_weighted ( & mut r, |item| item . 0 ) . unwrap ( ) ;
991
+ let item = items. choose_weighted ( & mut r, get_weight ) . unwrap ( ) ;
971
992
chosen[ item. 1 ] += 1 ;
972
993
}
973
994
verify ( chosen) ;
974
995
975
996
// choose_weighted_mut
976
- let mut items = [ ( 0u32 , 0i32 ) ; 14 ] ;
997
+ let mut items = [ ( 0u32 , 0i32 ) ; 14 ] ; // (weight, count)
977
998
for ( i, item) in items. iter_mut ( ) . enumerate ( ) {
978
999
* item = ( weights[ i] , 0 ) ;
979
1000
}
980
1001
for _ in 0 ..1000 {
981
- items. choose_weighted_mut ( & mut r, |item| item . 0 ) . unwrap ( ) . 1 += 1 ;
1002
+ items. choose_weighted_mut ( & mut r, get_weight ) . unwrap ( ) . 1 += 1 ;
982
1003
}
983
1004
for ( ch, item) in chosen. iter_mut ( ) . zip ( items. iter ( ) ) {
984
1005
* ch = item. 1 ;
0 commit comments