1use rustc_data_structures::fx::FxHashMap;
2use rustc_index::Idx;
3use rustc_middle::mir::*;
4use rustc_middle::ty::Ty;
5use rustc_span::Span;
6use tracing::debug;
7
/// Collects modifications to a MIR [`Body`] and applies them all at once via
/// [`MirPatch::apply`].
///
/// The queued modifications are: new locals, new basic blocks, inserted statements,
/// statements to be turned into `Nop`, and terminator replacements.
pub(crate) struct MirPatch<'tcx> {
    /// Queued terminator replacements, keyed by the block whose terminator is replaced.
    term_patch_map: FxHashMap<BasicBlock, TerminatorKind<'tcx>>,
    /// Locations of statements that will be replaced by `Nop`.
    nop_statements: Vec<Location>,
    /// Blocks to append to the body; they are numbered starting at `next_block`.
    new_blocks: Vec<BasicBlockData<'tcx>>,
    /// Statements to insert, each before whatever statement currently sits at its location.
    new_statements: Vec<(Location, StatementKind<'tcx>)>,
    /// Local declarations to append to the body.
    new_locals: Vec<LocalDecl<'tcx>>,
    /// Known statement-free block ending in `UnwindResume`, if any.
    resume_block: Option<BasicBlock>,
    /// Known statement-free `Unreachable` block in the cleanup path, if any.
    unreachable_cleanup_block: Option<BasicBlock>,
    /// Known statement-free `Unreachable` block outside the cleanup path, if any.
    unreachable_no_cleanup_block: Option<BasicBlock>,
    /// Known statement-free `UnwindTerminate` block, together with its reason.
    terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
    /// Span of the patched body; used for the source info of blocks the patch creates.
    body_span: Span,
    /// Index the next newly-created local will receive. It starts at the body's local
    /// count and grows by one for every local the patch adds.
    next_local: usize,
    /// Number of blocks the body had when the patch was created. New blocks are
    /// numbered after it. Despite its name, this is never advanced.
    next_block: usize,
}
32
33impl<'tcx> MirPatch<'tcx> {
34 pub(crate) fn new(body: &Body<'tcx>) -> Self {
36 let mut result = MirPatch {
37 term_patch_map: Default::default(),
38 nop_statements: vec![],
39 new_blocks: vec![],
40 new_statements: vec![],
41 new_locals: vec![],
42 next_local: body.local_decls.len(),
43 next_block: body.basic_blocks.len(),
44 resume_block: None,
45 unreachable_cleanup_block: None,
46 unreachable_no_cleanup_block: None,
47 terminate_block: None,
48 body_span: body.span,
49 };
50
51 for (bb, block) in body.basic_blocks.iter_enumerated() {
52 if matches!(block.terminator().kind, TerminatorKind::UnwindResume)
54 && block.statements.is_empty()
55 {
56 result.resume_block = Some(bb);
57 continue;
58 }
59
60 if matches!(block.terminator().kind, TerminatorKind::Unreachable)
62 && block.statements.is_empty()
63 {
64 if block.is_cleanup {
65 result.unreachable_cleanup_block = Some(bb);
66 } else {
67 result.unreachable_no_cleanup_block = Some(bb);
68 }
69 continue;
70 }
71
72 if let TerminatorKind::UnwindTerminate(reason) = block.terminator().kind
74 && block.statements.is_empty()
75 {
76 result.terminate_block = Some((bb, reason));
77 continue;
78 }
79 }
80
81 result
82 }
83
84 pub(crate) fn resume_block(&mut self) -> BasicBlock {
85 if let Some(bb) = self.resume_block {
86 return bb;
87 }
88
89 let bb = self.new_block(BasicBlockData::new(
90 Some(Terminator {
91 source_info: SourceInfo::outermost(self.body_span),
92 kind: TerminatorKind::UnwindResume,
93 }),
94 true,
95 ));
96 self.resume_block = Some(bb);
97 bb
98 }
99
100 pub(crate) fn unreachable_cleanup_block(&mut self) -> BasicBlock {
101 if let Some(bb) = self.unreachable_cleanup_block {
102 return bb;
103 }
104
105 let bb = self.new_block(BasicBlockData::new(
106 Some(Terminator {
107 source_info: SourceInfo::outermost(self.body_span),
108 kind: TerminatorKind::Unreachable,
109 }),
110 true,
111 ));
112 self.unreachable_cleanup_block = Some(bb);
113 bb
114 }
115
116 pub(crate) fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
117 if let Some(bb) = self.unreachable_no_cleanup_block {
118 return bb;
119 }
120
121 let bb = self.new_block(BasicBlockData::new(
122 Some(Terminator {
123 source_info: SourceInfo::outermost(self.body_span),
124 kind: TerminatorKind::Unreachable,
125 }),
126 false,
127 ));
128 self.unreachable_no_cleanup_block = Some(bb);
129 bb
130 }
131
132 pub(crate) fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
133 if let Some((cached_bb, cached_reason)) = self.terminate_block
134 && reason == cached_reason
135 {
136 return cached_bb;
137 }
138
139 let bb = self.new_block(BasicBlockData::new(
140 Some(Terminator {
141 source_info: SourceInfo::outermost(self.body_span),
142 kind: TerminatorKind::UnwindTerminate(reason),
143 }),
144 true,
145 ));
146 self.terminate_block = Some((bb, reason));
147 bb
148 }
149
150 pub(crate) fn is_term_patched(&self, bb: BasicBlock) -> bool {
152 self.term_patch_map.contains_key(&bb)
153 }
154
155 pub(crate) fn block<'a>(
157 &'a self,
158 body: &'a Body<'tcx>,
159 bb: BasicBlock,
160 ) -> &'a BasicBlockData<'tcx> {
161 match bb.index().checked_sub(body.basic_blocks.len()) {
162 Some(new) => &self.new_blocks[new],
163 None => &body[bb],
164 }
165 }
166
167 pub(crate) fn terminator_loc(&self, body: &Body<'tcx>, bb: BasicBlock) -> Location {
168 let offset = self.block(body, bb).statements.len();
169 Location { block: bb, statement_index: offset }
170 }
171
172 pub(crate) fn new_local_with_info(
174 &mut self,
175 ty: Ty<'tcx>,
176 span: Span,
177 local_info: LocalInfo<'tcx>,
178 ) -> Local {
179 let index = self.next_local;
180 self.next_local += 1;
181 let mut new_decl = LocalDecl::new(ty, span);
182 **new_decl.local_info.as_mut().unwrap_crate_local() = local_info;
183 self.new_locals.push(new_decl);
184 Local::new(index)
185 }
186
187 pub(crate) fn new_temp(&mut self, ty: Ty<'tcx>, span: Span) -> Local {
189 let index = self.next_local;
190 self.next_local += 1;
191 self.new_locals.push(LocalDecl::new(ty, span));
192 Local::new(index)
193 }
194
195 pub(crate) fn local_ty(&self, local: Local) -> Ty<'tcx> {
197 let local = local.as_usize();
198 assert!(local < self.next_local);
199 let new_local_idx = self.new_locals.len() - (self.next_local - local);
200 self.new_locals[new_local_idx].ty
201 }
202
203 pub(crate) fn new_block(&mut self, data: BasicBlockData<'tcx>) -> BasicBlock {
205 let block = BasicBlock::from_usize(self.next_block + self.new_blocks.len());
206 debug!("MirPatch: new_block: {:?}: {:?}", block, data);
207 self.new_blocks.push(data);
208 block
209 }
210
211 pub(crate) fn patch_terminator(&mut self, block: BasicBlock, new: TerminatorKind<'tcx>) {
213 assert!(!self.term_patch_map.contains_key(&block));
214 debug!("MirPatch: patch_terminator({:?}, {:?})", block, new);
215 self.term_patch_map.insert(block, new);
216 }
217
218 #[tracing::instrument(level = "debug", skip(self))]
223 pub(crate) fn nop_statement(&mut self, loc: Location) {
224 self.nop_statements.push(loc);
225 }
226
227 pub(crate) fn add_statement(&mut self, loc: Location, stmt: StatementKind<'tcx>) {
241 debug!("MirPatch: add_statement({:?}, {:?})", loc, stmt);
242 self.new_statements.push((loc, stmt));
243 }
244
245 pub(crate) fn add_assign(&mut self, loc: Location, place: Place<'tcx>, rv: Rvalue<'tcx>) {
247 self.add_statement(loc, StatementKind::Assign(Box::new((place, rv))));
248 }
249
250 pub(crate) fn apply(self, body: &mut Body<'tcx>) {
252 debug!(
253 "MirPatch: {:?} new temps, starting from index {}: {:?}",
254 self.new_locals.len(),
255 body.local_decls.len(),
256 self.new_locals
257 );
258 debug!(
259 "MirPatch: {} new blocks, starting from index {}",
260 self.new_blocks.len(),
261 body.basic_blocks.len()
262 );
263 debug_assert_eq!(self.next_block, body.basic_blocks.len());
264 let bbs = if self.term_patch_map.is_empty() && self.new_blocks.is_empty() {
265 body.basic_blocks.as_mut_preserves_cfg()
266 } else {
267 body.basic_blocks.as_mut()
268 };
269 bbs.extend(self.new_blocks);
270 body.local_decls.extend(self.new_locals);
271
272 for loc in self.nop_statements {
273 bbs[loc.block].statements[loc.statement_index].make_nop();
274 }
275
276 let mut new_statements = self.new_statements;
277
278 new_statements.sort_by_key(|s| s.0);
281
282 let mut delta = 0;
283 let mut last_bb = START_BLOCK;
284 for (mut loc, stmt) in new_statements {
285 if loc.block != last_bb {
286 delta = 0;
287 last_bb = loc.block;
288 }
289 debug!("MirPatch: adding statement {:?} at loc {:?}+{}", stmt, loc, delta);
290 loc.statement_index += delta;
291 let source_info = Self::source_info_for_index(&bbs[loc.block], loc);
292 bbs[loc.block]
293 .statements
294 .insert(loc.statement_index, Statement::new(source_info, stmt));
295 delta += 1;
296 }
297
298 #[allow(rustc::potential_query_instability)]
300 for (src, patch) in self.term_patch_map {
301 debug!("MirPatch: patching block {:?}", src);
302 let bb = &mut bbs[src];
303 if let TerminatorKind::Unreachable = patch {
304 bb.statements.clear();
305 }
306 bb.terminator_mut().kind = patch;
307 }
308 }
309
310 fn source_info_for_index(data: &BasicBlockData<'_>, loc: Location) -> SourceInfo {
311 match data.statements.get(loc.statement_index) {
312 Some(stmt) => stmt.source_info,
313 None => data.terminator().source_info,
314 }
315 }
316
317 pub(crate) fn source_info_for_location(&self, body: &Body<'tcx>, loc: Location) -> SourceInfo {
318 let data = self.block(body, loc.block);
319 Self::source_info_for_index(data, loc)
320 }
321}