rustc_mir_transform/patch.rs

1use rustc_data_structures::fx::FxHashMap;
2use rustc_index::Idx;
3use rustc_middle::mir::*;
4use rustc_middle::ty::Ty;
5use rustc_span::Span;
6use tracing::debug;
7
8/// This struct lets you "patch" a MIR body, i.e. modify it. You can queue up
9/// various changes, such as the addition of new statements and basic blocks
10/// and replacement of terminators, and then apply the queued changes all at
11/// once with `apply`. This is useful for MIR transformation passes.
12pub(crate) struct MirPatch<'tcx> {
13    term_patch_map: FxHashMap<BasicBlock, TerminatorKind<'tcx>>,
14    /// Set of statements that should be replaced by `Nop`.
15    nop_statements: Vec<Location>,
16    new_blocks: Vec<BasicBlockData<'tcx>>,
17    new_statements: Vec<(Location, StatementKind<'tcx>)>,
18    new_locals: Vec<LocalDecl<'tcx>>,
19    resume_block: Option<BasicBlock>,
20    // Only for unreachable in cleanup path.
21    unreachable_cleanup_block: Option<BasicBlock>,
22    // Only for unreachable not in cleanup path.
23    unreachable_no_cleanup_block: Option<BasicBlock>,
24    // Cached block for UnwindTerminate (with reason)
25    terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
26    body_span: Span,
27    next_local: usize,
28    /// The number of blocks at the start of the transformation. New blocks
29    /// get appended at the end.
30    next_block: usize,
31}
32
33impl<'tcx> MirPatch<'tcx> {
34    /// Creates a new, empty patch.
35    pub(crate) fn new(body: &Body<'tcx>) -> Self {
36        let mut result = MirPatch {
37            term_patch_map: Default::default(),
38            nop_statements: vec![],
39            new_blocks: vec![],
40            new_statements: vec![],
41            new_locals: vec![],
42            next_local: body.local_decls.len(),
43            next_block: body.basic_blocks.len(),
44            resume_block: None,
45            unreachable_cleanup_block: None,
46            unreachable_no_cleanup_block: None,
47            terminate_block: None,
48            body_span: body.span,
49        };
50
51        for (bb, block) in body.basic_blocks.iter_enumerated() {
52            // Check if we already have a resume block
53            if matches!(block.terminator().kind, TerminatorKind::UnwindResume)
54                && block.statements.is_empty()
55            {
56                result.resume_block = Some(bb);
57                continue;
58            }
59
60            // Check if we already have an unreachable block
61            if matches!(block.terminator().kind, TerminatorKind::Unreachable)
62                && block.statements.is_empty()
63            {
64                if block.is_cleanup {
65                    result.unreachable_cleanup_block = Some(bb);
66                } else {
67                    result.unreachable_no_cleanup_block = Some(bb);
68                }
69                continue;
70            }
71
72            // Check if we already have a terminate block
73            if let TerminatorKind::UnwindTerminate(reason) = block.terminator().kind
74                && block.statements.is_empty()
75            {
76                result.terminate_block = Some((bb, reason));
77                continue;
78            }
79        }
80
81        result
82    }
83
84    pub(crate) fn resume_block(&mut self) -> BasicBlock {
85        if let Some(bb) = self.resume_block {
86            return bb;
87        }
88
89        let bb = self.new_block(BasicBlockData::new(
90            Some(Terminator {
91                source_info: SourceInfo::outermost(self.body_span),
92                kind: TerminatorKind::UnwindResume,
93            }),
94            true,
95        ));
96        self.resume_block = Some(bb);
97        bb
98    }
99
100    pub(crate) fn unreachable_cleanup_block(&mut self) -> BasicBlock {
101        if let Some(bb) = self.unreachable_cleanup_block {
102            return bb;
103        }
104
105        let bb = self.new_block(BasicBlockData::new(
106            Some(Terminator {
107                source_info: SourceInfo::outermost(self.body_span),
108                kind: TerminatorKind::Unreachable,
109            }),
110            true,
111        ));
112        self.unreachable_cleanup_block = Some(bb);
113        bb
114    }
115
116    pub(crate) fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
117        if let Some(bb) = self.unreachable_no_cleanup_block {
118            return bb;
119        }
120
121        let bb = self.new_block(BasicBlockData::new(
122            Some(Terminator {
123                source_info: SourceInfo::outermost(self.body_span),
124                kind: TerminatorKind::Unreachable,
125            }),
126            false,
127        ));
128        self.unreachable_no_cleanup_block = Some(bb);
129        bb
130    }
131
132    pub(crate) fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
133        if let Some((cached_bb, cached_reason)) = self.terminate_block
134            && reason == cached_reason
135        {
136            return cached_bb;
137        }
138
139        let bb = self.new_block(BasicBlockData::new(
140            Some(Terminator {
141                source_info: SourceInfo::outermost(self.body_span),
142                kind: TerminatorKind::UnwindTerminate(reason),
143            }),
144            true,
145        ));
146        self.terminate_block = Some((bb, reason));
147        bb
148    }
149
150    /// Has a replacement of this block's terminator been queued in this patch?
151    pub(crate) fn is_term_patched(&self, bb: BasicBlock) -> bool {
152        self.term_patch_map.contains_key(&bb)
153    }
154
155    /// Universal getter for block data, either it is in 'old' blocks or in patched ones
156    pub(crate) fn block<'a>(
157        &'a self,
158        body: &'a Body<'tcx>,
159        bb: BasicBlock,
160    ) -> &'a BasicBlockData<'tcx> {
161        match bb.index().checked_sub(body.basic_blocks.len()) {
162            Some(new) => &self.new_blocks[new],
163            None => &body[bb],
164        }
165    }
166
167    pub(crate) fn terminator_loc(&self, body: &Body<'tcx>, bb: BasicBlock) -> Location {
168        let offset = self.block(body, bb).statements.len();
169        Location { block: bb, statement_index: offset }
170    }
171
172    /// Queues the addition of a new temporary with additional local info.
173    pub(crate) fn new_local_with_info(
174        &mut self,
175        ty: Ty<'tcx>,
176        span: Span,
177        local_info: LocalInfo<'tcx>,
178    ) -> Local {
179        let index = self.next_local;
180        self.next_local += 1;
181        let mut new_decl = LocalDecl::new(ty, span);
182        **new_decl.local_info.as_mut().unwrap_crate_local() = local_info;
183        self.new_locals.push(new_decl);
184        Local::new(index)
185    }
186
187    /// Queues the addition of a new temporary.
188    pub(crate) fn new_temp(&mut self, ty: Ty<'tcx>, span: Span) -> Local {
189        let index = self.next_local;
190        self.next_local += 1;
191        self.new_locals.push(LocalDecl::new(ty, span));
192        Local::new(index)
193    }
194
195    /// Returns the type of a local that's newly-added in the patch.
196    pub(crate) fn local_ty(&self, local: Local) -> Ty<'tcx> {
197        let local = local.as_usize();
198        assert!(local < self.next_local);
199        let new_local_idx = self.new_locals.len() - (self.next_local - local);
200        self.new_locals[new_local_idx].ty
201    }
202
203    /// Queues the addition of a new basic block.
204    pub(crate) fn new_block(&mut self, data: BasicBlockData<'tcx>) -> BasicBlock {
205        let block = BasicBlock::from_usize(self.next_block + self.new_blocks.len());
206        debug!("MirPatch: new_block: {:?}: {:?}", block, data);
207        self.new_blocks.push(data);
208        block
209    }
210
211    /// Queues the replacement of a block's terminator.
212    pub(crate) fn patch_terminator(&mut self, block: BasicBlock, new: TerminatorKind<'tcx>) {
213        assert!(!self.term_patch_map.contains_key(&block));
214        debug!("MirPatch: patch_terminator({:?}, {:?})", block, new);
215        self.term_patch_map.insert(block, new);
216    }
217
218    /// Mark given statement to be replaced by a `Nop`.
219    ///
220    /// This method only works on statements from the initial body, and cannot be used to remove
221    /// statements from `add_statement` or `add_assign`.
222    #[tracing::instrument(level = "debug", skip(self))]
223    pub(crate) fn nop_statement(&mut self, loc: Location) {
224        self.nop_statements.push(loc);
225    }
226
227    /// Queues the insertion of a statement at a given location. The statement
228    /// currently at that location, and all statements that follow, are shifted
229    /// down. If multiple statements are queued for addition at the same
230    /// location, the final statement order after calling `apply` will match
231    /// the queue insertion order.
232    ///
233    /// E.g. if we have `s0` at location `loc` and do these calls:
234    ///
235    ///   p.add_statement(loc, s1);
236    ///   p.add_statement(loc, s2);
237    ///   p.apply(body);
238    ///
239    /// then the final order will be `s1, s2, s0`, with `s1` at `loc`.
240    pub(crate) fn add_statement(&mut self, loc: Location, stmt: StatementKind<'tcx>) {
241        debug!("MirPatch: add_statement({:?}, {:?})", loc, stmt);
242        self.new_statements.push((loc, stmt));
243    }
244
245    /// Like `add_statement`, but specialized for assignments.
246    pub(crate) fn add_assign(&mut self, loc: Location, place: Place<'tcx>, rv: Rvalue<'tcx>) {
247        self.add_statement(loc, StatementKind::Assign(Box::new((place, rv))));
248    }
249
250    /// Applies the queued changes.
251    pub(crate) fn apply(self, body: &mut Body<'tcx>) {
252        debug!(
253            "MirPatch: {:?} new temps, starting from index {}: {:?}",
254            self.new_locals.len(),
255            body.local_decls.len(),
256            self.new_locals
257        );
258        debug!(
259            "MirPatch: {} new blocks, starting from index {}",
260            self.new_blocks.len(),
261            body.basic_blocks.len()
262        );
263        debug_assert_eq!(self.next_block, body.basic_blocks.len());
264        let bbs = if self.term_patch_map.is_empty() && self.new_blocks.is_empty() {
265            body.basic_blocks.as_mut_preserves_cfg()
266        } else {
267            body.basic_blocks.as_mut()
268        };
269        bbs.extend(self.new_blocks);
270        body.local_decls.extend(self.new_locals);
271
272        for loc in self.nop_statements {
273            bbs[loc.block].statements[loc.statement_index].make_nop();
274        }
275
276        let mut new_statements = self.new_statements;
277
278        // This must be a stable sort to provide the ordering described in the
279        // comment for `add_statement`.
280        new_statements.sort_by_key(|s| s.0);
281
282        let mut delta = 0;
283        let mut last_bb = START_BLOCK;
284        for (mut loc, stmt) in new_statements {
285            if loc.block != last_bb {
286                delta = 0;
287                last_bb = loc.block;
288            }
289            debug!("MirPatch: adding statement {:?} at loc {:?}+{}", stmt, loc, delta);
290            loc.statement_index += delta;
291            let source_info = Self::source_info_for_index(&bbs[loc.block], loc);
292            bbs[loc.block]
293                .statements
294                .insert(loc.statement_index, Statement::new(source_info, stmt));
295            delta += 1;
296        }
297
298        // The order in which we patch terminators does not change the result.
299        #[allow(rustc::potential_query_instability)]
300        for (src, patch) in self.term_patch_map {
301            debug!("MirPatch: patching block {:?}", src);
302            let bb = &mut bbs[src];
303            if let TerminatorKind::Unreachable = patch {
304                bb.statements.clear();
305            }
306            bb.terminator_mut().kind = patch;
307        }
308    }
309
310    fn source_info_for_index(data: &BasicBlockData<'_>, loc: Location) -> SourceInfo {
311        match data.statements.get(loc.statement_index) {
312            Some(stmt) => stmt.source_info,
313            None => data.terminator().source_info,
314        }
315    }
316
317    pub(crate) fn source_info_for_location(&self, body: &Body<'tcx>, loc: Location) -> SourceInfo {
318        let data = self.block(body, loc.block);
319        Self::source_info_for_index(data, loc)
320    }
321}