1use rustc_index::{Idx, IndexVec};
2use rustc_middle::mir::*;
3use rustc_middle::ty::Ty;
4use rustc_span::Span;
5use tracing::debug;
6
7pub(crate) struct MirPatch<'tcx> {
12 term_patch_map: IndexVec<BasicBlock, Option<TerminatorKind<'tcx>>>,
13 new_blocks: Vec<BasicBlockData<'tcx>>,
14 new_statements: Vec<(Location, StatementKind<'tcx>)>,
15 new_locals: Vec<LocalDecl<'tcx>>,
16 resume_block: Option<BasicBlock>,
17 unreachable_cleanup_block: Option<BasicBlock>,
19 unreachable_no_cleanup_block: Option<BasicBlock>,
21 terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
23 body_span: Span,
24 next_local: usize,
25}
26
27impl<'tcx> MirPatch<'tcx> {
28 pub(crate) fn new(body: &Body<'tcx>) -> Self {
30 let mut result = MirPatch {
31 term_patch_map: IndexVec::from_elem(None, &body.basic_blocks),
32 new_blocks: vec![],
33 new_statements: vec![],
34 new_locals: vec![],
35 next_local: body.local_decls.len(),
36 resume_block: None,
37 unreachable_cleanup_block: None,
38 unreachable_no_cleanup_block: None,
39 terminate_block: None,
40 body_span: body.span,
41 };
42
43 for (bb, block) in body.basic_blocks.iter_enumerated() {
44 if matches!(block.terminator().kind, TerminatorKind::UnwindResume)
46 && block.statements.is_empty()
47 {
48 result.resume_block = Some(bb);
49 continue;
50 }
51
52 if matches!(block.terminator().kind, TerminatorKind::Unreachable)
54 && block.statements.is_empty()
55 {
56 if block.is_cleanup {
57 result.unreachable_cleanup_block = Some(bb);
58 } else {
59 result.unreachable_no_cleanup_block = Some(bb);
60 }
61 continue;
62 }
63
64 if let TerminatorKind::UnwindTerminate(reason) = block.terminator().kind
66 && block.statements.is_empty()
67 {
68 result.terminate_block = Some((bb, reason));
69 continue;
70 }
71 }
72
73 result
74 }
75
76 pub(crate) fn resume_block(&mut self) -> BasicBlock {
77 if let Some(bb) = self.resume_block {
78 return bb;
79 }
80
81 let bb = self.new_block(BasicBlockData {
82 statements: vec![],
83 terminator: Some(Terminator {
84 source_info: SourceInfo::outermost(self.body_span),
85 kind: TerminatorKind::UnwindResume,
86 }),
87 is_cleanup: true,
88 });
89 self.resume_block = Some(bb);
90 bb
91 }
92
93 pub(crate) fn unreachable_cleanup_block(&mut self) -> BasicBlock {
94 if let Some(bb) = self.unreachable_cleanup_block {
95 return bb;
96 }
97
98 let bb = self.new_block(BasicBlockData {
99 statements: vec![],
100 terminator: Some(Terminator {
101 source_info: SourceInfo::outermost(self.body_span),
102 kind: TerminatorKind::Unreachable,
103 }),
104 is_cleanup: true,
105 });
106 self.unreachable_cleanup_block = Some(bb);
107 bb
108 }
109
110 pub(crate) fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
111 if let Some(bb) = self.unreachable_no_cleanup_block {
112 return bb;
113 }
114
115 let bb = self.new_block(BasicBlockData {
116 statements: vec![],
117 terminator: Some(Terminator {
118 source_info: SourceInfo::outermost(self.body_span),
119 kind: TerminatorKind::Unreachable,
120 }),
121 is_cleanup: false,
122 });
123 self.unreachable_no_cleanup_block = Some(bb);
124 bb
125 }
126
127 pub(crate) fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
128 if let Some((cached_bb, cached_reason)) = self.terminate_block
129 && reason == cached_reason
130 {
131 return cached_bb;
132 }
133
134 let bb = self.new_block(BasicBlockData {
135 statements: vec![],
136 terminator: Some(Terminator {
137 source_info: SourceInfo::outermost(self.body_span),
138 kind: TerminatorKind::UnwindTerminate(reason),
139 }),
140 is_cleanup: true,
141 });
142 self.terminate_block = Some((bb, reason));
143 bb
144 }
145
146 pub(crate) fn is_term_patched(&self, bb: BasicBlock) -> bool {
148 self.term_patch_map[bb].is_some()
149 }
150
151 pub(crate) fn new_local_with_info(
153 &mut self,
154 ty: Ty<'tcx>,
155 span: Span,
156 local_info: LocalInfo<'tcx>,
157 ) -> Local {
158 let index = self.next_local;
159 self.next_local += 1;
160 let mut new_decl = LocalDecl::new(ty, span);
161 **new_decl.local_info.as_mut().unwrap_crate_local() = local_info;
162 self.new_locals.push(new_decl);
163 Local::new(index)
164 }
165
166 pub(crate) fn new_temp(&mut self, ty: Ty<'tcx>, span: Span) -> Local {
168 let index = self.next_local;
169 self.next_local += 1;
170 self.new_locals.push(LocalDecl::new(ty, span));
171 Local::new(index)
172 }
173
174 pub(crate) fn local_ty(&self, local: Local) -> Ty<'tcx> {
176 let local = local.as_usize();
177 assert!(local < self.next_local);
178 let new_local_idx = self.new_locals.len() - (self.next_local - local);
179 self.new_locals[new_local_idx].ty
180 }
181
182 pub(crate) fn new_block(&mut self, data: BasicBlockData<'tcx>) -> BasicBlock {
184 let block = BasicBlock::new(self.term_patch_map.len());
185 debug!("MirPatch: new_block: {:?}: {:?}", block, data);
186 self.new_blocks.push(data);
187 self.term_patch_map.push(None);
188 block
189 }
190
191 pub(crate) fn patch_terminator(&mut self, block: BasicBlock, new: TerminatorKind<'tcx>) {
193 assert!(self.term_patch_map[block].is_none());
194 debug!("MirPatch: patch_terminator({:?}, {:?})", block, new);
195 self.term_patch_map[block] = Some(new);
196 }
197
198 pub(crate) fn add_statement(&mut self, loc: Location, stmt: StatementKind<'tcx>) {
212 debug!("MirPatch: add_statement({:?}, {:?})", loc, stmt);
213 self.new_statements.push((loc, stmt));
214 }
215
216 pub(crate) fn add_assign(&mut self, loc: Location, place: Place<'tcx>, rv: Rvalue<'tcx>) {
218 self.add_statement(loc, StatementKind::Assign(Box::new((place, rv))));
219 }
220
221 pub(crate) fn apply(self, body: &mut Body<'tcx>) {
223 debug!(
224 "MirPatch: {:?} new temps, starting from index {}: {:?}",
225 self.new_locals.len(),
226 body.local_decls.len(),
227 self.new_locals
228 );
229 debug!(
230 "MirPatch: {} new blocks, starting from index {}",
231 self.new_blocks.len(),
232 body.basic_blocks.len()
233 );
234 let bbs = if self.term_patch_map.is_empty() && self.new_blocks.is_empty() {
235 body.basic_blocks.as_mut_preserves_cfg()
236 } else {
237 body.basic_blocks.as_mut()
238 };
239 bbs.extend(self.new_blocks);
240 body.local_decls.extend(self.new_locals);
241 for (src, patch) in self.term_patch_map.into_iter_enumerated() {
242 if let Some(patch) = patch {
243 debug!("MirPatch: patching block {:?}", src);
244 bbs[src].terminator_mut().kind = patch;
245 }
246 }
247
248 let mut new_statements = self.new_statements;
249
250 new_statements.sort_by_key(|s| s.0);
253
254 let mut delta = 0;
255 let mut last_bb = START_BLOCK;
256 for (mut loc, stmt) in new_statements {
257 if loc.block != last_bb {
258 delta = 0;
259 last_bb = loc.block;
260 }
261 debug!("MirPatch: adding statement {:?} at loc {:?}+{}", stmt, loc, delta);
262 loc.statement_index += delta;
263 let source_info = Self::source_info_for_index(&body[loc.block], loc);
264 body[loc.block]
265 .statements
266 .insert(loc.statement_index, Statement { source_info, kind: stmt });
267 delta += 1;
268 }
269 }
270
271 fn source_info_for_index(data: &BasicBlockData<'_>, loc: Location) -> SourceInfo {
272 match data.statements.get(loc.statement_index) {
273 Some(stmt) => stmt.source_info,
274 None => data.terminator().source_info,
275 }
276 }
277
278 pub(crate) fn source_info_for_location(&self, body: &Body<'tcx>, loc: Location) -> SourceInfo {
279 let data = match loc.block.index().checked_sub(body.basic_blocks.len()) {
280 Some(new) => &self.new_blocks[new],
281 None => &body[loc.block],
282 };
283 Self::source_info_for_index(data, loc)
284 }
285}