1use std::iter;
2
3use rustc_middle::bug;
4use rustc_middle::mir::interpret::Scalar;
5use rustc_middle::mir::{
6 BasicBlock, BinOp, Body, Operand, Place, Rvalue, Statement, StatementKind, SwitchTargets,
7 TerminatorKind,
8};
9use rustc_middle::ty::{Ty, TyCtxt};
10use tracing::trace;
11
12use crate::ssa::SsaLocals;
13
/// MIR pass that replaces a `SwitchInt` on the boolean result of an `Eq`/`Ne` comparison
/// between a place and an integral or `char` constant with a `SwitchInt` directly on that
/// place.
///
/// For example:
///
/// ```text
/// _2 = Eq(copy _1, const 7_i32);
/// switchInt(move _2) -> [0: bb2, otherwise: bb1];
/// ```
///
/// becomes
///
/// ```text
/// switchInt(copy _1) -> [7: bb1, otherwise: bb2];
/// ```
///
/// In this example the comparison statement is also turned into a nop, because the original
/// switch moved its discriminant.
pub(super) struct SimplifyComparisonIntegral;
29
30impl<'tcx> crate::MirPass<'tcx> for SimplifyComparisonIntegral {
31 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
32 sess.mir_opt_level() > 1
33 }
34
35 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
36 trace!("Running SimplifyComparisonIntegral on {:?}", body.source);
37
38 let typing_env = body.typing_env(tcx);
39 let ssa = SsaLocals::new(tcx, body, typing_env);
40 let helper = OptimizationFinder { body };
41 let opts = helper.find_optimizations(&ssa);
42 let mut storage_deads_to_insert = vec![];
43 let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![];
44 for opt in opts {
45 trace!("SUCCESS: Applying {:?}", opt);
46 let bbs = &mut body.basic_blocks_mut();
48 let bb = &mut bbs[opt.bb_idx];
49 let new_value = match opt.branch_value_scalar {
50 Scalar::Int(int) => {
51 let layout = tcx
52 .layout_of(typing_env.as_query_input(opt.branch_value_ty))
53 .expect("if we have an evaluated constant we must know the layout");
54 int.to_bits(layout.size)
55 }
56 Scalar::Ptr(..) => continue,
57 };
58 const FALSE: u128 = 0;
59
60 let mut new_targets = opt.targets;
61 let first_value = new_targets.iter().next().unwrap().0;
62 let first_is_false_target = first_value == FALSE;
63 match opt.op {
64 BinOp::Eq => {
65 if first_is_false_target {
67 new_targets.all_targets_mut().swap(0, 1);
68 }
69 }
70 BinOp::Ne => {
71 if !first_is_false_target {
73 new_targets.all_targets_mut().swap(0, 1);
74 }
75 }
76 _ => unreachable!(),
77 }
78
79 if opt.can_remove_bin_op_stmt {
82 bb.statements[opt.bin_op_stmt_idx].make_nop(true);
83 } else {
84 let (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();
92
93 use Operand::*;
94 match rhs {
95 Rvalue::BinaryOp(_, box (left @ Move(_), Constant(_))) => {
96 *left = Copy(opt.to_switch_on);
97 }
98 Rvalue::BinaryOp(_, box (Constant(_), right @ Move(_))) => {
99 *right = Copy(opt.to_switch_on);
100 }
101 _ => (),
102 }
103 }
104
105 let terminator = bb.terminator();
106
107 for (stmt_idx, stmt) in bb.statements.iter().enumerate() {
109 if !matches!(
110 stmt.kind,
111 StatementKind::StorageDead(local) if local == opt.to_switch_on.local
112 ) {
113 continue;
114 }
115 storage_deads_to_remove.push((stmt_idx, opt.bb_idx));
116 for bb_idx in new_targets.all_targets() {
119 storage_deads_to_insert.push((
120 *bb_idx,
121 Statement::new(
122 terminator.source_info,
123 StatementKind::StorageDead(opt.to_switch_on.local),
124 ),
125 ));
126 }
127 }
128
129 let [bb_cond, bb_otherwise] = match new_targets.all_targets() {
130 [a, b] => [*a, *b],
131 e => bug!("expected 2 switch targets, got: {:?}", e),
132 };
133
134 let targets = SwitchTargets::new(iter::once((new_value, bb_cond)), bb_otherwise);
135
136 let terminator = bb.terminator_mut();
137 terminator.kind =
138 TerminatorKind::SwitchInt { discr: Operand::Copy(opt.to_switch_on), targets };
139 }
140
141 for (idx, bb_idx) in storage_deads_to_remove {
142 body.basic_blocks_mut()[bb_idx].statements[idx].make_nop(true);
143 }
144
145 for (idx, stmt) in storage_deads_to_insert {
146 body.basic_blocks_mut()[idx].statements.insert(0, stmt);
147 }
148 }
149
150 fn is_required(&self) -> bool {
151 false
152 }
153}
154
/// Scanner over a MIR body, holding only a shared borrow of it, that collects rewrite
/// opportunities for `SimplifyComparisonIntegral`.
struct OptimizationFinder<'a, 'tcx> {
    /// The body being scanned.
    body: &'a Body<'tcx>,
}
158
159impl<'tcx> OptimizationFinder<'_, 'tcx> {
160 fn find_optimizations(&self, ssa: &SsaLocals) -> Vec<OptimizationInfo<'tcx>> {
161 self.body
162 .basic_blocks
163 .iter_enumerated()
164 .filter_map(|(bb_idx, bb)| {
165 let (discr, targets) = bb.terminator().kind.as_switch()?;
167 let place_switched_on = discr.place()?;
168 if !ssa.is_ssa(place_switched_on.local) || !place_switched_on.is_stable_offset() {
170 return None;
171 }
172
173 bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| {
175 match &stmt.kind {
176 rustc_middle::mir::StatementKind::Assign(box (lhs, rhs))
177 if *lhs == place_switched_on =>
178 {
179 match rhs {
180 Rvalue::BinaryOp(
181 op @ (BinOp::Eq | BinOp::Ne),
182 box (left, right),
183 ) => {
184 let (branch_value_scalar, branch_value_ty, to_switch_on) =
185 find_branch_value_info(left, right, ssa)?;
186
187 Some(OptimizationInfo {
188 bin_op_stmt_idx: stmt_idx,
189 bb_idx,
190 can_remove_bin_op_stmt: discr.is_move(),
191 to_switch_on,
192 branch_value_scalar,
193 branch_value_ty,
194 op: *op,
195 targets: targets.clone(),
196 })
197 }
198 _ => None,
199 }
200 }
201 _ => None,
202 }
203 })
204 })
205 .collect()
206 }
207}
208
209fn find_branch_value_info<'tcx>(
210 left: &Operand<'tcx>,
211 right: &Operand<'tcx>,
212 ssa: &SsaLocals,
213) -> Option<(Scalar, Ty<'tcx>, Place<'tcx>)> {
214 use Operand::*;
217 match (left, right) {
218 (Constant(branch_value), Copy(to_switch_on) | Move(to_switch_on))
219 | (Copy(to_switch_on) | Move(to_switch_on), Constant(branch_value)) => {
220 if !ssa.is_ssa(to_switch_on.local) || !to_switch_on.is_stable_offset() {
222 return None;
223 }
224 let branch_value_ty = branch_value.const_.ty();
225 if !branch_value_ty.is_integral() && !branch_value_ty.is_char() {
228 return None;
229 };
230 let branch_value_scalar = branch_value.const_.try_to_scalar()?;
231 Some((branch_value_scalar, branch_value_ty, *to_switch_on))
232 }
233 _ => None,
234 }
235}
236
/// One rewrite opportunity found by `OptimizationFinder::find_optimizations`.
#[derive(Debug)]
struct OptimizationInfo<'tcx> {
    /// Block whose `SwitchInt` terminator will be rewritten.
    bb_idx: BasicBlock,
    /// Index, within the statements of `bb_idx`, of the `Eq`/`Ne` assignment that produces
    /// the switched-on value.
    bin_op_stmt_idx: usize,
    /// Whether the comparison statement may be turned into a nop; set when the original
    /// switch moved (rather than copied) its discriminant.
    can_remove_bin_op_stmt: bool,
    /// Place compared against the constant; the new switch reads it directly.
    to_switch_on: Place<'tcx>,
    /// Evaluated value of the constant operand of the comparison.
    branch_value_scalar: Scalar,
    /// Type of the constant operand (integral or `char`); its layout size is used to
    /// extract the constant's bits.
    branch_value_ty: Ty<'tcx>,
    /// The comparison operator: always `BinOp::Eq` or `BinOp::Ne`.
    op: BinOp,
    /// Targets of the original switch on the comparison result.
    targets: SwitchTargets,
}