Skip to main content

libm/math/support/int_traits/
narrowing_div.rs

1/* SPDX-License-Identifier: MIT OR Apache-2.0 */
2use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256};
3
4/// Trait for unsigned division of a double-wide integer
5/// when the quotient doesn't overflow.
6///
7/// This is the inverse of widening multiplication:
8///  - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
9///  - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`,
10pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
11    /// Computes `(self / n, self % n))`
12    ///
13    /// # Safety
14    /// The caller must ensure that `self.hi() < n`, or equivalently,
15    /// that the quotient does not overflow.
16    unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H);
17
18    /// Returns `Some((self / n, self % n))` when `self.hi() < n`.
19    fn checked_narrowing_div_rem(self, n: Self::H) -> Option<(Self::H, Self::H)> {
20        if self.hi() < n {
21            Some(unsafe { self.unchecked_narrowing_div_rem(n) })
22        } else {
23            None
24        }
25    }
26}
27
// Primitive double-wide types already support `/` and `%`, so the division
// is carried out in the wide type and both results are narrowed afterwards.
macro_rules! impl_narrowing_div_primitive {
    ($D:ident) => {
        impl NarrowingDiv for $D {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if self.hi() >= n {
                    // SAFETY: the caller guarantees `self.hi() < n`,
                    // so this branch can never be taken.
                    unsafe { core::hint::unreachable_unchecked() }
                }
                let divisor = n.widen();
                let quotient = self / divisor;
                let remainder = self % divisor;
                (quotient.cast(), remainder.cast())
            }
        }
    };
}
42
// Builds `u4N / u2N` division out of `u2N / uN` division.
// Simplicity is favoured over raw speed here.
macro_rules! impl_narrowing_div_recurse {
    ($D:ident) => {
        impl NarrowingDiv for $D {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if self.hi() >= n {
                    // SAFETY: the caller guarantees `self.hi() < n`.
                    unsafe { core::hint::unreachable_unchecked() }
                }

                // Normalize: shift until the divisor's top bit is set.
                // `n` cannot be zero, because `self.hi() < n`.
                let shift = n.leading_zeros();
                let dividend = self << shift;
                let divisor = n << shift;

                // Split the shifted dividend into its high half and two low digits.
                let top = dividend.hi();
                let (digit0, digit1) = dividend.lo().lo_hi();

                // SAFETY: the shift above leaves `divisor.leading_zeros() == 0`,
                // and `top < divisor` follows from `self.hi() < n`.
                let (q_hi, rem) = unsafe { div_three_digits_by_two(digit1, top, divisor) };
                // SAFETY: the divisor is unchanged, and `rem < divisor` is the
                // postcondition of the previous call.
                let (q_lo, rem) = unsafe { div_three_digits_by_two(digit0, rem, divisor) };

                // The remainder was computed in the shifted scale; shift it back.
                (Self::H::from_lo_hi(q_lo, q_hi), rem >> shift)
            }
        }
    };
}
74
75impl NarrowingDiv for u16 {
    unsafe fn unchecked_narrowing_div_rem(self, n: Self::H)
        -> (Self::H, Self::H) {
        if self.hi() >= n { unsafe { core::hint::unreachable_unchecked() } }
        ((self / n.widen()).cast(), (self % n.widen()).cast())
    }
}impl_narrowing_div_primitive!(u16);
76impl NarrowingDiv for u32 {
    unsafe fn unchecked_narrowing_div_rem(self, n: Self::H)
        -> (Self::H, Self::H) {
        if self.hi() >= n { unsafe { core::hint::unreachable_unchecked() } }
        ((self / n.widen()).cast(), (self % n.widen()).cast())
    }
}impl_narrowing_div_primitive!(u32);
77impl NarrowingDiv for u64 {
    unsafe fn unchecked_narrowing_div_rem(self, n: Self::H)
        -> (Self::H, Self::H) {
        if self.hi() >= n { unsafe { core::hint::unreachable_unchecked() } }
        ((self / n.widen()).cast(), (self % n.widen()).cast())
    }
}impl_narrowing_div_primitive!(u64);
78impl NarrowingDiv for u128 {
    unsafe fn unchecked_narrowing_div_rem(self, n: Self::H)
        -> (Self::H, Self::H) {
        if self.hi() >= n { unsafe { core::hint::unreachable_unchecked() } }
        ((self / n.widen()).cast(), (self % n.widen()).cast())
    }
}impl_narrowing_div_primitive!(u128);
79impl NarrowingDiv for u256 {
    unsafe fn unchecked_narrowing_div_rem(self, n: Self::H)
        -> (Self::H, Self::H) {
        if self.hi() >= n { unsafe { core::hint::unreachable_unchecked() } }
        let lz = n.leading_zeros();
        let a = self << lz;
        let b = n << lz;
        let ah = a.hi();
        let (a0, a1) = a.lo().lo_hi();
        let (q1, r) = unsafe { div_three_digits_by_two(a1, ah, b) };
        let (q0, r) = unsafe { div_three_digits_by_two(a0, r, b) };
        (Self::H::from_lo_hi(q0, q1), r >> lz)
    }
}impl_narrowing_div_recurse!(u256);
80
81/// Implement `u3N / u2N`-division on top of `u2N / uN`-division.
82///
83/// Returns the quotient and remainder of `(a * R + a0) / n`,
84/// where `R = (1 << U::BITS)` is the digit size.
85///
86/// # Safety
87/// Requires that `n.leading_zeros() == 0` and `a < n`.
88unsafe fn div_three_digits_by_two<U>(a0: U, a: U::D, n: U::D) -> (U, U::D)
89where
90    U: HInt,
91    U::D: Int + NarrowingDiv,
92{
93    if n.leading_zeros() > 0 || a >= n {
94        unsafe { core::hint::unreachable_unchecked() }
95    }
96
97    // n = n1R + n0
98    let (n0, n1) = n.lo_hi();
99    // a = a2R + a1
100    let (a1, a2) = a.lo_hi();
101
102    let mut q;
103    let mut r;
104    let mut wrap;
105    // `a < n` is guaranteed by the caller, but `a2 == n1 && a1 < n0` is possible
106    if let Some((q0, r1)) = a.checked_narrowing_div_rem(n1) {
107        q = q0;
108        // a = qn1 + r1, where 0 <= r1 < n1
109
110        // Include the remainder with the low bits:
111        // r = a0 + r1R
112        r = U::D::from_lo_hi(a0, r1);
113
114        // Subtract the contribution of the divisor low bits with the estimated quotient
115        let d = q.widen_mul(n0);
116        (r, wrap) = r.overflowing_sub(d);
117
118        // Since `q` is the quotient of dividing with a slightly smaller divisor,
119        // it may be an overapproximation, but is never too small, and similarly,
120        // `r` is now either the correct remainder ...
121        if !wrap {
122            return (q, r);
123        }
124        // ... or the remainder went "negative" (by as much as `d = qn0 < RR`)
125        // and we have to adjust.
126        q -= U::ONE;
127    } else {
128        if true {
    if !(a2 == n1 && a1 < n0) {
        ::core::panicking::panic("assertion failed: a2 == n1 && a1 < n0")
    };
};debug_assert!(a2 == n1 && a1 < n0);
129        // Otherwise, `a2 == n1`, and the estimated quotient would be
130        // `R + (a1 % n1)`, but the correct quotient can't overflow.
131        // We'll start from `q = R = (1 << U::BITS)`,
132        // so `r = aR + a0 - qn = (a - n)R + a0`
133        r = U::D::from_lo_hi(a0, a1.wrapping_sub(n0));
134        // Since `a < n`, the first decrement is always needed:
135        q = U::MAX; /* R - 1 */
136    }
137
138    (r, wrap) = r.overflowing_add(n);
139    if wrap {
140        return (q, r);
141    }
142
143    // If the remainder still didn't wrap, we need another step.
144    q -= U::ONE;
145    (r, wrap) = r.overflowing_add(n);
146    // Since `n >= RR/2`, at least one of the two `r += n` must have wrapped.
147    if true {
    if !wrap {
        {
            ::core::panicking::panic_fmt(format_args!("estimated quotient should be off by at most two"));
        }
    };
};debug_assert!(wrap, "estimated quotient should be off by at most two");
148    (q, r)
149}
150
#[cfg(test)]
mod test {
    use super::{HInt, NarrowingDiv};

    /// Exhaustively checks over `u8` that narrowing division inverts
    /// widening multiplication, for remainders at both ends of `0..y`.
    #[test]
    fn inverse_mul() {
        for x in 0..=u8::MAX {
            for y in 1..=u8::MAX {
                let product = x.widen_mul(y);
                // Divide `x * y + extra` by `y`; with `extra < y` the result
                // must be `(x, extra)`.
                let divide = |extra: u8| (product + extra as u16).checked_narrowing_div_rem(y);

                assert_eq!(divide(0), Some((x, 0)));
                assert_eq!(divide(y - 1), Some((x, y - 1)));
                if y > 1 {
                    assert_eq!(divide(1), Some((x, 1)));
                    assert_eq!(divide(y - 2), Some((x, y - 2)));
                }
            }
        }
    }
}