miri/math.rs
1use rand::Rng as _;
2use rustc_apfloat::Float as _;
3use rustc_apfloat::ieee::IeeeFloat;
4
5/// Disturbes a floating-point result by a relative error in the range (-2^scale, 2^scale).
6///
7/// For a 2^N ULP error, you can use an `err_scale` of `-(F::PRECISION - 1 - N)`.
8/// In other words, a 1 ULP (absolute) error is the same as a `2^-(F::PRECISION-1)` relative error.
9/// (Subtracting 1 compensates for the integer bit.)
10pub(crate) fn apply_random_float_error<F: rustc_apfloat::Float>(
11 ecx: &mut crate::MiriInterpCx<'_>,
12 val: F,
13 err_scale: i32,
14) -> F {
15 let rng = ecx.machine.rng.get_mut();
16 // Generate a random integer in the range [0, 2^PREC).
17 // (When read as binary, the position of the first `1` determines the exponent,
18 // and the remaining bits fill the mantissa. `PREC` is one plus the size of the mantissa,
19 // so this all works out.)
20 let r = F::from_u128(rng.random_range(0..(1 << F::PRECISION))).value;
21 // Multiply this with 2^(scale - PREC). The result is between 0 and
22 // 2^PREC * 2^(scale - PREC) = 2^scale.
23 let err = r.scalbn(err_scale.strict_sub(F::PRECISION.try_into().unwrap()));
24 // give it a random sign
25 let err = if rng.random() { -err } else { err };
26 // multiple the value with (1+err)
27 (val * (F::from_u128(1).value + err).value).value
28}
29
30/// [`apply_random_float_error`] gives instructions to apply a 2^N ULP error.
31/// This function implements these instructions such that applying a 2^N ULP error is less error prone.
32/// So for a 2^N ULP error, you would pass N as the `ulp_exponent` argument.
33pub(crate) fn apply_random_float_error_ulp<F: rustc_apfloat::Float>(
34 ecx: &mut crate::MiriInterpCx<'_>,
35 val: F,
36 ulp_exponent: u32,
37) -> F {
38 let n = i32::try_from(ulp_exponent)
39 .expect("`err_scale_for_ulp`: exponent is too large to create an error scale");
40 // we know this fits
41 let prec = i32::try_from(F::PRECISION).unwrap();
42 let err_scale = -(prec - n - 1);
43 apply_random_float_error(ecx, val, err_scale)
44}
45
46pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
47 match x.category() {
48 // preserve zero sign
49 rustc_apfloat::Category::Zero => x,
50 // propagate NaN
51 rustc_apfloat::Category::NaN => x,
52 // sqrt of negative number is NaN
53 _ if x.is_negative() => IeeeFloat::NAN,
54 // sqrt(∞) = ∞
55 rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY,
56 rustc_apfloat::Category::Normal => {
57 // Floating point precision, excluding the integer bit
58 let prec = i32::try_from(S::PRECISION).unwrap() - 1;
59
60 // x = 2^(exp - prec) * mant
61 // where mant is an integer with prec+1 bits
62 // mant is a u128, which should be large enough for the largest prec (112 for f128)
63 let mut exp = x.ilogb();
64 let mut mant = x.scalbn(prec - exp).to_u128(128).value;
65
66 if exp % 2 != 0 {
67 // Make exponent even, so it can be divided by 2
68 exp -= 1;
69 mant <<= 1;
70 }
71
72 // Bit-by-bit (base-2 digit-by-digit) sqrt of mant.
73 // mant is treated here as a fixed point number with prec fractional bits.
74 // mant will be shifted left by one bit to have an extra fractional bit, which
75 // will be used to determine the rounding direction.
76
77 // res is the truncated sqrt of mant, where one bit is added at each iteration.
78 let mut res = 0u128;
79 // rem is the remainder with the current res
80 // rem_i = 2^i * ((mant<<1) - res_i^2)
81 // starting with res = 0, rem = mant<<1
82 let mut rem = mant << 1;
83 // s_i = 2*res_i
84 let mut s = 0u128;
85 // d is used to iterate over bits, from high to low (d_i = 2^(-i))
86 let mut d = 1u128 << (prec + 1);
87
88 // For iteration j=i+1, we need to find largest b_j = 0 or 1 such that
89 // (res_i + b_j * 2^(-j))^2 <= mant<<1
90 // Expanding (a + b)^2 = a^2 + b^2 + 2*a*b:
91 // res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1
92 // And rearranging the terms:
93 // b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2)
94 // b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i
95
96 while d != 0 {
97 // Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1:
98 // t = 2*res_i + 2^(-j)
99 let t = s + d;
100 if rem >= t {
101 // b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem
102 res += d;
103 s += d + d;
104 rem -= t;
105 }
106 // Adjust rem for next iteration
107 rem <<= 1;
108 // Shift iterator
109 d >>= 1;
110 }
111
112 // Remove extra fractional bit from result, rounding to nearest.
113 // If the last bit is 0, then the nearest neighbor is definitely the lower one.
114 // If the last bit is 1, it sounds like this may either be a tie (if there's
115 // infinitely many 0s after this 1), or the nearest neighbor is the upper one.
116 // However, since square roots are either exact or irrational, and an exact root
117 // would lead to the last "extra" bit being 0, we can exclude a tie in this case.
118 // We therefore always round up if the last bit is 1. When the last bit is 0,
119 // adding 1 will not do anything since the shift will discard it.
120 res = (res + 1) >> 1;
121
122 // Build resulting value with res as mantissa and exp/2 as exponent
123 IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec)
124 }
125 }
126}
127
#[cfg(test)]
mod tests {
    use rustc_apfloat::Float as _;
    use rustc_apfloat::ieee::{DoubleS, HalfS, IeeeFloat, QuadS, SingleS};

    use super::sqrt;

    #[test]
    fn test_sqrt() {
        #[track_caller]
        fn test<S: rustc_apfloat::ieee::Semantics>(x: &str, expected: &str) {
            let x: IeeeFloat<S> = x.parse().unwrap();
            let expected: IeeeFloat<S> = expected.parse().unwrap();
            let result = sqrt(x);
            assert_eq!(result, expected);
        }

        fn exact_tests<S: rustc_apfloat::ieee::Semantics>() {
            test::<S>("0", "0");
            test::<S>("1", "1");
            test::<S>("1.5625", "1.25");
            test::<S>("2.25", "1.5");
            test::<S>("4", "2");
            test::<S>("5.0625", "2.25");
            test::<S>("9", "3");
            test::<S>("16", "4");
            test::<S>("25", "5");
            test::<S>("36", "6");
            test::<S>("49", "7");
            test::<S>("64", "8");
            test::<S>("81", "9");
            test::<S>("100", "10");

            test::<S>("0.5625", "0.75");
            test::<S>("0.25", "0.5");
            test::<S>("0.0625", "0.25");
            test::<S>("0.00390625", "0.0625");
        }

        exact_tests::<HalfS>();
        exact_tests::<SingleS>();
        exact_tests::<DoubleS>();
        exact_tests::<QuadS>();

        test::<SingleS>("2", "1.4142135");
        test::<DoubleS>("2", "1.4142135623730951");

        test::<SingleS>("1.1", "1.0488088");
        test::<DoubleS>("1.1", "1.0488088481701516");

        test::<SingleS>("2.2", "1.4832398");
        test::<DoubleS>("2.2", "1.4832396974191326");

        test::<SingleS>("1.22101e-40", "1.10499205e-20");
        test::<DoubleS>("1.22101e-310", "1.1049932126488395e-155");

        test::<SingleS>("3.4028235e38", "1.8446743e19");
        test::<DoubleS>("1.7976931348623157e308", "1.3407807929942596e154");
    }

    /// Covers every early-exit branch of `sqrt` (zero, NaN, negative, infinity).
    /// Signs are checked with `is_negative` instead of `assert_eq!`, so the test does not
    /// depend on how `==` treats signed zeros.
    #[test]
    fn test_sqrt_special_values() {
        fn check<S: rustc_apfloat::ieee::Semantics>() {
            let zero = IeeeFloat::<S>::from_u128(0).value;
            let four = IeeeFloat::<S>::from_u128(4).value;
            let inf = IeeeFloat::<S>::INFINITY;

            // sqrt(+0) = +0 and sqrt(-0) = -0.
            let r = sqrt(zero);
            assert!(r.is_zero() && !r.is_negative());
            let r = sqrt(-zero);
            assert!(r.is_zero() && r.is_negative());

            // NaN propagates; negative inputs (finite or infinite) have no real root.
            assert!(sqrt(IeeeFloat::<S>::NAN).is_nan());
            assert!(sqrt(-four).is_nan());
            assert!(sqrt(-inf).is_nan());

            // sqrt(+∞) = +∞.
            let r = sqrt(inf);
            assert!(r.is_infinite() && !r.is_negative());
        }

        check::<HalfS>();
        check::<SingleS>();
        check::<DoubleS>();
        check::<QuadS>();
    }

    /// Exercises exponent handling over the whole double range, from the smallest subnormal
    /// (2^-1074) up to 2^1023, for both even and odd input exponents.
    #[test]
    fn test_sqrt_powers_of_two() {
        let one = IeeeFloat::<DoubleS>::from_u128(1).value;
        let sqrt2 = sqrt(one.scalbn(1));
        for k in -537..=511 {
            // sqrt(2^(2k)) = 2^k exactly.
            assert_eq!(sqrt(one.scalbn(2 * k)), one.scalbn(k), "sqrt(2^{})", 2 * k);
            // sqrt(2^(2k+1)) = sqrt(2) * 2^k; every result here is normal, so scaling is exact.
            assert_eq!(sqrt(one.scalbn(2 * k + 1)), sqrt2.scalbn(k), "sqrt(2^{})", 2 * k + 1);
        }
    }
}