rustc_mir_transform/
large_enums.rs

1use rustc_abi::{HasDataLayout, Size, TagEncoding, Variants};
2use rustc_data_structures::fx::FxHashMap;
3use rustc_middle::mir::interpret::AllocId;
4use rustc_middle::mir::*;
5use rustc_middle::ty::util::IntTypeExt;
6use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
7use rustc_session::Session;
8
9use crate::patch::MirPatch;
10
/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
/// enough discrepancy between the sizes of their variants.
///
/// For example, given an enum with two variants:
/// ```
/// enum Example {
///   Small,
///   Large([u32; 1024]),
/// }
/// ```
/// instead of always moving as many bytes as the large variant occupies, emit a memcpy whose
/// length depends on the variant that is actually stored.
/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).
///
/// In summary, this determines at runtime which enum variant is active and, instead of
/// copying all the bytes of the largest possible variant, copies only the bytes of the
/// currently active variant.
pub(super) struct EnumSizeOpt {
    /// Minimum difference, in bytes, between the sizes of the largest and the smallest variant
    /// for an enum to be considered by this pass.
    pub(crate) discrepancy: u64,
}
30
31impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
32    fn is_enabled(&self, sess: &Session) -> bool {
33        // There are some differences in behavior on wasm and ARM that are not properly
34        // understood, so we conservatively treat this optimization as unsound:
35        // https://github.com/rust-lang/rust/pull/85158#issuecomment-1101836457
36        sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3
37    }
38
39    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
40        // NOTE: This pass may produce different MIR based on the alignment of the target
41        // platform, but it will still be valid.
42
43        let mut alloc_cache = FxHashMap::default();
44        let typing_env = body.typing_env(tcx);
45
46        let mut patch = MirPatch::new(body);
47
48        for (block, data) in body.basic_blocks.as_mut().iter_enumerated_mut() {
49            for (statement_index, st) in data.statements.iter_mut().enumerate() {
50                let StatementKind::Assign(box (
51                    lhs,
52                    Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
53                )) = &st.kind
54                else {
55                    continue;
56                };
57
58                let location = Location { block, statement_index };
59
60                let ty = lhs.ty(&body.local_decls, tcx).ty;
61
62                let Some((adt_def, num_variants, alloc_id)) =
63                    self.candidate(tcx, typing_env, ty, &mut alloc_cache)
64                else {
65                    continue;
66                };
67
68                let span = st.source_info.span;
69
70                let tmp_ty = Ty::new_array(tcx, tcx.types.usize, num_variants as u64);
71                let size_array_local = patch.new_temp(tmp_ty, span);
72
73                let store_live = StatementKind::StorageLive(size_array_local);
74
75                let place = Place::from(size_array_local);
76                let constant_vals = ConstOperand {
77                    span,
78                    user_ty: None,
79                    const_: Const::Val(
80                        ConstValue::Indirect { alloc_id, offset: Size::ZERO },
81                        tmp_ty,
82                    ),
83                };
84                let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals)));
85                let const_assign = StatementKind::Assign(Box::new((place, rval)));
86
87                let discr_place =
88                    Place::from(patch.new_temp(adt_def.repr().discr_type().to_ty(tcx), span));
89                let store_discr =
90                    StatementKind::Assign(Box::new((discr_place, Rvalue::Discriminant(*rhs))));
91
92                let discr_cast_place = Place::from(patch.new_temp(tcx.types.usize, span));
93                let cast_discr = StatementKind::Assign(Box::new((
94                    discr_cast_place,
95                    Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_place), tcx.types.usize),
96                )));
97
98                let size_place = Place::from(patch.new_temp(tcx.types.usize, span));
99                let store_size = StatementKind::Assign(Box::new((
100                    size_place,
101                    Rvalue::Use(Operand::Copy(Place {
102                        local: size_array_local,
103                        projection: tcx.mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
104                    })),
105                )));
106
107                let dst = Place::from(patch.new_temp(Ty::new_mut_ptr(tcx, ty), span));
108                let dst_ptr =
109                    StatementKind::Assign(Box::new((dst, Rvalue::RawPtr(RawPtrKind::Mut, *lhs))));
110
111                let dst_cast_ty = Ty::new_mut_ptr(tcx, tcx.types.u8);
112                let dst_cast_place = Place::from(patch.new_temp(dst_cast_ty, span));
113                let dst_cast = StatementKind::Assign(Box::new((
114                    dst_cast_place,
115                    Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
116                )));
117
118                let src = Place::from(patch.new_temp(Ty::new_imm_ptr(tcx, ty), span));
119                let src_ptr =
120                    StatementKind::Assign(Box::new((src, Rvalue::RawPtr(RawPtrKind::Const, *rhs))));
121
122                let src_cast_ty = Ty::new_imm_ptr(tcx, tcx.types.u8);
123                let src_cast_place = Place::from(patch.new_temp(src_cast_ty, span));
124                let src_cast = StatementKind::Assign(Box::new((
125                    src_cast_place,
126                    Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
127                )));
128
129                let deinit_old = StatementKind::Deinit(Box::new(dst));
130
131                let copy_bytes = StatementKind::Intrinsic(Box::new(
132                    NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
133                        src: Operand::Copy(src_cast_place),
134                        dst: Operand::Copy(dst_cast_place),
135                        count: Operand::Copy(size_place),
136                    }),
137                ));
138
139                let store_dead = StatementKind::StorageDead(size_array_local);
140
141                let stmts = [
142                    store_live,
143                    const_assign,
144                    store_discr,
145                    cast_discr,
146                    store_size,
147                    dst_ptr,
148                    dst_cast,
149                    src_ptr,
150                    src_cast,
151                    deinit_old,
152                    copy_bytes,
153                    store_dead,
154                ];
155                for stmt in stmts {
156                    patch.add_statement(location, stmt);
157                }
158
159                st.make_nop();
160            }
161        }
162
163        patch.apply(body);
164    }
165
166    fn is_required(&self) -> bool {
167        false
168    }
169}
170
171impl EnumSizeOpt {
172    fn candidate<'tcx>(
173        &self,
174        tcx: TyCtxt<'tcx>,
175        typing_env: ty::TypingEnv<'tcx>,
176        ty: Ty<'tcx>,
177        alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>,
178    ) -> Option<(AdtDef<'tcx>, usize, AllocId)> {
179        let adt_def = match ty.kind() {
180            ty::Adt(adt_def, _args) if adt_def.is_enum() => adt_def,
181            _ => return None,
182        };
183        let layout = tcx.layout_of(typing_env.as_query_input(ty)).ok()?;
184        let variants = match &layout.variants {
185            Variants::Single { .. } | Variants::Empty => return None,
186            Variants::Multiple { tag_encoding: TagEncoding::Niche { .. }, .. } => return None,
187
188            Variants::Multiple { variants, .. } if variants.len() <= 1 => return None,
189            Variants::Multiple { variants, .. } => variants,
190        };
191        let min = variants.iter().map(|v| v.size).min().unwrap();
192        let max = variants.iter().map(|v| v.size).max().unwrap();
193        if max.bytes() - min.bytes() < self.discrepancy {
194            return None;
195        }
196
197        let num_discrs = adt_def.discriminants(tcx).count();
198        if variants.iter_enumerated().any(|(var_idx, _)| {
199            let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val;
200            (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs)
201        }) {
202            return None;
203        }
204        if let Some(alloc_id) = alloc_cache.get(&ty) {
205            return Some((*adt_def, num_discrs, *alloc_id));
206        }
207
208        let data_layout = tcx.data_layout();
209        let ptr_sized_int = data_layout.ptr_sized_integer();
210        let target_bytes = ptr_sized_int.size().bytes() as usize;
211        let mut data = vec![0; target_bytes * num_discrs];
212
213        // We use a macro because `$bytes` can be u32 or u64.
214        macro_rules! encode_store {
215            ($curr_idx: expr, $endian: expr, $bytes: expr) => {
216                let bytes = match $endian {
217                    rustc_abi::Endian::Little => $bytes.to_le_bytes(),
218                    rustc_abi::Endian::Big => $bytes.to_be_bytes(),
219                };
220                for (i, b) in bytes.into_iter().enumerate() {
221                    data[$curr_idx + i] = b;
222                }
223            };
224        }
225
226        for (var_idx, layout) in variants.iter_enumerated() {
227            let curr_idx =
228                target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
229            let sz = layout.size;
230            match ptr_sized_int {
231                rustc_abi::Integer::I32 => {
232                    encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32);
233                }
234                rustc_abi::Integer::I64 => {
235                    encode_store!(curr_idx, data_layout.endian, sz.bytes());
236                }
237                _ => unreachable!(),
238            };
239        }
240        let alloc = interpret::Allocation::from_bytes(
241            data,
242            tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
243            Mutability::Not,
244        );
245        let alloc = tcx.reserve_and_set_memory_alloc(tcx.mk_const_alloc(alloc));
246        Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc)))
247    }
248}