Skip to main content

rustc_mir_transform/
instsimplify.rs

1//! Performs various peephole optimizations.
2
3use rustc_abi::{ExternAbi, Integer};
4use rustc_hir::{LangItem, find_attr};
5use rustc_index::IndexVec;
6use rustc_middle::bug;
7use rustc_middle::mir::visit::MutVisitor;
8use rustc_middle::mir::*;
9use rustc_middle::ty::layout::{IntegerExt, ValidityRequirement};
10use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, layout};
11use rustc_span::{Symbol, sym};
12
13use crate::simplify::simplify_duplicate_switch_targets;
14
15pub(super) enum InstSimplify {
16    BeforeInline,
17    AfterSimplifyCfg,
18}
19
20impl<'tcx> crate::MirPass<'tcx> for InstSimplify {
21    fn name(&self) -> &'static str {
22        match self {
23            InstSimplify::BeforeInline => "InstSimplify-before-inline",
24            InstSimplify::AfterSimplifyCfg => "InstSimplify-after-simplifycfg",
25        }
26    }
27
28    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
29        sess.mir_opt_level() > 0
30    }
31
32    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
33        let preserve_ub_checks = find_attr!(tcx.hir_krate_attrs(), RustcPreserveUbChecks);
34        if !preserve_ub_checks {
35            SimplifyUbCheck { tcx }.visit_body(body);
36        }
37        let mut ctx = InstSimplifyContext {
38            tcx,
39            typing_env: body.typing_env(tcx),
40            local_decls: &mut body.local_decls,
41        };
42        for block in body.basic_blocks.as_mut() {
43            for statement in block.statements.iter_mut() {
44                let StatementKind::Assign((.., rvalue)) = &mut statement.kind else {
45                    continue;
46                };
47
48                ctx.simplify_bool_cmp(rvalue);
49                ctx.simplify_ref_deref(rvalue);
50                ctx.simplify_ptr_aggregate(rvalue);
51                ctx.simplify_cast(rvalue);
52                ctx.simplify_repeated_aggregate(rvalue);
53                ctx.simplify_repeat_once(rvalue);
54            }
55
56            let terminator = block.terminator.as_mut().unwrap();
57            ctx.simplify_primitive_clone(terminator, &mut block.statements);
58            ctx.simplify_size_or_align_of_val(terminator, &mut block.statements);
59            ctx.simplify_raw_eq(terminator, &mut block.statements);
60            ctx.simplify_intrinsic_assert(terminator);
61            ctx.simplify_nounwind_call(terminator);
62            simplify_duplicate_switch_targets(terminator);
63        }
64    }
65
66    fn is_required(&self) -> bool {
67        false
68    }
69}
70
71struct InstSimplifyContext<'a, 'tcx> {
72    tcx: TyCtxt<'tcx>,
73    local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>,
74    typing_env: ty::TypingEnv<'tcx>,
75}
76
77impl<'tcx> InstSimplifyContext<'_, 'tcx> {
78    /// Transform aggregates like [0, 0, 0, 0, 0] into [0; 5].
79    /// GVN can also do this optimization, but GVN is only run at mir-opt-level 2 so having this in
80    /// InstSimplify helps unoptimized builds.
81    fn simplify_repeated_aggregate(&self, rvalue: &mut Rvalue<'tcx>) {
82        let Rvalue::Aggregate(AggregateKind::Array(_), fields) = &*rvalue else {
83            return;
84        };
85        if fields.len() < 5 {
86            return;
87        }
88        let (first, rest) = fields[..].split_first().unwrap();
89        let Operand::Constant(first) = first else {
90            return;
91        };
92        let Ok(first_val) = first.const_.eval(self.tcx, self.typing_env, first.span) else {
93            return;
94        };
95        if rest.iter().all(|field| {
96            let Operand::Constant(field) = field else {
97                return false;
98            };
99            let field = field.const_.eval(self.tcx, self.typing_env, field.span);
100            field == Ok(first_val)
101        }) {
102            let len = ty::Const::from_target_usize(self.tcx, fields.len().try_into().unwrap());
103            *rvalue = Rvalue::Repeat(Operand::Constant(first.clone()), len);
104        }
105    }
106
107    /// Transform boolean comparisons into logical operations.
108    fn simplify_bool_cmp(&self, rvalue: &mut Rvalue<'tcx>) {
109        let Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), (a, b)) = &*rvalue else { return };
110        *rvalue = match (op, self.try_eval_bool(a), self.try_eval_bool(b)) {
111            // Transform "Eq(a, true)" ==> "a"
112            (BinOp::Eq, _, Some(true)) => Rvalue::Use(a.clone(), WithRetag::Yes),
113
114            // Transform "Ne(a, false)" ==> "a"
115            (BinOp::Ne, _, Some(false)) => Rvalue::Use(a.clone(), WithRetag::Yes),
116
117            // Transform "Eq(true, b)" ==> "b"
118            (BinOp::Eq, Some(true), _) => Rvalue::Use(b.clone(), WithRetag::Yes),
119
120            // Transform "Ne(false, b)" ==> "b"
121            (BinOp::Ne, Some(false), _) => Rvalue::Use(b.clone(), WithRetag::Yes),
122
123            // Transform "Eq(false, b)" ==> "Not(b)"
124            (BinOp::Eq, Some(false), _) => Rvalue::UnaryOp(UnOp::Not, b.clone()),
125
126            // Transform "Ne(true, b)" ==> "Not(b)"
127            (BinOp::Ne, Some(true), _) => Rvalue::UnaryOp(UnOp::Not, b.clone()),
128
129            // Transform "Eq(a, false)" ==> "Not(a)"
130            (BinOp::Eq, _, Some(false)) => Rvalue::UnaryOp(UnOp::Not, a.clone()),
131
132            // Transform "Ne(a, true)" ==> "Not(a)"
133            (BinOp::Ne, _, Some(true)) => Rvalue::UnaryOp(UnOp::Not, a.clone()),
134
135            _ => return,
136        };
137    }
138
139    fn try_eval_bool(&self, a: &Operand<'_>) -> Option<bool> {
140        let a = a.constant()?;
141        if a.const_.ty().is_bool() { a.const_.try_to_bool() } else { None }
142    }
143
144    /// Transform `&(*a)` ==> `a`.
145    fn simplify_ref_deref(&self, rvalue: &mut Rvalue<'tcx>) {
146        if let Rvalue::Ref(_, _, place) | Rvalue::RawPtr(_, place) = rvalue
147            && let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection()
148            && rvalue.ty(self.local_decls, self.tcx) == base.ty(self.local_decls, self.tcx).ty
149        {
150            *rvalue = Rvalue::Use(
151                Operand::Copy(Place {
152                    local: base.local,
153                    projection: self.tcx.mk_place_elems(base.projection),
154                }),
155                // This might have been a two-phase borrow, which we should not upgrade
156                // to a full `&mut` reborrow.
157                // FIXME: Once Stacked Borrows is fully removed, we can use `Yes` here as
158                // Tree Borrows treats two-phase and full borrows the same.
159                if matches!(
160                    rvalue,
161                    Rvalue::Ref(_, BorrowKind::Mut { kind: MutBorrowKind::TwoPhaseBorrow }, _)
162                ) {
163                    WithRetag::No
164                } else {
165                    WithRetag::Yes
166                },
167            );
168        }
169    }
170
171    /// Transform `Aggregate(RawPtr, [p, ()])` ==> `Cast(PtrToPtr, p)`.
172    fn simplify_ptr_aggregate(&self, rvalue: &mut Rvalue<'tcx>) {
173        if let Rvalue::Aggregate(AggregateKind::RawPtr(pointee_ty, mutability), fields) = rvalue
174            && let meta_ty = fields.raw[1].ty(self.local_decls, self.tcx)
175            && meta_ty.is_unit()
176        {
177            // The mutable borrows we're holding prevent printing `rvalue` here
178            let mut fields = std::mem::take(fields);
179            let _meta = fields.pop().unwrap();
180            let data = fields.pop().unwrap();
181            let ptr_ty = Ty::new_ptr(self.tcx, *pointee_ty, *mutability);
182            *rvalue = Rvalue::Cast(CastKind::PtrToPtr, data, ptr_ty);
183        }
184    }
185
186    fn simplify_cast(&self, rvalue: &mut Rvalue<'tcx>) {
187        let Rvalue::Cast(kind, operand, cast_ty) = rvalue else { return };
188
189        let operand_ty = operand.ty(self.local_decls, self.tcx);
190        if operand_ty == *cast_ty {
191            *rvalue = Rvalue::Use(operand.clone(), WithRetag::Yes);
192        } else if *kind == CastKind::Transmute
193            // Transmuting an integer to another integer is just a signedness cast
194            && let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) =
195                (operand_ty.kind(), cast_ty.kind())
196            && int.bit_width() == uint.bit_width()
197        {
198            // The width check isn't strictly necessary, as different widths
199            // are UB and thus we'd be allowed to turn it into a cast anyway.
200            // But let's keep the UB around for codegen to exploit later.
201            // (If `CastKind::Transmute` ever becomes *not* UB for mismatched sizes,
202            // then the width check is necessary for big-endian correctness.)
203            *kind = CastKind::IntToInt;
204        }
205    }
206
207    /// Simplify `[x; 1]` to just `[x]`.
208    fn simplify_repeat_once(&self, rvalue: &mut Rvalue<'tcx>) {
209        if let Rvalue::Repeat(operand, count) = rvalue
210            && let Some(1) = count.try_to_target_usize(self.tcx)
211        {
212            *rvalue = Rvalue::Aggregate(
213                Box::new(AggregateKind::Array(operand.ty(self.local_decls, self.tcx))),
214                [operand.clone()].into(),
215            );
216        }
217    }
218
219    fn simplify_primitive_clone(
220        &self,
221        terminator: &mut Terminator<'tcx>,
222        statements: &mut Vec<Statement<'tcx>>,
223    ) {
224        let TerminatorKind::Call {
225            func, args, destination, target: Some(destination_block), ..
226        } = &terminator.kind
227        else {
228            return;
229        };
230
231        // It's definitely not a clone if there are multiple arguments
232        let [arg] = &args[..] else { return };
233
234        // Only bother looking more if it's easy to know what we're calling
235        let Some((fn_def_id, ..)) = func.const_fn_def() else { return };
236
237        // These types are easily available from locals, so check that before
238        // doing DefId lookups to figure out what we're actually calling.
239        let arg_ty = arg.node.ty(self.local_decls, self.tcx);
240
241        let ty::Ref(_region, inner_ty, Mutability::Not) = *arg_ty.kind() else { return };
242
243        if !self.tcx.is_lang_item(fn_def_id, LangItem::CloneFn)
244            || !inner_ty.is_trivially_pure_clone_copy()
245        {
246            return;
247        }
248
249        let Some(arg_place) = arg.node.place() else { return };
250
251        statements.push(Statement::new(
252            terminator.source_info,
253            StatementKind::Assign(Box::new((
254                *destination,
255                Rvalue::Use(
256                    Operand::Copy(arg_place.project_deeper(&[ProjectionElem::Deref], self.tcx)),
257                    WithRetag::Yes,
258                ),
259            ))),
260        ));
261        terminator.kind = TerminatorKind::Goto { target: *destination_block };
262    }
263
264    /// Simplify `size_of_val` and `align_of_val` if we don't actually need
265    /// to look at the value in order to calculate the result:
266    /// - For `Sized` types we can always do this for both,
267    /// - For `align_of_val::<[T]>` we can return `align_of::<T>()`, since it
268    ///   doesn't depend on the slice's length and the elements are sized.
269    ///
270    /// This is here so it can run after inlining, where it's more useful.
271    /// (LowerIntrinsics is done in cleanup, before the optimization passes.)
272    ///
273    /// Note that we intentionally just produce the lang item constants so this
274    /// works on generic types and avoids any risk of layout calculation cycles.
275    fn simplify_size_or_align_of_val(
276        &self,
277        terminator: &mut Terminator<'tcx>,
278        statements: &mut Vec<Statement<'tcx>>,
279    ) {
280        let source_info = terminator.source_info;
281        if let TerminatorKind::Call {
282            func, args, destination, target: Some(destination_block), ..
283        } = &terminator.kind
284            && args.len() == 1
285            && let Some((fn_def_id, generics)) = func.const_fn_def()
286        {
287            let lang_item = if self.tcx.is_intrinsic(fn_def_id, sym::size_of_val) {
288                LangItem::SizeOf
289            } else if self.tcx.is_intrinsic(fn_def_id, sym::align_of_val) {
290                LangItem::AlignOf
291            } else {
292                return;
293            };
294            let generic_ty = generics.type_at(0);
295            let ty = if generic_ty.is_sized(self.tcx, self.typing_env) {
296                generic_ty
297            } else if let LangItem::AlignOf = lang_item
298                && let ty::Slice(elem_ty) = *generic_ty.kind()
299            {
300                elem_ty
301            } else {
302                return;
303            };
304
305            let const_def_id = self.tcx.require_lang_item(lang_item, source_info.span);
306            let const_op = Operand::unevaluated_constant(
307                self.tcx,
308                const_def_id,
309                &[ty.into()],
310                source_info.span,
311            );
312            statements.push(Statement::new(
313                source_info,
314                StatementKind::Assign(Box::new((
315                    *destination,
316                    Rvalue::Use(const_op, WithRetag::Yes),
317                ))),
318            ));
319            terminator.kind = TerminatorKind::Goto { target: *destination_block };
320        }
321    }
322
323    /// Simplify `raw_eq` intrinsic calls to `Eq` when the type has the size of a primitive.
324    ///
325    /// For example, replace `raw_eq::<[u8; 4]>(a, b)` with `Eq(Transmute(a), Transmute(b))`.
326    fn simplify_raw_eq(
327        &mut self,
328        terminator: &mut Terminator<'tcx>,
329        statements: &mut Vec<Statement<'tcx>>,
330    ) {
331        let tcx = self.tcx;
332        let source_info = terminator.source_info;
333        let span = source_info.span;
334        if let TerminatorKind::Call {
335            func, args, destination, target: Some(destination_block), ..
336        } = &terminator.kind
337            && args.len() == 2
338            && let Some((fn_def_id, generics)) = func.const_fn_def()
339            && tcx.is_intrinsic(fn_def_id, sym::raw_eq)
340            && let generic_ty = generics.type_at(0)
341            && let Ok(layout) = tcx.layout_of(self.typing_env.as_query_input(generic_ty))
342            && let Ok(integer) = Integer::from_size(layout.size)
343        {
344            let ref_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, generic_ty);
345            let uint_ty = integer.to_ty(tcx, false);
346
347            let mut transmute_operand = |op: &Operand<'tcx>| -> Operand<'tcx> {
348                let ref_local = self.local_decls.push(LocalDecl::new(ref_ty, span));
349                statements.push(Statement::new(
350                    source_info,
351                    StatementKind::Assign(Box::new((
352                        Place::from(ref_local),
353                        Rvalue::Use(op.clone(), WithRetag::Yes),
354                    ))),
355                ));
356                let place = Place::from(ref_local).project_deeper(&[ProjectionElem::Deref], tcx);
357                let int_local = self.local_decls.push(LocalDecl::new(uint_ty, span));
358                statements.push(Statement::new(
359                    source_info,
360                    StatementKind::Assign(Box::new((
361                        Place::from(int_local),
362                        Rvalue::Cast(CastKind::Transmute, Operand::Copy(place), uint_ty),
363                    ))),
364                ));
365                Operand::Move(Place::from(int_local))
366            };
367            let lhs_op = transmute_operand(&args[0].node);
368            let rhs_op = transmute_operand(&args[1].node);
369            statements.push(Statement::new(
370                source_info,
371                StatementKind::Assign(Box::new((
372                    *destination,
373                    Rvalue::BinaryOp(BinOp::Eq, Box::new((lhs_op, rhs_op))),
374                ))),
375            ));
376            terminator.kind = TerminatorKind::Goto { target: *destination_block };
377        }
378    }
379
380    fn simplify_nounwind_call(&self, terminator: &mut Terminator<'tcx>) {
381        let TerminatorKind::Call { ref func, ref mut unwind, .. } = terminator.kind else {
382            return;
383        };
384
385        let Some((def_id, _)) = func.const_fn_def() else {
386            return;
387        };
388
389        let body_ty = self.tcx.type_of(def_id).skip_binder();
390        let body_abi = match body_ty.kind() {
391            ty::FnDef(..) => body_ty.fn_sig(self.tcx).abi(),
392            ty::Closure(..) => ExternAbi::RustCall,
393            ty::Coroutine(..) => ExternAbi::Rust,
394            _ => bug!("unexpected body ty: {body_ty:?}"),
395        };
396
397        if !layout::fn_can_unwind(self.tcx, Some(def_id), body_abi) {
398            *unwind = UnwindAction::Unreachable;
399        }
400    }
401
402    fn simplify_intrinsic_assert(&self, terminator: &mut Terminator<'tcx>) {
403        let TerminatorKind::Call { ref func, target: ref mut target @ Some(target_block), .. } =
404            terminator.kind
405        else {
406            return;
407        };
408        let func_ty = func.ty(self.local_decls, self.tcx);
409        let Some((intrinsic_name, args)) = resolve_rust_intrinsic(self.tcx, func_ty) else {
410            return;
411        };
412        // The intrinsics we are interested in have one generic parameter
413        let [arg, ..] = args[..] else { return };
414
415        let known_is_valid =
416            intrinsic_assert_panics(self.tcx, self.typing_env, arg, intrinsic_name);
417        match known_is_valid {
418            // We don't know the layout or it's not validity assertion at all, don't touch it
419            None => {}
420            Some(true) => {
421                // If we know the assert panics, indicate to later opts that the call diverges
422                *target = None;
423            }
424            Some(false) => {
425                // If we know the assert does not panic, turn the call into a Goto
426                terminator.kind = TerminatorKind::Goto { target: target_block };
427            }
428        }
429    }
430}
431
432fn intrinsic_assert_panics<'tcx>(
433    tcx: TyCtxt<'tcx>,
434    typing_env: ty::TypingEnv<'tcx>,
435    arg: ty::GenericArg<'tcx>,
436    intrinsic_name: Symbol,
437) -> Option<bool> {
438    let requirement = ValidityRequirement::from_intrinsic(intrinsic_name)?;
439    let ty = arg.expect_ty();
440    Some(!tcx.check_validity_requirement((requirement, typing_env.as_query_input(ty))).ok()?)
441}
442
443fn resolve_rust_intrinsic<'tcx>(
444    tcx: TyCtxt<'tcx>,
445    func_ty: Ty<'tcx>,
446) -> Option<(Symbol, GenericArgsRef<'tcx>)> {
447    let ty::FnDef(def_id, args) = *func_ty.kind() else { return None };
448    let intrinsic = tcx.intrinsic(def_id)?;
449    Some((intrinsic.name, args))
450}
451
452struct SimplifyUbCheck<'tcx> {
453    tcx: TyCtxt<'tcx>,
454}
455
456impl<'tcx> MutVisitor<'tcx> for SimplifyUbCheck<'tcx> {
457    fn tcx(&self) -> TyCtxt<'tcx> {
458        self.tcx
459    }
460
461    fn visit_operand(&mut self, operand: &mut Operand<'tcx>, _: Location) {
462        if let Operand::RuntimeChecks(RuntimeChecks::UbChecks) = operand {
463            *operand = Operand::Constant(Box::new(ConstOperand {
464                span: rustc_span::DUMMY_SP,
465                user_ty: None,
466                const_: Const::Val(
467                    ConstValue::from_bool(self.tcx.sess.ub_checks()),
468                    self.tcx.types.bool,
469                ),
470            }));
471        }
472    }
473}