1use std::iter;
2
3use rustc_middle::bug;
4use rustc_middle::mir::interpret::Scalar;
5use rustc_middle::mir::{
6 BasicBlock, BinOp, Body, Operand, Place, Rvalue, Statement, StatementKind, SwitchTargets,
7 TerminatorKind,
8};
9use rustc_middle::ty::{Ty, TyCtxt};
10use tracing::trace;
11
/// MIR pass that turns a `SwitchInt` on the `bool` result of comparing a place with an
/// integer or `char` constant into a `SwitchInt` on the place itself.
///
/// For example, it turns
///
/// ```ignore (MIR)
/// _3 = Eq(move _4, const 43i32);
/// StorageDead(_4);
/// switchInt(move _3) -> [false: bb2, otherwise: bb3];
/// ```
///
/// into
///
/// ```ignore (MIR)
/// nop;
/// nop;
/// switchInt(move _4) -> [43i32: bb3, otherwise: bb2];
/// ```
///
/// with the `StorageDead(_4)` re-inserted at the start of `bb2` and `bb3`.
pub(super) struct SimplifyComparisonIntegral;
27
28impl<'tcx> crate::MirPass<'tcx> for SimplifyComparisonIntegral {
29 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
30 sess.mir_opt_level() > 0
31 }
32
33 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
34 trace!("Running SimplifyComparisonIntegral on {:?}", body.source);
35
36 let helper = OptimizationFinder { body };
37 let opts = helper.find_optimizations();
38 let mut storage_deads_to_insert = vec![];
39 let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![];
40 let typing_env = body.typing_env(tcx);
41 for opt in opts {
42 trace!("SUCCESS: Applying {:?}", opt);
43 let bbs = &mut body.basic_blocks_mut();
45 let bb = &mut bbs[opt.bb_idx];
46 let new_value = match opt.branch_value_scalar {
47 Scalar::Int(int) => {
48 let layout = tcx
49 .layout_of(typing_env.as_query_input(opt.branch_value_ty))
50 .expect("if we have an evaluated constant we must know the layout");
51 int.to_bits(layout.size)
52 }
53 Scalar::Ptr(..) => continue,
54 };
55 const FALSE: u128 = 0;
56
57 let mut new_targets = opt.targets;
58 let first_value = new_targets.iter().next().unwrap().0;
59 let first_is_false_target = first_value == FALSE;
60 match opt.op {
61 BinOp::Eq => {
62 if first_is_false_target {
64 new_targets.all_targets_mut().swap(0, 1);
65 }
66 }
67 BinOp::Ne => {
68 if !first_is_false_target {
70 new_targets.all_targets_mut().swap(0, 1);
71 }
72 }
73 _ => unreachable!(),
74 }
75
76 if opt.can_remove_bin_op_stmt {
79 bb.statements[opt.bin_op_stmt_idx].make_nop();
80 } else {
81 let (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();
89
90 use Operand::*;
91 match rhs {
92 Rvalue::BinaryOp(_, box (ref mut left @ Move(_), Constant(_))) => {
93 *left = Copy(opt.to_switch_on);
94 }
95 Rvalue::BinaryOp(_, box (Constant(_), ref mut right @ Move(_))) => {
96 *right = Copy(opt.to_switch_on);
97 }
98 _ => (),
99 }
100 }
101
102 let terminator = bb.terminator();
103
104 for (stmt_idx, stmt) in bb.statements.iter().enumerate() {
106 if !matches!(
107 stmt.kind,
108 StatementKind::StorageDead(local) if local == opt.to_switch_on.local
109 ) {
110 continue;
111 }
112 storage_deads_to_remove.push((stmt_idx, opt.bb_idx));
113 for bb_idx in new_targets.all_targets() {
116 storage_deads_to_insert.push((
117 *bb_idx,
118 Statement {
119 source_info: terminator.source_info,
120 kind: StatementKind::StorageDead(opt.to_switch_on.local),
121 },
122 ));
123 }
124 }
125
126 let [bb_cond, bb_otherwise] = match new_targets.all_targets() {
127 [a, b] => [*a, *b],
128 e => bug!("expected 2 switch targets, got: {:?}", e),
129 };
130
131 let targets = SwitchTargets::new(iter::once((new_value, bb_cond)), bb_otherwise);
132
133 let terminator = bb.terminator_mut();
134 terminator.kind =
135 TerminatorKind::SwitchInt { discr: Operand::Move(opt.to_switch_on), targets };
136 }
137
138 for (idx, bb_idx) in storage_deads_to_remove {
139 body.basic_blocks_mut()[bb_idx].statements[idx].make_nop();
140 }
141
142 for (idx, stmt) in storage_deads_to_insert {
143 body.basic_blocks_mut()[idx].statements.insert(0, stmt);
144 }
145 }
146
147 fn is_required(&self) -> bool {
148 false
149 }
150}
151
152struct OptimizationFinder<'a, 'tcx> {
153 body: &'a Body<'tcx>,
154}
155
156impl<'tcx> OptimizationFinder<'_, 'tcx> {
157 fn find_optimizations(&self) -> Vec<OptimizationInfo<'tcx>> {
158 self.body
159 .basic_blocks
160 .iter_enumerated()
161 .filter_map(|(bb_idx, bb)| {
162 let (place_switched_on, targets, place_switched_on_moved) =
164 match &bb.terminator().kind {
165 rustc_middle::mir::TerminatorKind::SwitchInt { discr, targets, .. } => {
166 Some((discr.place()?, targets, discr.is_move()))
167 }
168 _ => None,
169 }?;
170
171 bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| {
173 match &stmt.kind {
174 rustc_middle::mir::StatementKind::Assign(box (lhs, rhs))
175 if *lhs == place_switched_on =>
176 {
177 match rhs {
178 Rvalue::BinaryOp(
179 op @ (BinOp::Eq | BinOp::Ne),
180 box (left, right),
181 ) => {
182 let (branch_value_scalar, branch_value_ty, to_switch_on) =
183 find_branch_value_info(left, right)?;
184
185 Some(OptimizationInfo {
186 bin_op_stmt_idx: stmt_idx,
187 bb_idx,
188 can_remove_bin_op_stmt: place_switched_on_moved,
189 to_switch_on,
190 branch_value_scalar,
191 branch_value_ty,
192 op: *op,
193 targets: targets.clone(),
194 })
195 }
196 _ => None,
197 }
198 }
199 _ => None,
200 }
201 })
202 })
203 .collect()
204 }
205}
206
207fn find_branch_value_info<'tcx>(
208 left: &Operand<'tcx>,
209 right: &Operand<'tcx>,
210) -> Option<(Scalar, Ty<'tcx>, Place<'tcx>)> {
211 use Operand::*;
214 match (left, right) {
215 (Constant(branch_value), Copy(to_switch_on) | Move(to_switch_on))
216 | (Copy(to_switch_on) | Move(to_switch_on), Constant(branch_value)) => {
217 let branch_value_ty = branch_value.const_.ty();
218 if !branch_value_ty.is_integral() && !branch_value_ty.is_char() {
221 return None;
222 };
223 let branch_value_scalar = branch_value.const_.try_to_scalar()?;
224 Some((branch_value_scalar, branch_value_ty, *to_switch_on))
225 }
226 _ => None,
227 }
228}
229
/// One rewrite opportunity: a block whose `SwitchInt` branches on the `bool` result of an
/// `Eq`/`Ne` comparison between a place and a constant.
#[derive(Debug)]
struct OptimizationInfo<'tcx> {
    /// Basic block that holds both the comparison and the `SwitchInt`.
    bb_idx: BasicBlock,
    /// Index of the comparison statement within `bb_idx`.
    bin_op_stmt_idx: usize,
    /// Whether the comparison statement can be removed. The finder sets this from whether
    /// the `SwitchInt` moved the comparison result.
    can_remove_bin_op_stmt: bool,
    /// The place that was compared against the constant; the new `SwitchInt` branches on it.
    to_switch_on: Place<'tcx>,
    /// Evaluated value of the constant.
    branch_value_scalar: Scalar,
    /// Type of the constant (an integer or `char` type).
    branch_value_ty: Ty<'tcx>,
    /// The comparison operator: `BinOp::Eq` or `BinOp::Ne`.
    op: BinOp,
    /// Targets of the original `SwitchInt` on the comparison result.
    targets: SwitchTargets,
}