rustc_mir_transform/
patch.rs

1use rustc_data_structures::fx::FxHashMap;
2use rustc_index::Idx;
3use rustc_middle::mir::*;
4use rustc_middle::ty::Ty;
5use rustc_span::Span;
6use tracing::debug;
7
8/// This struct lets you "patch" a MIR body, i.e. modify it. You can queue up
9/// various changes, such as the addition of new statements and basic blocks
10/// and replacement of terminators, and then apply the queued changes all at
11/// once with `apply`. This is useful for MIR transformation passes.
12pub(crate) struct MirPatch<'tcx> {
13    term_patch_map: FxHashMap<BasicBlock, TerminatorKind<'tcx>>,
14    /// Set of statements that should be replaced by `Nop`.
15    nop_statements: Vec<Location>,
16    new_blocks: Vec<BasicBlockData<'tcx>>,
17    new_statements: Vec<(Location, StatementKind<'tcx>)>,
18    new_locals: Vec<LocalDecl<'tcx>>,
19    resume_block: Option<BasicBlock>,
20    // Only for unreachable in cleanup path.
21    unreachable_cleanup_block: Option<BasicBlock>,
22    // Only for unreachable not in cleanup path.
23    unreachable_no_cleanup_block: Option<BasicBlock>,
24    // Cached block for UnwindTerminate (with reason)
25    terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
26    body_span: Span,
27    /// The number of locals at the start of the transformation. New locals
28    /// get appended at the end.
29    next_local: usize,
30    /// The number of blocks at the start of the transformation. New blocks
31    /// get appended at the end.
32    next_block: usize,
33}
34
35impl<'tcx> MirPatch<'tcx> {
36    /// Creates a new, empty patch.
37    pub(crate) fn new(body: &Body<'tcx>) -> Self {
38        let mut result = MirPatch {
39            term_patch_map: Default::default(),
40            nop_statements: vec![],
41            new_blocks: vec![],
42            new_statements: vec![],
43            new_locals: vec![],
44            next_local: body.local_decls.len(),
45            next_block: body.basic_blocks.len(),
46            resume_block: None,
47            unreachable_cleanup_block: None,
48            unreachable_no_cleanup_block: None,
49            terminate_block: None,
50            body_span: body.span,
51        };
52
53        for (bb, block) in body.basic_blocks.iter_enumerated() {
54            // Check if we already have a resume block
55            if matches!(block.terminator().kind, TerminatorKind::UnwindResume)
56                && block.statements.is_empty()
57            {
58                result.resume_block = Some(bb);
59                continue;
60            }
61
62            // Check if we already have an unreachable block
63            if matches!(block.terminator().kind, TerminatorKind::Unreachable)
64                && block.statements.is_empty()
65            {
66                if block.is_cleanup {
67                    result.unreachable_cleanup_block = Some(bb);
68                } else {
69                    result.unreachable_no_cleanup_block = Some(bb);
70                }
71                continue;
72            }
73
74            // Check if we already have a terminate block
75            if let TerminatorKind::UnwindTerminate(reason) = block.terminator().kind
76                && block.statements.is_empty()
77            {
78                result.terminate_block = Some((bb, reason));
79                continue;
80            }
81        }
82
83        result
84    }
85
86    pub(crate) fn resume_block(&mut self) -> BasicBlock {
87        if let Some(bb) = self.resume_block {
88            return bb;
89        }
90
91        let bb = self.new_block(BasicBlockData::new(
92            Some(Terminator {
93                source_info: SourceInfo::outermost(self.body_span),
94                kind: TerminatorKind::UnwindResume,
95            }),
96            true,
97        ));
98        self.resume_block = Some(bb);
99        bb
100    }
101
102    pub(crate) fn unreachable_cleanup_block(&mut self) -> BasicBlock {
103        if let Some(bb) = self.unreachable_cleanup_block {
104            return bb;
105        }
106
107        let bb = self.new_block(BasicBlockData::new(
108            Some(Terminator {
109                source_info: SourceInfo::outermost(self.body_span),
110                kind: TerminatorKind::Unreachable,
111            }),
112            true,
113        ));
114        self.unreachable_cleanup_block = Some(bb);
115        bb
116    }
117
118    pub(crate) fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
119        if let Some(bb) = self.unreachable_no_cleanup_block {
120            return bb;
121        }
122
123        let bb = self.new_block(BasicBlockData::new(
124            Some(Terminator {
125                source_info: SourceInfo::outermost(self.body_span),
126                kind: TerminatorKind::Unreachable,
127            }),
128            false,
129        ));
130        self.unreachable_no_cleanup_block = Some(bb);
131        bb
132    }
133
134    pub(crate) fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
135        if let Some((cached_bb, cached_reason)) = self.terminate_block
136            && reason == cached_reason
137        {
138            return cached_bb;
139        }
140
141        let bb = self.new_block(BasicBlockData::new(
142            Some(Terminator {
143                source_info: SourceInfo::outermost(self.body_span),
144                kind: TerminatorKind::UnwindTerminate(reason),
145            }),
146            true,
147        ));
148        self.terminate_block = Some((bb, reason));
149        bb
150    }
151
152    /// Has a replacement of this block's terminator been queued in this patch?
153    pub(crate) fn is_term_patched(&self, bb: BasicBlock) -> bool {
154        self.term_patch_map.contains_key(&bb)
155    }
156
157    /// Universal getter for block data, either it is in 'old' blocks or in patched ones
158    pub(crate) fn block<'a>(
159        &'a self,
160        body: &'a Body<'tcx>,
161        bb: BasicBlock,
162    ) -> &'a BasicBlockData<'tcx> {
163        match bb.index().checked_sub(body.basic_blocks.len()) {
164            Some(new) => &self.new_blocks[new],
165            None => &body[bb],
166        }
167    }
168
169    pub(crate) fn terminator_loc(&self, body: &Body<'tcx>, bb: BasicBlock) -> Location {
170        let offset = self.block(body, bb).statements.len();
171        Location { block: bb, statement_index: offset }
172    }
173
174    /// Queues the addition of a new temporary with additional local info.
175    pub(crate) fn new_local_with_info(
176        &mut self,
177        ty: Ty<'tcx>,
178        span: Span,
179        local_info: LocalInfo<'tcx>,
180    ) -> Local {
181        let index = self.next_local + self.new_locals.len();
182        let mut new_decl = LocalDecl::new(ty, span);
183        **new_decl.local_info.as_mut().unwrap_crate_local() = local_info;
184        self.new_locals.push(new_decl);
185        Local::new(index)
186    }
187
188    /// Queues the addition of a new temporary.
189    pub(crate) fn new_temp(&mut self, ty: Ty<'tcx>, span: Span) -> Local {
190        let index = self.next_local + self.new_locals.len();
191        self.new_locals.push(LocalDecl::new(ty, span));
192        Local::new(index)
193    }
194
195    /// Returns the type of a local that's newly-added in the patch.
196    pub(crate) fn local_ty(&self, local: Local) -> Ty<'tcx> {
197        let local = local.as_usize();
198        assert!(local < self.next_local + self.new_locals.len());
199        let new_local_idx = local - self.next_local;
200        self.new_locals[new_local_idx].ty
201    }
202
203    /// Queues the addition of a new basic block.
204    pub(crate) fn new_block(&mut self, data: BasicBlockData<'tcx>) -> BasicBlock {
205        let block = BasicBlock::from_usize(self.next_block + self.new_blocks.len());
206        debug!("MirPatch: new_block: {:?}: {:?}", block, data);
207        self.new_blocks.push(data);
208        block
209    }
210
211    /// Queues the replacement of a block's terminator.
212    pub(crate) fn patch_terminator(&mut self, block: BasicBlock, new: TerminatorKind<'tcx>) {
213        assert!(!self.term_patch_map.contains_key(&block));
214        debug!("MirPatch: patch_terminator({:?}, {:?})", block, new);
215        self.term_patch_map.insert(block, new);
216    }
217
218    /// Mark given statement to be replaced by a `Nop`.
219    ///
220    /// This method only works on statements from the initial body, and cannot be used to remove
221    /// statements from `add_statement` or `add_assign`.
222    #[tracing::instrument(level = "debug", skip(self))]
223    pub(crate) fn nop_statement(&mut self, loc: Location) {
224        self.nop_statements.push(loc);
225    }
226
227    /// Queues the insertion of a statement at a given location. The statement
228    /// currently at that location, and all statements that follow, are shifted
229    /// down. If multiple statements are queued for addition at the same
230    /// location, the final statement order after calling `apply` will match
231    /// the queue insertion order.
232    ///
233    /// E.g. if we have `s0` at location `loc` and do these calls:
234    ///
235    ///   p.add_statement(loc, s1);
236    ///   p.add_statement(loc, s2);
237    ///   p.apply(body);
238    ///
239    /// then the final order will be `s1, s2, s0`, with `s1` at `loc`.
240    pub(crate) fn add_statement(&mut self, loc: Location, stmt: StatementKind<'tcx>) {
241        debug!("MirPatch: add_statement({:?}, {:?})", loc, stmt);
242        self.new_statements.push((loc, stmt));
243    }
244
245    /// Like `add_statement`, but specialized for assignments.
246    pub(crate) fn add_assign(&mut self, loc: Location, place: Place<'tcx>, rv: Rvalue<'tcx>) {
247        self.add_statement(loc, StatementKind::Assign(Box::new((place, rv))));
248    }
249
250    /// Applies the queued changes.
251    pub(crate) fn apply(self, body: &mut Body<'tcx>) {
252        debug!(
253            "MirPatch: {:?} new temps, starting from index {}: {:?}",
254            self.new_locals.len(),
255            body.local_decls.len(),
256            self.new_locals
257        );
258        debug!(
259            "MirPatch: {} new blocks, starting from index {}",
260            self.new_blocks.len(),
261            body.basic_blocks.len()
262        );
263        debug_assert_eq!(self.next_block, body.basic_blocks.len());
264        let bbs = if self.term_patch_map.is_empty() && self.new_blocks.is_empty() {
265            body.basic_blocks.as_mut_preserves_cfg()
266        } else {
267            body.basic_blocks.as_mut()
268        };
269        bbs.extend(self.new_blocks);
270        body.local_decls.extend(self.new_locals);
271
272        for loc in self.nop_statements {
273            bbs[loc.block].statements[loc.statement_index].make_nop(true);
274        }
275
276        let mut new_statements = self.new_statements;
277
278        // This must be a stable sort to provide the ordering described in the
279        // comment for `add_statement`.
280        new_statements.sort_by_key(|s| s.0);
281
282        let mut delta = 0;
283        let mut last_bb = START_BLOCK;
284        for (mut loc, stmt) in new_statements {
285            if loc.block != last_bb {
286                delta = 0;
287                last_bb = loc.block;
288            }
289            debug!("MirPatch: adding statement {:?} at loc {:?}+{}", stmt, loc, delta);
290            loc.statement_index += delta;
291            let source_info = Self::source_info_for_index(&bbs[loc.block], loc);
292            bbs[loc.block]
293                .statements
294                .insert(loc.statement_index, Statement::new(source_info, stmt));
295            delta += 1;
296        }
297
298        // The order in which we patch terminators does not change the result.
299        #[allow(rustc::potential_query_instability)]
300        for (src, patch) in self.term_patch_map {
301            debug!("MirPatch: patching block {:?}", src);
302            let bb = &mut bbs[src];
303            if let TerminatorKind::Unreachable = patch {
304                bb.statements.clear();
305            }
306            bb.terminator_mut().kind = patch;
307        }
308    }
309
310    fn source_info_for_index(data: &BasicBlockData<'_>, loc: Location) -> SourceInfo {
311        match data.statements.get(loc.statement_index) {
312            Some(stmt) => stmt.source_info,
313            None => data.terminator().source_info,
314        }
315    }
316
317    pub(crate) fn source_info_for_location(&self, body: &Body<'tcx>, loc: Location) -> SourceInfo {
318        let data = self.block(body, loc.block);
319        Self::source_info_for_index(data, loc)
320    }
321}