Skip to main content

rustc_mir_transform/
check_alignment.rs

1use rustc_abi::Align;
2use rustc_hir::LangItem;
3use rustc_index::IndexVec;
4use rustc_middle::mir::interpret::Scalar;
5use rustc_middle::mir::visit::PlaceContext;
6use rustc_middle::mir::*;
7use rustc_middle::ty::{Ty, TyCtxt};
8use rustc_session::Session;
9
10use crate::check_pointers::{BorrowedFieldProjectionMode, PointerCheck, check_pointers};
11
12pub(super) struct CheckAlignment;
13
14impl<'tcx> crate::MirPass<'tcx> for CheckAlignment {
15    fn is_enabled(&self, sess: &Session) -> bool {
16        sess.ub_checks()
17    }
18
19    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
20        // Skip trivially aligned place types.
21        let excluded_pointees = [tcx.types.bool, tcx.types.i8, tcx.types.u8];
22
23        // When checking the alignment of references to field projections (`&(*ptr).a`),
24        // we need to make sure that the reference is aligned according to the field type
25        // and not to the pointer type.
26        check_pointers(
27            tcx,
28            body,
29            &excluded_pointees,
30            insert_alignment_check,
31            BorrowedFieldProjectionMode::FollowProjections,
32        );
33    }
34
35    fn is_required(&self) -> bool {
36        true
37    }
38}
39
40/// Inserts the actual alignment check's logic. Returns a
41/// [AssertKind::MisalignedPointerDereference] on failure.
42fn insert_alignment_check<'tcx>(
43    tcx: TyCtxt<'tcx>,
44    pointer: Place<'tcx>,
45    pointee_ty: Ty<'tcx>,
46    _context: PlaceContext,
47    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
48    stmts: &mut Vec<Statement<'tcx>>,
49    source_info: SourceInfo,
50) -> PointerCheck<'tcx> {
51    // Cast the pointer to a *const ().
52    let const_raw_ptr = Ty::new_imm_ptr(tcx, tcx.types.unit);
53    let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr);
54    let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into();
55    stmts.push(Statement::new(source_info, StatementKind::Assign(Box::new((thin_ptr, rvalue)))));
56
57    // Transmute the pointer to a usize (equivalent to `ptr.addr()`).
58    let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize);
59    let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
60    stmts.push(Statement::new(source_info, StatementKind::Assign(Box::new((addr, rvalue)))));
61
62    // Get the alignment of the pointee
63    let align_def_id = tcx.require_lang_item(LangItem::AlignOf, source_info.span);
64    let alignment =
65        Operand::unevaluated_constant(tcx, align_def_id, &[pointee_ty.into()], source_info.span);
66
67    // Subtract 1 from the alignment to get the alignment mask
68    let alignment_mask =
69        local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
70    let one = Operand::Constant(Box::new(ConstOperand {
71        span: source_info.span,
72        user_ty: None,
73        const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)), tcx.types.usize),
74    }));
75    stmts.push(Statement::new(
76        source_info,
77        StatementKind::Assign(Box::new((
78            alignment_mask,
79            Rvalue::BinaryOp(BinOp::Sub, Box::new((alignment.clone(), one))),
80        ))),
81    ));
82
83    // If this target does not have reliable alignment, further limit the mask by anding it with
84    // the mask for the highest reliable alignment.
85    if let max_align = tcx.sess.target.max_reliable_alignment()
86        && max_align < Align::MAX
87    {
88        let max_mask = max_align.bytes() - 1;
89        let max_mask = Operand::Constant(Box::new(ConstOperand {
90            span: source_info.span,
91            user_ty: None,
92            const_: Const::Val(
93                ConstValue::Scalar(Scalar::from_target_usize(max_mask, &tcx)),
94                tcx.types.usize,
95            ),
96        }));
97        stmts.push(Statement::new(
98            source_info,
99            StatementKind::Assign(Box::new((
100                alignment_mask,
101                Rvalue::BinaryOp(
102                    BinOp::BitAnd,
103                    Box::new((Operand::Copy(alignment_mask), max_mask)),
104                ),
105            ))),
106        ));
107    }
108
109    // BitAnd the alignment mask with the pointer
110    let alignment_bits =
111        local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
112    stmts.push(Statement::new(
113        source_info,
114        StatementKind::Assign(Box::new((
115            alignment_bits,
116            Rvalue::BinaryOp(
117                BinOp::BitAnd,
118                Box::new((Operand::Copy(addr), Operand::Copy(alignment_mask))),
119            ),
120        ))),
121    ));
122
123    // Check if the alignment bits are all zero
124    let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
125    let zero = Operand::Constant(Box::new(ConstOperand {
126        span: source_info.span,
127        user_ty: None,
128        const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), tcx.types.usize),
129    }));
130    stmts.push(Statement::new(
131        source_info,
132        StatementKind::Assign(Box::new((
133            is_ok,
134            Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))),
135        ))),
136    ));
137
138    // Emit a check that asserts on the alignment and otherwise triggers a
139    // AssertKind::MisalignedPointerDereference.
140    PointerCheck {
141        cond: Operand::Copy(is_ok),
142        assert_kind: Box::new(AssertKind::MisalignedPointerDereference {
143            required: alignment,
144            found: Operand::Copy(addr),
145        }),
146    }
147}