1use rustc_data_structures::fx::FxHashMap;
2use rustc_index::Idx;
3use rustc_middle::mir::*;
4use rustc_middle::ty::Ty;
5use rustc_span::Span;
6use tracing::debug;
7
8pub(crate) struct MirPatch<'tcx> {
13 term_patch_map: FxHashMap<BasicBlock, TerminatorKind<'tcx>>,
14 nop_statements: Vec<Location>,
16 new_blocks: Vec<BasicBlockData<'tcx>>,
17 new_statements: Vec<(Location, StatementKind<'tcx>)>,
18 new_locals: Vec<LocalDecl<'tcx>>,
19 resume_block: Option<BasicBlock>,
20 unreachable_cleanup_block: Option<BasicBlock>,
22 unreachable_no_cleanup_block: Option<BasicBlock>,
24 terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
26 body_span: Span,
27 next_local: usize,
30 next_block: usize,
33}
34
35impl<'tcx> MirPatch<'tcx> {
36 pub(crate) fn new(body: &Body<'tcx>) -> Self {
38 let mut result = MirPatch {
39 term_patch_map: Default::default(),
40 nop_statements: vec![],
41 new_blocks: vec![],
42 new_statements: vec![],
43 new_locals: vec![],
44 next_local: body.local_decls.len(),
45 next_block: body.basic_blocks.len(),
46 resume_block: None,
47 unreachable_cleanup_block: None,
48 unreachable_no_cleanup_block: None,
49 terminate_block: None,
50 body_span: body.span,
51 };
52
53 for (bb, block) in body.basic_blocks.iter_enumerated() {
54 if matches!(block.terminator().kind, TerminatorKind::UnwindResume)
56 && block.statements.is_empty()
57 {
58 result.resume_block = Some(bb);
59 continue;
60 }
61
62 if matches!(block.terminator().kind, TerminatorKind::Unreachable)
64 && block.statements.is_empty()
65 {
66 if block.is_cleanup {
67 result.unreachable_cleanup_block = Some(bb);
68 } else {
69 result.unreachable_no_cleanup_block = Some(bb);
70 }
71 continue;
72 }
73
74 if let TerminatorKind::UnwindTerminate(reason) = block.terminator().kind
76 && block.statements.is_empty()
77 {
78 result.terminate_block = Some((bb, reason));
79 continue;
80 }
81 }
82
83 result
84 }
85
86 pub(crate) fn resume_block(&mut self) -> BasicBlock {
87 if let Some(bb) = self.resume_block {
88 return bb;
89 }
90
91 let bb = self.new_block(BasicBlockData::new(
92 Some(Terminator {
93 source_info: SourceInfo::outermost(self.body_span),
94 kind: TerminatorKind::UnwindResume,
95 }),
96 true,
97 ));
98 self.resume_block = Some(bb);
99 bb
100 }
101
102 pub(crate) fn unreachable_cleanup_block(&mut self) -> BasicBlock {
103 if let Some(bb) = self.unreachable_cleanup_block {
104 return bb;
105 }
106
107 let bb = self.new_block(BasicBlockData::new(
108 Some(Terminator {
109 source_info: SourceInfo::outermost(self.body_span),
110 kind: TerminatorKind::Unreachable,
111 }),
112 true,
113 ));
114 self.unreachable_cleanup_block = Some(bb);
115 bb
116 }
117
118 pub(crate) fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
119 if let Some(bb) = self.unreachable_no_cleanup_block {
120 return bb;
121 }
122
123 let bb = self.new_block(BasicBlockData::new(
124 Some(Terminator {
125 source_info: SourceInfo::outermost(self.body_span),
126 kind: TerminatorKind::Unreachable,
127 }),
128 false,
129 ));
130 self.unreachable_no_cleanup_block = Some(bb);
131 bb
132 }
133
134 pub(crate) fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
135 if let Some((cached_bb, cached_reason)) = self.terminate_block
136 && reason == cached_reason
137 {
138 return cached_bb;
139 }
140
141 let bb = self.new_block(BasicBlockData::new(
142 Some(Terminator {
143 source_info: SourceInfo::outermost(self.body_span),
144 kind: TerminatorKind::UnwindTerminate(reason),
145 }),
146 true,
147 ));
148 self.terminate_block = Some((bb, reason));
149 bb
150 }
151
152 pub(crate) fn is_term_patched(&self, bb: BasicBlock) -> bool {
154 self.term_patch_map.contains_key(&bb)
155 }
156
157 pub(crate) fn block<'a>(
159 &'a self,
160 body: &'a Body<'tcx>,
161 bb: BasicBlock,
162 ) -> &'a BasicBlockData<'tcx> {
163 match bb.index().checked_sub(body.basic_blocks.len()) {
164 Some(new) => &self.new_blocks[new],
165 None => &body[bb],
166 }
167 }
168
169 pub(crate) fn terminator_loc(&self, body: &Body<'tcx>, bb: BasicBlock) -> Location {
170 let offset = self.block(body, bb).statements.len();
171 Location { block: bb, statement_index: offset }
172 }
173
174 pub(crate) fn new_local_with_info(
176 &mut self,
177 ty: Ty<'tcx>,
178 span: Span,
179 local_info: LocalInfo<'tcx>,
180 ) -> Local {
181 let index = self.next_local + self.new_locals.len();
182 let mut new_decl = LocalDecl::new(ty, span);
183 **new_decl.local_info.as_mut().unwrap_crate_local() = local_info;
184 self.new_locals.push(new_decl);
185 Local::new(index)
186 }
187
188 pub(crate) fn new_temp(&mut self, ty: Ty<'tcx>, span: Span) -> Local {
190 let index = self.next_local + self.new_locals.len();
191 self.new_locals.push(LocalDecl::new(ty, span));
192 Local::new(index)
193 }
194
195 pub(crate) fn local_ty(&self, local: Local) -> Ty<'tcx> {
197 let local = local.as_usize();
198 assert!(local < self.next_local + self.new_locals.len());
199 let new_local_idx = local - self.next_local;
200 self.new_locals[new_local_idx].ty
201 }
202
203 pub(crate) fn new_block(&mut self, data: BasicBlockData<'tcx>) -> BasicBlock {
205 let block = BasicBlock::from_usize(self.next_block + self.new_blocks.len());
206 debug!("MirPatch: new_block: {:?}: {:?}", block, data);
207 self.new_blocks.push(data);
208 block
209 }
210
211 pub(crate) fn patch_terminator(&mut self, block: BasicBlock, new: TerminatorKind<'tcx>) {
213 assert!(!self.term_patch_map.contains_key(&block));
214 debug!("MirPatch: patch_terminator({:?}, {:?})", block, new);
215 self.term_patch_map.insert(block, new);
216 }
217
218 #[tracing::instrument(level = "debug", skip(self))]
223 pub(crate) fn nop_statement(&mut self, loc: Location) {
224 self.nop_statements.push(loc);
225 }
226
227 pub(crate) fn add_statement(&mut self, loc: Location, stmt: StatementKind<'tcx>) {
241 debug!("MirPatch: add_statement({:?}, {:?})", loc, stmt);
242 self.new_statements.push((loc, stmt));
243 }
244
245 pub(crate) fn add_assign(&mut self, loc: Location, place: Place<'tcx>, rv: Rvalue<'tcx>) {
247 self.add_statement(loc, StatementKind::Assign(Box::new((place, rv))));
248 }
249
250 pub(crate) fn apply(self, body: &mut Body<'tcx>) {
252 debug!(
253 "MirPatch: {:?} new temps, starting from index {}: {:?}",
254 self.new_locals.len(),
255 body.local_decls.len(),
256 self.new_locals
257 );
258 debug!(
259 "MirPatch: {} new blocks, starting from index {}",
260 self.new_blocks.len(),
261 body.basic_blocks.len()
262 );
263 debug_assert_eq!(self.next_block, body.basic_blocks.len());
264 let bbs = if self.term_patch_map.is_empty() && self.new_blocks.is_empty() {
265 body.basic_blocks.as_mut_preserves_cfg()
266 } else {
267 body.basic_blocks.as_mut()
268 };
269 bbs.extend(self.new_blocks);
270 body.local_decls.extend(self.new_locals);
271
272 for loc in self.nop_statements {
273 bbs[loc.block].statements[loc.statement_index].make_nop(true);
274 }
275
276 let mut new_statements = self.new_statements;
277
278 new_statements.sort_by_key(|s| s.0);
281
282 let mut delta = 0;
283 let mut last_bb = START_BLOCK;
284 for (mut loc, stmt) in new_statements {
285 if loc.block != last_bb {
286 delta = 0;
287 last_bb = loc.block;
288 }
289 debug!("MirPatch: adding statement {:?} at loc {:?}+{}", stmt, loc, delta);
290 loc.statement_index += delta;
291 let source_info = Self::source_info_for_index(&bbs[loc.block], loc);
292 bbs[loc.block]
293 .statements
294 .insert(loc.statement_index, Statement::new(source_info, stmt));
295 delta += 1;
296 }
297
298 #[allow(rustc::potential_query_instability)]
300 for (src, patch) in self.term_patch_map {
301 debug!("MirPatch: patching block {:?}", src);
302 let bb = &mut bbs[src];
303 if let TerminatorKind::Unreachable = patch {
304 bb.statements.clear();
305 }
306 bb.terminator_mut().kind = patch;
307 }
308 }
309
310 fn source_info_for_index(data: &BasicBlockData<'_>, loc: Location) -> SourceInfo {
311 match data.statements.get(loc.statement_index) {
312 Some(stmt) => stmt.source_info,
313 None => data.terminator().source_info,
314 }
315 }
316
317 pub(crate) fn source_info_for_location(&self, body: &Body<'tcx>, loc: Location) -> SourceInfo {
318 let data = self.block(body, loc.block);
319 Self::source_info_for_index(data, loc)
320 }
321}