rustc_mir_transform/
large_enums.rs

1use rustc_abi::{HasDataLayout, Size, TagEncoding, Variants};
2use rustc_data_structures::fx::FxHashMap;
3use rustc_middle::mir::interpret::AllocId;
4use rustc_middle::mir::*;
5use rustc_middle::ty::util::IntTypeExt;
6use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
7use rustc_session::Session;
8
9/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
10/// enough discrepancy between them.
11///
12/// i.e. If there are two variants:
13/// ```
14/// enum Example {
15///   Small,
16///   Large([u32; 1024]),
17/// }
18/// ```
19/// Instead of emitting moves of the large variant, perform a memcpy instead.
20/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).
21///
22/// In summary, what this does is at runtime determine which enum variant is active,
23/// and instead of copying all the bytes of the largest possible variant,
24/// copy only the bytes for the currently active variant.
25pub(super) struct EnumSizeOpt {
26    pub(crate) discrepancy: u64,
27}
28
29impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
30    fn is_enabled(&self, sess: &Session) -> bool {
31        // There are some differences in behavior on wasm and ARM that are not properly
32        // understood, so we conservatively treat this optimization as unsound:
33        // https://github.com/rust-lang/rust/pull/85158#issuecomment-1101836457
34        sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3
35    }
36
37    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
38        // NOTE: This pass may produce different MIR based on the alignment of the target
39        // platform, but it will still be valid.
40
41        let mut alloc_cache = FxHashMap::default();
42        let typing_env = body.typing_env(tcx);
43
44        let blocks = body.basic_blocks.as_mut();
45        let local_decls = &mut body.local_decls;
46
47        for bb in blocks {
48            bb.expand_statements(|st| {
49                let StatementKind::Assign(box (
50                    lhs,
51                    Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
52                )) = &st.kind
53                else {
54                    return None;
55                };
56
57                let ty = lhs.ty(local_decls, tcx).ty;
58
59                let (adt_def, num_variants, alloc_id) =
60                    self.candidate(tcx, typing_env, ty, &mut alloc_cache)?;
61
62                let source_info = st.source_info;
63                let span = source_info.span;
64
65                let tmp_ty = Ty::new_array(tcx, tcx.types.usize, num_variants as u64);
66                let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span));
67                let store_live =
68                    Statement { source_info, kind: StatementKind::StorageLive(size_array_local) };
69
70                let place = Place::from(size_array_local);
71                let constant_vals = ConstOperand {
72                    span,
73                    user_ty: None,
74                    const_: Const::Val(
75                        ConstValue::Indirect { alloc_id, offset: Size::ZERO },
76                        tmp_ty,
77                    ),
78                };
79                let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals)));
80                let const_assign =
81                    Statement { source_info, kind: StatementKind::Assign(Box::new((place, rval))) };
82
83                let discr_place = Place::from(
84                    local_decls.push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)),
85                );
86                let store_discr = Statement {
87                    source_info,
88                    kind: StatementKind::Assign(Box::new((
89                        discr_place,
90                        Rvalue::Discriminant(*rhs),
91                    ))),
92                };
93
94                let discr_cast_place =
95                    Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
96                let cast_discr = Statement {
97                    source_info,
98                    kind: StatementKind::Assign(Box::new((
99                        discr_cast_place,
100                        Rvalue::Cast(
101                            CastKind::IntToInt,
102                            Operand::Copy(discr_place),
103                            tcx.types.usize,
104                        ),
105                    ))),
106                };
107
108                let size_place =
109                    Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
110                let store_size = Statement {
111                    source_info,
112                    kind: StatementKind::Assign(Box::new((
113                        size_place,
114                        Rvalue::Use(Operand::Copy(Place {
115                            local: size_array_local,
116                            projection: tcx
117                                .mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
118                        })),
119                    ))),
120                };
121
122                let dst =
123                    Place::from(local_decls.push(LocalDecl::new(Ty::new_mut_ptr(tcx, ty), span)));
124                let dst_ptr = Statement {
125                    source_info,
126                    kind: StatementKind::Assign(Box::new((
127                        dst,
128                        Rvalue::RawPtr(RawPtrKind::Mut, *lhs),
129                    ))),
130                };
131
132                let dst_cast_ty = Ty::new_mut_ptr(tcx, tcx.types.u8);
133                let dst_cast_place =
134                    Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span)));
135                let dst_cast = Statement {
136                    source_info,
137                    kind: StatementKind::Assign(Box::new((
138                        dst_cast_place,
139                        Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
140                    ))),
141                };
142
143                let src =
144                    Place::from(local_decls.push(LocalDecl::new(Ty::new_imm_ptr(tcx, ty), span)));
145                let src_ptr = Statement {
146                    source_info,
147                    kind: StatementKind::Assign(Box::new((
148                        src,
149                        Rvalue::RawPtr(RawPtrKind::Const, *rhs),
150                    ))),
151                };
152
153                let src_cast_ty = Ty::new_imm_ptr(tcx, tcx.types.u8);
154                let src_cast_place =
155                    Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span)));
156                let src_cast = Statement {
157                    source_info,
158                    kind: StatementKind::Assign(Box::new((
159                        src_cast_place,
160                        Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
161                    ))),
162                };
163
164                let deinit_old =
165                    Statement { source_info, kind: StatementKind::Deinit(Box::new(dst)) };
166
167                let copy_bytes = Statement {
168                    source_info,
169                    kind: StatementKind::Intrinsic(Box::new(
170                        NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
171                            src: Operand::Copy(src_cast_place),
172                            dst: Operand::Copy(dst_cast_place),
173                            count: Operand::Copy(size_place),
174                        }),
175                    )),
176                };
177
178                let store_dead =
179                    Statement { source_info, kind: StatementKind::StorageDead(size_array_local) };
180
181                let iter = [
182                    store_live,
183                    const_assign,
184                    store_discr,
185                    cast_discr,
186                    store_size,
187                    dst_ptr,
188                    dst_cast,
189                    src_ptr,
190                    src_cast,
191                    deinit_old,
192                    copy_bytes,
193                    store_dead,
194                ]
195                .into_iter();
196
197                st.make_nop();
198
199                Some(iter)
200            });
201        }
202    }
203
204    fn is_required(&self) -> bool {
205        false
206    }
207}
208
209impl EnumSizeOpt {
210    fn candidate<'tcx>(
211        &self,
212        tcx: TyCtxt<'tcx>,
213        typing_env: ty::TypingEnv<'tcx>,
214        ty: Ty<'tcx>,
215        alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>,
216    ) -> Option<(AdtDef<'tcx>, usize, AllocId)> {
217        let adt_def = match ty.kind() {
218            ty::Adt(adt_def, _args) if adt_def.is_enum() => adt_def,
219            _ => return None,
220        };
221        let layout = tcx.layout_of(typing_env.as_query_input(ty)).ok()?;
222        let variants = match &layout.variants {
223            Variants::Single { .. } | Variants::Empty => return None,
224            Variants::Multiple { tag_encoding: TagEncoding::Niche { .. }, .. } => return None,
225
226            Variants::Multiple { variants, .. } if variants.len() <= 1 => return None,
227            Variants::Multiple { variants, .. } => variants,
228        };
229        let min = variants.iter().map(|v| v.size).min().unwrap();
230        let max = variants.iter().map(|v| v.size).max().unwrap();
231        if max.bytes() - min.bytes() < self.discrepancy {
232            return None;
233        }
234
235        let num_discrs = adt_def.discriminants(tcx).count();
236        if variants.iter_enumerated().any(|(var_idx, _)| {
237            let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val;
238            (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs)
239        }) {
240            return None;
241        }
242        if let Some(alloc_id) = alloc_cache.get(&ty) {
243            return Some((*adt_def, num_discrs, *alloc_id));
244        }
245
246        let data_layout = tcx.data_layout();
247        let ptr_sized_int = data_layout.ptr_sized_integer();
248        let target_bytes = ptr_sized_int.size().bytes() as usize;
249        let mut data = vec![0; target_bytes * num_discrs];
250
251        // We use a macro because `$bytes` can be u32 or u64.
252        macro_rules! encode_store {
253            ($curr_idx: expr, $endian: expr, $bytes: expr) => {
254                let bytes = match $endian {
255                    rustc_abi::Endian::Little => $bytes.to_le_bytes(),
256                    rustc_abi::Endian::Big => $bytes.to_be_bytes(),
257                };
258                for (i, b) in bytes.into_iter().enumerate() {
259                    data[$curr_idx + i] = b;
260                }
261            };
262        }
263
264        for (var_idx, layout) in variants.iter_enumerated() {
265            let curr_idx =
266                target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
267            let sz = layout.size;
268            match ptr_sized_int {
269                rustc_abi::Integer::I32 => {
270                    encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32);
271                }
272                rustc_abi::Integer::I64 => {
273                    encode_store!(curr_idx, data_layout.endian, sz.bytes());
274                }
275                _ => unreachable!(),
276            };
277        }
278        let alloc = interpret::Allocation::from_bytes(
279            data,
280            tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
281            Mutability::Not,
282        );
283        let alloc = tcx.reserve_and_set_memory_alloc(tcx.mk_const_alloc(alloc));
284        Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc)))
285    }
286}