// rand/seq/iterator.rs

// Copyright 2018-2024 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! `IteratorRandom`: random sampling methods for iterators.

11#[allow(unused)]
12use super::IndexedRandom;
13use super::coin_flipper::CoinFlipper;
14use crate::{Rng, RngExt};
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18/// Extension trait on iterators, providing random sampling methods.
19///
20/// This trait is implemented on all iterators `I` where `I: Iterator + Sized`
21/// and provides methods for
22/// choosing one or more elements. You must `use` this trait:
23///
24/// ```
25/// use rand::seq::IteratorRandom;
26///
27/// let faces = "😀😎😐😕😠😢";
28/// println!("I am {}!", faces.chars().choose(&mut rand::rng()).unwrap());
29/// ```
30/// Example output (non-deterministic):
31/// ```none
32/// I am 😀!
33/// ```
34pub trait IteratorRandom: Iterator + Sized {
35    /// Uniformly sample one element
36    ///
37    /// Assuming that the [`Iterator::size_hint`] is correct, this method
38    /// returns one uniformly-sampled random element of the slice, or `None`
39    /// only if the slice is empty. Incorrect bounds on the `size_hint` may
40    /// cause this method to incorrectly return `None` if fewer elements than
41    /// the advertised `lower` bound are present and may prevent sampling of
42    /// elements beyond an advertised `upper` bound (i.e. incorrect `size_hint`
43    /// is memory-safe, but may result in unexpected `None` result and
44    /// non-uniform distribution).
45    ///
46    /// With an accurate [`Iterator::size_hint`] and where [`Iterator::nth`] is
47    /// a constant-time operation, this method can offer `O(1)` performance.
48    /// Where no size hint is
49    /// available, complexity is `O(n)` where `n` is the iterator length.
50    /// Partial hints (where `lower > 0`) also improve performance.
51    ///
52    /// Note further that [`Iterator::size_hint`] may affect the number of RNG
53    /// samples used as well as the result (while remaining uniform sampling).
54    /// Consider instead using [`IteratorRandom::choose_stable`] to avoid
55    /// [`Iterator`] combinators which only change size hints from affecting the
56    /// results.
57    ///
58    /// # Example
59    ///
60    /// ```
61    /// use rand::seq::IteratorRandom;
62    ///
63    /// let words = "Mary had a little lamb".split(' ');
64    /// println!("{}", words.choose(&mut rand::rng()).unwrap());
65    /// ```
66    fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67    where
68        R: Rng + ?Sized,
69    {
70        let (mut lower, mut upper) = self.size_hint();
71        let mut result = None;
72
73        // Handling for this condition outside the loop allows the optimizer to eliminate the loop
74        // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g.
75        // seq_iter_choose_from_1000.
76        if upper == Some(lower) {
77            return match lower {
78                0 => None,
79                1 => self.next(),
80                _ => self.nth(rng.random_range(..lower)),
81            };
82        }
83
84        let mut coin_flipper = CoinFlipper::new(rng);
85        let mut consumed = 0;
86
87        // Continue until the iterator is exhausted
88        loop {
89            if lower > 1 {
90                let ix = coin_flipper.rng.random_range(..lower + consumed);
91                let skip = if ix < lower {
92                    result = self.nth(ix);
93                    lower - (ix + 1)
94                } else {
95                    lower
96                };
97                if upper == Some(lower) {
98                    return result;
99                }
100                consumed += lower;
101                if skip > 0 {
102                    self.nth(skip - 1);
103                }
104            } else {
105                let elem = self.next();
106                if elem.is_none() {
107                    return result;
108                }
109                consumed += 1;
110                if coin_flipper.random_ratio_one_over(consumed) {
111                    result = elem;
112                }
113            }
114
115            let hint = self.size_hint();
116            lower = hint.0;
117            upper = hint.1;
118        }
119    }
120
121    /// Uniformly sample one element (stable)
122    ///
123    /// This method is very similar to [`choose`] except that the result
124    /// only depends on the length of the iterator and the values produced by
125    /// `rng`. Notably for any iterator of a given length this will make the
126    /// same requests to `rng` and if the same sequence of values are produced
127    /// the same index will be selected from `self`. This may be useful if you
128    /// need consistent results no matter what type of iterator you are working
129    /// with. If you do not need this stability prefer [`choose`].
130    ///
131    /// Note that this method still uses [`Iterator::size_hint`] to skip
132    /// constructing elements where possible, however the selection and `rng`
133    /// calls are the same in the face of this optimization. If you want to
134    /// force every element to be created regardless call `.inspect(|e| ())`.
135    ///
136    /// [`choose`]: IteratorRandom::choose
137    //
138    // Clippy is wrong here: we need to iterate over all entries with the RNG to
139    // ensure that choosing is *stable*.
140    // "allow(unknown_lints)" can be removed when switching to at least
141    // rust-version 1.86.0, see:
142    // https://rust-lang.github.io/rust-clippy/master/index.html#double_ended_iterator_last
143    #[allow(unknown_lints)]
144    #[allow(clippy::double_ended_iterator_last)]
145    fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
146    where
147        R: Rng + ?Sized,
148    {
149        let mut consumed = 0;
150        let mut result = None;
151        let mut coin_flipper = CoinFlipper::new(rng);
152
153        loop {
154            // Currently the only way to skip elements is `nth()`. So we need to
155            // store what index to access next here.
156            // This should be replaced by `advance_by()` once it is stable:
157            // https://github.com/rust-lang/rust/issues/77404
158            let mut next = 0;
159
160            let (lower, _) = self.size_hint();
161            if lower >= 2 {
162                let highest_selected = (0..lower)
163                    .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
164                    .last();
165
166                consumed += lower;
167                next = lower;
168
169                if let Some(ix) = highest_selected {
170                    result = self.nth(ix);
171                    next -= ix + 1;
172                    if true {
    if !result.is_some() {
        {
            ::core::panicking::panic_fmt(format_args!("iterator shorter than size_hint().0"));
        }
    };
};debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
173                }
174            }
175
176            let elem = self.nth(next);
177            if elem.is_none() {
178                return result;
179            }
180
181            if coin_flipper.random_ratio_one_over(consumed + 1) {
182                result = elem;
183            }
184            consumed += 1;
185        }
186    }
187
188    /// Uniformly sample `amount` distinct elements into a buffer
189    ///
190    /// Collects values at random from the iterator into a supplied buffer
191    /// until that buffer is filled.
192    ///
193    /// Although the elements are selected randomly, the order of elements in
194    /// the buffer is neither stable nor fully random. If random ordering is
195    /// desired, shuffle the result.
196    ///
197    /// Returns the number of elements added to the buffer. This equals the length
198    /// of the buffer unless the iterator contains insufficient elements, in which
199    /// case this equals the number of elements available.
200    ///
201    /// Complexity is `O(n)` where `n` is the length of the iterator.
202    /// For slices, prefer [`IndexedRandom::sample`].
203    fn sample_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
204    where
205        R: Rng + ?Sized,
206    {
207        let amount = buf.len();
208        let mut len = 0;
209        while len < amount {
210            if let Some(elem) = self.next() {
211                buf[len] = elem;
212                len += 1;
213            } else {
214                // Iterator exhausted; stop early
215                return len;
216            }
217        }
218
219        // Continue, since the iterator was not exhausted
220        for (i, elem) in self.enumerate() {
221            let k = rng.random_range(..i + 1 + amount);
222            if let Some(slot) = buf.get_mut(k) {
223                *slot = elem;
224            }
225        }
226        len
227    }
228
229    /// Uniformly sample `amount` distinct elements into a [`Vec`]
230    ///
231    /// This is equivalent to `sample_fill` except for the result type.
232    ///
233    /// Although the elements are selected randomly, the order of elements in
234    /// the buffer is neither stable nor fully random. If random ordering is
235    /// desired, shuffle the result.
236    ///
237    /// The length of the returned vector equals `amount` unless the iterator
238    /// contains insufficient elements, in which case it equals the number of
239    /// elements available.
240    ///
241    /// Complexity is `O(n)` where `n` is the length of the iterator.
242    /// For slices, prefer [`IndexedRandom::sample`].
243    #[cfg(feature = "alloc")]
244    fn sample<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
245    where
246        R: Rng + ?Sized,
247    {
248        let mut reservoir = Vec::from_iter(self.by_ref().take(amount));
249
250        // Continue unless the iterator was exhausted
251        //
252        // note: this prevents iterators that "restart" from causing problems.
253        // If the iterator stops once, then so do we.
254        if reservoir.len() == amount {
255            for (i, elem) in self.enumerate() {
256                let k = rng.random_range(..i + 1 + amount);
257                if let Some(slot) = reservoir.get_mut(k) {
258                    *slot = elem;
259                }
260            }
261        }
262        reservoir
263    }
264
265    /// Deprecated: use [`Self::sample_fill`] instead
266    #[deprecated(since = "0.10.0", note = "Renamed to `sample_fill`")]
267    fn choose_multiple_fill<R>(self, rng: &mut R, buf: &mut [Self::Item]) -> usize
268    where
269        R: Rng + ?Sized,
270    {
271        self.sample_fill(rng, buf)
272    }
273
274    /// Deprecated: use [`Self::sample`] instead
275    #[cfg(feature = "alloc")]
276    #[deprecated(since = "0.10.0", note = "Renamed to `sample`")]
277    fn choose_multiple<R>(self, rng: &mut R, amount: usize) -> Vec<Self::Item>
278    where
279        R: Rng + ?Sized,
280    {
281        self.sample(rng, amount)
282    }
283}
284
285impl<I> IteratorRandom for I where I: Iterator + Sized {}
286
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec::Vec;

    /// Wraps an iterator without overriding `size_hint`, so the default
    /// `(0, None)` is reported and the unhinted sampling paths are exercised.
    #[derive(Clone)]
    struct UnhintedIterator<I: Iterator + Clone> {
        iter: I,
    }
    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    /// Wraps an iterator and reports as lower bound only the number of
    /// elements left in the current chunk of `chunk_size` elements. The upper
    /// bound is the exact remaining length if `hint_total_size`, else `None`.
    #[derive(Clone)]
    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        chunk_remaining: usize,
        chunk_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            if self.chunk_remaining == 0 {
                self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len());
            }
            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);

            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                self.chunk_remaining,
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// Wraps an iterator and reports as lower bound `min(remaining,
    /// window_size)`. The upper bound is the exact remaining length if
    /// `hint_total_size`, else `None`.
    #[derive(Clone)]
    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        window_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                core::cmp::min(self.iter.len(), self.window_size),
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Samples should follow Binomial(1000, 1/9)
                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
                // Note: have seen 153, which is unlikely but not impossible.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        assert_eq!((0..0).choose(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose_stable() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Samples should follow Binomial(1000, 1/9)
                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
                // Note: have seen 153, which is unlikely but not impossible.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        // Empty iterators must yield `None` from `choose_stable` itself
        // (previously these lines mistakenly exercised `choose`).
        assert_eq!((0..0).choose_stable(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose_stable_stability() {
        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
            let r = &mut crate::test::rng(109);
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            chosen
        }

        let reference = test_iter(0..9);
        assert_eq!(
            test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
            reference
        );

        #[cfg(feature = "alloc")]
        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            }),
            reference
        );
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = crate::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = vals.iter().sample(&mut r, 5);
        let large_sample = vals.iter().sample(&mut r, vals.len() + 5);

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // no randomization happens when amount >= len
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` was built from the half-open range `min_val..max_val`.
        assert!(
            small_sample
                .iter()
                .all(|e| { **e >= min_val && **e < max_val })
        );
    }

    #[test]
    fn value_stability_choose() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(33));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(91)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(91)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(34)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(34)
        );
    }

    #[test]
    fn value_stability_choose_stable() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose_stable(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(27));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
    }

    #[test]
    fn value_stability_sample() {
        fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
            let mut rng = crate::test::rng(412);
            let mut buf = [0u32; 8];
            assert_eq!(iter.clone().sample_fill(&mut rng, &mut buf), v.len());
            assert_eq!(&buf[0..v.len()], v);

            #[cfg(feature = "alloc")]
            {
                let mut rng = crate::test::rng(412);
                assert_eq!(iter.sample(&mut rng, v.len()), v);
            }
        }

        do_test(0..4, &[0, 1, 2, 3]);
        do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
        do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
    }
}