1#![warn(missing_docs)]
7
8use crate::builder::{FinalStage, NumThreadsStage};
9use parking_lot::{Condvar, Mutex};
10use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
11use std::collections::BinaryHeap;
12use std::panic::{self, AssertUnwindSafe};
13use std::sync::atomic::{self, AtomicBool};
14use std::sync::Arc;
15use std::thread;
16use std::time::{Duration, Instant};
17
18pub mod builder;
19
/// A handle to a job scheduled on a [`ScheduledThreadPool`].
///
/// The only operation it offers is cancellation. Dropping the handle does
/// not cancel the job; it keeps running as scheduled.
#[derive(Debug)]
pub struct JobHandle(Arc<AtomicBool>);

impl JobHandle {
    /// Cancels the job.
    ///
    /// The cancellation flag is checked right before each execution. An
    /// execution that is already in progress is not interrupted, but no
    /// further execution of this job will start.
    pub fn cancel(&self) {
        let canceled_flag = &self.0;
        canceled_flag.store(true, atomic::Ordering::SeqCst);
    }
}
30
/// The kinds of work the pool can run, together with whatever is needed to
/// reschedule the recurring ones.
enum JobType {
    /// Runs exactly once.
    Once(Box<dyn FnOnce() + Send + 'static>),
    /// Recurs every `rate`, measured from the previous *scheduled* time.
    FixedRate {
        f: Box<dyn FnMut() + Send + 'static>,
        rate: Duration,
    },
    /// Like `FixedRate`, but the closure returns the next rate; `None` stops it.
    DynamicRate(Box<dyn FnMut() -> Option<Duration> + Send + 'static>),
    /// Recurs `delay` after the previous execution has *returned*.
    FixedDelay {
        f: Box<dyn FnMut() + Send + 'static>,
        delay: Duration,
    },
    /// Like `FixedDelay`, but the closure returns the next delay; `None` stops it.
    DynamicDelay(Box<dyn FnMut() -> Option<Duration> + Send + 'static>),
}

/// A unit of work waiting in the pool's queue.
struct Job {
    /// What to run, and how to reschedule it if it recurs.
    type_: JobType,
    /// The instant at which the job becomes due.
    time: Instant,
    /// Shared with the `JobHandle` given to the caller; `true` once cancelled.
    canceled: Arc<AtomicBool>,
}

// Jobs are compared solely by due time, with the comparison inverted.
// `BinaryHeap` is a max-heap, so making the *earliest* job the "greatest" one
// turns the queue into a min-heap keyed on due time. Equality uses the same
// key so that `Eq` stays consistent with `Ord`.
impl Ord for Job {
    fn cmp(&self, other: &Job) -> Ordering {
        other.time.cmp(&self.time)
    }
}

impl PartialOrd for Job {
    fn partial_cmp(&self, other: &Job) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for Job {
    fn eq(&self, other: &Job) -> bool {
        self.time.eq(&other.time)
    }
}

impl Eq for Job {}
71
72struct InnerPool {
73 queue: BinaryHeap<Job>,
74 shutdown: bool,
75 on_drop_behavior: OnPoolDropBehavior,
76}
77
78struct SharedPool {
79 inner: Mutex<InnerPool>,
80 cvar: Condvar,
81}
82
83impl SharedPool {
84 fn run(&self, job: Job) {
85 let mut inner = self.inner.lock();
86
87 if inner.shutdown {
89 return;
90 }
91
92 match inner.queue.peek() {
93 None => self.cvar.notify_all(),
94 Some(e) if e.time > job.time => self.cvar.notify_all(),
95 _ => 0,
96 };
97 inner.queue.push(job);
98 }
99}
100
/// Controls what happens to jobs that are still queued when a
/// [`ScheduledThreadPool`] is dropped.
///
/// In both modes, executions that are already running finish normally, and
/// recurring jobs are not rescheduled once the pool has been dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OnPoolDropBehavior {
    /// Keep the worker threads alive until every job already in the queue
    /// has run at its scheduled time.
    CompletePendingScheduled,

    /// Let the worker threads exit without running the jobs left in the queue.
    DiscardPendingScheduled,
}
114
115pub struct ScheduledThreadPool {
123 shared: Arc<SharedPool>,
124}
125
126impl Drop for ScheduledThreadPool {
127 fn drop(&mut self) {
128 self.shared.inner.lock().shutdown = true;
129 self.shared.cvar.notify_all();
130 }
131}
132
133impl ScheduledThreadPool {
134 pub fn new(num_threads: usize) -> ScheduledThreadPool {
140 Self::builder().num_threads(num_threads).build()
141 }
142
143 pub fn builder() -> builder::NumThreadsStage {
145 NumThreadsStage(())
146 }
147
148 #[deprecated(note = "use ScheduledThreadPool::builder", since = "0.2.7")]
158 pub fn with_name(thread_name: &str, num_threads: usize) -> ScheduledThreadPool {
159 Self::builder()
160 .num_threads(num_threads)
161 .thread_name_pattern(thread_name)
162 .build()
163 }
164
165 fn new_inner(builder: FinalStage) -> ScheduledThreadPool {
166 let inner = InnerPool {
167 queue: BinaryHeap::new(),
168 shutdown: false,
169 on_drop_behavior: builder.on_drop_behavior,
170 };
171
172 let shared = SharedPool {
173 inner: Mutex::new(inner),
174 cvar: Condvar::new(),
175 };
176
177 let pool = ScheduledThreadPool {
178 shared: Arc::new(shared),
179 };
180
181 for i in 0..builder.num_threads {
182 Worker::start(
183 builder
184 .thread_name_pattern
185 .map(|n| n.replace("{}", &i.to_string())),
186 pool.shared.clone(),
187 );
188 }
189
190 pool
191 }
192
193 pub fn execute<F>(&self, job: F) -> JobHandle
195 where
196 F: FnOnce() + Send + 'static,
197 {
198 self.execute_after(Duration::from_secs(0), job)
199 }
200
201 pub fn execute_after<F>(&self, delay: Duration, job: F) -> JobHandle
203 where
204 F: FnOnce() + Send + 'static,
205 {
206 self.execute_after_inner(delay, Box::new(job))
207 }
208
209 fn execute_after_inner(
210 &self,
211 delay: Duration,
212 job: Box<dyn FnOnce() + Send + 'static>,
213 ) -> JobHandle {
214 let canceled = Arc::new(AtomicBool::new(false));
215 let job = Job {
216 type_: JobType::Once(job),
217 time: Instant::now() + delay,
218 canceled: canceled.clone(),
219 };
220 self.shared.run(job);
221 JobHandle(canceled)
222 }
223
224 pub fn execute_at_fixed_rate<F>(
234 &self,
235 initial_delay: Duration,
236 rate: Duration,
237 f: F,
238 ) -> JobHandle
239 where
240 F: FnMut() + Send + 'static,
241 {
242 self.execute_at_fixed_rate_inner(initial_delay, rate, Box::new(f))
243 }
244
245 fn execute_at_fixed_rate_inner(
246 &self,
247 initial_delay: Duration,
248 rate: Duration,
249 f: Box<dyn FnMut() + Send + 'static>,
250 ) -> JobHandle {
251 let canceled = Arc::new(AtomicBool::new(false));
252 let job = Job {
253 type_: JobType::FixedRate { f, rate },
254 time: Instant::now() + initial_delay,
255 canceled: canceled.clone(),
256 };
257 self.shared.run(job);
258 JobHandle(canceled)
259 }
260
261 pub fn execute_at_dynamic_rate<F>(&self, initial_delay: Duration, f: F) -> JobHandle
271 where
272 F: FnMut() -> Option<Duration> + Send + 'static,
273 {
274 self.execute_at_dynamic_rate_inner(initial_delay, Box::new(f))
275 }
276
277 fn execute_at_dynamic_rate_inner(
278 &self,
279 initial_delay: Duration,
280 f: Box<dyn FnMut() -> Option<Duration> + Send + 'static>,
281 ) -> JobHandle {
282 let canceled = Arc::new(AtomicBool::new(false));
283 let job = Job {
284 type_: JobType::DynamicRate(f),
285 time: Instant::now() + initial_delay,
286 canceled: canceled.clone(),
287 };
288 self.shared.run(job);
289 JobHandle(canceled)
290 }
291
292 pub fn execute_with_fixed_delay<F>(
303 &self,
304 initial_delay: Duration,
305 delay: Duration,
306 f: F,
307 ) -> JobHandle
308 where
309 F: FnMut() + Send + 'static,
310 {
311 self.execute_with_fixed_delay_inner(initial_delay, delay, Box::new(f))
312 }
313
314 fn execute_with_fixed_delay_inner(
315 &self,
316 initial_delay: Duration,
317 delay: Duration,
318 f: Box<dyn FnMut() + Send + 'static>,
319 ) -> JobHandle {
320 let canceled = Arc::new(AtomicBool::new(false));
321 let job = Job {
322 type_: JobType::FixedDelay { f, delay },
323 time: Instant::now() + initial_delay,
324 canceled: canceled.clone(),
325 };
326 self.shared.run(job);
327 JobHandle(canceled)
328 }
329
330 pub fn execute_with_dynamic_delay<F>(&self, initial_delay: Duration, f: F) -> JobHandle
341 where
342 F: FnMut() -> Option<Duration> + Send + 'static,
343 {
344 self.execute_with_dynamic_delay_inner(initial_delay, Box::new(f))
345 }
346
347 fn execute_with_dynamic_delay_inner(
348 &self,
349 initial_delay: Duration,
350 f: Box<dyn FnMut() -> Option<Duration> + Send + 'static>,
351 ) -> JobHandle {
352 let canceled = Arc::new(AtomicBool::new(false));
353 let job = Job {
354 type_: JobType::DynamicDelay(f),
355 time: Instant::now() + initial_delay,
356 canceled: canceled.clone(),
357 };
358 self.shared.run(job);
359 JobHandle(canceled)
360 }
361}
362
363struct Worker {
364 shared: Arc<SharedPool>,
365}
366
367impl Worker {
368 fn start(name: Option<String>, shared: Arc<SharedPool>) {
369 let mut worker = Worker { shared };
370
371 let mut thread = thread::Builder::new();
372 if let Some(name) = name {
373 thread = thread.name(name);
374 }
375 thread.spawn(move || worker.run()).unwrap();
376 }
377
378 fn run(&mut self) {
379 while let Some(job) = self.get_job() {
380 let _ = panic::catch_unwind(AssertUnwindSafe(|| self.run_job(job)));
382 }
383 }
384
385 fn get_job(&self) -> Option<Job> {
386 enum Need {
387 Wait,
388 WaitTimeout(Duration),
389 }
390
391 let mut inner = self.shared.inner.lock();
392 loop {
393 let now = Instant::now();
394
395 let need = match inner.queue.peek() {
396 None if inner.shutdown => return None,
397 None => Need::Wait,
398 Some(_)
399 if inner.shutdown
400 && inner.on_drop_behavior
401 == OnPoolDropBehavior::DiscardPendingScheduled =>
402 {
403 return None
404 }
405 Some(e) if e.time <= now => break,
406 Some(e) => Need::WaitTimeout(e.time - now),
407 };
408
409 match need {
410 Need::Wait => self.shared.cvar.wait(&mut inner),
411 Need::WaitTimeout(t) => {
412 self.shared.cvar.wait_until(&mut inner, now + t);
413 }
414 };
415 }
416
417 Some(inner.queue.pop().unwrap())
418 }
419
420 fn run_job(&self, job: Job) {
421 if job.canceled.load(atomic::Ordering::SeqCst) {
422 return;
423 }
424
425 match job.type_ {
426 JobType::Once(f) => f(),
427 JobType::FixedRate { mut f, rate } => {
428 f();
429 let new_job = Job {
430 type_: JobType::FixedRate { f, rate },
431 time: job.time + rate,
432 canceled: job.canceled,
433 };
434 self.shared.run(new_job)
435 }
436 JobType::DynamicRate(mut f) => {
437 if let Some(next_rate) = f() {
438 let new_job = Job {
439 type_: JobType::DynamicRate(f),
440 time: job.time + next_rate,
441 canceled: job.canceled,
442 };
443 self.shared.run(new_job)
444 }
445 }
446 JobType::FixedDelay { mut f, delay } => {
447 f();
448 let new_job = Job {
449 type_: JobType::FixedDelay { f, delay },
450 time: Instant::now() + delay,
451 canceled: job.canceled,
452 };
453 self.shared.run(new_job)
454 }
455 JobType::DynamicDelay(mut f) => {
456 if let Some(next_delay) = f() {
457 let new_job = Job {
458 type_: JobType::DynamicDelay(f),
459 time: Instant::now() + next_delay,
460 canceled: job.canceled,
461 };
462 self.shared.run(new_job)
463 }
464 }
465 }
466 }
467}
468
#[cfg(test)]
mod test {
    use std::sync::mpsc::{channel, Receiver, Sender};
    use std::sync::{Arc, Barrier};
    use std::time::Duration;

    use super::{OnPoolDropBehavior, ScheduledThreadPool};

    // Number of worker threads, and of jobs submitted by most tests.
    const TEST_TASKS: usize = 4;

    // Every job submitted with `execute` runs exactly once.
    #[test]
    fn test_works() {
        let pool = ScheduledThreadPool::new(TEST_TASKS);

        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move || {
                tx.send(1usize).unwrap();
            });
        }

        assert_eq!(rx.iter().take(TEST_TASKS).sum::<usize>(), TEST_TASKS);
    }

    // Same as `test_works`, but the pool is constructed through the builder.
    #[test]
    fn test_works_with_builder() {
        let pool = ScheduledThreadPool::builder()
            .num_threads(TEST_TASKS)
            .build();

        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move || {
                tx.send(1usize).unwrap();
            });
        }

        assert_eq!(rx.iter().take(TEST_TASKS).sum::<usize>(), TEST_TASKS);
    }

    #[test]
    #[should_panic(expected = "num_threads must be positive")]
    fn test_zero_tasks_panic() {
        ScheduledThreadPool::new(0);
    }

    #[test]
    #[should_panic(expected = "num_threads must be positive")]
    fn test_num_threads_zero_panics_with_builder() {
        ScheduledThreadPool::builder().num_threads(0);
    }

    // Workers must survive jobs that panic.
    #[test]
    fn test_recovery_from_subtask_panic() {
        let pool = ScheduledThreadPool::new(TEST_TASKS);

        // The barrier holds TEST_TASKS parties, so all panicking jobs must run
        // concurrently. That means every one of the TEST_TASKS workers panics.
        let waiter = Arc::new(Barrier::new(TEST_TASKS));
        for _ in 0..TEST_TASKS {
            let waiter = waiter.clone();
            pool.execute(move || {
                waiter.wait();
                panic!();
            });
        }

        let (tx, rx) = channel();
        // Again every worker must take part concurrently, which only succeeds
        // if all of them are still alive after the panics above.
        let waiter = Arc::new(Barrier::new(TEST_TASKS));
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            let waiter = waiter.clone();
            pool.execute(move || {
                waiter.wait();
                tx.send(1usize).unwrap();
            });
        }

        assert_eq!(rx.iter().take(TEST_TASKS).sum::<usize>(), TEST_TASKS);
    }

    // Jobs run in order of their due time, not their submission order.
    #[test]
    fn test_execute_after() {
        let pool = ScheduledThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();

        let tx1 = tx.clone();
        pool.execute_after(Duration::from_secs(1), move || tx1.send(1usize).unwrap());
        pool.execute_after(Duration::from_millis(500), move || tx.send(2usize).unwrap());

        assert_eq!(2, rx.recv().unwrap());
        assert_eq!(1, rx.recv().unwrap());
    }

    // With the default drop behavior, jobs already queued still run, in
    // due-time order, after the pool is dropped.
    #[test]
    fn test_jobs_complete_after_drop() {
        let pool = ScheduledThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();

        let tx1 = tx.clone();
        pool.execute_after(Duration::from_secs(1), move || tx1.send(1usize).unwrap());
        pool.execute_after(Duration::from_millis(500), move || tx.send(2usize).unwrap());

        drop(pool);

        assert_eq!(2, rx.recv().unwrap());
        assert_eq!(1, rx.recv().unwrap());
    }

    // With `DiscardPendingScheduled`, no queued job runs after the drop. The
    // channel disconnects once the discarded jobs, which own every `Sender`,
    // are dropped.
    #[test]
    fn test_jobs_do_not_complete_after_drop_if_behavior_is_discard() {
        let pool = ScheduledThreadPool::builder()
            .num_threads(TEST_TASKS)
            .on_drop_behavior(OnPoolDropBehavior::DiscardPendingScheduled)
            .build();
        let (tx, rx) = channel();

        let tx1 = tx.clone();
        pool.execute_after(Duration::from_secs(1), move || tx1.send(1usize).unwrap());
        pool.execute_after(Duration::from_millis(500), move || tx.send(2usize).unwrap());

        drop(pool);

        assert!(rx.recv().is_err());
    }

    // NOTE(review): this body is identical to the test above. Consider
    // removing one, or making one of them construct the pool differently.
    #[test]
    fn test_jobs_do_not_complete_after_drop_if_behavior_is_discard_using_builder() {
        let pool = ScheduledThreadPool::builder()
            .num_threads(TEST_TASKS)
            .on_drop_behavior(OnPoolDropBehavior::DiscardPendingScheduled)
            .build();
        let (tx, rx) = channel();

        let tx1 = tx.clone();
        pool.execute_after(Duration::from_secs(1), move || tx1.send(1usize).unwrap());
        pool.execute_after(Duration::from_millis(500), move || tx.send(2usize).unwrap());

        drop(pool);

        assert!(rx.recv().is_err());
    }

    #[test]
    fn test_fixed_rate_jobs_stop_after_drop() {
        test_jobs_stop_after_drop(
            |pool: &Arc<ScheduledThreadPool>, tx: Sender<i32>, rx2: Receiver<()>| {
                // The job keeps the pool alive until its second execution.
                let mut pool2 = Some(pool.clone());
                let mut i = 0i32;
                pool.execute_at_fixed_rate(
                    Duration::from_millis(500),
                    Duration::from_millis(500),
                    move || {
                        i += 1;
                        tx.send(i).unwrap();
                        rx2.recv().unwrap();
                        if i == 2 {
                            drop(pool2.take().unwrap());
                        }
                    },
                );
            },
        );
    }

    #[test]
    fn test_dynamic_delay_jobs_stop_after_drop() {
        test_jobs_stop_after_drop(
            |pool: &Arc<ScheduledThreadPool>, tx: Sender<i32>, rx2: Receiver<()>| {
                // The job keeps the pool alive until its second execution.
                let mut pool2 = Some(pool.clone());
                let mut i = 0i32;
                pool.execute_with_dynamic_delay(Duration::from_millis(500), move || {
                    i += 1;
                    tx.send(i).unwrap();
                    rx2.recv().unwrap();
                    if i == 2 {
                        drop(pool2.take().unwrap());
                    }
                    Some(Duration::from_millis(500))
                });
            },
        );
    }

    #[test]
    fn test_dynamic_rate_jobs_stop_after_drop() {
        test_jobs_stop_after_drop(
            |pool: &Arc<ScheduledThreadPool>, tx: Sender<i32>, rx2: Receiver<()>| {
                // The job keeps the pool alive until its second execution.
                let mut pool2 = Some(pool.clone());
                let mut i = 0i32;
                pool.execute_at_dynamic_rate(Duration::from_millis(500), move || {
                    i += 1;
                    tx.send(i).unwrap();
                    rx2.recv().unwrap();
                    if i == 2 {
                        drop(pool2.take().unwrap());
                    }
                    Some(Duration::from_millis(500))
                });
            },
        );
    }

    // Shared driver for the `*_jobs_stop_after_drop` tests, run under both
    // drop behaviors. `execute_fn` schedules a recurring job that:
    // - sends its execution count on `tx`;
    // - then blocks on `rx2` until the test lets it continue;
    // - drops its own clone of the pool during its second execution.
    fn test_jobs_stop_after_drop<F>(mut execute_fn: F)
    where
        F: FnMut(&Arc<ScheduledThreadPool>, Sender<i32>, Receiver<()>),
    {
        use super::OnPoolDropBehavior::*;
        for drop_behavior in [CompletePendingScheduled, DiscardPendingScheduled] {
            let pool = Arc::new(
                ScheduledThreadPool::builder()
                    .num_threads(TEST_TASKS)
                    .on_drop_behavior(drop_behavior)
                    .build(),
            );
            let (tx, rx) = channel();
            let (tx2, rx2) = channel();

            execute_fn(&pool, tx, rx2);

            // The job still holds a clone of the pool, so dropping this handle
            // does not shut the pool down yet.
            drop(pool);

            assert_eq!(Ok(1), rx.recv());
            tx2.send(()).unwrap();
            assert_eq!(Ok(2), rx.recv());
            tx2.send(()).unwrap();
            // The second execution dropped the last pool handle. Its reschedule
            // is therefore rejected, and the discarded job drops its `Sender`,
            // which disconnects `rx`.
            assert!(rx.recv().is_err());
        }
    }

    // After `cancel`, the next scheduled execution is skipped. The job, and
    // the `Sender` it owns, is discarded, which disconnects the channel.
    #[test]
    fn cancellation() {
        let pool = ScheduledThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();

        let handle = pool.execute_at_fixed_rate(
            Duration::from_millis(500),
            Duration::from_millis(500),
            move || {
                tx.send(()).unwrap();
            },
        );

        rx.recv().unwrap();
        handle.cancel();
        assert!(rx.recv().is_err());
    }
}
721}