1use rustc_abi::{HasDataLayout, Size, TagEncoding, Variants};
2use rustc_const_eval::interpret::{Scalar, alloc_range};
3use rustc_data_structures::fx::FxHashMap;
4use rustc_middle::mir::interpret::AllocId;
5use rustc_middle::mir::*;
6use rustc_middle::ty::util::IntTypeExt;
7use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
8use rustc_session::Session;
9
10use crate::patch::MirPatch;
11
12pub(super) struct EnumSizeOpt {
30 pub(crate) discrepancy: u64,
31}
32
33impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
34 fn is_enabled(&self, sess: &Session) -> bool {
35 sess.opts.unstable_opts.unsound_mir_opts && sess.mir_opt_level() >= 3
39 }
40
41 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
42 let mut alloc_cache = FxHashMap::default();
46 let typing_env = body.typing_env(tcx);
47
48 let mut patch = MirPatch::new(body);
49
50 for (block, data) in body.basic_blocks.as_mut().iter_enumerated_mut() {
51 for (statement_index, st) in data.statements.iter_mut().enumerate() {
52 let StatementKind::Assign(box (
53 lhs,
54 Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
55 )) = &st.kind
56 else {
57 continue;
58 };
59
60 let location = Location { block, statement_index };
61
62 let ty = lhs.ty(&body.local_decls, tcx).ty;
63
64 let Some((adt_def, num_variants, alloc_id)) =
65 self.candidate(tcx, typing_env, ty, &mut alloc_cache)
66 else {
67 continue;
68 };
69
70 let span = st.source_info.span;
71
72 let tmp_ty = Ty::new_array(tcx, tcx.types.usize, num_variants as u64);
73 let size_array_local = patch.new_temp(tmp_ty, span);
74
75 let store_live = StatementKind::StorageLive(size_array_local);
76
77 let place = Place::from(size_array_local);
78 let constant_vals = ConstOperand {
79 span,
80 user_ty: None,
81 const_: Const::Val(
82 ConstValue::Indirect { alloc_id, offset: Size::ZERO },
83 tmp_ty,
84 ),
85 };
86 let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals)));
87 let const_assign = StatementKind::Assign(Box::new((place, rval)));
88
89 let discr_place =
90 Place::from(patch.new_temp(adt_def.repr().discr_type().to_ty(tcx), span));
91 let store_discr =
92 StatementKind::Assign(Box::new((discr_place, Rvalue::Discriminant(*rhs))));
93
94 let discr_cast_place = Place::from(patch.new_temp(tcx.types.usize, span));
95 let cast_discr = StatementKind::Assign(Box::new((
96 discr_cast_place,
97 Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_place), tcx.types.usize),
98 )));
99
100 let size_place = Place::from(patch.new_temp(tcx.types.usize, span));
101 let store_size = StatementKind::Assign(Box::new((
102 size_place,
103 Rvalue::Use(Operand::Copy(Place {
104 local: size_array_local,
105 projection: tcx.mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
106 })),
107 )));
108
109 let dst = Place::from(patch.new_temp(Ty::new_mut_ptr(tcx, ty), span));
110 let dst_ptr =
111 StatementKind::Assign(Box::new((dst, Rvalue::RawPtr(RawPtrKind::Mut, *lhs))));
112
113 let dst_cast_ty = Ty::new_mut_ptr(tcx, tcx.types.u8);
114 let dst_cast_place = Place::from(patch.new_temp(dst_cast_ty, span));
115 let dst_cast = StatementKind::Assign(Box::new((
116 dst_cast_place,
117 Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
118 )));
119
120 let src = Place::from(patch.new_temp(Ty::new_imm_ptr(tcx, ty), span));
121 let src_ptr =
122 StatementKind::Assign(Box::new((src, Rvalue::RawPtr(RawPtrKind::Const, *rhs))));
123
124 let src_cast_ty = Ty::new_imm_ptr(tcx, tcx.types.u8);
125 let src_cast_place = Place::from(patch.new_temp(src_cast_ty, span));
126 let src_cast = StatementKind::Assign(Box::new((
127 src_cast_place,
128 Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
129 )));
130
131 let copy_bytes = StatementKind::Intrinsic(Box::new(
132 NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
133 src: Operand::Copy(src_cast_place),
134 dst: Operand::Copy(dst_cast_place),
135 count: Operand::Copy(size_place),
136 }),
137 ));
138
139 let store_dead = StatementKind::StorageDead(size_array_local);
140
141 let stmts = [
142 store_live,
143 const_assign,
144 store_discr,
145 cast_discr,
146 store_size,
147 dst_ptr,
148 dst_cast,
149 src_ptr,
150 src_cast,
151 copy_bytes,
152 store_dead,
153 ];
154 for stmt in stmts {
155 patch.add_statement(location, stmt);
156 }
157
158 st.make_nop(true);
159 }
160 }
161
162 patch.apply(body);
163 }
164
165 fn is_required(&self) -> bool {
166 false
167 }
168}
169
170impl EnumSizeOpt {
171 fn candidate<'tcx>(
172 &self,
173 tcx: TyCtxt<'tcx>,
174 typing_env: ty::TypingEnv<'tcx>,
175 ty: Ty<'tcx>,
176 alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>,
177 ) -> Option<(AdtDef<'tcx>, usize, AllocId)> {
178 let adt_def = match ty.kind() {
179 ty::Adt(adt_def, _args) if adt_def.is_enum() => adt_def,
180 _ => return None,
181 };
182 let layout = tcx.layout_of(typing_env.as_query_input(ty)).ok()?;
183 let variants = match &layout.variants {
184 Variants::Single { .. } | Variants::Empty => return None,
185 Variants::Multiple { tag_encoding: TagEncoding::Niche { .. }, .. } => return None,
186
187 Variants::Multiple { variants, .. } if variants.len() <= 1 => return None,
188 Variants::Multiple { variants, .. } => variants,
189 };
190 let min = variants.iter().map(|v| v.size).min().unwrap();
191 let max = variants.iter().map(|v| v.size).max().unwrap();
192 if max.bytes() - min.bytes() < self.discrepancy {
193 return None;
194 }
195
196 let num_discrs = adt_def.discriminants(tcx).count();
197 if variants.iter_enumerated().any(|(var_idx, _)| {
198 let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val;
199 (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs)
200 }) {
201 return None;
202 }
203 if let Some(alloc_id) = alloc_cache.get(&ty) {
204 return Some((*adt_def, num_discrs, *alloc_id));
205 }
206
207 let data_layout = tcx.data_layout();
209 let ptr_size = data_layout.pointer_size();
210 let mut alloc = interpret::Allocation::from_bytes(
211 vec![0; ptr_size.bytes_usize() * num_discrs],
212 tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
213 Mutability::Mut,
214 (),
215 );
216 for (var_idx, layout) in variants.iter_enumerated() {
217 let curr_idx = ptr_size * adt_def.discriminant_for_variant(tcx, var_idx).val as u64;
218 let val = Scalar::from_target_usize(layout.size.bytes(), &tcx);
219 alloc.write_scalar(&tcx, alloc_range(curr_idx, val.size()), val).unwrap();
220 }
221 alloc.mutability = Mutability::Not;
222 let alloc = tcx.reserve_and_set_memory_alloc(tcx.mk_const_alloc(alloc));
223
224 Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc)))
225 }
226}