1use rustc_abi::{FIRST_VARIANT, FieldIdx};
2use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
3use rustc_hir::LangItem;
4use rustc_index::IndexVec;
5use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
6use rustc_middle::bug;
7use rustc_middle::mir::visit::*;
8use rustc_middle::mir::*;
9use rustc_middle::ty::{self, Ty, TyCtxt};
10use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
11use tracing::{debug, instrument};
12
13use crate::patch::MirPatch;
14
/// MIR pass that splits non-escaping aggregate locals into one fresh local per field and
/// rewrites the body to use those per-field locals instead of the aggregate.
pub(super) struct ScalarReplacementOfAggregates;
16
17impl<'tcx> crate::MirPass<'tcx> for ScalarReplacementOfAggregates {
18 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
19 sess.mir_opt_level() >= 2
20 }
21
22 #[instrument(level = "debug", skip(self, tcx, body))]
23 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
24 debug!(def_id = ?body.source.def_id());
25
26 if tcx.type_of(body.source.def_id()).instantiate_identity().is_coroutine() {
28 return;
29 }
30
31 let mut excluded = excluded_locals(body);
32 let typing_env = body.typing_env(tcx);
33 loop {
34 debug!(?excluded);
35 let escaping = escaping_locals(tcx, typing_env, &excluded, body);
36 debug!(?escaping);
37 let replacements = compute_flattening(tcx, typing_env, body, escaping);
38 debug!(?replacements);
39 let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
40 if !all_dead_locals.is_empty() {
41 excluded.union(&all_dead_locals);
42 excluded = {
43 let mut growable = GrowableBitSet::from(excluded);
44 growable.ensure(body.local_decls.len());
45 growable.into()
46 };
47 } else {
48 break;
49 }
50 }
51 }
52
53 fn is_required(&self) -> bool {
54 false
55 }
56}
57
58fn escaping_locals<'tcx>(
66 tcx: TyCtxt<'tcx>,
67 typing_env: ty::TypingEnv<'tcx>,
68 excluded: &DenseBitSet<Local>,
69 body: &Body<'tcx>,
70) -> DenseBitSet<Local> {
71 let is_excluded_ty = |ty: Ty<'tcx>| {
72 if ty.is_union() || ty.is_enum() {
73 return true;
74 }
75 if let ty::Adt(def, _args) = ty.kind() {
76 if def.repr().simd() {
77 return true;
79 }
80 if tcx.is_lang_item(def.did(), LangItem::DynMetadata) {
81 return true;
84 }
85 let variant = def.variant(FIRST_VARIANT);
87 if variant.fields.len() > 1 {
88 return false;
91 }
92 let Ok(layout) = tcx.layout_of(typing_env.as_query_input(ty)) else {
93 return true;
95 };
96 if layout.layout.largest_niche().is_some() {
97 return true;
99 }
100 }
101 false
103 };
104
105 let mut set = DenseBitSet::new_empty(body.local_decls.len());
106 set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
107 for (local, decl) in body.local_decls().iter_enumerated() {
108 if excluded.contains(local) || is_excluded_ty(decl.ty) {
109 set.insert(local);
110 }
111 }
112 let mut visitor = EscapeVisitor { set };
113 visitor.visit_body(body);
114 return visitor.set;
115
116 struct EscapeVisitor {
117 set: DenseBitSet<Local>,
118 }
119
120 impl<'tcx> Visitor<'tcx> for EscapeVisitor {
121 fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) {
122 self.set.insert(local);
123 }
124
125 fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
126 if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
128 return;
129 }
130 self.super_place(place, context, location);
131 }
132
133 fn visit_assign(
134 &mut self,
135 lvalue: &Place<'tcx>,
136 rvalue: &Rvalue<'tcx>,
137 location: Location,
138 ) {
139 if lvalue.as_local().is_some() {
140 match rvalue {
141 Rvalue::Aggregate(..) | Rvalue::Use(..) => {
143 self.visit_rvalue(rvalue, location);
144 return;
145 }
146 _ => {}
147 }
148 }
149 self.super_assign(lvalue, rvalue, location)
150 }
151
152 fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
153 match statement.kind {
154 StatementKind::StorageLive(..)
156 | StatementKind::StorageDead(..)
157 | StatementKind::Deinit(..) => return,
158 _ => self.super_statement(statement, location),
159 }
160 }
161
162 fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
165 }
166}
167
/// Per-field replacement locals produced by `compute_flattening`.
#[derive(Default, Debug)]
struct ReplacementMap<'tcx> {
    /// Indexed by the original local; `None` means that local is not flattened.
    /// Otherwise the inner vector is indexed by field, and each `Some((ty, new_local))`
    /// holds the field's type and the local that replaces it. `None` field entries are
    /// skipped by `place_fragments`.
    fragments: IndexVec<Local, Option<IndexVec<FieldIdx, Option<(Ty<'tcx>, Local)>>>>,
}
174
175impl<'tcx> ReplacementMap<'tcx> {
176 fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
177 let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else {
178 return None;
179 };
180 let fields = self.fragments[place.local].as_ref()?;
181 let (_, new_local) = fields[f]?;
182 Some(Place { local: new_local, projection: tcx.mk_place_elems(rest) })
183 }
184
185 fn place_fragments(
186 &self,
187 place: Place<'tcx>,
188 ) -> Option<impl Iterator<Item = (FieldIdx, Ty<'tcx>, Local)> + '_> {
189 let local = place.as_local()?;
190 let fields = self.fragments[local].as_ref()?;
191 Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
192 let (ty, local) = opt_ty_local?;
193 Some((field, ty, local))
194 }))
195 }
196}
197
198fn compute_flattening<'tcx>(
203 tcx: TyCtxt<'tcx>,
204 typing_env: ty::TypingEnv<'tcx>,
205 body: &mut Body<'tcx>,
206 escaping: DenseBitSet<Local>,
207) -> ReplacementMap<'tcx> {
208 let mut fragments = IndexVec::from_elem(None, &body.local_decls);
209
210 for local in body.local_decls.indices() {
211 if escaping.contains(local) {
212 continue;
213 }
214 let decl = body.local_decls[local].clone();
215 let ty = decl.ty;
216 iter_fields(ty, tcx, typing_env, |variant, field, field_ty| {
217 if variant.is_some() {
218 return;
220 };
221 let new_local =
222 body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
223 fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
224 });
225 }
226 ReplacementMap { fragments }
227}
228
229fn replace_flattened_locals<'tcx>(
231 tcx: TyCtxt<'tcx>,
232 body: &mut Body<'tcx>,
233 replacements: ReplacementMap<'tcx>,
234) -> DenseBitSet<Local> {
235 let mut all_dead_locals = DenseBitSet::new_empty(replacements.fragments.len());
236 for (local, replacements) in replacements.fragments.iter_enumerated() {
237 if replacements.is_some() {
238 all_dead_locals.insert(local);
239 }
240 }
241 debug!(?all_dead_locals);
242 if all_dead_locals.is_empty() {
243 return all_dead_locals;
244 }
245
246 let mut visitor = ReplacementVisitor {
247 tcx,
248 local_decls: &body.local_decls,
249 replacements: &replacements,
250 all_dead_locals,
251 patch: MirPatch::new(body),
252 };
253 for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
254 visitor.visit_basic_block_data(bb, data);
255 }
256 for scope in &mut body.source_scopes {
257 visitor.visit_source_scope_data(scope);
258 }
259 for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() {
260 visitor.visit_user_type_annotation(index, annotation);
261 }
262 visitor.expand_var_debug_info(&mut body.var_debug_info);
263 let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
264 patch.apply(body);
265 all_dead_locals
266}
267
/// Rewrites a body according to a `ReplacementMap`, queueing new statements in a
/// `MirPatch`.
struct ReplacementVisitor<'tcx, 'll> {
    tcx: TyCtxt<'tcx>,
    /// Declarations of the body being rewritten. Used to compute the type of a place in
    /// `expand_var_debug_info`.
    local_decls: &'ll LocalDecls<'tcx>,
    /// Which locals/fields are replaced, and by which new locals.
    replacements: &'ll ReplacementMap<'tcx>,
    /// Every flattened local. `visit_local` asserts that none of these is reached while
    /// visiting the rewritten code.
    all_dead_locals: DenseBitSet<Local>,
    /// Statements queued for insertion during the traversal.
    patch: MirPatch<'tcx>,
}
278
279impl<'tcx> ReplacementVisitor<'tcx, '_> {
280 #[instrument(level = "trace", skip(self))]
281 fn expand_var_debug_info(&mut self, var_debug_info: &mut Vec<VarDebugInfo<'tcx>>) {
282 var_debug_info.flat_map_in_place(|mut var_debug_info| {
283 let place = match var_debug_info.value {
284 VarDebugInfoContents::Const(_) => return vec![var_debug_info],
285 VarDebugInfoContents::Place(ref mut place) => place,
286 };
287
288 if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
289 *place = repl;
290 return vec![var_debug_info];
291 }
292
293 let Some(parts) = self.replacements.place_fragments(*place) else {
294 return vec![var_debug_info];
295 };
296
297 let ty = place.ty(self.local_decls, self.tcx).ty;
298
299 parts
300 .map(|(field, field_ty, replacement_local)| {
301 let mut var_debug_info = var_debug_info.clone();
302 let composite = var_debug_info.composite.get_or_insert_with(|| {
303 Box::new(VarDebugInfoFragment { ty, projection: Vec::new() })
304 });
305 composite.projection.push(PlaceElem::Field(field, field_ty));
306
307 var_debug_info.value = VarDebugInfoContents::Place(replacement_local.into());
308 var_debug_info
309 })
310 .collect()
311 });
312 }
313}
314
impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    /// Redirects `local.field.rest...` to `new_local.rest...` when that field has a
    /// replacement local; otherwise continues the default traversal (which reaches
    /// `visit_local`).
    fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
        if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
            *place = repl
        } else {
            self.super_place(place, context, location)
        }
    }

    /// Expands statements that mention a flattened local as a whole.
    ///
    /// New statements are queued in `self.patch` at `location` (or after it, for
    /// constants). Statements not handled by an arm below fall through to
    /// `super_statement`, which redirects field projections via `visit_place`.
    #[instrument(level = "trace", skip(self))]
    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
        match statement.kind {
            // `StorageLive(a)` of a flattened local is duplicated for every replacement
            // local and the original becomes a nop. Storage statements never reach
            // `super_statement`, whether or not they were expanded.
            StatementKind::StorageLive(l) => {
                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
                    for (_, _, fl) in final_locals {
                        self.patch.add_statement(location, StatementKind::StorageLive(fl));
                    }
                    statement.make_nop();
                }
                return;
            }
            // Same as `StorageLive`, for `StorageDead`.
            StatementKind::StorageDead(l) => {
                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
                    for (_, _, fl) in final_locals {
                        self.patch.add_statement(location, StatementKind::StorageDead(fl));
                    }
                    statement.make_nop();
                }
                return;
            }
            // `Deinit(a)` of a flattened bare local is duplicated per replacement local;
            // any other `Deinit` falls through to the default traversal.
            StatementKind::Deinit(box place) => {
                if let Some(final_locals) = self.replacements.place_fragments(place) {
                    for (_, _, fl) in final_locals {
                        self.patch
                            .add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
                    }
                    statement.make_nop();
                    return;
                }
            }

            // `a = Aggregate { 0: x, 1: y, .. }` with `a` flattened becomes
            // `a_0 = x; a_1 = y; ...`. Operands whose field has no replacement local are
            // dropped, and the original statement becomes a nop.
            StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
                if let Some(local) = place.as_local()
                    && let Some(final_locals) = &self.replacements.fragments[local]
                {
                    // Taking the operands is fine: the statement is turned into a nop below.
                    let operands = std::mem::take(operands);
                    for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
                        if let Some((_, new_local)) = opt_ty_local {
                            // The operand may itself mention other flattened locals.
                            self.visit_operand(&mut operand, location);

                            let rvalue = Rvalue::Use(operand);
                            self.patch.add_statement(
                                location,
                                StatementKind::Assign(Box::new((new_local.into(), rvalue))),
                            );
                        }
                    }
                    statement.make_nop();
                    return;
                }
            }

            // `a = const` with `a` flattened: the statement is kept, and
            // `a_0 = move a.0; a_1 = move a.1; ...` are inserted right after it.
            // NOTE(review): presumably a later constant-propagation pass folds these
            // projections into constants — confirm.
            StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
                if let Some(final_locals) = self.replacements.place_fragments(place) {
                    // Insert *after* the original statement, which still defines `a`.
                    let location = location.successor_within_block();
                    for (field, ty, new_local) in final_locals {
                        let rplace = self.tcx.mk_place_field(place, field, ty);
                        let rvalue = Rvalue::Use(Operand::Move(rplace));
                        self.patch.add_statement(
                            location,
                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
                        );
                    }
                    // The original statement is kept and not visited further, so the
                    // mention of `a` it contains does not reach `visit_local`.
                    return;
                }
            }

            // `a = copy/move p` with `a` flattened becomes `a_0 = copy/move p.0; ...`.
            // Each `p.i` is itself redirected when `p` was flattened too. The original
            // statement becomes a nop.
            StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
                let (rplace, copy) = match *op {
                    Operand::Copy(rplace) => (rplace, true),
                    Operand::Move(rplace) => (rplace, false),
                    // Constants are matched by the previous arm, so this is unreachable.
                    Operand::Constant(_) => bug!(),
                };
                if let Some(final_locals) = self.replacements.place_fragments(lhs) {
                    for (field, ty, new_local) in final_locals {
                        let rplace = self.tcx.mk_place_field(rplace, field, ty);
                        debug!(?rplace);
                        let rplace = self
                            .replacements
                            .replace_place(self.tcx, rplace.as_ref())
                            .unwrap_or(rplace);
                        debug!(?rplace);
                        // Preserve the copy/move kind of the original operand.
                        let rvalue = if copy {
                            Rvalue::Use(Operand::Copy(rplace))
                        } else {
                            Rvalue::Use(Operand::Move(rplace))
                        };
                        self.patch.add_statement(
                            location,
                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
                        );
                    }
                    statement.make_nop();
                    return;
                }
            }

            _ => {}
        }
        self.super_statement(statement, location)
    }

    /// Any local still reached by the traversal (i.e. not rewritten above) must not be a
    /// flattened local.
    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
        assert!(!self.all_dead_locals.contains(*local));
    }
}