1use std::ops::Neg;
2
3use rand::Rng as _;
4use rustc_apfloat::ieee::{DoubleS, HalfS, IeeeFloat, Semantics, SingleS};
5use rustc_apfloat::{Float, FloatConvert};
6use rustc_middle::ty::{self, FloatTy, ScalarInt};
7
8use crate::*;
9
/// Conversion from a soft-float (`rustc_apfloat`) value to a host floating-point value.
///
/// The implementations in this file convert by reinterpreting the bit pattern.
pub trait ToHost {
    /// The host float type that `to_host` produces.
    type HostFloat;

    /// Converts `self` into the host float type.
    fn to_host(self) -> Self::HostFloat;
}
15
/// Conversion from a host floating-point value to a soft-float (`rustc_apfloat`) value.
///
/// The implementations in this file convert by reinterpreting the bit pattern.
pub trait ToSoft {
    /// The soft-float type that `to_soft` produces.
    type SoftFloat;

    /// Converts `self` into the soft-float type.
    fn to_soft(self) -> Self::SoftFloat;
}
21
22impl ToHost for rustc_apfloat::ieee::Double {
23 type HostFloat = f64;
24
25 fn to_host(self) -> Self::HostFloat {
26 f64::from_bits(self.to_bits().try_into().unwrap())
27 }
28}
29
30impl ToSoft for f64 {
31 type SoftFloat = rustc_apfloat::ieee::Double;
32
33 fn to_soft(self) -> Self::SoftFloat {
34 Float::from_bits(self.to_bits().into())
35 }
36}
37
38impl ToHost for rustc_apfloat::ieee::Single {
39 type HostFloat = f32;
40
41 fn to_host(self) -> Self::HostFloat {
42 f32::from_bits(self.to_bits().try_into().unwrap())
43 }
44}
45
46impl ToSoft for f32 {
47 type SoftFloat = rustc_apfloat::ieee::Single;
48
49 fn to_soft(self) -> Self::SoftFloat {
50 Float::from_bits(self.to_bits().into())
51 }
52}
53
54impl ToHost for rustc_apfloat::ieee::Half {
55 type HostFloat = f16;
56
57 fn to_host(self) -> Self::HostFloat {
58 f16::from_bits(self.to_bits().try_into().unwrap())
59 }
60}
61
62impl ToSoft for f16 {
63 type SoftFloat = rustc_apfloat::ieee::Half;
64
65 fn to_soft(self) -> Self::SoftFloat {
66 Float::from_bits(self.to_bits().into())
67 }
68}
69
70pub(crate) fn apply_random_float_error<F: rustc_apfloat::Float>(
72 ecx: &mut crate::MiriInterpCx<'_>,
73 val: F,
74 err_scale: i32,
75) -> F {
76 if !ecx.machine.float_nondet
77 || matches!(ecx.machine.float_rounding_error, FloatRoundingErrorMode::None)
78 || val.is_zero()
80 || !val.is_finite()
82 {
83 return val;
84 }
85 let rng = ecx.machine.rng.get_mut();
86
87 let r = F::from_u128(match ecx.machine.float_rounding_error {
92 FloatRoundingErrorMode::Random => rng.random_range(0..(1 << F::PRECISION)),
93 FloatRoundingErrorMode::Max => (1 << F::PRECISION) - 1, FloatRoundingErrorMode::None => unreachable!(),
95 })
96 .value;
97 let err = r.scalbn(err_scale.strict_sub(F::PRECISION.try_into().unwrap()));
100 let err = if rng.random() { -err } else { err };
102 (val + (val * err).value).value
105}
106
107pub(crate) fn apply_random_float_error_ulp<F: rustc_apfloat::Float>(
109 ecx: &mut crate::MiriInterpCx<'_>,
110 val: F,
111 max_error: u32,
112) -> F {
113 if !ecx.machine.float_nondet
117 || matches!(ecx.machine.float_rounding_error, FloatRoundingErrorMode::None)
118 || val.is_zero()
121 || !val.is_finite()
123 {
124 return val;
125 }
126 let rng = ecx.machine.rng.get_mut();
127
128 let max_error = i64::from(max_error);
129 let error = match ecx.machine.float_rounding_error {
130 FloatRoundingErrorMode::Random => rng.random_range(-max_error..=max_error),
131 FloatRoundingErrorMode::Max =>
132 if rng.random() {
133 max_error
134 } else {
135 -max_error
136 },
137 FloatRoundingErrorMode::None => unreachable!(),
138 };
139 let ulp = (((val.next_up().value - val).value + (val - val.next_down().value).value).value
141 / F::from_u128(2).value)
142 .value;
143 (val + (ulp * F::from_i128(error.into()).value).value).value
145}
146
147pub(crate) fn apply_random_float_error_to_imm<'tcx>(
150 ecx: &mut MiriInterpCx<'tcx>,
151 val: ImmTy<'tcx>,
152 max_error: u32,
153) -> InterpResult<'tcx, ImmTy<'tcx>> {
154 let scalar = val.to_scalar_int()?;
155 let res: ScalarInt = match val.layout.ty.kind() {
156 ty::Float(FloatTy::F16) =>
157 apply_random_float_error_ulp(ecx, scalar.to_f16(), max_error).into(),
158 ty::Float(FloatTy::F32) =>
159 apply_random_float_error_ulp(ecx, scalar.to_f32(), max_error).into(),
160 ty::Float(FloatTy::F64) =>
161 apply_random_float_error_ulp(ecx, scalar.to_f64(), max_error).into(),
162 ty::Float(FloatTy::F128) =>
163 apply_random_float_error_ulp(ecx, scalar.to_f128(), max_error).into(),
164 _ => bug!("intrinsic called with non-float input type"),
165 };
166
167 interp_ok(ImmTy::from_scalar_int(res, val.layout))
168}
169
/// Normalizes a float intrinsic / libm function name to its base operation name.
///
/// It first removes one width suffix (`f16`, `f32`, `f64` or `f128`), then a trailing libm-style
/// `f`. The name `erf` is the one exception: its final `f` belongs to the name itself.
///
/// Examples: `sinf32` -> `sin`, `sinf` -> `sin`, `erff` -> `erf`, `erf` -> `erf`.
fn strip_float_suffix(intrinsic_name: &str) -> &str {
    const WIDTH_SUFFIXES: [&str; 4] = ["f16", "f32", "f64", "f128"];

    let base = WIDTH_SUFFIXES
        .iter()
        .find_map(|&suffix| intrinsic_name.strip_suffix(suffix))
        .unwrap_or(intrinsic_name);

    match base {
        "erf" => base,
        _ => base.strip_suffix('f').unwrap_or(base),
    }
}
184
185pub(crate) fn clamp_float_value<S: Semantics>(
188 intrinsic_name: &str,
189 val: IeeeFloat<S>,
190) -> IeeeFloat<S>
191where
192 IeeeFloat<S>: IeeeExt,
193{
194 let zero = IeeeFloat::<S>::ZERO;
195 let one = IeeeFloat::<S>::one();
196 let two = IeeeFloat::<S>::two();
197 let pi = IeeeFloat::<S>::pi();
198 let pi_over_2 = (pi / two).value;
199
200 match strip_float_suffix(intrinsic_name) {
201 "sin" | "cos" | "tanh" => val.clamp(one.neg(), one),
203
204 "exp" | "exp2" => val.maximum(zero),
206
207 "cosh" => val.maximum(one),
209
210 "acos" => val.clamp(zero, pi),
212
213 "asin" => val.clamp(pi.neg(), pi),
215
216 "atan" => val.clamp(pi_over_2.neg(), pi_over_2),
218
219 "erf" => val.clamp(one.neg(), one),
221
222 "erfc" => val.clamp(zero, two),
224
225 "atan2" => val.clamp(pi.neg(), pi),
227
228 _ => val,
229 }
230}
231
232pub(crate) fn fixed_float_value<S: Semantics>(
265 ecx: &mut MiriInterpCx<'_>,
266 intrinsic_name: &str,
267 args: &[IeeeFloat<S>],
268) -> Option<IeeeFloat<S>>
269where
270 IeeeFloat<S>: IeeeExt,
271{
272 let this = ecx.eval_context_mut();
273 let one = IeeeFloat::<S>::one();
274 let two = IeeeFloat::<S>::two();
275 let three = IeeeFloat::<S>::three();
276 let pi = IeeeFloat::<S>::pi();
277 let pi_over_2 = (pi / two).value;
278 let pi_over_4 = (pi_over_2 / two).value;
279
280 Some(match (strip_float_suffix(intrinsic_name), args) {
281 ("cos" | "cosh", [input]) if input.is_zero() => one,
283
284 ("exp" | "exp2", [input]) if input.is_zero() => one,
286
287 ("tanh", [input]) if input.is_infinite() => one.copy_sign(*input),
289
290 ("atan", [input]) if input.is_infinite() => pi_over_2.copy_sign(*input),
292
293 ("erf", [input]) if input.is_infinite() => one.copy_sign(*input),
295
296 ("erfc", [input]) if input.is_neg_infinity() => (one + one).value,
298
299 ("_hypot" | "hypot", [x, y]) if !x.is_nan() && y.is_zero() => x.abs(),
302
303 ("atan2", [x, y]) if (x.is_zero() && (y.is_negative() && !y.is_nan())) => pi.copy_sign(*x),
307
308 ("atan2", [x, y]) if (!x.is_zero() && !x.is_infinite()) && y.is_neg_infinity() =>
310 pi.copy_sign(*x),
311
312 ("atan2", [x, y]) if !x.is_zero() && y.is_zero() => pi_over_2.copy_sign(*x),
315
316 ("atan2", [x, y]) if x.is_infinite() && y.is_neg_infinity() =>
318 (pi_over_4 * three).value.copy_sign(*x),
319
320 ("atan2", [x, y]) if x.is_infinite() && y.is_pos_infinity() => pi_over_4.copy_sign(*x),
322
323 ("atan2", [x, y]) if x.is_infinite() && (!y.is_infinite() && !y.is_nan()) =>
325 pi_over_2.copy_sign(*x),
326
327 ("pow", [base, exp]) if *base == -one && exp.is_infinite() => one,
329
330 ("pow", [base, exp]) if *base == one => {
332 let rng = this.machine.rng.get_mut();
333 let return_nan = exp.is_signaling() && this.machine.float_nondet && rng.random();
337 if return_nan { this.generate_nan(args) } else { one }
338 }
339
340 ("pow", [base, exp]) if exp.is_zero() => {
342 let rng = this.machine.rng.get_mut();
343 let return_nan = base.is_signaling() && this.machine.float_nondet && rng.random();
347 if return_nan { this.generate_nan(args) } else { one }
348 }
349
350 _ => return None,
353 })
354}
355
356pub(crate) fn fixed_powi_value<S: Semantics>(
359 ecx: &mut MiriInterpCx<'_>,
360 base: IeeeFloat<S>,
361 exp: i32,
362) -> Option<IeeeFloat<S>>
363where
364 IeeeFloat<S>: IeeeExt,
365{
366 match exp {
367 0 => {
368 let one = IeeeFloat::<S>::one();
369 let rng = ecx.machine.rng.get_mut();
370 let return_nan = base.is_signaling() && ecx.machine.float_nondet && rng.random();
373 Some(if return_nan { ecx.generate_nan(&[base]) } else { one })
374 }
375
376 _ => None,
377 }
378}
379
/// Computes the square root of `x` in software, independently of the host.
///
/// - `±0` is returned unchanged (the sign of zero is preserved).
/// - A NaN input is returned unchanged.
/// - Any other negative input (including `-∞`) yields `F::NAN`.
/// - `+∞` yields `+∞`.
/// - Other inputs use a bit-by-bit integer square root of the significand, rounded to nearest.
pub(crate) fn sqrt<F: Float>(x: F) -> F {
    match x.category() {
        // Preserve the sign of zero.
        rustc_apfloat::Category::Zero => x,
        // Propagate NaN inputs as-is.
        rustc_apfloat::Category::NaN => x,
        // Checked after the Zero/NaN arms: every remaining negative value (finite or infinite)
        // has no real square root.
        _ if x.is_negative() => F::NAN,
        // sqrt(+∞) = +∞
        rustc_apfloat::Category::Infinity => F::INFINITY,
        rustc_apfloat::Category::Normal => {
            // Number of fraction bits, i.e. the precision without the integer bit.
            let prec = i32::try_from(F::PRECISION).unwrap() - 1;

            // Decompose x = mant * 2^(exp - prec), where `mant` is an integer holding the
            // `prec + 1` significant bits (a u128 is wide enough for every IEEE format).
            let mut exp = x.ilogb();
            let mut mant = x.scalbn(prec - exp).to_u128(128).value;

            if exp % 2 != 0 {
                // Make the exponent even so it can be halved exactly; compensate by doubling
                // the mantissa.
                exp -= 1;
                mant <<= 1;
            }

            // Bit-by-bit (base-2 digit-by-digit) square root of `mant`, read as a fixed-point
            // number with `prec` fractional bits. `mant` is shifted left once more so that
            // one extra result bit is produced; that bit decides the rounding direction.
            //
            // `res`: the truncated root built so far, one bit per iteration.
            let mut res = 0u128;
            // `rem`: the remainder `(mant << 1) - res^2`, rescaled by 2 after every step.
            let mut rem = mant << 1;
            // `s`: always `2 * res` (in the same scaling).
            let mut s = 0u128;
            // `d`: weight of the bit currently being decided, from high to low.
            let mut d = 1u128 << (prec + 1);

            while d != 0 {
                // Trial: can this bit be set without `res^2` exceeding `mant << 1`?
                // Expanding (res + d)^2 reduces that to the test `rem >= 2*res + d`.
                let t = s + d;
                if rem >= t {
                    // Keep the bit and account for it in `s` and the remainder.
                    res += d;
                    s += d + d;
                    rem -= t;
                }
                // Rescale the remainder for the next (lower) bit.
                rem <<= 1;
                // Move on to the next lower bit.
                d >>= 1;
            }

            // Drop the extra bit, rounding to nearest. An exact root produces a 0 extra bit,
            // and an inexact square root is irrational, so a 1 extra bit can never be an exact
            // tie: rounding up whenever it is 1 is correct. When it is 0, the added 1 is
            // discarded by the shift.
            res = (res + 1) >> 1;

            // Reassemble: `res` is the significand and `exp / 2` the halved exponent.
            F::from_u128(res).value.scalbn(exp / 2 - prec)
        }
    }
}
461
462pub fn sqrt_op<'tcx, F: Float + FloatConvert<F> + Into<Scalar>>(
463 this: &mut MiriInterpCx<'tcx>,
464 f: &OpTy<'tcx>,
465 dest: &MPlaceTy<'tcx>,
466) -> InterpResult<'tcx> {
467 let f: F = this.read_scalar(f)?.to_float()?;
468 let res = math::sqrt(f);
470 let res = this.adjust_nan(res, &[f]);
471 this.write_scalar(res, dest)
472}
473
/// Floating-point operations evaluated with the host's native float implementation.
///
/// The implementations in this file convert the soft-float value to the host type, call the
/// corresponding host method, and convert the result back.
pub trait HostFloatOperation {
    /// Sine of `self`.
    fn host_sin(self) -> Self;
    /// Cosine of `self`.
    fn host_cos(self) -> Self;
    /// `e` raised to `self`.
    fn host_exp(self) -> Self;
    /// `2` raised to `self`.
    fn host_exp2(self) -> Self;
    /// Natural logarithm of `self`.
    fn host_log(self) -> Self;
    /// Base-10 logarithm of `self`.
    fn host_log10(self) -> Self;
    /// Base-2 logarithm of `self`.
    fn host_log2(self) -> Self;
    /// `self` raised to the float power `y`.
    fn host_powf(self, y: Self) -> Self;
    /// `self` raised to the integer power `y`.
    fn host_powi(self, y: i32) -> Self;
}
485
486macro_rules! impl_float_host_operations {
487 ($ty:ty) => {
488 impl HostFloatOperation for $ty {
489 fn host_sin(self) -> Self {
490 self.to_host().sin().to_soft()
491 }
492 fn host_cos(self) -> Self {
493 self.to_host().cos().to_soft()
494 }
495 fn host_exp(self) -> Self {
496 self.to_host().exp().to_soft()
497 }
498 fn host_exp2(self) -> Self {
499 self.to_host().exp2().to_soft()
500 }
501 fn host_log(self) -> Self {
502 self.to_host().ln().to_soft()
503 }
504 fn host_log10(self) -> Self {
505 self.to_host().log10().to_soft()
506 }
507 fn host_log2(self) -> Self {
508 self.to_host().log2().to_soft()
509 }
510 fn host_powf(self, y: Self) -> Self {
511 self.to_host().powf(y.to_host()).to_soft()
512 }
513 fn host_powi(self, y: i32) -> Self {
514 self.to_host().powi(y).to_soft()
515 }
516 }
517 };
518}
519
520impl_float_host_operations!(IeeeFloat<HalfS>);
521impl_float_host_operations!(IeeeFloat<SingleS>);
522impl_float_host_operations!(IeeeFloat<DoubleS>);
523
/// A one-argument floating-point operation evaluated through `HostFloatOperation`.
#[derive(Clone, Copy, Debug)]
pub enum HostUnaryFloatOp {
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// `e^x`.
    Exp,
    /// `2^x`.
    Exp2,
    /// Natural logarithm.
    Log,
    /// Base-10 logarithm.
    Log10,
    /// Base-2 logarithm.
    Log2,
}
534
535pub fn host_unary_float_op<'tcx, S: Semantics>(
536 this: &mut MiriInterpCx<'tcx>,
537 f: &OpTy<'tcx>,
538 op: HostUnaryFloatOp,
539 dest: &MPlaceTy<'tcx>,
540) -> InterpResult<'tcx>
541where
542 IeeeFloat<S>: HostFloatOperation + IeeeExt + Float + Into<Scalar>,
543{
544 use HostFloatOperation;
545
546 let f: IeeeFloat<S> = this.read_scalar(f)?.to_float()?;
547
548 let name = match op {
550 HostUnaryFloatOp::Sin => "sin",
551 HostUnaryFloatOp::Cos => "cos",
552 HostUnaryFloatOp::Exp => "exp",
553 HostUnaryFloatOp::Exp2 => "exp2",
554 HostUnaryFloatOp::Log => "log",
555 HostUnaryFloatOp::Log10 => "log10",
556 HostUnaryFloatOp::Log2 => "log2",
557 };
558 let res = math::fixed_float_value(this, name, &[f]).unwrap_or_else(|| {
559 let res = match op {
562 HostUnaryFloatOp::Sin => f.host_sin(),
563 HostUnaryFloatOp::Cos => f.host_cos(),
564 HostUnaryFloatOp::Exp => f.host_exp(),
565 HostUnaryFloatOp::Exp2 => f.host_exp2(),
566 HostUnaryFloatOp::Log => f.host_log(),
567 HostUnaryFloatOp::Log10 => f.host_log10(),
568 HostUnaryFloatOp::Log2 => f.host_log2(),
569 };
570
571 let res = math::apply_random_float_error_ulp(this, res, 4);
574
575 math::clamp_float_value(name, res)
578 });
579
580 let res = this.adjust_nan(res, &[f]);
581 this.write_scalar(res, dest)?;
582 interp_ok(())
583}
584
585pub trait IeeeExt: rustc_apfloat::Float {
587 #[inline]
590 fn one() -> Self {
591 Self::from_u128(1).value
592 }
593
594 #[inline]
595 fn two() -> Self {
596 Self::from_u128(2).value
597 }
598
599 #[inline]
600 fn three() -> Self {
601 Self::from_u128(3).value
602 }
603
604 fn pi() -> Self;
605
606 #[inline]
607 fn clamp(self, min: Self, max: Self) -> Self {
608 self.maximum(min).minimum(max)
609 }
610}
611
612macro_rules! impl_ieee_pi {
613 ($float_ty:ident, $semantic:ty) => {
614 impl IeeeExt for IeeeFloat<$semantic> {
615 #[inline]
616 fn pi() -> Self {
617 Self::from_bits(core::$float_ty::consts::PI.to_bits().into())
619 }
620 }
621 };
622}
623
624impl_ieee_pi!(f16, HalfS);
625impl_ieee_pi!(f32, SingleS);
626impl_ieee_pi!(f64, DoubleS);
627
#[cfg(test)]
mod tests {
    use rustc_apfloat::Float;
    use rustc_apfloat::ieee::{DoubleS, HalfS, IeeeFloat, QuadS, SingleS};

    use super::{sqrt, strip_float_suffix};

    #[test]
    fn test_sqrt() {
        #[track_caller]
        fn test<S: rustc_apfloat::ieee::Semantics>(x: &str, expected: &str) {
            let x: IeeeFloat<S> = x.parse().unwrap();
            let expected: IeeeFloat<S> = expected.parse().unwrap();
            let result = sqrt(x);
            assert_eq!(result, expected);
        }

        fn exact_tests<S: rustc_apfloat::ieee::Semantics>() {
            test::<S>("0", "0");
            test::<S>("1", "1");
            test::<S>("1.5625", "1.25");
            test::<S>("2.25", "1.5");
            test::<S>("4", "2");
            test::<S>("5.0625", "2.25");
            test::<S>("9", "3");
            test::<S>("16", "4");
            test::<S>("25", "5");
            test::<S>("36", "6");
            test::<S>("49", "7");
            test::<S>("64", "8");
            test::<S>("81", "9");
            test::<S>("100", "10");

            test::<S>("0.5625", "0.75");
            test::<S>("0.25", "0.5");
            test::<S>("0.0625", "0.25");
            test::<S>("0.00390625", "0.0625");
        }

        exact_tests::<HalfS>();
        exact_tests::<SingleS>();
        exact_tests::<DoubleS>();
        exact_tests::<QuadS>();

        test::<SingleS>("2", "1.4142135");
        test::<DoubleS>("2", "1.4142135623730951");

        test::<SingleS>("1.1", "1.0488088");
        test::<DoubleS>("1.1", "1.0488088481701516");

        test::<SingleS>("2.2", "1.4832398");
        test::<DoubleS>("2.2", "1.4832396974191326");

        test::<SingleS>("1.22101e-40", "1.10499205e-20");
        test::<DoubleS>("1.22101e-310", "1.1049932126488395e-155");

        test::<SingleS>("3.4028235e38", "1.8446743e19");
        test::<DoubleS>("1.7976931348623157e308", "1.3407807929942596e154");
    }

    /// Special inputs: zero signs, infinities, negative values and NaN.
    #[test]
    fn test_sqrt_special_values() {
        #[track_caller]
        fn check<S: rustc_apfloat::ieee::Semantics>() {
            let zero = IeeeFloat::<S>::ZERO;
            let inf = IeeeFloat::<S>::INFINITY;
            let four: IeeeFloat<S> = "4".parse().unwrap();

            // The sign of zero is preserved (`==` alone cannot tell +0 from -0).
            let res = sqrt(zero);
            assert!(res.is_zero() && !res.is_negative());
            let res = sqrt(-zero);
            assert!(res.is_zero() && res.is_negative());

            // sqrt(+∞) = +∞
            let res = sqrt(inf);
            assert!(res.is_infinite() && !res.is_negative());

            // Negative inputs, finite or infinite, have no real square root.
            assert!(sqrt(-four).is_nan());
            assert!(sqrt(-inf).is_nan());

            // NaN propagates.
            assert!(sqrt(IeeeFloat::<S>::NAN).is_nan());
        }

        check::<HalfS>();
        check::<SingleS>();
        check::<DoubleS>();
        check::<QuadS>();
    }

    #[test]
    fn test_strip_float_suffix() {
        assert_eq!(strip_float_suffix("sinf32"), "sin");
        assert_eq!(strip_float_suffix("sinf"), "sin");
        assert_eq!(strip_float_suffix("powf16"), "pow");
        assert_eq!(strip_float_suffix("powf128"), "pow");
        assert_eq!(strip_float_suffix("atan2f64"), "atan2");
        // `erf` keeps its own trailing `f`.
        assert_eq!(strip_float_suffix("erf"), "erf");
        assert_eq!(strip_float_suffix("erff"), "erf");
        assert_eq!(strip_float_suffix("erfcf"), "erfc");
        assert_eq!(strip_float_suffix("_hypot"), "_hypot");
    }
}
686}