1/* SPDX-License-Identifier: MIT OR Apache-2.0 */
2use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256};
34/// Trait for unsigned division of a double-wide integer
5/// when the quotient doesn't overflow.
6///
7/// This is the inverse of widening multiplication:
8/// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
9/// - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`,
10pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
11/// Computes `(self / n, self % n))`
12 ///
13 /// # Safety
14 /// The caller must ensure that `self.hi() < n`, or equivalently,
15 /// that the quotient does not overflow.
16unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H);
1718/// Returns `Some((self / n, self % n))` when `self.hi() < n`.
19fn checked_narrowing_div_rem(self, n: Self::H) -> Option<(Self::H, Self::H)> {
20if self.hi() < n {
21Some(unsafe { self.unchecked_narrowing_div_rem(n) })
22 } else {
23None24 }
25 }
26}
// Primitive double-wide types already provide `/` and `%`, so the
// narrowing division is carried out directly in the wide type and
// both results are then cast down to the half-width type.
macro_rules! impl_narrowing_div_primitive {
    ($D:ident) => {
        impl NarrowingDiv for $D {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if self.hi() >= n {
                    // SAFETY: the caller guarantees `self.hi() < n`,
                    // so this branch can never be taken.
                    unsafe { core::hint::unreachable_unchecked() }
                }
                // Widen the divisor once and reuse it for both operations.
                let divisor = n.widen();
                let quotient = self / divisor;
                let remainder = self % divisor;
                (quotient.cast(), remainder.cast())
            }
        }
    };
}
// Builds `u4N / u2N` division out of `u2N / uN` division.
// Simplicity is favoured over raw speed here.
macro_rules! impl_narrowing_div_recurse {
    ($D:ident) => {
        impl NarrowingDiv for $D {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if self.hi() >= n {
                    // SAFETY: the caller guarantees `self.hi() < n`,
                    // so this branch can never be taken.
                    unsafe { core::hint::unreachable_unchecked() }
                }

                // Normalize: shift the divisor's most significant set bit into
                // the top position, and shift the dividend by the same amount.
                // `self.hi() < n` implies `n != 0`.
                let shift = n.leading_zeros();
                let divisor = n << shift;
                let dividend = self << shift;

                // Split the shifted dividend into its high half and the two
                // digits of its low half.
                let top = dividend.hi();
                let (low_digit, mid_digit) = dividend.lo().lo_hi();

                // SAFETY: `divisor.leading_zeros() == 0` by the shift above (this
                // covers both calls), and `top < divisor` follows from `self.hi() < n`.
                let (q_hi, rem) = unsafe { div_three_digits_by_two(mid_digit, top, divisor) };
                // SAFETY: `rem < divisor` is the postcondition of the previous call.
                let (q_lo, rem) = unsafe { div_three_digits_by_two(low_digit, rem, divisor) };

                // The remainder is still scaled by the normalization shift; undo it.
                (Self::H::from_lo_hi(q_lo, q_hi), rem >> shift)
            }
        }
    };
}
7475impl NarrowingDiv for u16 {
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H)
-> (Self::H, Self::H) {
if self.hi() >= n { unsafe { core::hint::unreachable_unchecked() } }
((self / n.widen()).cast(), (self % n.widen()).cast())
}
}impl_narrowing_div_primitive!(u16);
76impl NarrowingDiv for u32 {
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H)
-> (Self::H, Self::H) {
if self.hi() >= n { unsafe { core::hint::unreachable_unchecked() } }
((self / n.widen()).cast(), (self % n.widen()).cast())
}
}impl_narrowing_div_primitive!(u32);
77impl NarrowingDiv for u64 {
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H)
-> (Self::H, Self::H) {
if self.hi() >= n { unsafe { core::hint::unreachable_unchecked() } }
((self / n.widen()).cast(), (self % n.widen()).cast())
}
}impl_narrowing_div_primitive!(u64);
78impl NarrowingDiv for u128 {
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H)
-> (Self::H, Self::H) {
if self.hi() >= n { unsafe { core::hint::unreachable_unchecked() } }
((self / n.widen()).cast(), (self % n.widen()).cast())
}
}impl_narrowing_div_primitive!(u128);
79impl NarrowingDiv for u256 {
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H)
-> (Self::H, Self::H) {
if self.hi() >= n { unsafe { core::hint::unreachable_unchecked() } }
let lz = n.leading_zeros();
let a = self << lz;
let b = n << lz;
let ah = a.hi();
let (a0, a1) = a.lo().lo_hi();
let (q1, r) = unsafe { div_three_digits_by_two(a1, ah, b) };
let (q0, r) = unsafe { div_three_digits_by_two(a0, r, b) };
(Self::H::from_lo_hi(q0, q1), r >> lz)
}
}impl_narrowing_div_recurse!(u256);
8081/// Implement `u3N / u2N`-division on top of `u2N / uN`-division.
82///
83/// Returns the quotient and remainder of `(a * R + a0) / n`,
84/// where `R = (1 << U::BITS)` is the digit size.
85///
86/// # Safety
87/// Requires that `n.leading_zeros() == 0` and `a < n`.
88unsafe fn div_three_digits_by_two<U>(a0: U, a: U::D, n: U::D) -> (U, U::D)
89where
90U: HInt,
91 U::D: Int + NarrowingDiv,
92{
93if n.leading_zeros() > 0 || a >= n {
94unsafe { core::hint::unreachable_unchecked() }
95 }
9697// n = n1R + n0
98let (n0, n1) = n.lo_hi();
99// a = a2R + a1
100let (a1, a2) = a.lo_hi();
101102let mut q;
103let mut r;
104let mut wrap;
105// `a < n` is guaranteed by the caller, but `a2 == n1 && a1 < n0` is possible
106if let Some((q0, r1)) = a.checked_narrowing_div_rem(n1) {
107q = q0;
108// a = qn1 + r1, where 0 <= r1 < n1
109110 // Include the remainder with the low bits:
111 // r = a0 + r1R
112r = U::D::from_lo_hi(a0, r1);
113114// Subtract the contribution of the divisor low bits with the estimated quotient
115let d = q.widen_mul(n0);
116 (r, wrap) = r.overflowing_sub(d);
117118// Since `q` is the quotient of dividing with a slightly smaller divisor,
119 // it may be an overapproximation, but is never too small, and similarly,
120 // `r` is now either the correct remainder ...
121if !wrap {
122return (q, r);
123 }
124// ... or the remainder went "negative" (by as much as `d = qn0 < RR`)
125 // and we have to adjust.
126q -= U::ONE;
127 } else {
128if true {
if !(a2 == n1 && a1 < n0) {
::core::panicking::panic("assertion failed: a2 == n1 && a1 < n0")
};
};debug_assert!(a2 == n1 && a1 < n0);
129// Otherwise, `a2 == n1`, and the estimated quotient would be
130 // `R + (a1 % n1)`, but the correct quotient can't overflow.
131 // We'll start from `q = R = (1 << U::BITS)`,
132 // so `r = aR + a0 - qn = (a - n)R + a0`
133r = U::D::from_lo_hi(a0, a1.wrapping_sub(n0));
134// Since `a < n`, the first decrement is always needed:
135q = U::MAX; /* R - 1 */
136}
137138 (r, wrap) = r.overflowing_add(n);
139if wrap {
140return (q, r);
141 }
142143// If the remainder still didn't wrap, we need another step.
144q -= U::ONE;
145 (r, wrap) = r.overflowing_add(n);
146// Since `n >= RR/2`, at least one of the two `r += n` must have wrapped.
147if true {
if !wrap {
{
::core::panicking::panic_fmt(format_args!("estimated quotient should be off by at most two"));
}
};
};debug_assert!(wrap, "estimated quotient should be off by at most two");
148 (q, r)
149}
#[cfg(test)]
mod test {
    use super::{HInt, NarrowingDiv};

    /// Exhaustively checks, for every `u8` pair `(x, y)` with `y != 0`, that
    /// narrowing division undoes `x * y + r` for remainders at both ends of
    /// the valid range `0..y`.
    #[test]
    fn inverse_mul() {
        for x in 0..=u8::MAX {
            for y in 1..=u8::MAX {
                let product = x.widen_mul(y);
                // Dividing `x * y + rem` by `y` must give back `(x, rem)`.
                let expect_round_trip = |rem: u8| {
                    assert_eq!(
                        (product + u16::from(rem)).checked_narrowing_div_rem(y),
                        Some((x, rem))
                    );
                };

                expect_round_trip(0);
                expect_round_trip(y - 1);
                if y > 1 {
                    expect_round_trip(1);
                    expect_round_trip(y - 2);
                }
            }
        }
    }
}