libm/math/sqrt.rs
1/* origin: FreeBSD /usr/src/lib/msun/src/e_sqrt.c */
2/*
3 * ====================================================
4 * Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
5 *
6 * Developed at SunSoft, a Sun Microsystems, Inc. business.
7 * Permission to use, copy, modify, and distribute this
8 * software is freely granted, provided that this notice
9 * is preserved.
10 * ====================================================
11 */
12/* sqrt(x)
13 * Return correctly rounded sqrt.
14 * ------------------------------------------
15 * | Use the hardware sqrt if you have one |
16 * ------------------------------------------
17 * Method:
18 * Bit by bit method using integer arithmetic. (Slow, but portable)
19 * 1. Normalization
20 * Scale x to y in [1,4) with even powers of 2:
21 * find an integer k such that 1 <= (y=x*2^(2k)) < 4, then
22 * sqrt(x) = 2^k * sqrt(y)
 * 2. Bit by bit computation
 *    Let q_i = sqrt(y) truncated to i bits after the binary point
 *    (so q_0 = 1), and define
 *
 *        s_i = 2 q_i,    y_i = 2^{i+1} (y - q_i^2).             (1)
 *
 *    To compute q_{i+1} from q_i, one checks whether
 *
 *        (q_i + 2^{-(i+1)})^2 <= y.                             (2)
 *
 *    If (2) is false, then q_{i+1} = q_i; otherwise
 *    q_{i+1} = q_i + 2^{-(i+1)}.
 *
 *    With some algebraic manipulation, it is not difficult to see
 *    that (2) is equivalent to
 *
 *        s_i + 2^{-(i+1)} <= y_i.                               (3)
 *
 *    The advantage of (3) is that s_i and y_i can be computed by
 *    the following recurrence:
 *    if (3) is false,
 *
 *        s_{i+1} = s_i,    y_{i+1} = 2 y_i;                     (4)
 *
 *    otherwise,
 *
 *        s_{i+1} = s_i + 2^{-i},
 *        y_{i+1} = 2 (y_i - s_i - 2^{-(i+1)}).                  (5)
 *
 *    One may easily use induction to prove (4) and (5). The factor
 *    of 2 in y_{i+1} corresponds to the code shifting the remainder
 *    left by one bit on every iteration.
 *    Note: since the left-hand side of (3) contains only i+2 bits,
 *    a full (53-bit) comparison in (3) is not necessary.
 * 3. Final rounding
 *    After generating the 53-bit result, we compute one more bit.
 *    Together with the remainder, this decides whether the result is
 *    exact, more than 1/2 ulp, or less than 1/2 ulp (it is never
 *    exactly 1/2 ulp).
 *    The rounding mode can be detected by checking whether
 *    huge + tiny is equal to huge, and whether huge - tiny is
 *    equal to huge, for some floating-point numbers "huge" and "tiny".
71 *
72 * Special cases:
73 * sqrt(+-0) = +-0 ... exact
74 * sqrt(inf) = inf
75 * sqrt(-ve) = NaN ... with invalid signal
76 * sqrt(NaN) = NaN ... with invalid signal for signaling NaN
77 */
78
79use core::f64;
80
81/// The square root of `x` (f64).
82#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
83pub fn sqrt(x: f64) -> f64 {
84 // On wasm32 we know that LLVM's intrinsic will compile to an optimized
85 // `f64.sqrt` native instruction, so we can leverage this for both code size
86 // and speed.
87 llvm_intrinsically_optimized! {
88 #[cfg(target_arch = "wasm32")] {
89 return if x < 0.0 {
90 f64::NAN
91 } else {
92 unsafe { ::core::intrinsics::sqrtf64(x) }
93 }
94 }
95 }
96 #[cfg(all(target_feature = "sse2", not(feature = "force-soft-floats")))]
97 {
98 // Note: This path is unlikely since LLVM will usually have already
99 // optimized sqrt calls into hardware instructions if sse2 is available,
100 // but if someone does end up here they'll appreciate the speed increase.
101 #[cfg(target_arch = "x86")]
102 use core::arch::x86::*;
103 #[cfg(target_arch = "x86_64")]
104 use core::arch::x86_64::*;
105 unsafe {
106 let m = _mm_set_sd(x);
107 let m_sqrt = _mm_sqrt_pd(m);
108 _mm_cvtsd_f64(m_sqrt)
109 }
110 }
111 #[cfg(any(not(target_feature = "sse2"), feature = "force-soft-floats"))]
112 {
113 use core::num::Wrapping;
114
115 const TINY: f64 = 1.0e-300;
116
117 let mut z: f64;
118 let sign: Wrapping<u32> = Wrapping(0x80000000);
119 let mut ix0: i32;
120 let mut s0: i32;
121 let mut q: i32;
122 let mut m: i32;
123 let mut t: i32;
124 let mut i: i32;
125 let mut r: Wrapping<u32>;
126 let mut t1: Wrapping<u32>;
127 let mut s1: Wrapping<u32>;
128 let mut ix1: Wrapping<u32>;
129 let mut q1: Wrapping<u32>;
130
131 ix0 = (x.to_bits() >> 32) as i32;
132 ix1 = Wrapping(x.to_bits() as u32);
133
134 /* take care of Inf and NaN */
135 if (ix0 & 0x7ff00000) == 0x7ff00000 {
136 return x * x + x; /* sqrt(NaN)=NaN, sqrt(+inf)=+inf, sqrt(-inf)=sNaN */
137 }
138 /* take care of zero */
139 if ix0 <= 0 {
140 if ((ix0 & !(sign.0 as i32)) | ix1.0 as i32) == 0 {
141 return x; /* sqrt(+-0) = +-0 */
142 }
143 if ix0 < 0 {
144 return (x - x) / (x - x); /* sqrt(-ve) = sNaN */
145 }
146 }
147 /* normalize x */
148 m = ix0 >> 20;
149 if m == 0 {
150 /* subnormal x */
151 while ix0 == 0 {
152 m -= 21;
153 ix0 |= (ix1 >> 11).0 as i32;
154 ix1 <<= 21;
155 }
156 i = 0;
157 while (ix0 & 0x00100000) == 0 {
158 i += 1;
159 ix0 <<= 1;
160 }
161 m -= i - 1;
162 ix0 |= (ix1 >> (32 - i) as usize).0 as i32;
163 ix1 = ix1 << i as usize;
164 }
165 m -= 1023; /* unbias exponent */
166 ix0 = (ix0 & 0x000fffff) | 0x00100000;
167 if (m & 1) == 1 {
168 /* odd m, double x to make it even */
169 ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
170 ix1 += ix1;
171 }
172 m >>= 1; /* m = [m/2] */
173
174 /* generate sqrt(x) bit by bit */
175 ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
176 ix1 += ix1;
177 q = 0; /* [q,q1] = sqrt(x) */
178 q1 = Wrapping(0);
179 s0 = 0;
180 s1 = Wrapping(0);
181 r = Wrapping(0x00200000); /* r = moving bit from right to left */
182
183 while r != Wrapping(0) {
184 t = s0 + r.0 as i32;
185 if t <= ix0 {
186 s0 = t + r.0 as i32;
187 ix0 -= t;
188 q += r.0 as i32;
189 }
190 ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
191 ix1 += ix1;
192 r >>= 1;
193 }
194
195 r = sign;
196 while r != Wrapping(0) {
197 t1 = s1 + r;
198 t = s0;
199 if t < ix0 || (t == ix0 && t1 <= ix1) {
200 s1 = t1 + r;
201 if (t1 & sign) == sign && (s1 & sign) == Wrapping(0) {
202 s0 += 1;
203 }
204 ix0 -= t;
205 if ix1 < t1 {
206 ix0 -= 1;
207 }
208 ix1 -= t1;
209 q1 += r;
210 }
211 ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
212 ix1 += ix1;
213 r >>= 1;
214 }
215
216 /* use floating add to find out rounding direction */
217 if (ix0 as u32 | ix1.0) != 0 {
218 z = 1.0 - TINY; /* raise inexact flag */
219 if z >= 1.0 {
220 z = 1.0 + TINY;
221 if q1.0 == 0xffffffff {
222 q1 = Wrapping(0);
223 q += 1;
224 } else if z > 1.0 {
225 if q1.0 == 0xfffffffe {
226 q += 1;
227 }
228 q1 += Wrapping(2);
229 } else {
230 q1 += q1 & Wrapping(1);
231 }
232 }
233 }
234 ix0 = (q >> 1) + 0x3fe00000;
235 ix1 = q1 >> 1;
236 if (q & 1) == 1 {
237 ix1 |= sign;
238 }
239 ix0 += m << 20;
240 f64::from_bits((ix0 as u64) << 32 | ix1.0 as u64)
241 }
242}
243
244#[cfg(test)]
245mod tests {
246 use core::f64::*;
247
248 use super::*;
249
250 #[test]
251 fn sanity_check() {
252 assert_eq!(sqrt(100.0), 10.0);
253 assert_eq!(sqrt(4.0), 2.0);
254 }
255
256 /// The spec: https://en.cppreference.com/w/cpp/numeric/math/sqrt
257 #[test]
258 fn spec_tests() {
259 // Not Asserted: FE_INVALID exception is raised if argument is negative.
260 assert!(sqrt(-1.0).is_nan());
261 assert!(sqrt(NAN).is_nan());
262 for f in [0.0, -0.0, INFINITY].iter().copied() {
263 assert_eq!(sqrt(f), f);
264 }
265 }
266
267 #[test]
268 fn conformance_tests() {
269 let values = [3.14159265359, 10000.0, f64::from_bits(0x0000000f), INFINITY];
270 let results = [
271 4610661241675116657u64,
272 4636737291354636288u64,
273 2197470602079456986u64,
274 9218868437227405312u64,
275 ];
276
277 for i in 0..values.len() {
278 let bits = f64::to_bits(sqrt(values[i]));
279 assert_eq!(results[i], bits);
280 }
281 }
282}