rustc_mir_transform/
sroa.rs

1use rustc_abi::{FIRST_VARIANT, FieldIdx};
2use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
3use rustc_hir::LangItem;
4use rustc_index::IndexVec;
5use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
6use rustc_middle::bug;
7use rustc_middle::mir::visit::*;
8use rustc_middle::mir::*;
9use rustc_middle::ty::{self, Ty, TyCtxt};
10use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
11use tracing::{debug, instrument};
12
13use crate::patch::MirPatch;
14
15pub(super) struct ScalarReplacementOfAggregates;
16
17impl<'tcx> crate::MirPass<'tcx> for ScalarReplacementOfAggregates {
18    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
19        sess.mir_opt_level() >= 2
20    }
21
22    #[instrument(level = "debug", skip(self, tcx, body))]
23    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
24        debug!(def_id = ?body.source.def_id());
25
26        // Avoid query cycles (coroutines require optimized MIR for layout).
27        if tcx.type_of(body.source.def_id()).instantiate_identity().is_coroutine() {
28            return;
29        }
30
31        let mut excluded = excluded_locals(body);
32        let typing_env = body.typing_env(tcx);
33        loop {
34            debug!(?excluded);
35            let escaping = escaping_locals(tcx, typing_env, &excluded, body);
36            debug!(?escaping);
37            let replacements = compute_flattening(tcx, typing_env, body, escaping);
38            debug!(?replacements);
39            let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
40            if !all_dead_locals.is_empty() {
41                excluded.union(&all_dead_locals);
42                excluded = {
43                    let mut growable = GrowableBitSet::from(excluded);
44                    growable.ensure(body.local_decls.len());
45                    growable.into()
46                };
47            } else {
48                break;
49            }
50        }
51    }
52
53    fn is_required(&self) -> bool {
54        false
55    }
56}
57
58/// Identify all locals that are not eligible for SROA.
59///
60/// There are 3 cases:
61/// - the aggregated local is used or passed to other code (function parameters and arguments);
62/// - the locals is a union or an enum;
63/// - the local's address is taken, and thus the relative addresses of the fields are observable to
64///   client code.
65fn escaping_locals<'tcx>(
66    tcx: TyCtxt<'tcx>,
67    typing_env: ty::TypingEnv<'tcx>,
68    excluded: &DenseBitSet<Local>,
69    body: &Body<'tcx>,
70) -> DenseBitSet<Local> {
71    let is_excluded_ty = |ty: Ty<'tcx>| {
72        if ty.is_union() || ty.is_enum() {
73            return true;
74        }
75        if let ty::Adt(def, _args) = ty.kind() {
76            if def.repr().simd() {
77                // Exclude #[repr(simd)] types so that they are not de-optimized into an array
78                return true;
79            }
80            if tcx.is_lang_item(def.did(), LangItem::DynMetadata) {
81                // codegen wants to see the `DynMetadata<T>`,
82                // not the inner reference-to-opaque-type.
83                return true;
84            }
85            // We already excluded unions and enums, so this ADT must have one variant
86            let variant = def.variant(FIRST_VARIANT);
87            if variant.fields.len() > 1 {
88                // If this has more than one field, it cannot be a wrapper that only provides a
89                // niche, so we do not want to automatically exclude it.
90                return false;
91            }
92            let Ok(layout) = tcx.layout_of(typing_env.as_query_input(ty)) else {
93                // We can't get the layout
94                return true;
95            };
96            if layout.layout.largest_niche().is_some() {
97                // This type has a niche
98                return true;
99            }
100        }
101        // Default for non-ADTs
102        false
103    };
104
105    let mut set = DenseBitSet::new_empty(body.local_decls.len());
106    set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
107    for (local, decl) in body.local_decls().iter_enumerated() {
108        if excluded.contains(local) || is_excluded_ty(decl.ty) {
109            set.insert(local);
110        }
111    }
112    let mut visitor = EscapeVisitor { set };
113    visitor.visit_body(body);
114    return visitor.set;
115
116    struct EscapeVisitor {
117        set: DenseBitSet<Local>,
118    }
119
120    impl<'tcx> Visitor<'tcx> for EscapeVisitor {
121        fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) {
122            self.set.insert(local);
123        }
124
125        fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
126            // Mirror the implementation in PreFlattenVisitor.
127            if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
128                return;
129            }
130            self.super_place(place, context, location);
131        }
132
133        fn visit_assign(
134            &mut self,
135            lvalue: &Place<'tcx>,
136            rvalue: &Rvalue<'tcx>,
137            location: Location,
138        ) {
139            if lvalue.as_local().is_some() {
140                match rvalue {
141                    // Aggregate assignments are expanded in run_pass.
142                    Rvalue::Aggregate(..) | Rvalue::Use(..) => {
143                        self.visit_rvalue(rvalue, location);
144                        return;
145                    }
146                    _ => {}
147                }
148            }
149            self.super_assign(lvalue, rvalue, location)
150        }
151
152        fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
153            match statement.kind {
154                // Storage statements are expanded in run_pass.
155                StatementKind::StorageLive(..)
156                | StatementKind::StorageDead(..)
157                | StatementKind::Deinit(..) => return,
158                _ => self.super_statement(statement, location),
159            }
160        }
161
162        // We ignore anything that happens in debuginfo, since we expand it using
163        // `VarDebugInfoFragment`.
164        fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
165    }
166}
167
168#[derive(Default, Debug)]
169struct ReplacementMap<'tcx> {
170    /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
171    /// and deinit statement and debuginfo.
172    fragments: IndexVec<Local, Option<IndexVec<FieldIdx, Option<(Ty<'tcx>, Local)>>>>,
173}
174
175impl<'tcx> ReplacementMap<'tcx> {
176    fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
177        let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else {
178            return None;
179        };
180        let fields = self.fragments[place.local].as_ref()?;
181        let (_, new_local) = fields[f]?;
182        Some(Place { local: new_local, projection: tcx.mk_place_elems(rest) })
183    }
184
185    fn place_fragments(
186        &self,
187        place: Place<'tcx>,
188    ) -> Option<impl Iterator<Item = (FieldIdx, Ty<'tcx>, Local)> + '_> {
189        let local = place.as_local()?;
190        let fields = self.fragments[local].as_ref()?;
191        Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
192            let (ty, local) = opt_ty_local?;
193            Some((field, ty, local))
194        }))
195    }
196}
197
198/// Compute the replacement of flattened places into locals.
199///
200/// For each eligible place, we assign a new local to each accessed field.
201/// The replacement will be done later in `ReplacementVisitor`.
202fn compute_flattening<'tcx>(
203    tcx: TyCtxt<'tcx>,
204    typing_env: ty::TypingEnv<'tcx>,
205    body: &mut Body<'tcx>,
206    escaping: DenseBitSet<Local>,
207) -> ReplacementMap<'tcx> {
208    let mut fragments = IndexVec::from_elem(None, &body.local_decls);
209
210    for local in body.local_decls.indices() {
211        if escaping.contains(local) {
212            continue;
213        }
214        let decl = body.local_decls[local].clone();
215        let ty = decl.ty;
216        iter_fields(ty, tcx, typing_env, |variant, field, field_ty| {
217            if variant.is_some() {
218                // Downcasts are currently not supported.
219                return;
220            };
221            let new_local =
222                body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
223            fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
224        });
225    }
226    ReplacementMap { fragments }
227}
228
229/// Perform the replacement computed by `compute_flattening`.
230fn replace_flattened_locals<'tcx>(
231    tcx: TyCtxt<'tcx>,
232    body: &mut Body<'tcx>,
233    replacements: ReplacementMap<'tcx>,
234) -> DenseBitSet<Local> {
235    let mut all_dead_locals = DenseBitSet::new_empty(replacements.fragments.len());
236    for (local, replacements) in replacements.fragments.iter_enumerated() {
237        if replacements.is_some() {
238            all_dead_locals.insert(local);
239        }
240    }
241    debug!(?all_dead_locals);
242    if all_dead_locals.is_empty() {
243        return all_dead_locals;
244    }
245
246    let mut visitor = ReplacementVisitor {
247        tcx,
248        local_decls: &body.local_decls,
249        replacements: &replacements,
250        all_dead_locals,
251        patch: MirPatch::new(body),
252    };
253    for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
254        visitor.visit_basic_block_data(bb, data);
255    }
256    for scope in &mut body.source_scopes {
257        visitor.visit_source_scope_data(scope);
258    }
259    for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() {
260        visitor.visit_user_type_annotation(index, annotation);
261    }
262    visitor.expand_var_debug_info(&mut body.var_debug_info);
263    let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
264    patch.apply(body);
265    all_dead_locals
266}
267
268struct ReplacementVisitor<'tcx, 'll> {
269    tcx: TyCtxt<'tcx>,
270    /// This is only used to compute the type for `VarDebugInfoFragment`.
271    local_decls: &'ll LocalDecls<'tcx>,
272    /// Work to do.
273    replacements: &'ll ReplacementMap<'tcx>,
274    /// This is used to check that we are not leaving references to replaced locals behind.
275    all_dead_locals: DenseBitSet<Local>,
276    patch: MirPatch<'tcx>,
277}
278
279impl<'tcx> ReplacementVisitor<'tcx, '_> {
280    #[instrument(level = "trace", skip(self))]
281    fn expand_var_debug_info(&mut self, var_debug_info: &mut Vec<VarDebugInfo<'tcx>>) {
282        var_debug_info.flat_map_in_place(|mut var_debug_info| {
283            let place = match var_debug_info.value {
284                VarDebugInfoContents::Const(_) => return vec![var_debug_info],
285                VarDebugInfoContents::Place(ref mut place) => place,
286            };
287
288            if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
289                *place = repl;
290                return vec![var_debug_info];
291            }
292
293            let Some(parts) = self.replacements.place_fragments(*place) else {
294                return vec![var_debug_info];
295            };
296
297            let ty = place.ty(self.local_decls, self.tcx).ty;
298
299            parts
300                .map(|(field, field_ty, replacement_local)| {
301                    let mut var_debug_info = var_debug_info.clone();
302                    let composite = var_debug_info.composite.get_or_insert_with(|| {
303                        Box::new(VarDebugInfoFragment { ty, projection: Vec::new() })
304                    });
305                    composite.projection.push(PlaceElem::Field(field, field_ty));
306
307                    var_debug_info.value = VarDebugInfoContents::Place(replacement_local.into());
308                    var_debug_info
309                })
310                .collect()
311        });
312    }
313}
314
315impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
316    fn tcx(&self) -> TyCtxt<'tcx> {
317        self.tcx
318    }
319
320    fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
321        if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
322            *place = repl
323        } else {
324            self.super_place(place, context, location)
325        }
326    }
327
328    #[instrument(level = "trace", skip(self))]
329    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
330        match statement.kind {
331            // Duplicate storage and deinit statements, as they pretty much apply to all fields.
332            StatementKind::StorageLive(l) => {
333                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
334                    for (_, _, fl) in final_locals {
335                        self.patch.add_statement(location, StatementKind::StorageLive(fl));
336                    }
337                    statement.make_nop();
338                }
339                return;
340            }
341            StatementKind::StorageDead(l) => {
342                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
343                    for (_, _, fl) in final_locals {
344                        self.patch.add_statement(location, StatementKind::StorageDead(fl));
345                    }
346                    statement.make_nop();
347                }
348                return;
349            }
350            StatementKind::Deinit(box place) => {
351                if let Some(final_locals) = self.replacements.place_fragments(place) {
352                    for (_, _, fl) in final_locals {
353                        self.patch
354                            .add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
355                    }
356                    statement.make_nop();
357                    return;
358                }
359            }
360
361            // We have `a = Struct { 0: x, 1: y, .. }`.
362            // We replace it by
363            // ```
364            // a_0 = x
365            // a_1 = y
366            // ...
367            // ```
368            StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
369                if let Some(local) = place.as_local()
370                    && let Some(final_locals) = &self.replacements.fragments[local]
371                {
372                    // This is ok as we delete the statement later.
373                    let operands = std::mem::take(operands);
374                    for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
375                        if let Some((_, new_local)) = opt_ty_local {
376                            // Replace mentions of SROA'd locals that appear in the operand.
377                            self.visit_operand(&mut operand, location);
378
379                            let rvalue = Rvalue::Use(operand);
380                            self.patch.add_statement(
381                                location,
382                                StatementKind::Assign(Box::new((new_local.into(), rvalue))),
383                            );
384                        }
385                    }
386                    statement.make_nop();
387                    return;
388                }
389            }
390
391            // We have `a = some constant`
392            // We add the projections.
393            // ```
394            // a_0 = a.0
395            // a_1 = a.1
396            // ...
397            // ```
398            // ConstProp will pick up the pieces and replace them by actual constants.
399            StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
400                if let Some(final_locals) = self.replacements.place_fragments(place) {
401                    // Put the deaggregated statements *after* the original one.
402                    let location = location.successor_within_block();
403                    for (field, ty, new_local) in final_locals {
404                        let rplace = self.tcx.mk_place_field(place, field, ty);
405                        let rvalue = Rvalue::Use(Operand::Move(rplace));
406                        self.patch.add_statement(
407                            location,
408                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
409                        );
410                    }
411                    // We still need `place.local` to exist, so don't make it nop.
412                    return;
413                }
414            }
415
416            // We have `a = move? place`
417            // We replace it by
418            // ```
419            // a_0 = move? place.0
420            // a_1 = move? place.1
421            // ...
422            // ```
423            StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
424                let (rplace, copy) = match *op {
425                    Operand::Copy(rplace) => (rplace, true),
426                    Operand::Move(rplace) => (rplace, false),
427                    Operand::Constant(_) => bug!(),
428                };
429                if let Some(final_locals) = self.replacements.place_fragments(lhs) {
430                    for (field, ty, new_local) in final_locals {
431                        let rplace = self.tcx.mk_place_field(rplace, field, ty);
432                        debug!(?rplace);
433                        let rplace = self
434                            .replacements
435                            .replace_place(self.tcx, rplace.as_ref())
436                            .unwrap_or(rplace);
437                        debug!(?rplace);
438                        let rvalue = if copy {
439                            Rvalue::Use(Operand::Copy(rplace))
440                        } else {
441                            Rvalue::Use(Operand::Move(rplace))
442                        };
443                        self.patch.add_statement(
444                            location,
445                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
446                        );
447                    }
448                    statement.make_nop();
449                    return;
450                }
451            }
452
453            _ => {}
454        }
455        self.super_statement(statement, location)
456    }
457
458    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
459        assert!(!self.all_dead_locals.contains(*local));
460    }
461}