Skip to main content

libm/math/support/
modular.rs

/* SPDX-License-Identifier: MIT OR Apache-2.0 */

//! This module provides accelerated modular multiplication by large powers
//! of two, which is needed for computing floating point remainders in `fmod`
//! and similar functions.
//!
//! To keep the equations somewhat concise, the following conventions are used:
//!  - all integer operations are in the mathematical sense, without overflow
//!  - concatenation means multiplication: `2xq = 2 * x * q` (and `RR = R * R`)
//!  - `R = (1 << U::BITS)` is the modulus of wrapping arithmetic in `U`
11
12use crate::support::int_traits::NarrowingDiv;
13use crate::support::{DInt, HInt, Int};
14
15/// Compute the remainder `(x << e) % y` with unbounded integers.
16/// Requires `x < 2y` and `y.leading_zeros() >= 2`
17pub fn linear_mul_reduction<U>(x: U, mut e: u32, mut y: U) -> U
18where
19    U: HInt + Int<Unsigned = U>,
20    U::D: NarrowingDiv,
21{
22    if !(y <= U::MAX >> 2) {
    ::core::panicking::panic("assertion failed: y <= U::MAX >> 2")
};assert!(y <= U::MAX >> 2);
23    if !(x < (y << 1)) {
    ::core::panicking::panic("assertion failed: x < (y << 1)")
};assert!(x < (y << 1));
24    let _0 = U::ZERO;
25    let _1 = U::ONE;
26
27    // power of two divisors
28    if (y & (y - _1)).is_zero() {
29        if e < U::BITS {
30            // shift and only keep low bits
31            return (x << e) & (y - _1);
32        } else {
33            // would shift out all the bits
34            return _0;
35        }
36    }
37
38    // Use the identity `(x << e) % y == ((x << (e + s)) % (y << s)) >> s`
39    // to shift the divisor so it has exactly two leading zeros to satisfy
40    // the precondition of `Reducer::new`
41    let s = y.leading_zeros() - 2;
42    e += s;
43    y <<= s;
44
45    // `m: Reducer` keeps track of the remainder `x` in a form that makes it
46    //  very efficient to do `x <<= k` modulo `y` for integers `k < U::BITS`
47    let mut m = Reducer::new(x, y);
48
49    // Use the faster special case with constant `k == U::BITS - 1` while we can
50    while e >= U::BITS - 1 {
51        m.word_reduce();
52        e -= U::BITS - 1;
53    }
54    // Finish with the variable shift operation
55    m.shift_reduce(e);
56
57    // The partial remainder is in `[0, 2y)` ...
58    let r = m.partial_remainder();
59    // ... so check and correct, and compensate for the earlier shift.
60    r.checked_sub(y).unwrap_or(r) >> s
61}
62
63/// Helper type for computing the reductions. The implementation has a number
64/// of seemingly weird choices, but everything is aimed at streamlining
65/// `Reducer::word_reduce` into its current form.
66///
67/// Implicitly contains:
68///  n in (R/8, R/4)
69///  x in [0, 2n)
70/// The value of `n` is fixed for a given `Reducer`,
71/// but the value of `x` is modified by the methods.
72#[derive(#[automatically_derived]
impl<U: ::core::fmt::Debug + HInt> ::core::fmt::Debug for Reducer<U> where
    U::D: ::core::fmt::Debug {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::debug_struct_field3_finish(f, "Reducer", "m",
            &self.m, "r", &self.r, "_2xq", &&self._2xq)
    }
}Debug, #[automatically_derived]
impl<U: ::core::clone::Clone + HInt> ::core::clone::Clone for Reducer<U> where
    U::D: ::core::clone::Clone {
    #[inline]
    fn clone(&self) -> Reducer<U> {
        Reducer {
            m: ::core::clone::Clone::clone(&self.m),
            r: ::core::clone::Clone::clone(&self.r),
            _2xq: ::core::clone::Clone::clone(&self._2xq),
        }
    }
}Clone, #[automatically_derived]
impl<U: ::core::cmp::PartialEq + HInt> ::core::cmp::PartialEq for Reducer<U>
    where U::D: ::core::cmp::PartialEq {
    #[inline]
    fn eq(&self, other: &Reducer<U>) -> bool {
        self.m == other.m && self.r == other.r && self._2xq == other._2xq
    }
}PartialEq, #[automatically_derived]
impl<U: ::core::cmp::Eq + HInt> ::core::cmp::Eq for Reducer<U> where
    U::D: ::core::cmp::Eq {
    #[inline]
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_receiver_is_total_eq(&self) {
        let _: ::core::cmp::AssertParamIsEq<U>;
        let _: ::core::cmp::AssertParamIsEq<U::D>;
    }
}Eq)]
73struct Reducer<U: HInt> {
74    // m = 2n
75    m: U,
76    // q = (RR/2) / m
77    // r = (RR/2) % m
78    // Then RR/2 = qm + r, where `0 <= r < m`
79    // The value `q` is only needed during construction, so isn't saved.
80    r: U,
81    // The value `x` is implicitly stored as `2 * q * x`:
82    _2xq: U::D,
83}
84
85impl<U> Reducer<U>
86where
87    U: HInt,
88    U: Int<Unsigned = U>,
89{
90    /// Construct a reducer for `(x << _) mod n`.
91    ///
92    /// Requires `R/8 < n < R/4` and `x < 2n`.
93    fn new(x: U, n: U) -> Self
94    where
95        U::D: NarrowingDiv,
96    {
97        let _1 = U::ONE;
98        if !(n > (_1 << (U::BITS - 3))) {
    ::core::panicking::panic("assertion failed: n > (_1 << (U::BITS - 3))")
};assert!(n > (_1 << (U::BITS - 3)));
99        if !(n < (_1 << (U::BITS - 2))) {
    ::core::panicking::panic("assertion failed: n < (_1 << (U::BITS - 2))")
};assert!(n < (_1 << (U::BITS - 2)));
100        let m = n << 1;
101        if !(x < m) { ::core::panicking::panic("assertion failed: x < m") };assert!(x < m);
102
103        // We need to compute the parameters
104        // `q = (RR/2) / m`
105        // `r = (RR/2) % m`
106
107        // Since `m` is in `(R/4, R/2)`, the quotient `q` is in `[R, 2R)`, and
108        // it would overflow in `U` if computed directly. Instead, we compute
109        // `f = q - R`, which is in `[0, R)`. To do so, we simply subtract `Rm`
110        // from the dividend, which doesn't change the remainder:
111        // `f = R(R/2 - m) / m`
112        // `r = R(R/2 - m) % m`
113        let dividend = ((_1 << (U::BITS - 1)) - m).widen_hi();
114        let (f, r) = dividend.checked_narrowing_div_rem(m).unwrap();
115
116        // As `x < m`, `xq < qm <= RR/2`
117        // Thus `2xq = 2xR + 2xf` does not overflow in `U::D`.
118        let _2x = x + x;
119        let _2xq = _2x.widen_hi() + _2x.widen_mul(f);
120        Self { m, r, _2xq }
121    }
122
123    /// Extract the current remainder `x` in the range `[0, 2n)`
124    fn partial_remainder(&self) -> U {
125        // `RR/2 = qm + r`, where `0 <= r < m`
126        // `2xq = uR + v`,  where `0 <= v < R`
127
128        // The goal is to extract the current value of `x` from the value `2xq`
129        // that we actually have. A bit simplified, we could multiply it by `m`
130        // to obtain `2xqm == 2x(RR/2 - r) == xRR - 2xr`, where `2xr < RR`.
131        // We could just round that up to the next multiple of `RR` to get `x`,
132        // but we can avoid having to multiply the full double-wide `2xq` by
133        // making a couple of adjustments:
134
135        // First, let's only use the high half `u` for the product, and
136        // include an additional error term due to the truncation:
137        //  `mu = xR - (2xr + mv)/R`
138
139        // Next, show bounds for the error term
140        //  `0 <= mv < mR` follows from `0 <= v < R`
141        //  `0 <= 2xr < mR` follows from `0 <= x < m < R/2` and `0 <= r < m`
142        // Adding those together, we have:
143        //  `0 <= (mv + 2xr)/R < 2m`
144        // Which also implies:
145        //  `0 < 2m - (mv + 2xr)/R <= 2m < R`
146
147        // For that reason, we can use `u + 2` as the factor to obtain
148        //  `m(u + 2) = xR + (2m - (mv + 2xr)/R)`
149        // By the previous inequality, the second term fits neatly in the lower
150        // half, so we get exactly `x` as the high half.
151        let u = self._2xq.hi();
152        let _2 = U::ONE + U::ONE;
153        self.m.widen_mul(u + _2).hi()
154
155        // Additionally, we should ensure that `u + 2` cannot overflow:
156        // Since `x < m` and `2qm <= RR`,
157        //  `2xq <= 2q(m-1) <= RR - 2q`
158        // As we also have `q > R`,
159        //  `2xq < RR - 2R`
160        // which is sufficient.
161    }
162
163    /// Replace the remainder `x` with `(x << k) - un`,
164    /// for a suitable quotient `u`, which is returned.
165    ///
166    /// Requires that `k < U::BITS`.
167    fn shift_reduce(&mut self, k: u32) -> U {
168        if !(k < U::BITS) {
    ::core::panicking::panic("assertion failed: k < U::BITS")
};assert!(k < U::BITS);
169
170        // First, split the shifted value:
171        // `2xq << k = aRR/2 + b`, where `0 <= b < RR/2`
172        let a = self._2xq.hi() >> (U::BITS - 1 - k);
173        let (low, high) = (self._2xq << k).lo_hi();
174        let b = U::D::from_lo_hi(low, high & (U::MAX >> 1));
175
176        // Then, subtract `2anq = aqm`:
177        // ```
178        // (2xq << k) - aqm
179        // = aRR/2 + b - aqm
180        // = a(RR/2 - qm) + b
181        // = ar + b
182        // ```
183        self._2xq = a.widen_mul(self.r) + b;
184        a
185
186        // Since `a` is at most the high half of `2xq`, we have
187        //  `a + 2 < R` (shown above, in `partial_remainder`)
188        // Using that together with `b < RR/2` and `r < m < R/2`,
189        // we get `(a + 2)r + b < RR`, so
190        //  `ar + b < RR - 2r = 2mq`
191        // which shows that the new remainder still satisfies `x < m`.
192    }
193
194    // NB: `word_reduce()` is just the special case `shift_reduce(U::BITS - 1)`
195    // that optimizes especially well. The correspondence is that `a == u` and
196    //  `b == (v >> 1).widen_hi()`
197    //
198    /// Replace the remainder `x` with `x(R/2) - un`,
199    /// for a suitable quotient `u`, which is returned.
200    fn word_reduce(&mut self) -> U {
201        // To do so, we replace `2xq = uR + v` with
202        // ```
203        // 2 * (x(R/2) - un) * q
204        // = xqR - 2unq
205        // = xqR - uqm
206        // = uRR/2 + vR/2 - uRR/2 + ur
207        // = ur + (v/2)R
208        // ```
209        let (v, u) = self._2xq.lo_hi();
210        self._2xq = u.widen_mul(self.r) + U::widen_hi(v >> 1);
211        u
212
213        // Additional notes:
214        //  1. As `v` is the low bits of `2xq`, it is even and can be halved.
215        //  2. The new remainder is `(xr + mv/2) / R` (see below)
216        //      and since `v < R`, `r < m`, `x < m < R/2`,
217        //      that is also strictly less than `m`.
218        // ```
219        // (x(R/2) - un)R
220        //      = xRR/2 - (m/2)uR
221        //      = x(qm + r) - (m/2)(2xq - v)
222        //      = xqm + xr - xqm + mv/2
223        //      = xr + mv/2
224        // ```
225    }
226}
227
#[cfg(test)]
mod test {
    use crate::support::linear_mul_reduction;
    use crate::support::modular::Reducer;

    /// Exhaustively checks every `Reducer` operation on `u8` against plain
    /// `u32` arithmetic, including the `word_reduce` / `shift_reduce(7)`
    /// equivalence.
    #[test]
    fn reducer_ops() {
        for n in 33..=63_u8 {
            let wide_n = u32::from(n);
            for x in 0..2 * n {
                let start = Reducer::new(x, n);
                let x0 = u32::from(start.partial_remainder());
                assert_eq!(u32::from(x), x0);

                for k in 0..=7 {
                    let mut shifted = start.clone();
                    let quot = u32::from(shifted.shift_reduce(k));
                    let x1 = u32::from(shifted.partial_remainder());
                    assert_eq!(x1, (x0 << k) - quot * wide_n);
                    assert!(x1 < 2 * wide_n);
                    assert!((shifted._2xq as u32).is_multiple_of(2 * x1));

                    // `word_reduce` must match `shift_reduce(U::BITS - 1)`
                    if k == 7 {
                        let mut via_word = start.clone();
                        let w = u32::from(via_word.word_reduce());
                        assert_eq!(quot, w);
                        assert_eq!(via_word, shifted);
                    }
                }
            }
        }
    }

    /// Compares `linear_mul_reduction` on `u8` with an incrementally
    /// maintained remainder for every valid `(x, y)` pair.
    #[test]
    fn reduction_u8() {
        for y in 1..64u8 {
            for x in 0..2 * y {
                let mut expected = x % y;
                for e in 0..100 {
                    assert_eq!(expected, linear_mul_reduction(x, e, y));
                    // `expected < y < 64`, so doubling cannot overflow `u8`
                    let doubled = expected << 1;
                    expected = if doubled >= y { doubled - y } else { doubled };
                }
            }
        }
    }

    /// Spot checks on `u128`, plus a long run of shifts for a large divisor.
    #[test]
    fn reduction_u128() {
        assert_eq!(
            linear_mul_reduction::<u128>(17, 100, 123456789),
            (17 << 100) % 123456789
        );

        // power-of-two divisor
        let pow2 = 1_u128 << 116;
        assert_eq!(linear_mul_reduction(0xdead_beef, 100, pow2), 0xbeef << 100);

        let x = 10_u128.pow(37);
        let y = 11_u128.pow(36);
        assert!(x < y);
        let mut expected = x;
        for e in 0..1000 {
            assert_eq!(expected, linear_mul_reduction(x, e, y));
            let doubled = expected << 1;
            expected = if doubled >= y { doubled - y } else { doubled };
            assert!(expected != 0);
        }
    }
}