1#[allow(unused)]
12use super::IndexedRandom;
13use super::coin_flipper::CoinFlipper;
14use crate::{Rng, RngExt};
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18pub trait IteratorRandom: Iterator + Sized {
35 fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67 where
68 R: Rng + ?Sized,
69 {
70 let (mut lower, mut upper) = self.size_hint();
71 let mut result = None;
72
73 if upper == Some(lower) {
77 return match lower {
78 0 => None,
79 1 => self.next(),
80 _ => self.nth(rng.random_range(..lower)),
81 };
82 }
83
84 let mut coin_flipper = CoinFlipper::new(rng);
85 let mut consumed = 0;
86
87 loop {
89 if lower > 1 {
90 let ix = coin_flipper.rng.random_range(..lower + consumed);
91 let skip = if ix < lower {
92 result = self.nth(ix);
93 lower - (ix + 1)
94 } else {
95 lower
96 };
97 if upper == Some(lower) {
98 return result;
99 }
100 consumed += lower;
101 if skip > 0 {
102 self.nth(skip - 1);
103 }
104 } else {
105 let elem = self.next();
106 if elem.is_none() {
107 return result;
108 }
109 consumed += 1;
110 if coin_flipper.random_ratio_one_over(consumed) {
111 result = elem;
112 }
113 }
114
115 let hint = self.size_hint();
116 lower = hint.0;
117 upper = hint.1;
118 }
119 }
120
121 #[allow(unknown_lints)]
144 #[allow(clippy::double_ended_iterator_last)]
145 fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
146 where
147 R: Rng + ?Sized,
148 {
149 let mut consumed = 0;
150 let mut result = None;
151 let mut coin_flipper = CoinFlipper::new(rng);
152
153 loop {
154 let mut next = 0;
159
160 let (lower, _) = self.size_hint();
161 if lower >= 2 {
162 let highest_selected = (0..lower)
163 .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
164 .last();
165
166 consumed += lower;
167 next = lower;
168
169 if let Some(ix) = highest_selected {
170 result = self.nth(ix);
171 next -= ix + 1;
172 if true {
if !result.is_some() {
{
::core::panicking::panic_fmt(format_args!("iterator shorter than size_hint().0"));
}
};
};debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
173 }
174 }
175
176 let elem = self.nth(next);
177 if elem.is_none() {
178 return result;
179 }
180
181 if coin_flipper.random_ratio_one_over(consumed + 1) {
182 result = elem;
183 }
184 consumed += 1;
185 }
186 }
187
188 fn sample_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
204 where
205 R: Rng + ?Sized,
206 {
207 let amount = buf.len();
208 let mut len = 0;
209 while len < amount {
210 if let Some(elem) = self.next() {
211 buf[len] = elem;
212 len += 1;
213 } else {
214 return len;
216 }
217 }
218
219 for (i, elem) in self.enumerate() {
221 let k = rng.random_range(..i + 1 + amount);
222 if let Some(slot) = buf.get_mut(k) {
223 *slot = elem;
224 }
225 }
226 len
227 }
228
229 #[cfg(feature = "alloc")]
244 fn sample<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
245 where
246 R: Rng + ?Sized,
247 {
248 let mut reservoir = Vec::from_iter(self.by_ref().take(amount));
249
250 if reservoir.len() == amount {
255 for (i, elem) in self.enumerate() {
256 let k = rng.random_range(..i + 1 + amount);
257 if let Some(slot) = reservoir.get_mut(k) {
258 *slot = elem;
259 }
260 }
261 }
262 reservoir
263 }
264
265 #[deprecated(since = "0.10.0", note = "Renamed to `sample_fill`")]
267 fn choose_multiple_fill<R>(self, rng: &mut R, buf: &mut [Self::Item]) -> usize
268 where
269 R: Rng + ?Sized,
270 {
271 self.sample_fill(rng, buf)
272 }
273
274 #[cfg(feature = "alloc")]
276 #[deprecated(since = "0.10.0", note = "Renamed to `sample`")]
277 fn choose_multiple<R>(self, rng: &mut R, amount: usize) -> Vec<Self::Item>
278 where
279 R: Rng + ?Sized,
280 {
281 self.sample(rng, amount)
282 }
283}
284
285impl<I> IteratorRandom for I where I: Iterator + Sized {}
286
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec::Vec;

    /// Forwards only `next`, so `size_hint` falls back to the default `(0, None)`.
    #[derive(Clone)]
    struct UnhintedIterator<I: Iterator + Clone> {
        iter: I,
    }
    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    /// Reports, as its lower size bound, only the elements left in the current
    /// chunk of `chunk_size`. The upper bound is exact when `hint_total_size`
    /// is set, and `None` otherwise.
    #[derive(Clone)]
    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        chunk_remaining: usize,
        chunk_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            if self.chunk_remaining == 0 {
                self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len());
            }
            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);

            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                self.chunk_remaining,
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// Reports `min(remaining, window_size)` as its lower size bound. The upper
    /// bound is exact when `hint_total_size` is set, and `None` otherwise.
    #[derive(Clone)]
    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        window_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                core::cmp::min(self.iter.len(), self.window_size),
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_iterator_choose() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Loose bounds around the expected 1000/9 ≈ 111 per bucket; the
                // fixed seed keeps the check deterministic.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        assert_eq!((0..0).choose(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_iterator_choose_stable() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Loose bounds around the expected 1000/9 ≈ 111 per bucket; the
                // fixed seed keeps the check deterministic.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        // Empty iterators must yield `None` from the stable variant too.
        assert_eq!((0..0).choose_stable(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_iterator_choose_stable_stability() {
        // Each call restarts from the same seed, so differing size hints must
        // not change the tallies.
        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
            let r = &mut crate::test::rng(109);
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            chosen
        }

        let reference = test_iter(0..9);
        assert_eq!(
            test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
            reference
        );

        #[cfg(feature = "alloc")]
        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            }),
            reference
        );
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = crate::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = vals.iter().sample(&mut r, 5);
        let large_sample = vals.iter().sample(&mut r, vals.len() + 5);

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // Asking for more than is available returns everything, in order.
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` comes from the half-open range `min_val..max_val`.
        assert!(
            small_sample
                .iter()
                .all(|e| { **e >= min_val && **e < max_val })
        );
    }

    #[test]
    fn value_stability_choose() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(33));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(91)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(91)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(34)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(34)
        );
    }

    #[test]
    fn value_stability_choose_stable() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose_stable(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(27));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
    }

    #[test]
    fn value_stability_sample() {
        fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
            let mut rng = crate::test::rng(412);
            let mut buf = [0u32; 8];
            assert_eq!(iter.clone().sample_fill(&mut rng, &mut buf), v.len());
            assert_eq!(&buf[0..v.len()], v);

            #[cfg(feature = "alloc")]
            {
                let mut rng = crate::test::rng(412);
                assert_eq!(iter.sample(&mut rng, v.len()), v);
            }
        }

        do_test(0..4, &[0, 1, 2, 3]);
        do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
        do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
    }
}