miri/math.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
use rand::Rng as _;
use rand::distributions::Distribution as _;
use rustc_apfloat::Float as _;
use rustc_apfloat::ieee::IeeeFloat;
/// Disturbes a floating-point result by a relative error on the order of (-2^scale, 2^scale).
pub(crate) fn apply_random_float_error<F: rustc_apfloat::Float>(
ecx: &mut crate::MiriInterpCx<'_>,
val: F,
err_scale: i32,
) -> F {
let rng = ecx.machine.rng.get_mut();
// Generate a random integer in the range [0, 2^PREC).
let dist = rand::distributions::Uniform::new(0, 1 << F::PRECISION);
let err = F::from_u128(dist.sample(rng))
.value
.scalbn(err_scale.strict_sub(F::PRECISION.try_into().unwrap()));
// give it a random sign
let err = if rng.gen::<bool>() { -err } else { err };
// multiple the value with (1+err)
(val * (F::from_u128(1).value + err).value).value
}
pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
match x.category() {
// preserve zero sign
rustc_apfloat::Category::Zero => x,
// propagate NaN
rustc_apfloat::Category::NaN => x,
// sqrt of negative number is NaN
_ if x.is_negative() => IeeeFloat::NAN,
// sqrt(∞) = ∞
rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY,
rustc_apfloat::Category::Normal => {
// Floating point precision, excluding the integer bit
let prec = i32::try_from(S::PRECISION).unwrap() - 1;
// x = 2^(exp - prec) * mant
// where mant is an integer with prec+1 bits
// mant is a u128, which should be large enough for the largest prec (112 for f128)
let mut exp = x.ilogb();
let mut mant = x.scalbn(prec - exp).to_u128(128).value;
if exp % 2 != 0 {
// Make exponent even, so it can be divided by 2
exp -= 1;
mant <<= 1;
}
// Bit-by-bit (base-2 digit-by-digit) sqrt of mant.
// mant is treated here as a fixed point number with prec fractional bits.
// mant will be shifted left by one bit to have an extra fractional bit, which
// will be used to determine the rounding direction.
// res is the truncated sqrt of mant, where one bit is added at each iteration.
let mut res = 0u128;
// rem is the remainder with the current res
// rem_i = 2^i * ((mant<<1) - res_i^2)
// starting with res = 0, rem = mant<<1
let mut rem = mant << 1;
// s_i = 2*res_i
let mut s = 0u128;
// d is used to iterate over bits, from high to low (d_i = 2^(-i))
let mut d = 1u128 << (prec + 1);
// For iteration j=i+1, we need to find largest b_j = 0 or 1 such that
// (res_i + b_j * 2^(-j))^2 <= mant<<1
// Expanding (a + b)^2 = a^2 + b^2 + 2*a*b:
// res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1
// And rearranging the terms:
// b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2)
// b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i
while d != 0 {
// Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1:
// t = 2*res_i + 2^(-j)
let t = s + d;
if rem >= t {
// b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem
res += d;
s += d + d;
rem -= t;
}
// Adjust rem for next iteration
rem <<= 1;
// Shift iterator
d >>= 1;
}
// Remove extra fractional bit from result, rounding to nearest.
// If the last bit is 0, then the nearest neighbor is definitely the lower one.
// If the last bit is 1, it sounds like this may either be a tie (if there's
// infinitely many 0s after this 1), or the nearest neighbor is the upper one.
// However, since square roots are either exact or irrational, and an exact root
// would lead to the last "extra" bit being 0, we can exclude a tie in this case.
// We therefore always round up if the last bit is 1. When the last bit is 0,
// adding 1 will not do anything since the shift will discard it.
res = (res + 1) >> 1;
// Build resulting value with res as mantissa and exp/2 as exponent
IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec)
}
}
}
#[cfg(test)]
mod tests {
use rustc_apfloat::ieee::{DoubleS, HalfS, IeeeFloat, QuadS, SingleS};
use super::sqrt;
#[test]
fn test_sqrt() {
#[track_caller]
fn test<S: rustc_apfloat::ieee::Semantics>(x: &str, expected: &str) {
let x: IeeeFloat<S> = x.parse().unwrap();
let expected: IeeeFloat<S> = expected.parse().unwrap();
let result = sqrt(x);
assert_eq!(result, expected);
}
fn exact_tests<S: rustc_apfloat::ieee::Semantics>() {
test::<S>("0", "0");
test::<S>("1", "1");
test::<S>("1.5625", "1.25");
test::<S>("2.25", "1.5");
test::<S>("4", "2");
test::<S>("5.0625", "2.25");
test::<S>("9", "3");
test::<S>("16", "4");
test::<S>("25", "5");
test::<S>("36", "6");
test::<S>("49", "7");
test::<S>("64", "8");
test::<S>("81", "9");
test::<S>("100", "10");
test::<S>("0.5625", "0.75");
test::<S>("0.25", "0.5");
test::<S>("0.0625", "0.25");
test::<S>("0.00390625", "0.0625");
}
exact_tests::<HalfS>();
exact_tests::<SingleS>();
exact_tests::<DoubleS>();
exact_tests::<QuadS>();
test::<SingleS>("2", "1.4142135");
test::<DoubleS>("2", "1.4142135623730951");
test::<SingleS>("1.1", "1.0488088");
test::<DoubleS>("1.1", "1.0488088481701516");
test::<SingleS>("2.2", "1.4832398");
test::<DoubleS>("2.2", "1.4832396974191326");
test::<SingleS>("1.22101e-40", "1.10499205e-20");
test::<DoubleS>("1.22101e-310", "1.1049932126488395e-155");
test::<SingleS>("3.4028235e38", "1.8446743e19");
test::<DoubleS>("1.7976931348623157e308", "1.3407807929942596e154");
}
}