1use crate::support::{
45    CastFrom, CastInto, DInt, Float, FpResult, HInt, Int, IntTy, MinInt, Round, Status, cold_path,
46};
47
48#[inline]
49pub fn sqrt<F>(x: F) -> F
50where
51    F: Float + SqrtHelper,
52    F::Int: HInt,
53    F::Int: From<u8>,
54    F::Int: From<F::ISet2>,
55    F::Int: CastInto<F::ISet1>,
56    F::Int: CastInto<F::ISet2>,
57    u32: CastInto<F::Int>,
58{
59    sqrt_round(x, Round::Nearest).val
60}
61
62#[inline]
63pub fn sqrt_round<F>(x: F, _round: Round) -> FpResult<F>
64where
65    F: Float + SqrtHelper,
66    F::Int: HInt,
67    F::Int: From<u8>,
68    F::Int: From<F::ISet2>,
69    F::Int: CastInto<F::ISet1>,
70    F::Int: CastInto<F::ISet2>,
71    u32: CastInto<F::Int>,
72{
73    let zero = IntTy::<F>::ZERO;
74    let one = IntTy::<F>::ONE;
75
76    let mut ix = x.to_bits();
77
78    let noshift = F::BITS <= u32::BITS;
81    let (mut top, special_case) = if noshift {
82        let exp_lsb = one << F::SIG_BITS;
83        let special_case = ix.wrapping_sub(exp_lsb) >= F::EXP_MASK - exp_lsb;
84        (Exp::NoShift(()), special_case)
85    } else {
86        let top = u32::cast_from(ix >> F::SIG_BITS);
87        let special_case = top.wrapping_sub(1) >= F::EXP_SAT - 1;
88        (Exp::Shifted(top), special_case)
89    };
90
91    if special_case {
93        cold_path();
94
95        if ix << 1 == zero {
97            return FpResult::ok(x);
98        }
99
100        if ix == F::EXP_MASK {
102            return FpResult::ok(x);
103        }
104
105        if ix > F::EXP_MASK {
107            return FpResult::new(F::NAN, Status::INVALID);
108        }
109
110        let scaled = x * F::from_parts(false, F::SIG_BITS + F::EXP_BIAS, zero);
112        ix = scaled.to_bits();
113        match top {
114            Exp::Shifted(ref mut v) => {
115                *v = scaled.ex();
116                *v = (*v).wrapping_sub(F::SIG_BITS);
117            }
118            Exp::NoShift(()) => {
119                ix = ix.wrapping_sub((F::SIG_BITS << F::SIG_BITS).cast());
120            }
121        }
122    }
123
124    let (m_u2, exp) = match top {
129        Exp::Shifted(top) => {
130            let mut e = top;
132            let mut m_u2 = (ix | F::IMPLICIT_BIT) << F::EXP_BITS;
134            let even = (e & 1) != 0;
135            if even {
136                m_u2 >>= 1;
137            }
138            e = (e.wrapping_add(F::EXP_SAT >> 1)) >> 1;
139            (m_u2, Exp::Shifted(e))
140        }
141        Exp::NoShift(()) => {
142            let even = ix & (one << F::SIG_BITS) != zero;
143
144            let mut e_noshift = ix >> 1;
146            e_noshift += (F::EXP_MASK ^ (F::SIGN_MASK >> 1)) >> 1;
148            e_noshift &= F::EXP_MASK;
149
150            let m1 = (ix << F::EXP_BITS) | F::SIGN_MASK;
151            let m0 = (ix << (F::EXP_BITS - 1)) & !F::SIGN_MASK;
152            let m_u2 = if even { m0 } else { m1 };
153
154            (m_u2, Exp::NoShift(e_noshift))
155        }
156    };
157
158    let i = usize::cast_from(ix >> (F::SIG_BITS - 6)) & 0b1111111;
160
161    let r1_u0: F::ISet1 = F::ISet1::cast_from(RSQRT_TAB[i]) << (F::ISet1::BITS - 16);
164    let s1_u2: F::ISet1 = ((m_u2) >> (F::BITS - F::ISet1::BITS)).cast();
165
166    let (r1_u0, _s1_u2) = goldschmidt::<F, F::ISet1>(r1_u0, s1_u2, F::SET1_ROUNDS, false);
168
169    let r2_u0: F::ISet2 = F::ISet2::from(r1_u0) << (F::ISet2::BITS - F::ISet1::BITS);
171    let s2_u2: F::ISet2 = ((m_u2) >> (F::BITS - F::ISet2::BITS)).cast();
172    let (r2_u0, _s2_u2) = goldschmidt::<F, F::ISet2>(r2_u0, s2_u2, F::SET2_ROUNDS, false);
173
174    let r_u0: F::Int = F::Int::from(r2_u0) << (F::BITS - F::ISet2::BITS);
176    let s_u2: F::Int = m_u2;
177    let (_r_u0, s_u2) = goldschmidt::<F, F::Int>(r_u0, s_u2, F::FINAL_ROUNDS, true);
178
179    let mut m = s_u2 >> (F::EXP_BITS - 2);
181
182    let shift = 2 * F::SIG_BITS - (F::BITS - 2);
202
203    let d0 = (m_u2 << shift).wrapping_sub(m.wrapping_mul(m));
205    let d1 = m.wrapping_sub(d0);
207    m += d1 >> (F::BITS - 1);
208    m &= F::SIG_MASK;
209
210    match exp {
211        Exp::Shifted(e) => m |= IntTy::<F>::cast_from(e) << F::SIG_BITS,
212        Exp::NoShift(e) => m |= e,
213    };
214
215    let mut y = F::from_bits(m);
216
217    if F::BITS > 16 {
219        let d2 = d1.wrapping_add(m).wrapping_add(one);
222        let mut tiny = if d2 == zero {
223            cold_path();
224            zero
225        } else {
226            F::IMPLICIT_BIT
227        };
228
229        tiny |= (d1 ^ d2) & F::SIGN_MASK;
230        let t = F::from_bits(tiny);
231        y = y + t;
232    }
233
234    FpResult::ok(y)
235}
236
237fn wmulh<I: HInt>(a: I, b: I) -> I {
239    a.widen_mul(b).hi()
240}
241
242#[inline]
253fn goldschmidt<F, I>(mut r_u0: I, mut s_u2: I, count: u32, final_set: bool) -> (I, I)
254where
255    F: SqrtHelper,
256    I: HInt + From<u8>,
257{
258    let three_u2 = I::from(0b11u8) << (I::BITS - 2);
259    let mut u_u0 = r_u0;
260
261    for i in 0..count {
262        s_u2 = wmulh(s_u2, u_u0);
265
266        if i > 0 && (!final_set || i + 1 < count) {
275            s_u2 <<= 1;
276        }
277
278        let d_u2 = wmulh(s_u2, r_u0);
280        u_u0 = three_u2.wrapping_sub(d_u2);
281
282        r_u0 = wmulh(r_u0, u_u0) << 1;
284    }
285
286    (r_u0, s_u2)
287}
288
/// How the exponent is carried through the square root computation.
enum Exp<T> {
    /// The exponent has been shifted down into a `u32`.
    Shifted(u32),
    /// The exponent stays in place inside a value of type `T` (or `()` while it still lives
    /// inside the float's own bit pattern).
    NoShift(T),
}

298pub trait SqrtHelper: Float {
300    type ISet1: HInt + Into<Self::ISet2> + CastFrom<Self::Int> + From<u8>;
302    type ISet2: HInt + From<Self::ISet1> + From<u8>;
304
305    const SET1_ROUNDS: u32 = 0;
307    const SET2_ROUNDS: u32 = 0;
309    const FINAL_ROUNDS: u32;
311}
312
/// `f16`: no reduced-width iterations (`SET1_ROUNDS` / `SET2_ROUNDS` keep their default of 0).
#[cfg(f16_enabled)]
impl SqrtHelper for f16 {
    type ISet1 = u16;
    type ISet2 = u16;

    const FINAL_ROUNDS: u32 = 2;
}

321impl SqrtHelper for f32 {
322    type ISet1 = u32; type ISet2 = u32; const FINAL_ROUNDS: u32 = 3;
326}
327
328impl SqrtHelper for f64 {
329    type ISet1 = u32; type ISet2 = u32;
331
332    const SET2_ROUNDS: u32 = 2;
333    const FINAL_ROUNDS: u32 = 2;
334}
335
/// `f128`: one iteration at `u32` width, two at `u64` width, then two at full width.
#[cfg(f128_enabled)]
impl SqrtHelper for f128 {
    type ISet1 = u32;
    type ISet2 = u64;

    const SET1_ROUNDS: u32 = 1;
    const SET2_ROUNDS: u32 = 2;
    const FINAL_ROUNDS: u32 = 2;
}

/// Lookup table of 128 16-bit initial estimates used to seed the iteration in `sqrt_round`.
///
/// It is indexed by a 7-bit value built from the lowest exponent bit and the top six
/// significand bits of the input; the selected entry is shifted to the top of `ISet1` before
/// use. (Presumably each entry is a fixed point approximation of `1 / sqrt(m)` for the
/// corresponding interval — TODO confirm against the table's generator.)
#[rustfmt::skip]
static RSQRT_TAB: [u16; 128] = [
    0xb451, 0xb2f0, 0xb196, 0xb044, 0xaef9, 0xadb6, 0xac79, 0xab43,
    0xaa14, 0xa8eb, 0xa7c8, 0xa6aa, 0xa592, 0xa480, 0xa373, 0xa26b,
    0xa168, 0xa06a, 0x9f70, 0x9e7b, 0x9d8a, 0x9c9d, 0x9bb5, 0x9ad1,
    0x99f0, 0x9913, 0x983a, 0x9765, 0x9693, 0x95c4, 0x94f8, 0x9430,
    0x936b, 0x92a9, 0x91ea, 0x912e, 0x9075, 0x8fbe, 0x8f0a, 0x8e59,
    0x8daa, 0x8cfe, 0x8c54, 0x8bac, 0x8b07, 0x8a64, 0x89c4, 0x8925,
    0x8889, 0x87ee, 0x8756, 0x86c0, 0x862b, 0x8599, 0x8508, 0x8479,
    0x83ec, 0x8361, 0x82d8, 0x8250, 0x81c9, 0x8145, 0x80c2, 0x8040,
    0xff02, 0xfd0e, 0xfb25, 0xf947, 0xf773, 0xf5aa, 0xf3ea, 0xf234,
    0xf087, 0xeee3, 0xed47, 0xebb3, 0xea27, 0xe8a3, 0xe727, 0xe5b2,
    0xe443, 0xe2dc, 0xe17a, 0xe020, 0xdecb, 0xdd7d, 0xdc34, 0xdaf1,
    0xd9b3, 0xd87b, 0xd748, 0xd61a, 0xd4f1, 0xd3cd, 0xd2ad, 0xd192,
    0xd07b, 0xcf69, 0xce5b, 0xcd51, 0xcc4a, 0xcb48, 0xca4a, 0xc94f,
    0xc858, 0xc764, 0xc674, 0xc587, 0xc49d, 0xc3b7, 0xc2d4, 0xc1f4,
    0xc116, 0xc03c, 0xbf65, 0xbe90, 0xbdbe, 0xbcef, 0xbc23, 0xbb59,
    0xba91, 0xb9cc, 0xb90a, 0xb84a, 0xb78c, 0xb6d0, 0xb617, 0xb560,
];

#[cfg(test)]
mod tests {
    use super::*;

    /// Check the special-input behaviour of `sqrt_round` for one float type.
    fn spec_test<F>()
    where
        F: Float + SqrtHelper,
        F::Int: HInt,
        F::Int: From<u8>,
        F::Int: From<F::ISet2>,
        F::Int: CastInto<F::ISet1>,
        F::Int: CastInto<F::ISet2>,
        u32: CastInto<F::Int>,
    {
        // Inputs expected to produce NaN with an INVALID status.
        let nan = [F::NEG_INFINITY, F::NEG_ONE, F::NAN, F::MIN];

        // Inputs expected to come back bit-for-bit unchanged with an OK status.
        let roundtrip = [F::ZERO, F::NEG_ZERO, F::INFINITY];

        for x in nan {
            let FpResult { val, status } = sqrt_round(x, Round::Nearest);
            assert!(val.is_nan());
            assert!(status == Status::INVALID);
        }

        for x in roundtrip {
            let FpResult { val, status } = sqrt_round(x, Round::Nearest);
            assert_biteq!(val, x);
            assert!(status == Status::OK);
        }
    }

    #[test]
    #[cfg(f16_enabled)]
    fn sanity_check_f16() {
        assert_biteq!(sqrt(100.0f16), 10.0);
        assert_biteq!(sqrt(4.0f16), 2.0);
    }

    #[test]
    #[cfg(f16_enabled)]
    fn spec_tests_f16() {
        spec_test::<f16>();
    }

    #[test]
    #[cfg(f16_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f16() {
        // Pairs of (input, expected result bits).
        let cases = [
            (f16::PI, 0x3f17_u16),
            (f16::from_bits(0x70e2), 0x5640_u16),
            (f16::from_bits(0x0000000f), 0x13bf_u16),
            (f16::INFINITY, f16::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f16::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    #[test]
    fn sanity_check_f32() {
        assert_biteq!(sqrt(100.0f32), 10.0);
        assert_biteq!(sqrt(4.0f32), 2.0);
    }

    #[test]
    fn spec_tests_f32() {
        spec_test::<f32>();
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f32() {
        // Pairs of (input, expected result bits).
        let cases = [
            (f32::PI, 0x3fe2dfc5_u32),
            (10000.0f32, 0x42c80000_u32),
            (f32::from_bits(0x0000000f), 0x1b2f456f_u32),
            (f32::INFINITY, f32::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f32::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    #[test]
    fn sanity_check_f64() {
        assert_biteq!(sqrt(100.0f64), 10.0);
        assert_biteq!(sqrt(4.0f64), 2.0);
    }

    #[test]
    fn spec_tests_f64() {
        spec_test::<f64>();
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f64() {
        // Pairs of (input, expected result bits).
        let cases = [
            (f64::PI, 0x3ffc5bf891b4ef6a_u64),
            (10000.0, 0x4059000000000000_u64),
            (f64::from_bits(0x0000000f), 0x1e7efbdeb14f4eda_u64),
            (f64::INFINITY, f64::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f64::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    #[test]
    #[cfg(f128_enabled)]
    fn sanity_check_f128() {
        assert_biteq!(sqrt(100.0f128), 10.0);
        assert_biteq!(sqrt(4.0f128), 2.0);
    }

    #[test]
    #[cfg(f128_enabled)]
    fn spec_tests_f128() {
        spec_test::<f128>();
    }

    #[test]
    #[cfg(f128_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f128() {
        // Pairs of (input, expected result bits).
        let cases = [
            (f128::PI, 0x3fffc5bf891b4ef6aa79c3b0520d5db9_u128),
            (
                f128::from_bits(0x400c3880000000000000000000000000),
                0x40059000000000000000000000000000_u128,
            ),
            (
                f128::from_bits(0x0000000f),
                0x1fc9efbdeb14f4ed9b17ae807907e1e9_u128,
            ),
            (f128::INFINITY, f128::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f128::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }
}