1use rustc_abi::{HasDataLayout, Size, TagEncoding, Variants};
2use rustc_data_structures::fx::FxHashMap;
3use rustc_middle::mir::interpret::AllocId;
4use rustc_middle::mir::*;
5use rustc_middle::ty::util::IntTypeExt;
6use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
7use rustc_session::Session;
8
/// MIR pass that shrinks copies of enums whose variants differ greatly in size.
///
/// Assignments that copy or move such an enum are rewritten to copy only as many bytes as
/// the source value's active variant occupies. The byte count is looked up, by discriminant,
/// in a constant per-type table of variant sizes.
pub(super) struct EnumSizeOpt {
    /// Minimum difference, in bytes, between an enum's largest and smallest variant for
    /// copies of that enum to be rewritten.
    pub(crate) discrepancy: u64,
}
28
29impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
30 fn is_enabled(&self, sess: &Session) -> bool {
31 sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3
35 }
36
37 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
38 let mut alloc_cache = FxHashMap::default();
42 let typing_env = body.typing_env(tcx);
43
44 let blocks = body.basic_blocks.as_mut();
45 let local_decls = &mut body.local_decls;
46
47 for bb in blocks {
48 bb.expand_statements(|st| {
49 let StatementKind::Assign(box (
50 lhs,
51 Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
52 )) = &st.kind
53 else {
54 return None;
55 };
56
57 let ty = lhs.ty(local_decls, tcx).ty;
58
59 let (adt_def, num_variants, alloc_id) =
60 self.candidate(tcx, typing_env, ty, &mut alloc_cache)?;
61
62 let source_info = st.source_info;
63 let span = source_info.span;
64
65 let tmp_ty = Ty::new_array(tcx, tcx.types.usize, num_variants as u64);
66 let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span));
67 let store_live =
68 Statement { source_info, kind: StatementKind::StorageLive(size_array_local) };
69
70 let place = Place::from(size_array_local);
71 let constant_vals = ConstOperand {
72 span,
73 user_ty: None,
74 const_: Const::Val(
75 ConstValue::Indirect { alloc_id, offset: Size::ZERO },
76 tmp_ty,
77 ),
78 };
79 let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals)));
80 let const_assign =
81 Statement { source_info, kind: StatementKind::Assign(Box::new((place, rval))) };
82
83 let discr_place = Place::from(
84 local_decls.push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)),
85 );
86 let store_discr = Statement {
87 source_info,
88 kind: StatementKind::Assign(Box::new((
89 discr_place,
90 Rvalue::Discriminant(*rhs),
91 ))),
92 };
93
94 let discr_cast_place =
95 Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
96 let cast_discr = Statement {
97 source_info,
98 kind: StatementKind::Assign(Box::new((
99 discr_cast_place,
100 Rvalue::Cast(
101 CastKind::IntToInt,
102 Operand::Copy(discr_place),
103 tcx.types.usize,
104 ),
105 ))),
106 };
107
108 let size_place =
109 Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
110 let store_size = Statement {
111 source_info,
112 kind: StatementKind::Assign(Box::new((
113 size_place,
114 Rvalue::Use(Operand::Copy(Place {
115 local: size_array_local,
116 projection: tcx
117 .mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
118 })),
119 ))),
120 };
121
122 let dst =
123 Place::from(local_decls.push(LocalDecl::new(Ty::new_mut_ptr(tcx, ty), span)));
124 let dst_ptr = Statement {
125 source_info,
126 kind: StatementKind::Assign(Box::new((
127 dst,
128 Rvalue::RawPtr(RawPtrKind::Mut, *lhs),
129 ))),
130 };
131
132 let dst_cast_ty = Ty::new_mut_ptr(tcx, tcx.types.u8);
133 let dst_cast_place =
134 Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span)));
135 let dst_cast = Statement {
136 source_info,
137 kind: StatementKind::Assign(Box::new((
138 dst_cast_place,
139 Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
140 ))),
141 };
142
143 let src =
144 Place::from(local_decls.push(LocalDecl::new(Ty::new_imm_ptr(tcx, ty), span)));
145 let src_ptr = Statement {
146 source_info,
147 kind: StatementKind::Assign(Box::new((
148 src,
149 Rvalue::RawPtr(RawPtrKind::Const, *rhs),
150 ))),
151 };
152
153 let src_cast_ty = Ty::new_imm_ptr(tcx, tcx.types.u8);
154 let src_cast_place =
155 Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span)));
156 let src_cast = Statement {
157 source_info,
158 kind: StatementKind::Assign(Box::new((
159 src_cast_place,
160 Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
161 ))),
162 };
163
164 let deinit_old =
165 Statement { source_info, kind: StatementKind::Deinit(Box::new(dst)) };
166
167 let copy_bytes = Statement {
168 source_info,
169 kind: StatementKind::Intrinsic(Box::new(
170 NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
171 src: Operand::Copy(src_cast_place),
172 dst: Operand::Copy(dst_cast_place),
173 count: Operand::Copy(size_place),
174 }),
175 )),
176 };
177
178 let store_dead =
179 Statement { source_info, kind: StatementKind::StorageDead(size_array_local) };
180
181 let iter = [
182 store_live,
183 const_assign,
184 store_discr,
185 cast_discr,
186 store_size,
187 dst_ptr,
188 dst_cast,
189 src_ptr,
190 src_cast,
191 deinit_old,
192 copy_bytes,
193 store_dead,
194 ]
195 .into_iter();
196
197 st.make_nop();
198
199 Some(iter)
200 });
201 }
202 }
203
204 fn is_required(&self) -> bool {
205 false
206 }
207}
208
209impl EnumSizeOpt {
210 fn candidate<'tcx>(
211 &self,
212 tcx: TyCtxt<'tcx>,
213 typing_env: ty::TypingEnv<'tcx>,
214 ty: Ty<'tcx>,
215 alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>,
216 ) -> Option<(AdtDef<'tcx>, usize, AllocId)> {
217 let adt_def = match ty.kind() {
218 ty::Adt(adt_def, _args) if adt_def.is_enum() => adt_def,
219 _ => return None,
220 };
221 let layout = tcx.layout_of(typing_env.as_query_input(ty)).ok()?;
222 let variants = match &layout.variants {
223 Variants::Single { .. } | Variants::Empty => return None,
224 Variants::Multiple { tag_encoding: TagEncoding::Niche { .. }, .. } => return None,
225
226 Variants::Multiple { variants, .. } if variants.len() <= 1 => return None,
227 Variants::Multiple { variants, .. } => variants,
228 };
229 let min = variants.iter().map(|v| v.size).min().unwrap();
230 let max = variants.iter().map(|v| v.size).max().unwrap();
231 if max.bytes() - min.bytes() < self.discrepancy {
232 return None;
233 }
234
235 let num_discrs = adt_def.discriminants(tcx).count();
236 if variants.iter_enumerated().any(|(var_idx, _)| {
237 let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val;
238 (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs)
239 }) {
240 return None;
241 }
242 if let Some(alloc_id) = alloc_cache.get(&ty) {
243 return Some((*adt_def, num_discrs, *alloc_id));
244 }
245
246 let data_layout = tcx.data_layout();
247 let ptr_sized_int = data_layout.ptr_sized_integer();
248 let target_bytes = ptr_sized_int.size().bytes() as usize;
249 let mut data = vec![0; target_bytes * num_discrs];
250
251 macro_rules! encode_store {
253 ($curr_idx: expr, $endian: expr, $bytes: expr) => {
254 let bytes = match $endian {
255 rustc_abi::Endian::Little => $bytes.to_le_bytes(),
256 rustc_abi::Endian::Big => $bytes.to_be_bytes(),
257 };
258 for (i, b) in bytes.into_iter().enumerate() {
259 data[$curr_idx + i] = b;
260 }
261 };
262 }
263
264 for (var_idx, layout) in variants.iter_enumerated() {
265 let curr_idx =
266 target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
267 let sz = layout.size;
268 match ptr_sized_int {
269 rustc_abi::Integer::I32 => {
270 encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32);
271 }
272 rustc_abi::Integer::I64 => {
273 encode_store!(curr_idx, data_layout.endian, sz.bytes());
274 }
275 _ => unreachable!(),
276 };
277 }
278 let alloc = interpret::Allocation::from_bytes(
279 data,
280 tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
281 Mutability::Not,
282 );
283 let alloc = tcx.reserve_and_set_memory_alloc(tcx.mk_const_alloc(alloc));
284 Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc)))
285 }
286}