Skip to main content

rustc_mir_transform/
check_enums.rs

1use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
2use rustc_data_structures::thin_vec::ThinVec;
3use rustc_hir::LangItem;
4use rustc_index::IndexVec;
5use rustc_middle::bug;
6use rustc_middle::mir::visit::Visitor;
7use rustc_middle::mir::*;
8use rustc_middle::ty::layout::PrimitiveExt;
9use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
10use rustc_session::Session;
11use tracing::debug;
12
/// This pass inserts checks for a valid enum discriminant where they are most
/// likely to find UB, because checking everywhere like Miri would generate too
/// much MIR.
///
/// Concretely, a check is placed in front of every statement that transmutes a
/// value into an enum type. The pass is only enabled when UB checks are turned
/// on for the session.
pub(super) struct CheckEnums;
17
18impl<'tcx> crate::MirPass<'tcx> for CheckEnums {
19    fn is_enabled(&self, sess: &Session) -> bool {
20        sess.ub_checks()
21    }
22
23    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
24        // This pass emits new panics. If for whatever reason we do not have a panic
25        // implementation, running this pass may cause otherwise-valid code to not compile.
26        if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
27            return;
28        }
29
30        let typing_env = body.typing_env(tcx);
31        let basic_blocks = body.basic_blocks.as_mut();
32        let local_decls = &mut body.local_decls;
33
34        // This operation inserts new blocks. Each insertion changes the Location for all
35        // statements/blocks after. Iterating or visiting the MIR in order would require updating
36        // our current location after every insertion. By iterating backwards, we dodge this issue:
37        // The only Locations that an insertion changes have already been handled.
38        for block in basic_blocks.indices().rev() {
39            for statement_index in (0..basic_blocks[block].statements.len()).rev() {
40                let location = Location { block, statement_index };
41                let statement = &basic_blocks[block].statements[statement_index];
42                let source_info = statement.source_info;
43
44                let mut finder = EnumFinder::new(tcx, local_decls, typing_env);
45                finder.visit_statement(statement, location);
46
47                for check in finder.into_found_enums() {
48                    debug!("Inserting enum check");
49                    let new_block = split_block(basic_blocks, location);
50
51                    match check {
52                        EnumCheckType::Direct { op_size, .. }
53                        | EnumCheckType::WithNiche { op_size, .. }
54                            if op_size.bytes() == 0 =>
55                        {
56                            // It is never valid to use a ZST as a discriminant for an inhabited enum, but that will
57                            // have been caught by the type checker. Do nothing but ensure that a bug has been signaled.
58                            tcx.dcx().span_delayed_bug(
59                                source_info.span,
60                                "cannot build enum discriminant from zero-sized type",
61                            );
62                            basic_blocks[block].terminator = Some(Terminator {
63                                source_info,
64                                kind: TerminatorKind::Goto { target: new_block },
65                                attributes: ThinVec::new(),
66                            });
67                        }
68                        EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => {
69                            insert_direct_enum_check(
70                                tcx,
71                                local_decls,
72                                basic_blocks,
73                                block,
74                                source_op,
75                                discr,
76                                op_size,
77                                valid_discrs,
78                                source_info,
79                                new_block,
80                            )
81                        }
82                        EnumCheckType::Uninhabited => insert_uninhabited_enum_check(
83                            tcx,
84                            local_decls,
85                            &mut basic_blocks[block],
86                            source_info,
87                            new_block,
88                        ),
89                        EnumCheckType::WithNiche {
90                            source_op,
91                            discr,
92                            op_size,
93                            offset,
94                            valid_range,
95                        } => insert_niche_check(
96                            tcx,
97                            local_decls,
98                            &mut basic_blocks[block],
99                            source_op,
100                            valid_range,
101                            discr,
102                            op_size,
103                            offset,
104                            source_info,
105                            new_block,
106                        ),
107                    }
108                }
109            }
110        }
111    }
112
113    fn is_required(&self) -> bool {
114        true
115    }
116}
117
/// Represents the different kinds of enum checks we can insert.
enum EnumCheckType<'tcx> {
    /// We know we try to create an uninhabited enum from an inhabited value.
    Uninhabited,
    /// We know the enum does no niche optimizations and can thus easily compute
    /// the valid discriminants.
    Direct {
        /// Copy of the operand that is transmuted into the enum.
        source_op: Operand<'tcx>,
        /// Integer type and size of the enum's tag.
        discr: TyAndSize<'tcx>,
        /// Layout size of the transmuted operand's type.
        op_size: Size,
        /// Discriminant values of all variants of the enum.
        valid_discrs: Vec<u128>,
    },
    /// We try to construct an enum that has a niche.
    WithNiche {
        /// Copy of the operand that is transmuted into the enum.
        source_op: Operand<'tcx>,
        /// Integer type and size of the enum's tag.
        discr: TyAndSize<'tcx>,
        /// Layout size of the transmuted operand's type.
        op_size: Size,
        /// Offset of the tag field inside the enum's layout.
        offset: Size,
        /// Valid range of the tag, taken from the enum layout's tag scalar.
        valid_range: WrappingRange,
    },
}
139
/// An integer type together with its size, used to describe the tag of an enum.
#[derive(Debug, Copy, Clone)]
struct TyAndSize<'tcx> {
    /// The integer type the tag is read as.
    pub ty: Ty<'tcx>,
    /// The size of the tag.
    pub size: Size,
}
145
/// A [Visitor] that finds the construction of enums and evaluates which checks
/// we should apply.
struct EnumFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Local declarations of the visited body, used to compute the type of operands.
    local_decls: &'a mut LocalDecls<'tcx>,
    /// Typing environment used for the layout queries.
    typing_env: TypingEnv<'tcx>,
    /// Checks collected so far, in the order the constructions were visited.
    enums: Vec<EnumCheckType<'tcx>>,
}
154
155impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
156    fn new(
157        tcx: TyCtxt<'tcx>,
158        local_decls: &'a mut LocalDecls<'tcx>,
159        typing_env: TypingEnv<'tcx>,
160    ) -> Self {
161        EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() }
162    }
163
164    /// Returns the found enum creations and which checks should be inserted.
165    fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
166        self.enums
167    }
168}
169
170impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
171    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
172        if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
173            let ty::Adt(adt_def, _) = ty.kind() else {
174                return;
175            };
176            if !adt_def.is_enum() {
177                return;
178            }
179
180            let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
181                return;
182            };
183            let Ok(op_layout) = self
184                .tcx
185                .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
186            else {
187                return;
188            };
189
190            match enum_layout.variants {
191                Variants::Empty if op_layout.is_uninhabited() => return,
192                // An empty enum that tries to be constructed from an inhabited value, this
193                // is never correct.
194                Variants::Empty => {
195                    // The enum layout is uninhabited but we construct it from sth inhabited.
196                    // This is always UB.
197                    self.enums.push(EnumCheckType::Uninhabited);
198                }
199                // Construction of Single value enums is always fine.
200                Variants::Single { .. } => {}
201                // Construction of an enum with multiple variants but no niche optimizations.
202                Variants::Multiple {
203                    tag_encoding: TagEncoding::Direct,
204                    tag: Scalar::Initialized { value, .. },
205                    ..
206                } => {
207                    let valid_discrs =
208                        adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
209
210                    let discr =
211                        TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
212                    self.enums.push(EnumCheckType::Direct {
213                        source_op: op.to_copy(),
214                        discr,
215                        op_size: op_layout.size,
216                        valid_discrs,
217                    });
218                }
219                // Construction of an enum with multiple variants and niche optimizations.
220                Variants::Multiple {
221                    tag_encoding: TagEncoding::Niche { .. },
222                    tag: Scalar::Initialized { value, valid_range, .. },
223                    tag_field,
224                    ..
225                } => {
226                    let discr =
227                        TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
228                    self.enums.push(EnumCheckType::WithNiche {
229                        source_op: op.to_copy(),
230                        discr,
231                        op_size: op_layout.size,
232                        offset: enum_layout.fields.offset(tag_field.as_usize()),
233                        valid_range,
234                    });
235                }
236                _ => return,
237            }
238
239            self.super_rvalue(rvalue, location);
240        }
241    }
242}
243
244fn split_block(
245    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
246    location: Location,
247) -> BasicBlock {
248    let block_data = &mut basic_blocks[location.block];
249
250    // Drain every statement after this one and move the current terminator to a new basic block.
251    let new_block = BasicBlockData::new_stmts(
252        block_data.statements.split_off(location.statement_index),
253        block_data.terminator.take(),
254        block_data.is_cleanup,
255    );
256
257    basic_blocks.push(new_block)
258}
259
260/// Inserts the cast of an operand (any type) to a u128 value that holds the discriminant value.
261fn insert_discr_cast_to_u128<'tcx>(
262    tcx: TyCtxt<'tcx>,
263    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
264    block_data: &mut BasicBlockData<'tcx>,
265    source_op: Operand<'tcx>,
266    discr: TyAndSize<'tcx>,
267    op_size: Size,
268    offset: Option<Size>,
269    source_info: SourceInfo,
270) -> Place<'tcx> {
271    let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> {
272        match size.bytes() {
273            1 => tcx.types.u8,
274            2 => tcx.types.u16,
275            4 => tcx.types.u32,
276            8 => tcx.types.u64,
277            16 => tcx.types.u128,
278            invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid),
279        }
280    };
281
282    let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
283        // The discriminant is less wide than the operand, cast the operand into
284        // [MaybeUninit; N] and then index into it.
285        let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8);
286        let array_len = op_size.bytes();
287        let mu_array_ty = Ty::new_array(tcx, mu, array_len);
288        let mu_array =
289            local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into();
290        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty);
291        block_data
292            .statements
293            .push(Statement::new(source_info, StatementKind::Assign(Box::new((mu_array, rvalue)))));
294
295        // Index into the array of MaybeUninit to get something that is actually
296        // as wide as the discriminant.
297        let offset = offset.unwrap_or(Size::ZERO);
298        let smaller_mu_array = mu_array.project_deeper(
299            &[ProjectionElem::Subslice {
300                from: offset.bytes(),
301                to: offset.bytes() + discr.size.bytes(),
302                from_end: false,
303            }],
304            tcx,
305        );
306
307        (CastKind::Transmute, Operand::Copy(smaller_mu_array))
308    } else {
309        let operand_int_ty = get_ty_for_size(tcx, op_size);
310
311        let op_as_int =
312            local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into();
313        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty);
314        block_data.statements.push(Statement::new(
315            source_info,
316            StatementKind::Assign(Box::new((op_as_int, rvalue))),
317        ));
318
319        (CastKind::IntToInt, Operand::Copy(op_as_int))
320    };
321
322    // Cast the resulting value to the actual discriminant integer type.
323    let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty);
324    let discr_in_discr_ty =
325        local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into();
326    block_data.statements.push(Statement::new(
327        source_info,
328        StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))),
329    ));
330
331    // Cast the discriminant to a u128 (base for comparisons of enum discriminants).
332    let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128);
333    let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128);
334    let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into();
335    block_data
336        .statements
337        .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr, rvalue)))));
338
339    discr
340}
341
342fn insert_direct_enum_check<'tcx>(
343    tcx: TyCtxt<'tcx>,
344    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
345    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
346    current_block: BasicBlock,
347    source_op: Operand<'tcx>,
348    discr: TyAndSize<'tcx>,
349    op_size: Size,
350    discriminants: Vec<u128>,
351    source_info: SourceInfo,
352    new_block: BasicBlock,
353) {
354    // Insert a new target block that is branched to in case of an invalid discriminant.
355    let invalid_discr_block_data = BasicBlockData::new(None, false);
356    let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
357    let block_data = &mut basic_blocks[current_block];
358    let discr_place = insert_discr_cast_to_u128(
359        tcx,
360        local_decls,
361        block_data,
362        source_op,
363        discr,
364        op_size,
365        None,
366        source_info,
367    );
368
369    // Mask out the bits of the discriminant type.
370    let mask = discr.size.unsigned_int_max();
371    let discr_masked =
372        local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
373    let rvalue = Rvalue::BinaryOp(
374        BinOp::BitAnd,
375        Box::new((
376            Operand::Copy(discr_place),
377            Operand::Constant(Box::new(ConstOperand {
378                span: source_info.span,
379                user_ty: None,
380                const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
381            })),
382        )),
383    );
384    block_data
385        .statements
386        .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr_masked, rvalue)))));
387
388    // Branch based on the discriminant value.
389    block_data.terminator = Some(Terminator {
390        source_info,
391        kind: TerminatorKind::SwitchInt {
392            discr: Operand::Copy(discr_masked),
393            targets: SwitchTargets::new(
394                discriminants
395                    .into_iter()
396                    .map(|discr_val| (discr.size.truncate(discr_val), new_block)),
397                invalid_discr_block,
398            ),
399        },
400        attributes: ThinVec::new(),
401    });
402
403    // Abort in case of an invalid enum discriminant.
404    basic_blocks[invalid_discr_block].terminator = Some(Terminator {
405        source_info,
406        kind: TerminatorKind::Assert {
407            cond: Operand::Constant(Box::new(ConstOperand {
408                span: source_info.span,
409                user_ty: None,
410                const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
411            })),
412            expected: true,
413            target: new_block,
414            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
415            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
416            // We never want to insert an unwind into unsafe code, because unwinding could
417            // make a failing UB check turn into much worse UB when we start unwinding.
418            unwind: UnwindAction::Unreachable,
419        },
420        attributes: ThinVec::new(),
421    });
422}
423
424fn insert_uninhabited_enum_check<'tcx>(
425    tcx: TyCtxt<'tcx>,
426    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
427    block_data: &mut BasicBlockData<'tcx>,
428    source_info: SourceInfo,
429    new_block: BasicBlock,
430) {
431    let is_ok: Place<'_> =
432        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
433    block_data.statements.push(Statement::new(
434        source_info,
435        StatementKind::Assign(Box::new((
436            is_ok,
437            Rvalue::Use(
438                Operand::Constant(Box::new(ConstOperand {
439                    span: source_info.span,
440                    user_ty: None,
441                    const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
442                })),
443                WithRetag::Yes, // it's a bool, retag doesn't matter
444            ),
445        ))),
446    ));
447
448    block_data.terminator = Some(Terminator {
449        source_info,
450        kind: TerminatorKind::Assert {
451            cond: Operand::Copy(is_ok),
452            expected: true,
453            target: new_block,
454            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new(
455                ConstOperand {
456                    span: source_info.span,
457                    user_ty: None,
458                    const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128),
459                },
460            )))),
461            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
462            // We never want to insert an unwind into unsafe code, because unwinding could
463            // make a failing UB check turn into much worse UB when we start unwinding.
464            unwind: UnwindAction::Unreachable,
465        },
466        attributes: ThinVec::new(),
467    });
468}
469
470fn insert_niche_check<'tcx>(
471    tcx: TyCtxt<'tcx>,
472    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
473    block_data: &mut BasicBlockData<'tcx>,
474    source_op: Operand<'tcx>,
475    valid_range: WrappingRange,
476    discr: TyAndSize<'tcx>,
477    op_size: Size,
478    offset: Size,
479    source_info: SourceInfo,
480    new_block: BasicBlock,
481) {
482    let discr = insert_discr_cast_to_u128(
483        tcx,
484        local_decls,
485        block_data,
486        source_op,
487        discr,
488        op_size,
489        Some(offset),
490        source_info,
491    );
492
493    // Compare the discriminant against the valid_range.
494    let start_const = Operand::Constant(Box::new(ConstOperand {
495        span: source_info.span,
496        user_ty: None,
497        const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128),
498    }));
499    let end_start_diff_const = Operand::Constant(Box::new(ConstOperand {
500        span: source_info.span,
501        user_ty: None,
502        const_: Const::Val(
503            ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)),
504            tcx.types.u128,
505        ),
506    }));
507
508    let discr_diff: Place<'_> =
509        local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
510    block_data.statements.push(Statement::new(
511        source_info,
512        StatementKind::Assign(Box::new((
513            discr_diff,
514            Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr), start_const))),
515        ))),
516    ));
517
518    let is_ok: Place<'_> =
519        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
520    block_data.statements.push(Statement::new(
521        source_info,
522        StatementKind::Assign(Box::new((
523            is_ok,
524            Rvalue::BinaryOp(
525                // This is a `WrappingRange`, so make sure to get the wrapping right.
526                BinOp::Le,
527                Box::new((Operand::Copy(discr_diff), end_start_diff_const)),
528            ),
529        ))),
530    ));
531
532    block_data.terminator = Some(Terminator {
533        source_info,
534        kind: TerminatorKind::Assert {
535            cond: Operand::Copy(is_ok),
536            expected: true,
537            target: new_block,
538            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
539            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
540            // We never want to insert an unwind into unsafe code, because unwinding could
541            // make a failing UB check turn into much worse UB when we start unwinding.
542            unwind: UnwindAction::Unreachable,
543        },
544        attributes: ThinVec::new(),
545    });
546}