1use rustc_index::{Idx, IndexVec};
2use rustc_middle::mir::*;
3use rustc_middle::ty::Ty;
4use rustc_span::Span;
5use tracing::debug;
6
7pub(crate) struct MirPatch<'tcx> {
11 patch_map: IndexVec<BasicBlock, Option<TerminatorKind<'tcx>>>,
12 new_blocks: Vec<BasicBlockData<'tcx>>,
13 new_statements: Vec<(Location, StatementKind<'tcx>)>,
14 new_locals: Vec<LocalDecl<'tcx>>,
15 resume_block: Option<BasicBlock>,
16 unreachable_cleanup_block: Option<BasicBlock>,
18 unreachable_no_cleanup_block: Option<BasicBlock>,
20 terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
22 body_span: Span,
23 next_local: usize,
24}
25
26impl<'tcx> MirPatch<'tcx> {
27 pub(crate) fn new(body: &Body<'tcx>) -> Self {
28 let mut result = MirPatch {
29 patch_map: IndexVec::from_elem(None, &body.basic_blocks),
30 new_blocks: vec![],
31 new_statements: vec![],
32 new_locals: vec![],
33 next_local: body.local_decls.len(),
34 resume_block: None,
35 unreachable_cleanup_block: None,
36 unreachable_no_cleanup_block: None,
37 terminate_block: None,
38 body_span: body.span,
39 };
40
41 for (bb, block) in body.basic_blocks.iter_enumerated() {
42 if matches!(block.terminator().kind, TerminatorKind::UnwindResume)
44 && block.statements.is_empty()
45 {
46 result.resume_block = Some(bb);
47 continue;
48 }
49
50 if matches!(block.terminator().kind, TerminatorKind::Unreachable)
52 && block.statements.is_empty()
53 {
54 if block.is_cleanup {
55 result.unreachable_cleanup_block = Some(bb);
56 } else {
57 result.unreachable_no_cleanup_block = Some(bb);
58 }
59 continue;
60 }
61
62 if let TerminatorKind::UnwindTerminate(reason) = block.terminator().kind
64 && block.statements.is_empty()
65 {
66 result.terminate_block = Some((bb, reason));
67 continue;
68 }
69 }
70
71 result
72 }
73
74 pub(crate) fn resume_block(&mut self) -> BasicBlock {
75 if let Some(bb) = self.resume_block {
76 return bb;
77 }
78
79 let bb = self.new_block(BasicBlockData {
80 statements: vec![],
81 terminator: Some(Terminator {
82 source_info: SourceInfo::outermost(self.body_span),
83 kind: TerminatorKind::UnwindResume,
84 }),
85 is_cleanup: true,
86 });
87 self.resume_block = Some(bb);
88 bb
89 }
90
91 pub(crate) fn unreachable_cleanup_block(&mut self) -> BasicBlock {
92 if let Some(bb) = self.unreachable_cleanup_block {
93 return bb;
94 }
95
96 let bb = self.new_block(BasicBlockData {
97 statements: vec![],
98 terminator: Some(Terminator {
99 source_info: SourceInfo::outermost(self.body_span),
100 kind: TerminatorKind::Unreachable,
101 }),
102 is_cleanup: true,
103 });
104 self.unreachable_cleanup_block = Some(bb);
105 bb
106 }
107
108 pub(crate) fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
109 if let Some(bb) = self.unreachable_no_cleanup_block {
110 return bb;
111 }
112
113 let bb = self.new_block(BasicBlockData {
114 statements: vec![],
115 terminator: Some(Terminator {
116 source_info: SourceInfo::outermost(self.body_span),
117 kind: TerminatorKind::Unreachable,
118 }),
119 is_cleanup: false,
120 });
121 self.unreachable_no_cleanup_block = Some(bb);
122 bb
123 }
124
125 pub(crate) fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
126 if let Some((cached_bb, cached_reason)) = self.terminate_block
127 && reason == cached_reason
128 {
129 return cached_bb;
130 }
131
132 let bb = self.new_block(BasicBlockData {
133 statements: vec![],
134 terminator: Some(Terminator {
135 source_info: SourceInfo::outermost(self.body_span),
136 kind: TerminatorKind::UnwindTerminate(reason),
137 }),
138 is_cleanup: true,
139 });
140 self.terminate_block = Some((bb, reason));
141 bb
142 }
143
144 pub(crate) fn is_patched(&self, bb: BasicBlock) -> bool {
145 self.patch_map[bb].is_some()
146 }
147
148 pub(crate) fn new_local_with_info(
149 &mut self,
150 ty: Ty<'tcx>,
151 span: Span,
152 local_info: LocalInfo<'tcx>,
153 ) -> Local {
154 let index = self.next_local;
155 self.next_local += 1;
156 let mut new_decl = LocalDecl::new(ty, span);
157 **new_decl.local_info.as_mut().assert_crate_local() = local_info;
158 self.new_locals.push(new_decl);
159 Local::new(index)
160 }
161
162 pub(crate) fn new_temp(&mut self, ty: Ty<'tcx>, span: Span) -> Local {
163 let index = self.next_local;
164 self.next_local += 1;
165 self.new_locals.push(LocalDecl::new(ty, span));
166 Local::new(index)
167 }
168
169 pub(crate) fn new_block(&mut self, data: BasicBlockData<'tcx>) -> BasicBlock {
170 let block = BasicBlock::new(self.patch_map.len());
171 debug!("MirPatch: new_block: {:?}: {:?}", block, data);
172 self.new_blocks.push(data);
173 self.patch_map.push(None);
174 block
175 }
176
177 pub(crate) fn patch_terminator(&mut self, block: BasicBlock, new: TerminatorKind<'tcx>) {
178 assert!(self.patch_map[block].is_none());
179 debug!("MirPatch: patch_terminator({:?}, {:?})", block, new);
180 self.patch_map[block] = Some(new);
181 }
182
183 pub(crate) fn add_statement(&mut self, loc: Location, stmt: StatementKind<'tcx>) {
184 debug!("MirPatch: add_statement({:?}, {:?})", loc, stmt);
185 self.new_statements.push((loc, stmt));
186 }
187
188 pub(crate) fn add_assign(&mut self, loc: Location, place: Place<'tcx>, rv: Rvalue<'tcx>) {
189 self.add_statement(loc, StatementKind::Assign(Box::new((place, rv))));
190 }
191
192 pub(crate) fn apply(self, body: &mut Body<'tcx>) {
193 debug!(
194 "MirPatch: {:?} new temps, starting from index {}: {:?}",
195 self.new_locals.len(),
196 body.local_decls.len(),
197 self.new_locals
198 );
199 debug!(
200 "MirPatch: {} new blocks, starting from index {}",
201 self.new_blocks.len(),
202 body.basic_blocks.len()
203 );
204 let bbs = if self.patch_map.is_empty() && self.new_blocks.is_empty() {
205 body.basic_blocks.as_mut_preserves_cfg()
206 } else {
207 body.basic_blocks.as_mut()
208 };
209 bbs.extend(self.new_blocks);
210 body.local_decls.extend(self.new_locals);
211 for (src, patch) in self.patch_map.into_iter_enumerated() {
212 if let Some(patch) = patch {
213 debug!("MirPatch: patching block {:?}", src);
214 bbs[src].terminator_mut().kind = patch;
215 }
216 }
217
218 let mut new_statements = self.new_statements;
219 new_statements.sort_by_key(|s| s.0);
220
221 let mut delta = 0;
222 let mut last_bb = START_BLOCK;
223 for (mut loc, stmt) in new_statements {
224 if loc.block != last_bb {
225 delta = 0;
226 last_bb = loc.block;
227 }
228 debug!("MirPatch: adding statement {:?} at loc {:?}+{}", stmt, loc, delta);
229 loc.statement_index += delta;
230 let source_info = Self::source_info_for_index(&body[loc.block], loc);
231 body[loc.block]
232 .statements
233 .insert(loc.statement_index, Statement { source_info, kind: stmt });
234 delta += 1;
235 }
236 }
237
238 fn source_info_for_index(data: &BasicBlockData<'_>, loc: Location) -> SourceInfo {
239 match data.statements.get(loc.statement_index) {
240 Some(stmt) => stmt.source_info,
241 None => data.terminator().source_info,
242 }
243 }
244
245 pub(crate) fn source_info_for_location(&self, body: &Body<'tcx>, loc: Location) -> SourceInfo {
246 let data = match loc.block.index().checked_sub(body.basic_blocks.len()) {
247 Some(new) => &self.new_blocks[new],
248 None => &body[loc.block],
249 };
250 Self::source_info_for_index(data, loc)
251 }
252}