rustc_mir_transform/
patch.rs

1use rustc_index::{Idx, IndexVec};
2use rustc_middle::mir::*;
3use rustc_middle::ty::Ty;
4use rustc_span::Span;
5use tracing::debug;
6
7/// This struct represents a patch to MIR, which can add
8/// new statements and basic blocks and patch over block
9/// terminators.
10pub(crate) struct MirPatch<'tcx> {
11    patch_map: IndexVec<BasicBlock, Option<TerminatorKind<'tcx>>>,
12    new_blocks: Vec<BasicBlockData<'tcx>>,
13    new_statements: Vec<(Location, StatementKind<'tcx>)>,
14    new_locals: Vec<LocalDecl<'tcx>>,
15    resume_block: Option<BasicBlock>,
16    // Only for unreachable in cleanup path.
17    unreachable_cleanup_block: Option<BasicBlock>,
18    // Only for unreachable not in cleanup path.
19    unreachable_no_cleanup_block: Option<BasicBlock>,
20    // Cached block for UnwindTerminate (with reason)
21    terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
22    body_span: Span,
23    next_local: usize,
24}
25
26impl<'tcx> MirPatch<'tcx> {
27    pub(crate) fn new(body: &Body<'tcx>) -> Self {
28        let mut result = MirPatch {
29            patch_map: IndexVec::from_elem(None, &body.basic_blocks),
30            new_blocks: vec![],
31            new_statements: vec![],
32            new_locals: vec![],
33            next_local: body.local_decls.len(),
34            resume_block: None,
35            unreachable_cleanup_block: None,
36            unreachable_no_cleanup_block: None,
37            terminate_block: None,
38            body_span: body.span,
39        };
40
41        for (bb, block) in body.basic_blocks.iter_enumerated() {
42            // Check if we already have a resume block
43            if matches!(block.terminator().kind, TerminatorKind::UnwindResume)
44                && block.statements.is_empty()
45            {
46                result.resume_block = Some(bb);
47                continue;
48            }
49
50            // Check if we already have an unreachable block
51            if matches!(block.terminator().kind, TerminatorKind::Unreachable)
52                && block.statements.is_empty()
53            {
54                if block.is_cleanup {
55                    result.unreachable_cleanup_block = Some(bb);
56                } else {
57                    result.unreachable_no_cleanup_block = Some(bb);
58                }
59                continue;
60            }
61
62            // Check if we already have a terminate block
63            if let TerminatorKind::UnwindTerminate(reason) = block.terminator().kind
64                && block.statements.is_empty()
65            {
66                result.terminate_block = Some((bb, reason));
67                continue;
68            }
69        }
70
71        result
72    }
73
74    pub(crate) fn resume_block(&mut self) -> BasicBlock {
75        if let Some(bb) = self.resume_block {
76            return bb;
77        }
78
79        let bb = self.new_block(BasicBlockData {
80            statements: vec![],
81            terminator: Some(Terminator {
82                source_info: SourceInfo::outermost(self.body_span),
83                kind: TerminatorKind::UnwindResume,
84            }),
85            is_cleanup: true,
86        });
87        self.resume_block = Some(bb);
88        bb
89    }
90
91    pub(crate) fn unreachable_cleanup_block(&mut self) -> BasicBlock {
92        if let Some(bb) = self.unreachable_cleanup_block {
93            return bb;
94        }
95
96        let bb = self.new_block(BasicBlockData {
97            statements: vec![],
98            terminator: Some(Terminator {
99                source_info: SourceInfo::outermost(self.body_span),
100                kind: TerminatorKind::Unreachable,
101            }),
102            is_cleanup: true,
103        });
104        self.unreachable_cleanup_block = Some(bb);
105        bb
106    }
107
108    pub(crate) fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
109        if let Some(bb) = self.unreachable_no_cleanup_block {
110            return bb;
111        }
112
113        let bb = self.new_block(BasicBlockData {
114            statements: vec![],
115            terminator: Some(Terminator {
116                source_info: SourceInfo::outermost(self.body_span),
117                kind: TerminatorKind::Unreachable,
118            }),
119            is_cleanup: false,
120        });
121        self.unreachable_no_cleanup_block = Some(bb);
122        bb
123    }
124
125    pub(crate) fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
126        if let Some((cached_bb, cached_reason)) = self.terminate_block
127            && reason == cached_reason
128        {
129            return cached_bb;
130        }
131
132        let bb = self.new_block(BasicBlockData {
133            statements: vec![],
134            terminator: Some(Terminator {
135                source_info: SourceInfo::outermost(self.body_span),
136                kind: TerminatorKind::UnwindTerminate(reason),
137            }),
138            is_cleanup: true,
139        });
140        self.terminate_block = Some((bb, reason));
141        bb
142    }
143
144    pub(crate) fn is_patched(&self, bb: BasicBlock) -> bool {
145        self.patch_map[bb].is_some()
146    }
147
148    pub(crate) fn new_local_with_info(
149        &mut self,
150        ty: Ty<'tcx>,
151        span: Span,
152        local_info: LocalInfo<'tcx>,
153    ) -> Local {
154        let index = self.next_local;
155        self.next_local += 1;
156        let mut new_decl = LocalDecl::new(ty, span);
157        **new_decl.local_info.as_mut().assert_crate_local() = local_info;
158        self.new_locals.push(new_decl);
159        Local::new(index)
160    }
161
162    pub(crate) fn new_temp(&mut self, ty: Ty<'tcx>, span: Span) -> Local {
163        let index = self.next_local;
164        self.next_local += 1;
165        self.new_locals.push(LocalDecl::new(ty, span));
166        Local::new(index)
167    }
168
169    pub(crate) fn new_block(&mut self, data: BasicBlockData<'tcx>) -> BasicBlock {
170        let block = BasicBlock::new(self.patch_map.len());
171        debug!("MirPatch: new_block: {:?}: {:?}", block, data);
172        self.new_blocks.push(data);
173        self.patch_map.push(None);
174        block
175    }
176
177    pub(crate) fn patch_terminator(&mut self, block: BasicBlock, new: TerminatorKind<'tcx>) {
178        assert!(self.patch_map[block].is_none());
179        debug!("MirPatch: patch_terminator({:?}, {:?})", block, new);
180        self.patch_map[block] = Some(new);
181    }
182
183    pub(crate) fn add_statement(&mut self, loc: Location, stmt: StatementKind<'tcx>) {
184        debug!("MirPatch: add_statement({:?}, {:?})", loc, stmt);
185        self.new_statements.push((loc, stmt));
186    }
187
188    pub(crate) fn add_assign(&mut self, loc: Location, place: Place<'tcx>, rv: Rvalue<'tcx>) {
189        self.add_statement(loc, StatementKind::Assign(Box::new((place, rv))));
190    }
191
192    pub(crate) fn apply(self, body: &mut Body<'tcx>) {
193        debug!(
194            "MirPatch: {:?} new temps, starting from index {}: {:?}",
195            self.new_locals.len(),
196            body.local_decls.len(),
197            self.new_locals
198        );
199        debug!(
200            "MirPatch: {} new blocks, starting from index {}",
201            self.new_blocks.len(),
202            body.basic_blocks.len()
203        );
204        let bbs = if self.patch_map.is_empty() && self.new_blocks.is_empty() {
205            body.basic_blocks.as_mut_preserves_cfg()
206        } else {
207            body.basic_blocks.as_mut()
208        };
209        bbs.extend(self.new_blocks);
210        body.local_decls.extend(self.new_locals);
211        for (src, patch) in self.patch_map.into_iter_enumerated() {
212            if let Some(patch) = patch {
213                debug!("MirPatch: patching block {:?}", src);
214                bbs[src].terminator_mut().kind = patch;
215            }
216        }
217
218        let mut new_statements = self.new_statements;
219        new_statements.sort_by_key(|s| s.0);
220
221        let mut delta = 0;
222        let mut last_bb = START_BLOCK;
223        for (mut loc, stmt) in new_statements {
224            if loc.block != last_bb {
225                delta = 0;
226                last_bb = loc.block;
227            }
228            debug!("MirPatch: adding statement {:?} at loc {:?}+{}", stmt, loc, delta);
229            loc.statement_index += delta;
230            let source_info = Self::source_info_for_index(&body[loc.block], loc);
231            body[loc.block]
232                .statements
233                .insert(loc.statement_index, Statement { source_info, kind: stmt });
234            delta += 1;
235        }
236    }
237
238    fn source_info_for_index(data: &BasicBlockData<'_>, loc: Location) -> SourceInfo {
239        match data.statements.get(loc.statement_index) {
240            Some(stmt) => stmt.source_info,
241            None => data.terminator().source_info,
242        }
243    }
244
245    pub(crate) fn source_info_for_location(&self, body: &Body<'tcx>, loc: Location) -> SourceInfo {
246        let data = match loc.block.index().checked_sub(body.basic_blocks.len()) {
247            Some(new) => &self.new_blocks[new],
248            None => &body[loc.block],
249        };
250        Self::source_info_for_index(data, loc)
251    }
252}