bigdecimal/
impl_cmp.rs

1//! Implementation of comparison operations
2//!
3//! Comparisons between decimals and decimal refs
4//! are not directly supported as we lose some type
5//! inference features at the savings of a single
6//! '&' character.
7//!
8//! &BigDecimal and BigDecimalRef are comparable.
9//!
10
11use crate::*;
12
13use stdlib::cmp::Ordering;
14use stdlib::iter;
15
16impl PartialEq for BigDecimal
17{
18    fn eq(&self, rhs: &BigDecimal) -> bool {
19        self.to_ref() == rhs.to_ref()
20    }
21}
22
23impl<'rhs, T> PartialEq<T> for BigDecimalRef<'_>
24where
25    T: Into<BigDecimalRef<'rhs>> + Copy
26{
27    fn eq(&self, rhs: &T) -> bool {
28        let rhs: BigDecimalRef<'rhs> = (*rhs).into();
29        check_equality_bigdecimal_ref(*self, rhs)
30    }
31}
32
33fn check_equality_bigdecimal_ref(lhs: BigDecimalRef, rhs: BigDecimalRef) -> bool {
34    match (lhs.sign(), rhs.sign()) {
35        // both zero
36        (Sign::NoSign, Sign::NoSign) => return true,
37        // signs are different
38        (a, b) if a != b => return false,
39        // signs are same, do nothing
40        _ => {}
41    }
42
43    let unscaled_int;
44    let scaled_int;
45    let trailing_zero_count;
46    match arithmetic::checked_diff(lhs.scale, rhs.scale) {
47        (Ordering::Equal, _) => {
48            return lhs.digits == rhs.digits;
49        }
50        (Ordering::Greater, Some(scale_diff)) => {
51            unscaled_int = lhs.digits;
52            scaled_int = rhs.digits;
53            trailing_zero_count = scale_diff;
54        }
55        (Ordering::Less, Some(scale_diff)) => {
56            unscaled_int = rhs.digits;
57            scaled_int = lhs.digits;
58            trailing_zero_count = scale_diff;
59        }
60        _ => {
61            // all other cases imply overflow in difference of scale,
62            // numbers must not be equal
63            return false;
64        }
65    }
66
67    debug_assert_ne!(trailing_zero_count, 0);
68
69    // test if unscaled_int is guaranteed to be less than
70    // scaled_int*10^trailing_zero_count based on highest bit
71    if highest_bit_lessthan_scaled(unscaled_int, scaled_int, trailing_zero_count) {
72        return false;
73    }
74
75    // try compare without allocating
76    if trailing_zero_count < 20 {
77        let pow = ten_to_the_u64(trailing_zero_count as u8);
78
79        let mut a_digits = unscaled_int.iter_u32_digits();
80        let mut b_digits = scaled_int.iter_u32_digits();
81
82        let mut carry = 0;
83        loop {
84            match (a_digits.next(), b_digits.next()) {
85                (Some(next_a), Some(next_b)) => {
86                    let wide_b = match (next_b as u64).checked_mul(pow) {
87                        Some(tmp) => tmp + carry,
88                        None => break,
89                    };
90
91                    let true_b = wide_b as u32;
92
93                    if next_a != true_b {
94                        return false;
95                    }
96
97                    carry = wide_b >> 32;
98                }
99                (None, Some(_)) => {
100                    return false;
101                }
102                (Some(a_digit), None) => {
103                    if a_digit != (carry as u32) {
104                        return false
105                    }
106                    carry = 0;
107                }
108                (None, None) => {
109                    return carry == 0;
110                }
111            }
112        }
113
114        // we broke out of loop due to overflow - compare via allocation
115        let scaled_int = scaled_int * pow;
116        return &scaled_int == unscaled_int;
117    }
118
119    let trailing_zero_count = trailing_zero_count.to_usize().unwrap();
120    let unscaled_digits = unscaled_int.to_radix_le(10);
121
122    if trailing_zero_count > unscaled_digits.len() {
123        return false;
124    }
125
126    // split into digits below the other value, and digits overlapping
127    let (low_digits, overlap_digits) = unscaled_digits.split_at(trailing_zero_count);
128
129    // if any of the low digits are zero, they are not equal
130    if low_digits.iter().any(|&d| d != 0) {
131        return false;
132    }
133
134    let scaled_digits = scaled_int.to_radix_le(10);
135
136    // different lengths with trailing zeros
137    if overlap_digits.len() != scaled_digits.len() {
138        return false;
139    }
140
141    // return true if all digits are the same
142    overlap_digits.iter().zip(scaled_digits.iter()).all(|(digit_a, digit_b)| digit_a == digit_b)
143}
144
145
146impl PartialOrd for BigDecimal {
147    #[inline]
148    fn partial_cmp(&self, other: &BigDecimal) -> Option<Ordering> {
149        Some(self.cmp(other))
150    }
151}
152
153impl PartialOrd for BigDecimalRef<'_> {
154    fn partial_cmp(&self, other: &BigDecimalRef<'_>) -> Option<Ordering> {
155        Some(self.cmp(other))
156    }
157}
158
159
160impl Ord for BigDecimal {
161    #[inline]
162    fn cmp(&self, other: &BigDecimal) -> Ordering {
163        self.to_ref().cmp(&other.to_ref())
164    }
165}
166
167impl Ord for BigDecimalRef<'_> {
168    /// Complete ordering implementation for BigDecimal
169    ///
170    /// # Example
171    ///
172    /// ```
173    /// use std::str::FromStr;
174    ///
175    /// let a = bigdecimal::BigDecimal::from_str("-1").unwrap();
176    /// let b = bigdecimal::BigDecimal::from_str("1").unwrap();
177    /// assert!(a < b);
178    /// assert!(b > a);
179    /// let c = bigdecimal::BigDecimal::from_str("1").unwrap();
180    /// assert!(b >= c);
181    /// assert!(c >= b);
182    /// let d = bigdecimal::BigDecimal::from_str("10.0").unwrap();
183    /// assert!(d > c);
184    /// let e = bigdecimal::BigDecimal::from_str(".5").unwrap();
185    /// assert!(e < c);
186    /// ```
187    #[inline]
188    fn cmp(&self, other: &BigDecimalRef) -> Ordering {
189        use Ordering::*;
190
191        let scmp = self.sign().cmp(&other.sign());
192        if scmp != Ordering::Equal {
193            return scmp;
194        }
195
196        if self.sign() == Sign::NoSign {
197            return Ordering::Equal;
198        }
199
200        let result = match arithmetic::checked_diff(self.scale, other.scale) {
201            (Greater, Some(scale_diff)) | (Equal, Some(scale_diff)) => {
202                compare_scaled_biguints(self.digits, other.digits, scale_diff)
203            }
204            (Less, Some(scale_diff)) => {
205                compare_scaled_biguints(other.digits, self.digits, scale_diff).reverse()
206            }
207            (res, None) => {
208                // The difference in scale does not fit in a u64,
209                // we can safely assume the value of digits do not matter
210                // (unless we have a 2^64 (i.e. ~16 exabyte) long number
211
212                // larger scale means smaller number, reverse this ordering
213                res.reverse()
214            }
215        };
216
217        if other.sign == Sign::Minus {
218            result.reverse()
219        } else {
220            result
221        }
222    }
223}
224
225
226/// compare scaled uints: a <=> b * 10^{scale_diff}
227///
228fn compare_scaled_biguints(a: &BigUint, b: &BigUint, scale_diff: u64) -> Ordering {
229    use Ordering::*;
230
231    if scale_diff == 0 {
232        return a.cmp(b);
233    }
234
235    // check if highest bit of a is less than b * 10^scale_diff
236    if highest_bit_lessthan_scaled(a, b, scale_diff) {
237        return Ordering::Less;
238    }
239
240    // if biguints fit it u64 or u128, compare using those (avoiding allocations)
241    if let Some(result) = compare_scalar_biguints(a, b, scale_diff) {
242        return result;
243    }
244
245    let a_digit_count = count_decimal_digits_uint(a);
246    let b_digit_count = count_decimal_digits_uint(b);
247
248    let digit_count_cmp = a_digit_count.cmp(&(b_digit_count + scale_diff));
249    if digit_count_cmp != Equal {
250        return digit_count_cmp;
251    }
252
253    let a_digits = a.to_radix_le(10);
254    let b_digits = b.to_radix_le(10);
255
256    debug_assert_eq!(a_digits.len(), a_digit_count as usize);
257    debug_assert_eq!(b_digits.len(), b_digit_count as usize);
258
259    let mut a_it = a_digits.iter().rev();
260    let mut b_it = b_digits.iter().rev();
261
262    loop {
263        match (a_it.next(), b_it.next()) {
264            (Some(ai), Some(bi)) => {
265                match ai.cmp(bi) {
266                    Equal => continue,
267                    result => return result,
268                }
269            }
270            (Some(&ai), None) => {
271                if ai == 0 && a_it.all(Zero::is_zero) {
272                    return Equal;
273                } else {
274                    return Greater;
275                }
276            }
277            (None, Some(&bi)) => {
278                if bi == 0 && b_it.all(Zero::is_zero) {
279                    return Equal;
280                } else {
281                    return Less;
282                }
283            }
284            (None, None) => {
285                return Equal;
286            }
287        }
288    }
289}
290
291/// Try fitting biguints into primitive integers, using those for ordering if possible
292fn compare_scalar_biguints(a: &BigUint, b: &BigUint, scale_diff: u64) -> Option<Ordering> {
293    let scale_diff = scale_diff.to_usize()?;
294
295    // try u64, then u128
296    compare_scaled_uints::<u64>(a, b, scale_diff)
297    .or_else(|| compare_scaled_uints::<u128>(a, b, scale_diff))
298}
299
300/// Implementation comparing biguints cast to generic type
301fn compare_scaled_uints<'a, T>(a: &'a BigUint, b: &'a BigUint, scale_diff: usize) -> Option<Ordering>
302where
303    T: num_traits::PrimInt + TryFrom<&'a BigUint>
304{
305    let ten = T::from(10).unwrap();
306
307    let a = T::try_from(a).ok();
308    let b = T::try_from(b).ok().and_then(
309                |b| num_traits::checked_pow(ten, scale_diff).and_then(
310                    |p| b.checked_mul(&p)));
311
312    match (a, b) {
313        (Some(a), Some(scaled_b)) => Some(a.cmp(&scaled_b)),
314        // if scaled_b doesn't fit in size T, while 'a' does, then a is certainly less
315        (Some(_), None) => Some(Ordering::Less),
316        // if a doesn't fit in size T, while 'scaled_b' does, then a is certainly greater
317        (None, Some(_)) => Some(Ordering::Greater),
318        // neither fits, cannot determine relative size
319        (None, None) => None,
320    }
321}
322
323/// Return highest_bit(a) < highest_bit(b * 10^{scale})
324///
325/// Used for optimization when comparing scaled integers
326///
327/// ```math
328/// a < b * 10^{scale}
329/// log(a) < log(b) + scale * log(10)
330/// ```
331///
332fn highest_bit_lessthan_scaled(a: &BigUint, b: &BigUint, scale: u64) -> bool {
333    let a_bits = a.bits();
334    let b_bits = b.bits();
335    if a_bits < b_bits {
336        return true;
337    }
338    let log_scale = LOG2_10 * scale as f64;
339    match b_bits.checked_add(log_scale as u64) {
340        Some(scaled_b_bit) => a_bits < scaled_b_bit,
341        None => true, // overflowing u64 means we are definitely bigger
342    }
343}
344
// Unit tests for the comparison implementations in this file
#[cfg(test)]
mod test {
    use super::*;

    // Direct tests of the compare_scaled_biguints helper
    mod compare_scaled_biguints {
        use super::*;

        // Generate a #[test] from `name: "a" OP "b" e N`, where OP is one
        // of `>`, `<`, `=`; the first three arms translate the operator
        // into the matching Ordering variant for the final arm, which
        // parses both literals as BigUint and checks the result of
        // compare_scaled_biguints(&a, &b, N).
        macro_rules! impl_test {
            ($name:ident: $a:literal > $b:literal e $e:literal) => {
                impl_test!($name: $a Greater $b e $e);
            };
            ($name:ident: $a:literal < $b:literal e $e:literal) => {
                impl_test!($name: $a Less $b e $e);
            };
            ($name:ident: $a:literal = $b:literal e $e:literal) => {
                impl_test!($name: $a Equal $b e $e);
            };
            ($name:ident: $a:literal $op:ident $b:literal e $e:literal) => {
                #[test]
                fn $name() {
                    let a: BigUint = $a.parse().unwrap();
                    let b: BigUint = $b.parse().unwrap();

                    let result = compare_scaled_biguints(&a, &b, $e);
                    assert_eq!(result, Ordering::$op);
                }
            };
        }

        impl_test!(case_500_51e1: "500" < "51" e 1);
        impl_test!(case_500_44e1: "500" > "44" e 1);
        impl_test!(case_5000_50e2: "5000" = "50" e 2);
        impl_test!(case_1234e9_12345e9: "1234000000000" < "12345" e 9);
        impl_test!(case_1116xx459_759xx717e2: "1116386634271380982470843247639640260491505327092723527088459" < "759522625769651746138617259189939751893902453291243506584717" e 2);
    }

    /// Test that large-magnitude exponentials will not crash
    #[test]
    fn test_cmp_on_exp_boundaries() {
        // scales at both ends of the i64 range
        let a = BigDecimal::new(1.into(), i64::MAX);
        let z = BigDecimal::new(1.into(), i64::MIN);
        assert_ne!(a, z);
        assert_ne!(z, a);

        // larger scale means smaller number
        assert!(a < z);

        assert_eq!(a, a);
        assert_eq!(z, z);
    }

    // Strict-ordering tests on parsed decimals
    mod ord {
        use super::*;

        // Generate a #[test] from `name: "a" < "b"` which parses both
        // literals as BigDecimal and checks a < b, b > a, and a != b.
        macro_rules! impl_test {
            ($name:ident: $a:literal < $b:literal) => {
                #[test]
                fn $name() {
                    let a: BigDecimal = $a.parse().unwrap();
                    let b: BigDecimal = $b.parse().unwrap();

                    assert!(&a < &b);
                    assert!(&b > &a);
                    assert_ne!(a, b);
                }
            };
        }

        impl_test!(case_diff_signs: "-1" < "1");
        impl_test!(case_n1_0: "-1" < "0");
        impl_test!(case_0_1: "0" < "1");
        impl_test!(case_1d2345_1d2346: "1.2345" < "1.2346");
        impl_test!(case_compare_extreme: "1e-9223372036854775807" < "1");
        impl_test!(case_compare_extremes: "1e-9223372036854775807" < "1e9223372036854775807");
        impl_test!(case_small_difference: "472697816888807260.1604" < "472697816888807260.16040000000000000000001");
        impl_test!(case_very_small_diff: "-1.0000000000000000000000000000000000000000000000000001" < "-1");

        impl_test!(case_1_2p128: "1" < "340282366920938463463374607431768211455");
        impl_test!(case_1_1e39: "1000000000000000000000000000000000000000" < "1e41");

        impl_test!(case_1d414xxx573: "1.414213562373095048801688724209698078569671875376948073176679730000000000000000000000000000000000000" < "1.41421356237309504880168872420969807856967187537694807317667974000000000");
        impl_test!(case_11d414xxx573: "1.414213562373095048801688724209698078569671875376948073176679730000000000000000000000000000000000000" < "11.41421356237309504880168872420969807856967187537694807317667974000000000");
    }

    // Equality tests on decimals written with different scales
    mod eq {
        use super::*;

        // Generate a #[test] from `name: "a" = "b"` which parses both
        // literals as BigDecimal and checks equality by reference and by value.
        macro_rules! impl_test {
            ($name:ident: $a:literal = $b:literal) => {
                #[test]
                fn $name() {
                    let a: BigDecimal = $a.parse().unwrap();
                    let b: BigDecimal = $b.parse().unwrap();

                    assert_eq!(&a, &b);
                    assert_eq!(a, b);
                }
            };
        }

        impl_test!(case_zero: "0" = "0.00");
        impl_test!(case_1_1d00: "1" = "1.00");
        impl_test!(case_n1_n1000en3: "-1" = "-1000e-3");
        impl_test!(case_0d000034500_345en7: "0.000034500" = "345e-7");
    }

    // Equality between a BigDecimalRef, its negation, and a &BigDecimal
    #[test]
    fn test_borrow_neg_cmp() {
        // same value written in plain and in exponent notation
        let x: BigDecimal = "1514932018891593.916341142773".parse().unwrap();
        let y: BigDecimal = "1514932018891593916341142773e-12".parse().unwrap();

        assert_eq!(x, y);

        let x_ref = x.to_ref();
        assert_eq!(x_ref, &y);
        assert_ne!(x_ref.neg(), x_ref);
        assert_eq!(x_ref.neg().neg(), x_ref);
    }

    // Property tests, compiled only when the `property_tests` cfg is set
    #[cfg(property_tests)]
    mod prop {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig { cases: 5000, ..Default::default() })]

            // Ordering of decimals built from f64 values must agree
            // with the ordering of the f64 values themselves
            #[test]
            fn cmp_matches_f64(
                f in proptest::num::f64::NORMAL | proptest::num::f64::SUBNORMAL | proptest::num::f64::ZERO,
                g in proptest::num::f64::NORMAL | proptest::num::f64::SUBNORMAL | proptest::num::f64::ZERO
            ) {
                let a: BigDecimal = BigDecimal::from_f64(f).unwrap();
                let b: BigDecimal = BigDecimal::from_f64(g).unwrap();

                let expected = PartialOrd::partial_cmp(&f, &g).unwrap();
                let value = a.cmp(&b);

                prop_assert_eq!(expected, value)
            }
        }
    }
}