libm/math/support/modular.rs
1/* SPDX-License-Identifier: MIT OR Apache-2.0 */
2
3//! This module provides accelerated modular multiplication by large powers
4//! of two, which is needed for computing floating point remainders in `fmod`
5//! and similar functions.
6//!
7//! To keep the equations somewhat concise, the following conventions are used:
8//! - all integer operations are in the mathematical sense, without overflow
9//! - concatenation means multiplication: `2xq = 2 * x * q`
10//! - `R = (1 << U::BITS)` is the modulus of wrapping arithmetic in `U`
11
12use crate::support::int_traits::NarrowingDiv;
13use crate::support::{DInt, HInt, Int};
14
15/// Compute the remainder `(x << e) % y` with unbounded integers.
16/// Requires `x < 2y` and `y.leading_zeros() >= 2`
17pub fn linear_mul_reduction<U>(x: U, mut e: u32, mut y: U) -> U
18where
19 U: HInt + Int<Unsigned = U>,
20 U::D: NarrowingDiv,
21{
22 assert!(y <= U::MAX >> 2);
23 assert!(x < (y << 1));
24 let _0 = U::ZERO;
25 let _1 = U::ONE;
26
27 // power of two divisors
28 if (y & (y - _1)).is_zero() {
29 if e < U::BITS {
30 // shift and only keep low bits
31 return (x << e) & (y - _1);
32 } else {
33 // would shift out all the bits
34 return _0;
35 }
36 }
37
38 // Use the identity `(x << e) % y == ((x << (e + s)) % (y << s)) >> s`
39 // to shift the divisor so it has exactly two leading zeros to satisfy
40 // the precondition of `Reducer::new`
41 let s = y.leading_zeros() - 2;
42 e += s;
43 y <<= s;
44
45 // `m: Reducer` keeps track of the remainder `x` in a form that makes it
46 // very efficient to do `x <<= k` modulo `y` for integers `k < U::BITS`
47 let mut m = Reducer::new(x, y);
48
49 // Use the faster special case with constant `k == U::BITS - 1` while we can
50 while e >= U::BITS - 1 {
51 m.word_reduce();
52 e -= U::BITS - 1;
53 }
54 // Finish with the variable shift operation
55 m.shift_reduce(e);
56
57 // The partial remainder is in `[0, 2y)` ...
58 let r = m.partial_remainder();
59 // ... so check and correct, and compensate for the earlier shift.
60 r.checked_sub(y).unwrap_or(r) >> s
61}
62
63/// Helper type for computing the reductions. The implementation has a number
64/// of seemingly weird choices, but everything is aimed at streamlining
65/// `Reducer::word_reduce` into its current form.
66///
67/// Implicitly contains:
68/// n in (R/8, R/4)
69/// x in [0, 2n)
70/// The value of `n` is fixed for a given `Reducer`,
71/// but the value of `x` is modified by the methods.
72#[derive(Debug, Clone, PartialEq, Eq)]
73struct Reducer<U: HInt> {
74 // m = 2n
75 m: U,
76 // q = (RR/2) / m
77 // r = (RR/2) % m
78 // Then RR/2 = qm + r, where `0 <= r < m`
79 // The value `q` is only needed during construction, so isn't saved.
80 r: U,
81 // The value `x` is implicitly stored as `2 * q * x`:
82 _2xq: U::D,
83}
84
85impl<U> Reducer<U>
86where
87 U: HInt,
88 U: Int<Unsigned = U>,
89{
90 /// Construct a reducer for `(x << _) mod n`.
91 ///
92 /// Requires `R/8 < n < R/4` and `x < 2n`.
93 fn new(x: U, n: U) -> Self
94 where
95 U::D: NarrowingDiv,
96 {
97 let _1 = U::ONE;
98 assert!(n > (_1 << (U::BITS - 3)));
99 assert!(n < (_1 << (U::BITS - 2)));
100 let m = n << 1;
101 assert!(x < m);
102
103 // We need to compute the parameters
104 // `q = (RR/2) / m`
105 // `r = (RR/2) % m`
106
107 // Since `m` is in `(R/4, R/2)`, the quotient `q` is in `[R, 2R)`, and
108 // it would overflow in `U` if computed directly. Instead, we compute
109 // `f = q - R`, which is in `[0, R)`. To do so, we simply subtract `Rm`
110 // from the dividend, which doesn't change the remainder:
111 // `f = R(R/2 - m) / m`
112 // `r = R(R/2 - m) % m`
113 let dividend = ((_1 << (U::BITS - 1)) - m).widen_hi();
114 let (f, r) = dividend.checked_narrowing_div_rem(m).unwrap();
115
116 // As `x < m`, `xq < qm <= RR/2`
117 // Thus `2xq = 2xR + 2xf` does not overflow in `U::D`.
118 let _2x = x + x;
119 let _2xq = _2x.widen_hi() + _2x.widen_mul(f);
120 Self { m, r, _2xq }
121 }
122
123 /// Extract the current remainder `x` in the range `[0, 2n)`
124 fn partial_remainder(&self) -> U {
125 // `RR/2 = qm + r`, where `0 <= r < m`
126 // `2xq = uR + v`, where `0 <= v < R`
127
128 // The goal is to extract the current value of `x` from the value `2xq`
129 // that we actually have. A bit simplified, we could multiply it by `m`
130 // to obtain `2xqm == 2x(RR/2 - r) == xRR - 2xr`, where `2xr < RR`.
131 // We could just round that up to the next multiple of `RR` to get `x`,
132 // but we can avoid having to multiply the full double-wide `2xq` by
133 // making a couple of adjustments:
134
135 // First, let's only use the high half `u` for the product, and
136 // include an additional error term due to the truncation:
137 // `mu = xR - (2xr + mv)/R`
138
139 // Next, show bounds for the error term
140 // `0 <= mv < mR` follows from `0 <= v < R`
141 // `0 <= 2xr < mR` follows from `0 <= x < m < R/2` and `0 <= r < m`
142 // Adding those together, we have:
143 // `0 <= (mv + 2xr)/R < 2m`
144 // Which also implies:
145 // `0 < 2m - (mv + 2xr)/R <= 2m < R`
146
147 // For that reason, we can use `u + 2` as the factor to obtain
148 // `m(u + 2) = xR + (2m - (mv + 2xr)/R)`
149 // By the previous inequality, the second term fits neatly in the lower
150 // half, so we get exactly `x` as the high half.
151 let u = self._2xq.hi();
152 let _2 = U::ONE + U::ONE;
153 self.m.widen_mul(u + _2).hi()
154
155 // Additionally, we should ensure that `u + 2` cannot overflow:
156 // Since `x < m` and `2qm <= RR`,
157 // `2xq <= 2q(m-1) <= RR - 2q`
158 // As we also have `q > R`,
159 // `2xq < RR - 2R`
160 // which is sufficient.
161 }
162
163 /// Replace the remainder `x` with `(x << k) - un`,
164 /// for a suitable quotient `u`, which is returned.
165 ///
166 /// Requires that `k < U::BITS`.
167 fn shift_reduce(&mut self, k: u32) -> U {
168 assert!(k < U::BITS);
169
170 // First, split the shifted value:
171 // `2xq << k = aRR/2 + b`, where `0 <= b < RR/2`
172 let a = self._2xq.hi() >> (U::BITS - 1 - k);
173 let (low, high) = (self._2xq << k).lo_hi();
174 let b = U::D::from_lo_hi(low, high & (U::MAX >> 1));
175
176 // Then, subtract `2anq = aqm`:
177 // ```
178 // (2xq << k) - aqm
179 // = aRR/2 + b - aqm
180 // = a(RR/2 - qm) + b
181 // = ar + b
182 // ```
183 self._2xq = a.widen_mul(self.r) + b;
184 a
185
186 // Since `a` is at most the high half of `2xq`, we have
187 // `a + 2 < R` (shown above, in `partial_remainder`)
188 // Using that together with `b < RR/2` and `r < m < R/2`,
189 // we get `(a + 2)r + b < RR`, so
190 // `ar + b < RR - 2r = 2mq`
191 // which shows that the new remainder still satisfies `x < m`.
192 }
193
194 // NB: `word_reduce()` is just the special case `shift_reduce(U::BITS - 1)`
195 // that optimizes especially well. The correspondence is that `a == u` and
196 // `b == (v >> 1).widen_hi()`
197 //
198 /// Replace the remainder `x` with `x(R/2) - un`,
199 /// for a suitable quotient `u`, which is returned.
200 fn word_reduce(&mut self) -> U {
201 // To do so, we replace `2xq = uR + v` with
202 // ```
203 // 2 * (x(R/2) - un) * q
204 // = xqR - 2unq
205 // = xqR - uqm
206 // = uRR/2 + vR/2 - uRR/2 + ur
207 // = ur + (v/2)R
208 // ```
209 let (v, u) = self._2xq.lo_hi();
210 self._2xq = u.widen_mul(self.r) + U::widen_hi(v >> 1);
211 u
212
213 // Additional notes:
214 // 1. As `v` is the low bits of `2xq`, it is even and can be halved.
215 // 2. The new remainder is `(xr + mv/2) / R` (see below)
216 // and since `v < R`, `r < m`, `x < m < R/2`,
217 // that is also strictly less than `m`.
218 // ```
219 // (x(R/2) - un)R
220 // = xRR/2 - (m/2)uR
221 // = x(qm + r) - (m/2)(2xq - v)
222 // = xqm + xr - xqm + mv/2
223 // = xr + mv/2
224 // ```
225 }
226}
227
#[cfg(test)]
mod test {
    use crate::support::{linear_mul_reduction, modular::Reducer};

    /// Exhaustively checks the `Reducer` primitives on `u8`, for every
    /// modulus `n` accepted by `Reducer::new` and every start value `x < 2n`.
    #[test]
    fn reducer_ops() {
        for n in 33..=63_u8 {
            for x in 0..2 * n {
                let initial = Reducer::new(x, n);
                let n = u32::from(n);

                // A freshly built reducer must report exactly `x`.
                let x0 = u32::from(initial.partial_remainder());
                assert_eq!(u32::from(x), x0);

                for k in 0..=7 {
                    let mut red = initial.clone();
                    let u = u32::from(red.shift_reduce(k));
                    let x1 = u32::from(red.partial_remainder());

                    // The returned quotient accounts for the whole change,
                    // and the representation invariants still hold.
                    assert_eq!(x1, (x0 << k) - u * n);
                    assert!(x1 < 2 * n);
                    assert!(u32::from(red._2xq).is_multiple_of(2 * x1));

                    // `word_reduce` must agree with
                    // `shift_reduce(U::BITS - 1)`
                    if k == 7 {
                        let mut alt = initial.clone();
                        let w = alt.word_reduce();
                        assert_eq!(u, u32::from(w));
                        assert_eq!(alt, red);
                    }
                }
            }
        }
    }

    /// Compares `linear_mul_reduction` on `u8` against a remainder that is
    /// doubled modulo `y` step by step, for every divisor below 64.
    #[test]
    fn reduction_u8() {
        for y in 1..64u8 {
            for x in 0..2 * y {
                let mut expected = x % y;
                for e in 0..100 {
                    assert_eq!(expected, linear_mul_reduction(x, e, y));
                    // advance the expected remainder to the next exponent
                    let doubled = expected << 1;
                    expected = if doubled >= y { doubled - y } else { doubled };
                }
            }
        }
    }

    /// Spot checks on `u128`, including the power-of-two fast path and a long
    /// run of exponents with a large odd divisor.
    #[test]
    fn reduction_u128() {
        assert_eq!(
            linear_mul_reduction::<u128>(17, 100, 123456789),
            (17_u128 << 100) % 123456789
        );

        // power-of-two divisor
        assert_eq!(
            linear_mul_reduction(0xdead_beef, 100, 1_u128 << 116),
            0xbeef << 100
        );

        let x = 10_u128.pow(37);
        let y = 11_u128.pow(36);
        assert!(x < y);
        let mut expected = x;
        for e in 0..1000 {
            assert_eq!(expected, linear_mul_reduction(x, e, y));
            // advance the expected remainder to the next exponent
            let doubled = expected << 1;
            expected = if doubled >= y { doubled - y } else { doubled };
            assert!(expected != 0);
        }
    }
}
304}