rustc_mir_transform/
instsimplify.rs

1//! Performs various peephole optimizations.
2
3use rustc_abi::ExternAbi;
4use rustc_ast::attr;
5use rustc_hir::LangItem;
6use rustc_middle::bug;
7use rustc_middle::mir::*;
8use rustc_middle::ty::layout::ValidityRequirement;
9use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, layout};
10use rustc_span::{DUMMY_SP, Symbol, sym};
11
12use crate::simplify::simplify_duplicate_switch_targets;
13use crate::take_array;
14
15pub(super) enum InstSimplify {
16    BeforeInline,
17    AfterSimplifyCfg,
18}
19
20impl<'tcx> crate::MirPass<'tcx> for InstSimplify {
21    fn name(&self) -> &'static str {
22        match self {
23            InstSimplify::BeforeInline => "InstSimplify-before-inline",
24            InstSimplify::AfterSimplifyCfg => "InstSimplify-after-simplifycfg",
25        }
26    }
27
28    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
29        sess.mir_opt_level() > 0
30    }
31
32    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
33        let ctx = InstSimplifyContext {
34            tcx,
35            local_decls: &body.local_decls,
36            typing_env: body.typing_env(tcx),
37        };
38        let preserve_ub_checks =
39            attr::contains_name(tcx.hir().krate_attrs(), sym::rustc_preserve_ub_checks);
40        for block in body.basic_blocks.as_mut() {
41            for statement in block.statements.iter_mut() {
42                match statement.kind {
43                    StatementKind::Assign(box (_place, ref mut rvalue)) => {
44                        if !preserve_ub_checks {
45                            ctx.simplify_ub_check(rvalue);
46                        }
47                        ctx.simplify_bool_cmp(rvalue);
48                        ctx.simplify_ref_deref(rvalue);
49                        ctx.simplify_ptr_aggregate(rvalue);
50                        ctx.simplify_cast(rvalue);
51                        ctx.simplify_repeated_aggregate(rvalue);
52                        ctx.simplify_repeat_once(rvalue);
53                    }
54                    _ => {}
55                }
56            }
57
58            ctx.simplify_primitive_clone(block.terminator.as_mut().unwrap(), &mut block.statements);
59            ctx.simplify_intrinsic_assert(block.terminator.as_mut().unwrap());
60            ctx.simplify_nounwind_call(block.terminator.as_mut().unwrap());
61            simplify_duplicate_switch_targets(block.terminator.as_mut().unwrap());
62        }
63    }
64
65    fn is_required(&self) -> bool {
66        false
67    }
68}
69
70struct InstSimplifyContext<'a, 'tcx> {
71    tcx: TyCtxt<'tcx>,
72    local_decls: &'a LocalDecls<'tcx>,
73    typing_env: ty::TypingEnv<'tcx>,
74}
75
76impl<'tcx> InstSimplifyContext<'_, 'tcx> {
77    /// Transform aggregates like [0, 0, 0, 0, 0] into [0; 5].
78    /// GVN can also do this optimization, but GVN is only run at mir-opt-level 2 so having this in
79    /// InstSimplify helps unoptimized builds.
80    fn simplify_repeated_aggregate(&self, rvalue: &mut Rvalue<'tcx>) {
81        let Rvalue::Aggregate(box AggregateKind::Array(_), fields) = rvalue else {
82            return;
83        };
84        if fields.len() < 5 {
85            return;
86        }
87        let first = &fields[rustc_abi::FieldIdx::ZERO];
88        let Operand::Constant(first) = first else {
89            return;
90        };
91        let Ok(first_val) = first.const_.eval(self.tcx, self.typing_env, first.span) else {
92            return;
93        };
94        if fields.iter().all(|field| {
95            let Operand::Constant(field) = field else {
96                return false;
97            };
98            let field = field.const_.eval(self.tcx, self.typing_env, field.span);
99            field == Ok(first_val)
100        }) {
101            let len = ty::Const::from_target_usize(self.tcx, fields.len().try_into().unwrap());
102            *rvalue = Rvalue::Repeat(Operand::Constant(first.clone()), len);
103        }
104    }
105
106    /// Transform boolean comparisons into logical operations.
107    fn simplify_bool_cmp(&self, rvalue: &mut Rvalue<'tcx>) {
108        match rvalue {
109            Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), box (a, b)) => {
110                let new = match (op, self.try_eval_bool(a), self.try_eval_bool(b)) {
111                    // Transform "Eq(a, true)" ==> "a"
112                    (BinOp::Eq, _, Some(true)) => Some(Rvalue::Use(a.clone())),
113
114                    // Transform "Ne(a, false)" ==> "a"
115                    (BinOp::Ne, _, Some(false)) => Some(Rvalue::Use(a.clone())),
116
117                    // Transform "Eq(true, b)" ==> "b"
118                    (BinOp::Eq, Some(true), _) => Some(Rvalue::Use(b.clone())),
119
120                    // Transform "Ne(false, b)" ==> "b"
121                    (BinOp::Ne, Some(false), _) => Some(Rvalue::Use(b.clone())),
122
123                    // Transform "Eq(false, b)" ==> "Not(b)"
124                    (BinOp::Eq, Some(false), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())),
125
126                    // Transform "Ne(true, b)" ==> "Not(b)"
127                    (BinOp::Ne, Some(true), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())),
128
129                    // Transform "Eq(a, false)" ==> "Not(a)"
130                    (BinOp::Eq, _, Some(false)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())),
131
132                    // Transform "Ne(a, true)" ==> "Not(a)"
133                    (BinOp::Ne, _, Some(true)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())),
134
135                    _ => None,
136                };
137
138                if let Some(new) = new {
139                    *rvalue = new;
140                }
141            }
142
143            _ => {}
144        }
145    }
146
147    fn try_eval_bool(&self, a: &Operand<'_>) -> Option<bool> {
148        let a = a.constant()?;
149        if a.const_.ty().is_bool() { a.const_.try_to_bool() } else { None }
150    }
151
152    /// Transform `&(*a)` ==> `a`.
153    fn simplify_ref_deref(&self, rvalue: &mut Rvalue<'tcx>) {
154        if let Rvalue::Ref(_, _, place) | Rvalue::RawPtr(_, place) = rvalue {
155            if let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection() {
156                if rvalue.ty(self.local_decls, self.tcx) != base.ty(self.local_decls, self.tcx).ty {
157                    return;
158                }
159
160                *rvalue = Rvalue::Use(Operand::Copy(Place {
161                    local: base.local,
162                    projection: self.tcx.mk_place_elems(base.projection),
163                }));
164            }
165        }
166    }
167
168    /// Transform `Aggregate(RawPtr, [p, ()])` ==> `Cast(PtrToPtr, p)`.
169    fn simplify_ptr_aggregate(&self, rvalue: &mut Rvalue<'tcx>) {
170        if let Rvalue::Aggregate(box AggregateKind::RawPtr(pointee_ty, mutability), fields) = rvalue
171        {
172            let meta_ty = fields.raw[1].ty(self.local_decls, self.tcx);
173            if meta_ty.is_unit() {
174                // The mutable borrows we're holding prevent printing `rvalue` here
175                let mut fields = std::mem::take(fields);
176                let _meta = fields.pop().unwrap();
177                let data = fields.pop().unwrap();
178                let ptr_ty = Ty::new_ptr(self.tcx, *pointee_ty, *mutability);
179                *rvalue = Rvalue::Cast(CastKind::PtrToPtr, data, ptr_ty);
180            }
181        }
182    }
183
184    fn simplify_ub_check(&self, rvalue: &mut Rvalue<'tcx>) {
185        if let Rvalue::NullaryOp(NullOp::UbChecks, _) = *rvalue {
186            let const_ = Const::from_bool(self.tcx, self.tcx.sess.ub_checks());
187            let constant = ConstOperand { span: DUMMY_SP, const_, user_ty: None };
188            *rvalue = Rvalue::Use(Operand::Constant(Box::new(constant)));
189        }
190    }
191
192    fn simplify_cast(&self, rvalue: &mut Rvalue<'tcx>) {
193        if let Rvalue::Cast(kind, operand, cast_ty) = rvalue {
194            let operand_ty = operand.ty(self.local_decls, self.tcx);
195            if operand_ty == *cast_ty {
196                *rvalue = Rvalue::Use(operand.clone());
197            } else if *kind == CastKind::Transmute {
198                // Transmuting an integer to another integer is just a signedness cast
199                if let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) =
200                    (operand_ty.kind(), cast_ty.kind())
201                    && int.bit_width() == uint.bit_width()
202                {
203                    // The width check isn't strictly necessary, as different widths
204                    // are UB and thus we'd be allowed to turn it into a cast anyway.
205                    // But let's keep the UB around for codegen to exploit later.
206                    // (If `CastKind::Transmute` ever becomes *not* UB for mismatched sizes,
207                    // then the width check is necessary for big-endian correctness.)
208                    *kind = CastKind::IntToInt;
209                    return;
210                }
211            }
212        }
213    }
214
215    /// Simplify `[x; 1]` to just `[x]`.
216    fn simplify_repeat_once(&self, rvalue: &mut Rvalue<'tcx>) {
217        if let Rvalue::Repeat(operand, count) = rvalue
218            && let Some(1) = count.try_to_target_usize(self.tcx)
219        {
220            *rvalue = Rvalue::Aggregate(
221                Box::new(AggregateKind::Array(operand.ty(self.local_decls, self.tcx))),
222                [operand.clone()].into(),
223            );
224        }
225    }
226
227    fn simplify_primitive_clone(
228        &self,
229        terminator: &mut Terminator<'tcx>,
230        statements: &mut Vec<Statement<'tcx>>,
231    ) {
232        let TerminatorKind::Call { func, args, destination, target, .. } = &mut terminator.kind
233        else {
234            return;
235        };
236
237        // It's definitely not a clone if there are multiple arguments
238        let [arg] = &args[..] else { return };
239
240        let Some(destination_block) = *target else { return };
241
242        // Only bother looking more if it's easy to know what we're calling
243        let Some((fn_def_id, fn_args)) = func.const_fn_def() else { return };
244
245        // Clone needs one arg, so we can cheaply rule out other stuff
246        if fn_args.len() != 1 {
247            return;
248        }
249
250        // These types are easily available from locals, so check that before
251        // doing DefId lookups to figure out what we're actually calling.
252        let arg_ty = arg.node.ty(self.local_decls, self.tcx);
253
254        let ty::Ref(_region, inner_ty, Mutability::Not) = *arg_ty.kind() else { return };
255
256        if !inner_ty.is_trivially_pure_clone_copy() {
257            return;
258        }
259
260        if !self.tcx.is_lang_item(fn_def_id, LangItem::CloneFn) {
261            return;
262        }
263
264        let Ok([arg]) = take_array(args) else { return };
265        let Some(arg_place) = arg.node.place() else { return };
266
267        statements.push(Statement {
268            source_info: terminator.source_info,
269            kind: StatementKind::Assign(Box::new((
270                *destination,
271                Rvalue::Use(Operand::Copy(
272                    arg_place.project_deeper(&[ProjectionElem::Deref], self.tcx),
273                )),
274            ))),
275        });
276        terminator.kind = TerminatorKind::Goto { target: destination_block };
277    }
278
279    fn simplify_nounwind_call(&self, terminator: &mut Terminator<'tcx>) {
280        let TerminatorKind::Call { func, unwind, .. } = &mut terminator.kind else {
281            return;
282        };
283
284        let Some((def_id, _)) = func.const_fn_def() else {
285            return;
286        };
287
288        let body_ty = self.tcx.type_of(def_id).skip_binder();
289        let body_abi = match body_ty.kind() {
290            ty::FnDef(..) => body_ty.fn_sig(self.tcx).abi(),
291            ty::Closure(..) => ExternAbi::RustCall,
292            ty::Coroutine(..) => ExternAbi::Rust,
293            _ => bug!("unexpected body ty: {:?}", body_ty),
294        };
295
296        if !layout::fn_can_unwind(self.tcx, Some(def_id), body_abi) {
297            *unwind = UnwindAction::Unreachable;
298        }
299    }
300
301    fn simplify_intrinsic_assert(&self, terminator: &mut Terminator<'tcx>) {
302        let TerminatorKind::Call { func, target, .. } = &mut terminator.kind else {
303            return;
304        };
305        let Some(target_block) = target else {
306            return;
307        };
308        let func_ty = func.ty(self.local_decls, self.tcx);
309        let Some((intrinsic_name, args)) = resolve_rust_intrinsic(self.tcx, func_ty) else {
310            return;
311        };
312        // The intrinsics we are interested in have one generic parameter
313        if args.is_empty() {
314            return;
315        }
316
317        let known_is_valid =
318            intrinsic_assert_panics(self.tcx, self.typing_env, args[0], intrinsic_name);
319        match known_is_valid {
320            // We don't know the layout or it's not validity assertion at all, don't touch it
321            None => {}
322            Some(true) => {
323                // If we know the assert panics, indicate to later opts that the call diverges
324                *target = None;
325            }
326            Some(false) => {
327                // If we know the assert does not panic, turn the call into a Goto
328                terminator.kind = TerminatorKind::Goto { target: *target_block };
329            }
330        }
331    }
332}
333
334fn intrinsic_assert_panics<'tcx>(
335    tcx: TyCtxt<'tcx>,
336    typing_env: ty::TypingEnv<'tcx>,
337    arg: ty::GenericArg<'tcx>,
338    intrinsic_name: Symbol,
339) -> Option<bool> {
340    let requirement = ValidityRequirement::from_intrinsic(intrinsic_name)?;
341    let ty = arg.expect_ty();
342    Some(!tcx.check_validity_requirement((requirement, typing_env.as_query_input(ty))).ok()?)
343}
344
345fn resolve_rust_intrinsic<'tcx>(
346    tcx: TyCtxt<'tcx>,
347    func_ty: Ty<'tcx>,
348) -> Option<(Symbol, GenericArgsRef<'tcx>)> {
349    if let ty::FnDef(def_id, args) = *func_ty.kind() {
350        let intrinsic = tcx.intrinsic(def_id)?;
351        return Some((intrinsic.name, args));
352    }
353    None
354}