1use rustc_abi::FieldIdx;
2use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
3use rustc_hir::LangItem;
4use rustc_index::IndexVec;
5use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
6use rustc_middle::bug;
7use rustc_middle::mir::visit::*;
8use rustc_middle::mir::*;
9use rustc_middle::ty::{self, Ty, TyCtxt};
10use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
11use tracing::{debug, instrument};
12
13use crate::patch::MirPatch;
14
15pub(super) struct ScalarReplacementOfAggregates;
16
17impl<'tcx> crate::MirPass<'tcx> for ScalarReplacementOfAggregates {
18 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
19 sess.mir_opt_level() >= 2
20 }
21
22 #[instrument(level = "debug", skip(self, tcx, body))]
23 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
24 debug!(def_id = ?body.source.def_id());
25
26 if tcx.type_of(body.source.def_id()).instantiate_identity().is_coroutine() {
28 return;
29 }
30
31 let mut excluded = excluded_locals(body);
32 let typing_env = body.typing_env(tcx);
33 loop {
34 debug!(?excluded);
35 let escaping = escaping_locals(tcx, &excluded, body);
36 debug!(?escaping);
37 let replacements = compute_flattening(tcx, typing_env, body, escaping);
38 debug!(?replacements);
39 let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
40 if !all_dead_locals.is_empty() {
41 excluded.union(&all_dead_locals);
42 excluded = {
43 let mut growable = GrowableBitSet::from(excluded);
44 growable.ensure(body.local_decls.len());
45 growable.into()
46 };
47 } else {
48 break;
49 }
50 }
51 }
52
53 fn is_required(&self) -> bool {
54 false
55 }
56}
57
58fn escaping_locals<'tcx>(
66 tcx: TyCtxt<'tcx>,
67 excluded: &DenseBitSet<Local>,
68 body: &Body<'tcx>,
69) -> DenseBitSet<Local> {
70 let is_excluded_ty = |ty: Ty<'tcx>| {
71 if ty.is_union() || ty.is_enum() {
72 return true;
73 }
74 if let ty::Adt(def, _args) = ty.kind()
75 && (def.repr().simd() || tcx.is_lang_item(def.did(), LangItem::DynMetadata))
76 {
77 return true;
84 }
85 false
87 };
88
89 let mut set = DenseBitSet::new_empty(body.local_decls.len());
90 set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
91 for (local, decl) in body.local_decls().iter_enumerated() {
92 if excluded.contains(local) || is_excluded_ty(decl.ty) {
93 set.insert(local);
94 }
95 }
96 let mut visitor = EscapeVisitor { set };
97 visitor.visit_body(body);
98 return visitor.set;
99
100 struct EscapeVisitor {
101 set: DenseBitSet<Local>,
102 }
103
104 impl<'tcx> Visitor<'tcx> for EscapeVisitor {
105 fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) {
106 self.set.insert(local);
107 }
108
109 fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
110 if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
112 return;
113 }
114 self.super_place(place, context, location);
115 }
116
117 fn visit_assign(
118 &mut self,
119 lvalue: &Place<'tcx>,
120 rvalue: &Rvalue<'tcx>,
121 location: Location,
122 ) {
123 if lvalue.as_local().is_some() {
124 match rvalue {
125 Rvalue::Aggregate(..) | Rvalue::Use(..) => {
127 self.visit_rvalue(rvalue, location);
128 return;
129 }
130 _ => {}
131 }
132 }
133 self.super_assign(lvalue, rvalue, location)
134 }
135
136 fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
137 match statement.kind {
138 StatementKind::StorageLive(..) | StatementKind::StorageDead(..) => return,
140 _ => self.super_statement(statement, location),
141 }
142 }
143
144 fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
147 }
148}
149
150#[derive(Default, Debug)]
151struct ReplacementMap<'tcx> {
152 fragments: IndexVec<Local, Option<IndexVec<FieldIdx, Option<(Ty<'tcx>, Local)>>>>,
155}
156
157impl<'tcx> ReplacementMap<'tcx> {
158 fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
159 let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else {
160 return None;
161 };
162 let fields = self.fragments[place.local].as_ref()?;
163 let (_, new_local) = fields[f]?;
164 Some(Place { local: new_local, projection: tcx.mk_place_elems(rest) })
165 }
166
167 fn place_fragments(
168 &self,
169 place: Place<'tcx>,
170 ) -> Option<impl Iterator<Item = (FieldIdx, Ty<'tcx>, Local)>> {
171 let local = place.as_local()?;
172 let fields = self.fragments[local].as_ref()?;
173 Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
174 let (ty, local) = opt_ty_local?;
175 Some((field, ty, local))
176 }))
177 }
178}
179
180fn compute_flattening<'tcx>(
185 tcx: TyCtxt<'tcx>,
186 typing_env: ty::TypingEnv<'tcx>,
187 body: &mut Body<'tcx>,
188 escaping: DenseBitSet<Local>,
189) -> ReplacementMap<'tcx> {
190 let mut fragments = IndexVec::from_elem(None, &body.local_decls);
191
192 for local in body.local_decls.indices() {
193 if escaping.contains(local) {
194 continue;
195 }
196 let decl = body.local_decls[local].clone();
197 let ty = decl.ty;
198 iter_fields(ty, tcx, typing_env, |variant, field, field_ty| {
199 if variant.is_some() {
200 return;
202 };
203 let new_local =
204 body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
205 fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
206 });
207 }
208 ReplacementMap { fragments }
209}
210
211fn replace_flattened_locals<'tcx>(
213 tcx: TyCtxt<'tcx>,
214 body: &mut Body<'tcx>,
215 replacements: ReplacementMap<'tcx>,
216) -> DenseBitSet<Local> {
217 let mut all_dead_locals = DenseBitSet::new_empty(replacements.fragments.len());
218 for (local, replacements) in replacements.fragments.iter_enumerated() {
219 if replacements.is_some() {
220 all_dead_locals.insert(local);
221 }
222 }
223 debug!(?all_dead_locals);
224 if all_dead_locals.is_empty() {
225 return all_dead_locals;
226 }
227
228 let mut visitor = ReplacementVisitor {
229 tcx,
230 local_decls: &body.local_decls,
231 replacements: &replacements,
232 all_dead_locals,
233 patch: MirPatch::new(body),
234 };
235 for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
236 visitor.visit_basic_block_data(bb, data);
237 }
238 for scope in &mut body.source_scopes {
239 visitor.visit_source_scope_data(scope);
240 }
241 for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() {
242 visitor.visit_user_type_annotation(index, annotation);
243 }
244 visitor.expand_var_debug_info(&mut body.var_debug_info);
245 let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
246 patch.apply(body);
247 all_dead_locals
248}
249
250struct ReplacementVisitor<'tcx, 'll> {
251 tcx: TyCtxt<'tcx>,
252 local_decls: &'ll LocalDecls<'tcx>,
254 replacements: &'ll ReplacementMap<'tcx>,
256 all_dead_locals: DenseBitSet<Local>,
258 patch: MirPatch<'tcx>,
259}
260
261impl<'tcx> ReplacementVisitor<'tcx, '_> {
262 #[instrument(level = "trace", skip(self))]
263 fn expand_var_debug_info(&mut self, var_debug_info: &mut Vec<VarDebugInfo<'tcx>>) {
264 var_debug_info.flat_map_in_place(|mut var_debug_info| {
265 let place = match var_debug_info.value {
266 VarDebugInfoContents::Const(_) => return vec![var_debug_info],
267 VarDebugInfoContents::Place(ref mut place) => place,
268 };
269
270 if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
271 *place = repl;
272 return vec![var_debug_info];
273 }
274
275 let Some(parts) = self.replacements.place_fragments(*place) else {
276 return vec![var_debug_info];
277 };
278
279 let ty = place.ty(self.local_decls, self.tcx).ty;
280
281 parts
282 .map(|(field, field_ty, replacement_local)| {
283 let mut var_debug_info = var_debug_info.clone();
284 let composite = var_debug_info.composite.get_or_insert_with(|| {
285 Box::new(VarDebugInfoFragment { ty, projection: Vec::new() })
286 });
287 composite.projection.push(PlaceElem::Field(field, field_ty));
288
289 var_debug_info.value = VarDebugInfoContents::Place(replacement_local.into());
290 var_debug_info
291 })
292 .collect()
293 });
294 }
295}
296
297impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
298 fn tcx(&self) -> TyCtxt<'tcx> {
299 self.tcx
300 }
301
302 fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
303 if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
304 *place = repl
305 } else {
306 self.super_place(place, context, location)
307 }
308 }
309
310 #[instrument(level = "trace", skip(self))]
311 fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
312 match statement.kind {
313 StatementKind::StorageLive(l) => {
315 if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
316 for (_, _, fl) in final_locals {
317 self.patch.add_statement(location, StatementKind::StorageLive(fl));
318 }
319 statement.make_nop(true);
320 }
321 return;
322 }
323 StatementKind::StorageDead(l) => {
324 if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
325 for (_, _, fl) in final_locals {
326 self.patch.add_statement(location, StatementKind::StorageDead(fl));
327 }
328 statement.make_nop(true);
329 }
330 return;
331 }
332
333 StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
341 if let Some(local) = place.as_local()
342 && let Some(final_locals) = &self.replacements.fragments[local]
343 {
344 let operands = std::mem::take(operands);
346 for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
347 if let Some((_, new_local)) = opt_ty_local {
348 self.visit_operand(&mut operand, location);
350
351 let rvalue = Rvalue::Use(operand);
352 self.patch.add_statement(
353 location,
354 StatementKind::Assign(Box::new((new_local.into(), rvalue))),
355 );
356 }
357 }
358 statement.make_nop(true);
359 return;
360 }
361 }
362
363 StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
372 if let Some(final_locals) = self.replacements.place_fragments(place) {
373 let location = location.successor_within_block();
375 for (field, ty, new_local) in final_locals {
376 let rplace = self.tcx.mk_place_field(place, field, ty);
377 let rvalue = Rvalue::Use(Operand::Move(rplace));
378 self.patch.add_statement(
379 location,
380 StatementKind::Assign(Box::new((new_local.into(), rvalue))),
381 );
382 }
383 return;
385 }
386 }
387
388 StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
396 let (rplace, copy) = match *op {
397 Operand::Copy(rplace) => (rplace, true),
398 Operand::Move(rplace) => (rplace, false),
399 Operand::Constant(_) => bug!(),
400 };
401 if let Some(final_locals) = self.replacements.place_fragments(lhs) {
402 for (field, ty, new_local) in final_locals {
403 let rplace = self.tcx.mk_place_field(rplace, field, ty);
404 debug!(?rplace);
405 let rplace = self
406 .replacements
407 .replace_place(self.tcx, rplace.as_ref())
408 .unwrap_or(rplace);
409 debug!(?rplace);
410 let rvalue = if copy {
411 Rvalue::Use(Operand::Copy(rplace))
412 } else {
413 Rvalue::Use(Operand::Move(rplace))
414 };
415 self.patch.add_statement(
416 location,
417 StatementKind::Assign(Box::new((new_local.into(), rvalue))),
418 );
419 }
420 statement.make_nop(true);
421 return;
422 }
423 }
424
425 _ => {}
426 }
427 self.super_statement(statement, location)
428 }
429
430 fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
431 assert!(!self.all_dead_locals.contains(*local));
432 }
433}