use std::borrow::Borrow;
use std::collections::hash_map::Entry;
use std::fmt;
use std::hash::Hash;

use rustc_ast::Mutability;
use rustc_data_structures::fx::FxHashMap;
use rustc_hir::def_id::DefId;
use rustc_middle::mir;
use rustc_middle::mir::AssertMessage;
use rustc_middle::ty::{self, Ty};
use rustc_session::Limit;
use rustc_span::symbol::{sym, Symbol};
use rustc_target::abi::{Align, Size};
use rustc_target::spec::abi::Abi;

use crate::interpret::{
    self, compile_time_machine, AllocId, Allocation, Frame, ImmTy, InterpCx, InterpResult, Memory,
    OpTy, PlaceTy, Pointer, Scalar,
};

use super::error::*;

impl<'mir, 'tcx> InterpCx<'mir, 'tcx, CompileTimeInterpreter<'mir, 'tcx>> {
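    /// "Intercepts" calls to the panic entry points (`panic`, `panic_str`,
    /// `begin_panic`): instead of executing the function, read its message argument
    /// and raise it as a const-eval panic error. Returns `Ok(())` for any other
    /// function, in which case the caller should proceed normally.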
    fn hook_panic_fn(
        &mut self,
        instance: ty::Instance<'tcx>,
        args: &[OpTy<'tcx>],
    ) -> InterpResult<'tcx> {
        let def_id = instance.def_id();
        if Some(def_id) == self.tcx.lang_items().panic_fn()
            || Some(def_id) == self.tcx.lang_items().panic_str()
            || Some(def_id) == self.tcx.lang_items().begin_panic_fn()
        {
            // All of these functions take a single `&str` message argument.
            assert!(args.len() == 1);

            let msg_place = self.deref_operand(args[0])?;
            let msg = Symbol::intern(self.read_str(msg_place)?);
            let span = self.find_closest_untracked_caller_location();
            let (file, line, col) = self.location_triple_for_span(span);
            Err(ConstEvalErrKind::Panic { msg, file, line, col }.into())
        } else {
            Ok(())
        }
    }
}
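
/// Extra machine state for CTFE, and the Machine instance.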
pub struct CompileTimeInterpreter<'mir, 'tcx> {
    /// The number of terminators that can be evaluated before we throw a resource
    /// exhaustion error.
    ///
    /// Setting this to `0` disables the limit and allows the interpreter to run forever.
    pub steps_remaining: usize,

    /// The virtual call stack.
    pub(crate) stack: Vec<Frame<'mir, 'tcx, (), ()>>,
}
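
/// Extra memory-tracking state for const evaluation.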
#[derive(Copy, Clone, Debug)]
pub struct MemoryExtra {
    /// We need to make sure consts never point to anything mutable, even recursively. That is
    /// relied on for pattern matching on consts with references.
    /// To achieve this, two pieces have to work together:
    /// * Interning makes everything outside of statics immutable.
    /// * Pointers to allocations inside of statics can never leak outside, to a non-static global.
    /// This boolean here controls the second part.
    pub(super) can_access_statics: bool,
}

impl<'mir, 'tcx> CompileTimeInterpreter<'mir, 'tcx> {
    pub(super) fn new(const_eval_limit: Limit) -> Self {
        CompileTimeInterpreter { steps_remaining: const_eval_limit.0, stack: Vec::new() }
    }
}
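
// The CTFE machine uses a plain `FxHashMap` as its allocation map. These impls mostly
// forward to the map; the CTFE-specific part is that a *read* must never create an entry.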
impl<K: Hash + Eq, V> interpret::AllocMap<K, V> for FxHashMap<K, V> {
    #[inline(always)]
    fn contains_key<Q: ?Sized + Hash + Eq>(&mut self, k: &Q) -> bool
    where
        K: Borrow<Q>,
    {
        FxHashMap::contains_key(self, k)
    }

    #[inline(always)]
    fn insert(&mut self, k: K, v: V) -> Option<V> {
        FxHashMap::insert(self, k, v)
    }

    #[inline(always)]
    fn remove<Q: ?Sized + Hash + Eq>(&mut self, k: &Q) -> Option<V>
    where
        K: Borrow<Q>,
    {
        FxHashMap::remove(self, k)
    }

    #[inline(always)]
    fn filter_map_collect<T>(&self, mut f: impl FnMut(&K, &V) -> Option<T>) -> Vec<T> {
        self.iter().filter_map(move |(k, v)| f(k, &*v)).collect()
    }

    #[inline(always)]
    fn get_or<E>(&self, k: K, vacant: impl FnOnce() -> Result<V, E>) -> Result<&V, E> {
        match self.get(&k) {
            Some(v) => Ok(v),
            None => {
                // `vacant` is still run so that it can report a proper error (e.g., for a
                // dangling pointer); if it succeeds, a missing entry on read is a bug.
                vacant()?;
                bug!("The CTFE machine shouldn't ever need to extend the alloc_map when reading")
            }
        }
    }

    #[inline(always)]
    fn get_mut_or<E>(&mut self, k: K, vacant: impl FnOnce() -> Result<V, E>) -> Result<&mut V, E> {
        match self.entry(k) {
            Entry::Occupied(e) => Ok(e.into_mut()),
            Entry::Vacant(e) => {
                let v = vacant()?;
                Ok(e.insert(v))
            }
        }
    }
}
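
/// The interpreter context used for compile-time evaluation.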
crate type CompileTimeEvalContext<'mir, 'tcx> =
    InterpCx<'mir, 'tcx, CompileTimeInterpreter<'mir, 'tcx>>;
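
/// The kinds of machine-specific dynamic memory in CTFE; currently only the heap
/// memory handed out by the `const_allocate` intrinsic.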
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum MemoryKind {
    Heap,
}

impl fmt::Display for MemoryKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MemoryKind::Heap => write!(f, "heap allocation"),
        }
    }
}

impl interpret::MayLeak for MemoryKind {
    #[inline(always)]
    fn may_leak(self) -> bool {
        match self {
            // `const_allocate`d memory must not leak out of const evaluation.
            MemoryKind::Heap => false,
        }
    }
}

impl interpret::MayLeak for ! {
    #[inline(always)]
    fn may_leak(self) -> bool {
        // `!` is uninhabited, so this can never actually run; the never type
        // coerces to `bool`.
        self
    }
}

impl<'mir, 'tcx: 'mir> CompileTimeEvalContext<'mir, 'tcx> {
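    /// Returns `true` only if `a == b` is *guaranteed* to hold at runtime;
    /// `false` means "cannot be guaranteed", not "guaranteed unequal".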
    fn guaranteed_eq(&mut self, a: Scalar, b: Scalar) -> bool {
        match (a, b) {
            // Comparison between integers is always known.
            (Scalar::Int { .. }, Scalar::Int { .. }) => a == b,
            // Equality of a pointer with an integer can never be known for sure.
            (Scalar::Int { .. }, Scalar::Ptr(_)) | (Scalar::Ptr(_), Scalar::Int { .. }) => false,
            // FIXME: return `true` when both sides are the same pointer, *except* that
            // some things (like functions and vtables) do not have stable addresses.
            (Scalar::Ptr(_), Scalar::Ptr(_)) => false,
        }
    }
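
    /// Returns `true` only if `a != b` is *guaranteed* to hold at runtime;
    /// `false` means "cannot be guaranteed", not "guaranteed equal".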
    fn guaranteed_ne(&mut self, a: Scalar, b: Scalar) -> bool {
        match (a, b) {
            // Comparison between integers is always known.
            (Scalar::Int(_), Scalar::Int(_)) => a != b,
            // Comparison of a pointer with null is known if the pointer is in bounds:
            // an in-bounds pointer can never be null. Inequality with any other integer
            // can never be known for sure.
            (Scalar::Int(int), Scalar::Ptr(ptr)) | (Scalar::Ptr(ptr), Scalar::Int(int)) => {
                int.is_null() && !self.memory.ptr_may_be_null(ptr)
            }
            // FIXME: return `true` for at least some comparisons where we can reliably
            // determine the result of runtime inequality tests at compile-time.
            // Examples include comparison of addresses in different static items.
            (Scalar::Ptr(_), Scalar::Ptr(_)) => false,
        }
    }
}

impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for CompileTimeInterpreter<'mir, 'tcx> {
    compile_time_machine!(<'mir, 'tcx>);

    type MemoryKind = MemoryKind;
    type MemoryExtra = MemoryExtra;
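
    // Use the MIR prepared specifically for CTFE when it exists for this item;
    // shims and other glue fall back to the generic `instance_mir`.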
    fn load_mir(
        ecx: &InterpCx<'mir, 'tcx, Self>,
        instance: ty::InstanceDef<'tcx>,
    ) -> InterpResult<'tcx, &'tcx mir::Body<'tcx>> {
        match instance {
            ty::InstanceDef::Item(def) => {
                if ecx.tcx.is_ctfe_mir_available(def.did) {
                    Ok(ecx.tcx.mir_for_ctfe_opt_const_arg(def))
                } else {
                    throw_unsup!(NoMirFor(def.did))
                }
            }
            _ => Ok(ecx.tcx.instance_mir(instance)),
        }
    }
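
    // The entry point for function calls during const evaluation: only `const fn` may
    // actually be called; a non-const function either hits the panic hook above or is
    // rejected with an unsupported-operation error.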
    fn find_mir_or_eval_fn(
        ecx: &mut InterpCx<'mir, 'tcx, Self>,
        instance: ty::Instance<'tcx>,
        _abi: Abi,
        args: &[OpTy<'tcx>],
        _ret: Option<(PlaceTy<'tcx>, mir::BasicBlock)>,
        _unwind: Option<mir::BasicBlock>, // unwinding is not supported in consts
    ) -> InterpResult<'tcx, Option<&'mir mir::Body<'tcx>>> {
        debug!("find_mir_or_eval_fn: {:?}", instance);

        // Only check non-glue functions.
        if let ty::InstanceDef::Item(def) = instance.def {
            if !ecx.tcx.is_const_fn_raw(def.did) {
                // Some functions we support even if they are non-const -- but avoid testing
                // that for const fn!
                ecx.hook_panic_fn(instance, args)?;
                // We certainly do *not* want to actually call the function, so be sure
                // to return here.
                throw_unsup_format!("calling non-const function `{}`", instance)
            }
        }
        // This is a const fn. Call it.
        Ok(Some(match ecx.load_mir(instance.def, None) {
            Ok(body) => body,
            Err(err) => {
                // A missing MIR body at this point usually means an extern function;
                // give that case a more targeted error message.
                if let err_unsup!(NoMirFor(did)) = err.kind {
                    let path = ecx.tcx.def_path_str(did);
                    return Err(ConstEvalErrKind::NeedsRfc(format!(
                        "calling extern function `{}`",
                        path
                    ))
                    .into());
                }
                return Err(err);
            }
        }))
    }
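
    // Intrinsics shared with the runtime interpreter are handled by `emulate_intrinsic`;
    // the match below implements the CTFE-only ones.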
    fn call_intrinsic(
        ecx: &mut InterpCx<'mir, 'tcx, Self>,
        instance: ty::Instance<'tcx>,
        args: &[OpTy<'tcx>],
        ret: Option<(PlaceTy<'tcx>, mir::BasicBlock)>,
        _unwind: Option<mir::BasicBlock>,
    ) -> InterpResult<'tcx> {
        if ecx.emulate_intrinsic(instance, args, ret)? {
            return Ok(());
        }
        let intrinsic_name = ecx.tcx.item_name(instance.def_id());
        let (dest, ret) = match ret {
            // No return place means a diverging intrinsic, which we do not support.
            None => {
                return Err(ConstEvalErrKind::NeedsRfc(format!(
                    "calling intrinsic `{}`",
                    intrinsic_name
                ))
                .into());
            }
            Some(p) => p,
        };
        match intrinsic_name {
            sym::ptr_guaranteed_eq | sym::ptr_guaranteed_ne => {
                let a = ecx.read_immediate(args[0])?.to_scalar()?;
                let b = ecx.read_immediate(args[1])?.to_scalar()?;
                let cmp = if intrinsic_name == sym::ptr_guaranteed_eq {
                    ecx.guaranteed_eq(a, b)
                } else {
                    ecx.guaranteed_ne(a, b)
                };
                ecx.write_scalar(Scalar::from_bool(cmp), dest)?;
            }
            sym::const_allocate => {
                let size = ecx.read_scalar(args[0])?.to_machine_usize(ecx)?;
                let align = ecx.read_scalar(args[1])?.to_machine_usize(ecx)?;

                let align = match Align::from_bytes(align) {
                    Ok(a) => a,
                    Err(err) => throw_ub_format!("align has to be a power of 2, {}", err),
                };

                let ptr = ecx.memory.allocate(
                    Size::from_bytes(size as u64),
                    align,
                    interpret::MemoryKind::Machine(MemoryKind::Heap),
                );
                ecx.write_scalar(Scalar::Ptr(ptr), dest)?;
            }
            _ => {
                return Err(ConstEvalErrKind::NeedsRfc(format!(
                    "calling intrinsic `{}`",
                    intrinsic_name
                ))
                .into());
            }
        }

        ecx.go_to_block(ret);
        Ok(())
    }
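
    // An `Assert` terminator that fails during const-eval always raises an error; the
    // operands are evaluated first so the message can include the actual values.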
    fn assert_panic(
        ecx: &mut InterpCx<'mir, 'tcx, Self>,
        msg: &AssertMessage<'tcx>,
        _unwind: Option<mir::BasicBlock>,
    ) -> InterpResult<'tcx> {
        use rustc_middle::mir::AssertKind::*;
        // Convert `AssertKind<Operand>` to `AssertKind<ConstInt>`.
        let eval_to_int =
            |op| ecx.read_immediate(ecx.eval_operand(op, None)?).map(|x| x.to_const_int());
        let err = match msg {
            BoundsCheck { ref len, ref index } => {
                let len = eval_to_int(len)?;
                let index = eval_to_int(index)?;
                BoundsCheck { len, index }
            }
            Overflow(op, l, r) => Overflow(*op, eval_to_int(l)?, eval_to_int(r)?),
            OverflowNeg(op) => OverflowNeg(eval_to_int(op)?),
            DivisionByZero(op) => DivisionByZero(eval_to_int(op)?),
            RemainderByZero(op) => RemainderByZero(eval_to_int(op)?),
            ResumedAfterReturn(generator_kind) => ResumedAfterReturn(*generator_kind),
            ResumedAfterPanic(generator_kind) => ResumedAfterPanic(*generator_kind),
        };
        Err(ConstEvalErrKind::AssertFailure(err).into())
    }
    fn abort(_ecx: &mut InterpCx<'mir, 'tcx, Self>, msg: String) -> InterpResult<'tcx, !> {
        Err(ConstEvalErrKind::Abort(msg).into())
    }

    fn ptr_to_int(_mem: &Memory<'mir, 'tcx, Self>, _ptr: Pointer) -> InterpResult<'tcx, u64> {
        Err(ConstEvalErrKind::NeedsRfc("pointer-to-integer cast".to_string()).into())
    }

    fn binary_ptr_op(
        _ecx: &InterpCx<'mir, 'tcx, Self>,
        _bin_op: mir::BinOp,
        _left: ImmTy<'tcx>,
        _right: ImmTy<'tcx>,
    ) -> InterpResult<'tcx, (Scalar, bool, Ty<'tcx>)> {
        Err(ConstEvalErrKind::NeedsRfc("pointer arithmetic or comparison".to_string()).into())
    }

    fn box_alloc(
        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
        _dest: PlaceTy<'tcx>,
    ) -> InterpResult<'tcx> {
        Err(ConstEvalErrKind::NeedsRfc("heap allocations via `box` keyword".to_string()).into())
    }
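
    // Enforces the `const_eval_limit`: called before each terminator. Note that a
    // configured limit of 0 makes `steps_remaining` start at 0, disabling the check.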
    fn before_terminator(ecx: &mut InterpCx<'mir, 'tcx, Self>) -> InterpResult<'tcx> {
        // The step limit has already been hit in a previous call to `before_terminator`.
        if ecx.machine.steps_remaining == 0 {
            return Ok(());
        }

        ecx.machine.steps_remaining -= 1;
        if ecx.machine.steps_remaining == 0 {
            throw_exhaust!(StepLimitReached)
        }

        Ok(())
    }
    #[inline(always)]
    fn init_frame_extra(
        ecx: &mut InterpCx<'mir, 'tcx, Self>,
        frame: Frame<'mir, 'tcx>,
    ) -> InterpResult<'tcx, Frame<'mir, 'tcx>> {
        // Enforce the stack size limit. Add 1 because this runs before the new frame is pushed.
        if !ecx.tcx.sess.recursion_limit().value_within_limit(ecx.stack().len() + 1) {
            throw_exhaust!(StackFrameLimitReached)
        } else {
            Ok(frame)
        }
    }

    #[inline(always)]
    fn stack(
        ecx: &'a InterpCx<'mir, 'tcx, Self>,
    ) -> &'a [Frame<'mir, 'tcx, Self::PointerTag, Self::FrameExtra>] {
        &ecx.machine.stack
    }

    #[inline(always)]
    fn stack_mut(
        ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
    ) -> &'a mut Vec<Frame<'mir, 'tcx, Self::PointerTag, Self::FrameExtra>> {
        &mut ecx.machine.stack
    }
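
    // Gatekeeper for every access to a global allocation: writes are always rejected,
    // and reads of statics are only allowed where `can_access_statics` says so.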
    fn before_access_global(
        memory_extra: &MemoryExtra,
        alloc_id: AllocId,
        allocation: &Allocation,
        static_def_id: Option<DefId>,
        is_write: bool,
    ) -> InterpResult<'tcx> {
        if is_write {
            // Write access. These are never allowed, but we give a targeted error message.
            if allocation.mutability == Mutability::Not {
                Err(err_ub!(WriteToReadOnly(alloc_id)).into())
            } else {
                Err(ConstEvalErrKind::ModifiedGlobal.into())
            }
        } else {
            // Read access. These are usually allowed, with some exceptions.
            if memory_extra.can_access_statics {
                // Machine configuration allows us to read from anything
                // (e.g., a `static` initializer).
                Ok(())
            } else if static_def_id.is_some() {
                // Machine configuration does not allow us to read statics
                // (e.g., a `const` initializer).
                // See `MemoryExtra::can_access_statics` for why this check matters: if we
                // could read statics, we could read pointers to mutable allocations *inside*
                // statics.
                Err(ConstEvalErrKind::ConstAccessesStatic.into())
            } else {
                // Immutable global, this read is fine.
                // But make sure we never accept a read from something mutable, that would be
                // unsound. The content of this allocation may be different now and at
                // run-time, so if we permit reading now we might return the wrong value.
                assert_eq!(allocation.mutability, Mutability::Not);
                Ok(())
            }
        }
    }
}